Merge "NNAPI Burst object cleanup"
diff --git a/common/ExecutionBurstController.cpp b/common/ExecutionBurstController.cpp
index 32231d3..87f5c1e 100644
--- a/common/ExecutionBurstController.cpp
+++ b/common/ExecutionBurstController.cpp
@@ -14,18 +14,26 @@
  * limitations under the License.
  */
 
+#define LOG_TAG "ExecutionBurstController"
+
 #include "ExecutionBurstController.h"
 
 #include <android-base/logging.h>
+#include <string>
+#include "Tracing.h"
 
-namespace android {
-namespace nn {
+namespace android::nn {
 namespace {
-constexpr Timing invalidTiming = {UINT64_MAX, UINT64_MAX};
+
+using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
+using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
+
+constexpr Timing kInvalidTiming = {UINT64_MAX, UINT64_MAX};
+
 }  // anonymous namespace
 
-Return<void> ExecutionBurstCallback::getMemories(const hidl_vec<int32_t>& slots,
-                                                 getMemories_cb cb) {
+Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
+        const hidl_vec<int32_t>& slots, getMemories_cb cb) {
     std::lock_guard<std::mutex> guard(mMutex);
 
     // get all memories
@@ -45,8 +53,8 @@
     return Void();
 }
 
-std::vector<int32_t> ExecutionBurstCallback::getSlots(const hidl_vec<hidl_memory>& memories,
-                                                      const std::vector<intptr_t>& keys) {
+std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
+        const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
     std::lock_guard<std::mutex> guard(mMutex);
 
     // retrieve (or bind) all slots corresponding to memories
@@ -58,7 +66,8 @@
     return slots;
 }
 
-std::pair<bool, int32_t> ExecutionBurstCallback::freeMemory(intptr_t key) {
+std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
+        intptr_t key) {
     std::lock_guard<std::mutex> guard(mMutex);
 
     auto iter = mMemoryIdToSlotCache.find(key);
@@ -72,10 +81,12 @@
     }
 }
 
-int32_t ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory, intptr_t key) {
+int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
+                                                                        intptr_t key) {
     auto iter = mMemoryIdToSlotCache.find(key);
     if (iter == mMemoryIdToSlotCache.end()) {
         const int32_t slot = mNextSlot;
+        // TODO: change mNextSlot to uint64_t or maintain a free list of IDs
         mNextSlot = (mNextSlot + 1) % (1 << 30);
         mMemoryIdToSlotCache[key] = slot;
         mSlotToMemoryCache[slot] = memory;
@@ -86,15 +97,72 @@
     }
 }
 
+std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
+        const sp<IPreparedModel>& preparedModel, bool blocking) {
+    // check inputs
+    if (preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
+        return nullptr;
+    }
+
+    // create callback object
+    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
+    if (callback == nullptr) {
+        LOG(ERROR) << "ExecutionBurstController::create failed to create callback";
+        return nullptr;
+    }
+
+    // create FMQ objects
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow) FmqRequestChannel(
+            kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
+    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow) FmqResultChannel(
+            kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
+
+    // check FMQ objects
+    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
+        !fmqResultChannel->isValid()) {
+        LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
+        return nullptr;
+    }
+
+    // descriptors
+    const FmqRequestDescriptor& fmqRequestDescriptor = *fmqRequestChannel->getDesc();
+    const FmqResultDescriptor& fmqResultDescriptor = *fmqResultChannel->getDesc();
+
+    // configure burst
+    ErrorStatus errorStatus;
+    sp<IBurstContext> burstContext;
+    Return<void> ret = preparedModel->configureExecutionBurst(
+            callback, fmqRequestDescriptor, fmqResultDescriptor,
+            [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
+                errorStatus = status;
+                burstContext = context;
+            });
+
+    // check burst
+    if (errorStatus != ErrorStatus::NONE) {
+        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with "
+                   << toString(errorStatus);
+        return nullptr;
+    }
+    if (burstContext == nullptr) {
+        LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
+        return nullptr;
+    }
+
+    // make and return controller
+    return std::make_unique<ExecutionBurstController>(std::move(fmqRequestChannel),
+                                                      std::move(fmqResultChannel), burstContext,
+                                                      callback, blocking);
+}
+
 ExecutionBurstController::ExecutionBurstController(
         std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
         std::unique_ptr<FmqResultChannel> fmqResultChannel, const sp<IBurstContext>& burstContext,
-        const sp<IPreparedModel>& preparedModel, const sp<ExecutionBurstCallback>& callback,
-        bool blocking)
+        const sp<ExecutionBurstCallback>& callback, bool blocking)
     : mFmqRequestChannel(std::move(fmqRequestChannel)),
       mFmqResultChannel(std::move(fmqResultChannel)),
       mBurstContext(burstContext),
-      mPreparedModel(preparedModel),
       mMemoryCache(callback),
       mUsesFutex(blocking) {}
 
@@ -110,8 +178,14 @@
     using discriminator = FmqResultDatum::hidl_discriminator;
 
     // wait for result packet and read first element of result packet
+    // TODO: have a more elegant way to wait for data, and read it all at once.
+    // For example, EventFlag can be used to directly wait on the futex, and all
+    // the data can be read at once with a non-blocking call to
+    // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+    // MessageQueue::commitRead can be used to avoid an extra copy of the
+    // metadata.
     FmqResultDatum datum;
-    bool success = false;
+    bool success = true;
     if (mUsesFutex) {
         success = mFmqResultChannel->readBlocking(&datum, 1);
     } else {
@@ -133,7 +207,11 @@
 
     // retrieve remaining elements
     // NOTE: all of the data is already available at this point, so there's no
-    // need to do a blocking wait to wait for more data
+    // need to do a blocking wait to wait for more data. This is known because
+    // in FMQ, all writes are published (made available) atomically. Currently,
+    // the producer always publishes the entire packet in one function call, so
+    // if the first element of the packet is available, the remaining elements
+    // are also available.
     std::vector<FmqResultDatum> packet(count);
     packet.front() = datum;
     success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
@@ -238,7 +316,7 @@
     // validate packet information
     if (data[index].getDiscriminator() != discriminator::packetInformation) {
         LOG(ERROR) << "FMQ Result packet ill-formed";
-        return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+        return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
     }
 
     // unpackage packet information
@@ -253,7 +331,7 @@
         // validate operand information
         if (data[index].getDiscriminator() != discriminator::operandInformation) {
             LOG(ERROR) << "FMQ Result packet ill-formed";
-            return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+            return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
         }
 
         // unpackage operand information
@@ -269,7 +347,7 @@
             // validate dimension
             if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
                 LOG(ERROR) << "FMQ Result packet ill-formed";
-                return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+                return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
             }
 
             // unpackage dimension
@@ -287,7 +365,7 @@
     // validate execution timing
     if (data[index].getDiscriminator() != discriminator::executionTiming) {
         LOG(ERROR) << "FMQ Result packet ill-formed";
-        return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+        return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
     }
 
     // unpackage execution timing
@@ -297,7 +375,7 @@
     // validate packet information
     if (index != packetSize) {
         LOG(ERROR) << "FMQ Result packet ill-formed";
-        return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+        return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
     }
 
     // return result
@@ -306,6 +384,8 @@
 
 std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
         const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
+    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
+
     // serialize request
     std::vector<FmqRequestDatum> requestData = serialize(request, measure, memoryIds);
 
@@ -316,14 +396,14 @@
     bool success = sendPacket(requestData);
     if (!success) {
         LOG(ERROR) << "Error sending FMQ packet";
-        return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+        return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
     }
 
     // get result packet
     const std::vector<FmqResultDatum> resultData = getPacketBlocking();
     if (resultData.empty()) {
         LOG(ERROR) << "Error retrieving FMQ packet";
-        return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+        return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
     }
 
     // deserialize result
@@ -339,64 +419,4 @@
     }
 }
 
-std::unique_ptr<ExecutionBurstController> createExecutionBurstController(
-        const sp<IPreparedModel>& preparedModel, bool blocking) {
-    // check inputs
-    if (preparedModel == nullptr) {
-        LOG(ERROR) << "createExecutionBurstController passed a nullptr";
-        return nullptr;
-    }
-
-    // create callback object
-    sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
-    if (callback == nullptr) {
-        LOG(ERROR) << "createExecutionBurstController failed to create callback";
-        return nullptr;
-    }
-
-    // create FMQ objects
-    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow) FmqRequestChannel(
-            kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
-    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow) FmqResultChannel(
-            kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
-
-    // check FMQ objects
-    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
-        !fmqResultChannel->isValid()) {
-        LOG(ERROR) << "createExecutionBurstController failed to create FastMessageQueue";
-        return nullptr;
-    }
-
-    // descriptors
-    const FmqRequestDescriptor& fmqRequestDescriptor = *fmqRequestChannel->getDesc();
-    const FmqResultDescriptor& fmqResultDescriptor = *fmqResultChannel->getDesc();
-
-    // configure burst
-    ErrorStatus errorStatus;
-    sp<IBurstContext> burstContext;
-    Return<void> ret = preparedModel->configureExecutionBurst(
-            callback, fmqRequestDescriptor, fmqResultDescriptor,
-            [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
-                errorStatus = status;
-                burstContext = context;
-            });
-
-    // check burst
-    if (errorStatus != ErrorStatus::NONE) {
-        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with "
-                   << toString(errorStatus);
-        return nullptr;
-    }
-    if (burstContext == nullptr) {
-        LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
-        return nullptr;
-    }
-
-    // make and return controller
-    return std::make_unique<ExecutionBurstController>(std::move(fmqRequestChannel),
-                                                      std::move(fmqResultChannel), burstContext,
-                                                      preparedModel, callback, blocking);
-}
-
-}  // namespace nn
-}  // namespace android
+}  // namespace android::nn
diff --git a/common/ExecutionBurstServer.cpp b/common/ExecutionBurstServer.cpp
index 64a4ee2..6ede34d 100644
--- a/common/ExecutionBurstServer.cpp
+++ b/common/ExecutionBurstServer.cpp
@@ -14,44 +14,57 @@
  * limitations under the License.
  */
 
+#define LOG_TAG "ExecutionBurstServer"
+
 #include "ExecutionBurstServer.h"
 
 #include <android-base/logging.h>
+#include <set>
+#include <string>
+#include "Tracing.h"
 
-namespace android {
-namespace nn {
+namespace android::nn {
 
-BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback) : mCallback(callback) {}
+ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
+    : mCallback(callback) {}
 
-hidl_vec<hidl_memory> BurstMemoryCache::getMemories(const std::vector<int32_t>& slots) {
+hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
+        const std::vector<int32_t>& slots) {
     std::lock_guard<std::mutex> guard(mMutex);
 
     // find unique unknown slots
-    std::vector<int32_t> unknownSlots = slots;
-    std::sort(unknownSlots.begin(), unknownSlots.end());
-    auto last = std::unique(unknownSlots.begin(), unknownSlots.end());
-    unknownSlots.erase(last, unknownSlots.end());
+    std::set<int32_t> setOfUnknownSlots;
+    for (int32_t slot : slots) {
+        if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
+            setOfUnknownSlots.insert(slot);
+        }
+    }
+    const std::vector<int32_t> vecOfUnknownSlots(setOfUnknownSlots.begin(),
+                                                 setOfUnknownSlots.end());
 
     // retrieve unknown slots
-    ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
-    std::vector<hidl_memory> returnedMemories;
-    Return<void> ret = mCallback->getMemories(
-            unknownSlots, [&errorStatus, &returnedMemories](ErrorStatus status,
-                                                            const hidl_vec<hidl_memory>& memories) {
-                errorStatus = status;
-                if (status == ErrorStatus::NONE) {
-                    returnedMemories = memories;
-                }
-            });
+    if (!vecOfUnknownSlots.empty()) {
+        ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+        std::vector<hidl_memory> returnedMemories;
+        Return<void> ret = mCallback->getMemories(
+                vecOfUnknownSlots,
+                [&errorStatus, &returnedMemories](ErrorStatus status,
+                                                  const hidl_vec<hidl_memory>& memories) {
+                    errorStatus = status;
+                    if (status == ErrorStatus::NONE) {
+                        returnedMemories = memories;
+                    }
+                });
 
-    if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
-        LOG(ERROR) << "Error retrieving memories";
-        return {};
-    }
+        if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
+            LOG(ERROR) << "Error retrieving memories";
+            return {};
+        }
 
-    // add memories to unknown slots
-    for (size_t i = 0; i < unknownSlots.size(); ++i) {
-        mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
+        // add memories to unknown slots
+        for (size_t i = 0; i < vecOfUnknownSlots.size(); ++i) {
+            mSlotToMemoryCache[vecOfUnknownSlots[i]] = returnedMemories[i];
+        }
     }
 
     // get all slots
@@ -59,14 +72,42 @@
     for (size_t i = 0; i < slots.size(); ++i) {
         memories[i] = mSlotToMemoryCache[slots[i]];
     }
+
     return memories;
 }
 
-void BurstMemoryCache::freeMemory(int32_t slot) {
+void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
     std::lock_guard<std::mutex> guard(mMutex);
     mSlotToMemoryCache.erase(slot);
 }
 
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+    // check inputs
+    if (callback == nullptr || preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+        return nullptr;
+    }
+
+    // create FMQ objects
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
+                                                                 FmqRequestChannel(requestChannel)};
+    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
+                                                               FmqResultChannel(resultChannel)};
+
+    // check FMQ objects
+    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
+        !fmqResultChannel->isValid()) {
+        LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
+        return nullptr;
+    }
+
+    // make and return context
+    return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
+                                    std::move(fmqResultChannel), preparedModel);
+}
+
 ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
                                            std::unique_ptr<FmqRequestChannel> requestChannel,
                                            std::unique_ptr<FmqResultChannel> resultChannel,
@@ -85,9 +126,13 @@
     mTeardown = true;
 
     // force unblock
+    // ExecutionBurstServer is by default waiting on a request packet. If the
+    // client process destroys its burst object, the server will still be
+    // waiting on the futex (assuming mBlocking is true). This force unblock
+    // wakes up any thread waiting on the futex.
     if (mBlocking) {
-        // TODO: look for a different/better way to signal/notify the futex to wake
-        // up any thread waiting on it
+        // TODO: look for a different/better way to signal/notify the futex to
+        // wake up any thread waiting on it
         FmqRequestDatum datum;
         datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
                                  /*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
@@ -117,7 +162,13 @@
         return {};
     }
 
-    // wait for request packet and read first element of result packet
+    // wait for request packet and read first element of request packet
+    // TODO: have a more elegant way to wait for data, and read it all at once.
+    // For example, EventFlag can be used to directly wait on the futex, and all
+    // the data can be read at once with a non-blocking call to
+    // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+    // MessageQueue::commitRead can be used to avoid an extra copy of the
+    // metadata.
     FmqRequestDatum datum;
     bool success = false;
     if (mBlocking) {
@@ -139,13 +190,19 @@
         return {};
     }
 
+    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
+
     // unpack packet information
     const auto& packetInfo = datum.packetInformation();
     const size_t count = packetInfo.packetSize;
 
     // retrieve remaining elements
     // NOTE: all of the data is already available at this point, so there's no
-    // need to do a blocking wait to wait for more data
+    // need to do a blocking wait to wait for more data. This is known because
+    // in FMQ, all writes are published (made available) atomically. Currently,
+    // the producer always publishes the entire packet in one function call, so
+    // if the first element of the packet is available, the remaining elements
+    // are also available.
     std::vector<FmqRequestDatum> packet(count);
     packet.front() = datum;
     success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
@@ -365,6 +422,9 @@
             return;
         }
 
+        NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+                     "ExecutionBurstServer processing packet and returning results");
+
         // continue processing
         Request request;
         MeasureTiming measure;
@@ -374,6 +434,10 @@
         ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
         std::vector<OutputShape> outputShapes;
         Timing returnedTiming;
+        // This call to IPreparedModel::executeSynchronously occurs entirely
+        // within the same process, so ignore the Return<> errors via .isOk().
+        // TODO: verify it is safe to always call isOk() here, or if there is
+        // any benefit to checking any potential errors.
         mPreparedModel
                 ->executeSynchronously(request, measure,
                                        [&errorStatus, &outputShapes, &returnedTiming](
@@ -392,33 +456,4 @@
     }
 }
 
-sp<IBurstContext> createBurstContext(const sp<IBurstCallback>& callback,
-                                     const MQDescriptorSync<FmqRequestDatum>& requestChannel,
-                                     const MQDescriptorSync<FmqResultDatum>& resultChannel,
-                                     IPreparedModel* preparedModel) {
-    // check inputs
-    if (callback == nullptr || preparedModel == nullptr) {
-        LOG(ERROR) << "createExecutionBurstServer passed a nullptr";
-        return nullptr;
-    }
-
-    // create FMQ objects
-    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
-                                                                 FmqRequestChannel(requestChannel)};
-    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
-                                                               FmqResultChannel(resultChannel)};
-
-    // check FMQ objects
-    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
-        !fmqResultChannel->isValid()) {
-        LOG(ERROR) << "createExecutionBurstServer failed to create FastMessageQueue";
-        return nullptr;
-    }
-
-    // make and return context
-    return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
-                                    std::move(fmqResultChannel), preparedModel);
-}
-
-}  // namespace nn
-}  // namespace android
+}  // namespace android::nn
diff --git a/common/include/ExecutionBurstController.h b/common/include/ExecutionBurstController.h
index bf36470..7152325 100644
--- a/common/include/ExecutionBurstController.h
+++ b/common/include/ExecutionBurstController.h
@@ -27,16 +27,13 @@
 #include <tuple>
 #include "HalInterfaces.h"
 
-namespace android {
-namespace nn {
+namespace android::nn {
 
 using ::android::hardware::kSynchronizedReadWrite;
 using ::android::hardware::MessageQueue;
 using ::android::hardware::MQDescriptorSync;
 using FmqRequestChannel = MessageQueue<FmqRequestDatum, kSynchronizedReadWrite>;
 using FmqResultChannel = MessageQueue<FmqResultDatum, kSynchronizedReadWrite>;
-using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
-using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
 
 /**
  * Number of elements in the FMQ.
@@ -44,62 +41,96 @@
 constexpr const size_t kExecutionBurstChannelLength = 1024;
 
 /**
- * NN runtime burst callback object and memory cache.
- *
- * ExecutionBurstCallback associates a hidl_memory object with a slot number to
- * be passed across FMQ. The ExecutionBurstServer can use this callback to
- * retrieve this hidl_memory corresponding to the slot via HIDL.
- *
- * Whenever a hidl_memory object is copied, it will duplicate the underlying
- * file descriptor. Because the NN runtime currently copies the hidl_memory on
- * each execution, it is difficult to associate hidl_memory objects with
- * previously cached hidl_memory objects. For this reason, callers of this class
- * must pair each hidl_memory object with an associated key. For efficiency, if
- * two hidl_memory objects represent the same underlying buffer, they must use
- * the same key.
- */
-class ExecutionBurstCallback : public IBurstCallback {
-    DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
-
-   public:
-    ExecutionBurstCallback() = default;
-
-    Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
-
-    std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
-                                  const std::vector<intptr_t>& keys);
-    int32_t getSlot(const hidl_memory& memory, intptr_t key);
-    std::pair<bool, int32_t> freeMemory(intptr_t key);
-
-   private:
-    int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
-
-    std::mutex mMutex;
-    int32_t mNextSlot = 0;
-    std::map<intptr_t, int32_t> mMemoryIdToSlotCache;
-    std::map<int32_t, hidl_memory> mSlotToMemoryCache;
-};
-
-/**
- * NN runtime burst object
- *
- * TODO: provide high-level description of class
+ * The ExecutionBurstController class manages both the serialization and
+ * deserialization of data across FMQ, making it appear to the runtime as a
+ * regular synchronous inference. Additionally, this class manages the burst's
+ * memory cache.
  */
 class ExecutionBurstController {
     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);
 
+    /**
+     * NN runtime burst callback object and memory cache.
+     *
+     * ExecutionBurstCallback associates a hidl_memory object with a slot number
+     * to be passed across FMQ. The ExecutionBurstServer can use this callback
+     * to retrieve this hidl_memory corresponding to the slot via HIDL.
+     *
+     * Whenever a hidl_memory object is copied, it will duplicate the underlying
+     * file descriptor. Because the NN runtime currently copies the hidl_memory
+     * on each execution, it is difficult to associate hidl_memory objects with
+     * previously cached hidl_memory objects. For this reason, callers of this
+     * class must pair each hidl_memory object with an associated key. For
+     * efficiency, if two hidl_memory objects represent the same underlying
+     * buffer, they must use the same key.
+     */
+    class ExecutionBurstCallback : public IBurstCallback {
+        DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
+
+       public:
+        ExecutionBurstCallback() = default;
+
+        Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
+
+        std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
+                                      const std::vector<intptr_t>& keys);
+
+        int32_t getSlot(const hidl_memory& memory, intptr_t key);
+
+        /*
+         * This function performs two different actions:
+         * 1) Removes an entry from the cache (if present), including the local
+         *    storage of the hidl_memory object. Note that this call does not
+         *    free any corresponding hidl_memory object in ExecutionBurstServer,
+         *    which is separately freed via IBurstContext::freeMemory.
+         * 2) Returns whether a cache entry was removed and which slot was removed if
+         *    found. If the key did not correspond to any entry in the cache, a
+         *    slot number of 0 is returned. The slot number and whether the entry
+         *    existed is useful so the same slot can be freed in the
+         *    ExecutionBurstServer's cache via IBurstContext::freeMemory.
+         */
+        std::pair<bool, int32_t> freeMemory(intptr_t key);
+
+       private:
+        int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
+
+        std::mutex mMutex;
+        int32_t mNextSlot = 0;
+        std::map<intptr_t, int32_t> mMemoryIdToSlotCache;
+        std::map<int32_t, hidl_memory> mSlotToMemoryCache;
+    };
+
    public:
+    /**
+     * Creates a burst controller on a prepared model.
+     *
+     * Prefer this over ExecutionBurstController's constructor.
+     *
+     * @param preparedModel Model prepared for execution to execute on.
+     * @param blocking 'true' if the FMQ should use a futex to perform blocking
+     *     until data is available in a less responsive, but more energy
+     *     efficient manner. 'false' if the FMQ should use spin-looping to
+     *     wait until data is available in a more responsive, but less energy
+     *     efficient manner.
+     * @return ExecutionBurstController Execution burst controller object, or nullptr on failure.
+     */
+    static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
+                                                            bool blocking);
+
     ExecutionBurstController(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
                              std::unique_ptr<FmqResultChannel> fmqResultChannel,
                              const sp<IBurstContext>& burstContext,
-                             const sp<IPreparedModel>& preparedModel,
                              const sp<ExecutionBurstCallback>& callback, bool blocking);
 
     /**
      * Execute a request on a model.
      *
      * @param request Arguments to be executed on a model.
-     * @return status and output shape of the execution.
+     * @param measure Whether to collect timing measurements, either YES or NO
+     * @param memoryIds Identifiers corresponding to each memory object in the
+     *     request's pools.
+     * @return status and output shape of the execution and any execution time
+     *     measurements.
      */
     std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
             const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
@@ -122,22 +153,10 @@
     const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
     const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
     const sp<IBurstContext> mBurstContext;
-    const sp<IPreparedModel> mPreparedModel;
     const sp<ExecutionBurstCallback> mMemoryCache;
     const bool mUsesFutex;
 };
 
-/**
- * Creates a burst controller on a prepared model.
- *
- * @param preparedModel Model prepared for execution to execute on.
- * @param blocking 'true' if the FMQ should block until data is available.
- * @return ExecutionBurstController Execution burst controller object.
- */
-std::unique_ptr<ExecutionBurstController> createExecutionBurstController(
-        const sp<IPreparedModel>& preparedModel, bool blocking);
-
-}  // namespace nn
-}  // namespace android
+}  // namespace android::nn
 
 #endif  // ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
diff --git a/common/include/ExecutionBurstServer.h b/common/include/ExecutionBurstServer.h
index 13dfaaf..0a3222f 100644
--- a/common/include/ExecutionBurstServer.h
+++ b/common/include/ExecutionBurstServer.h
@@ -26,8 +26,7 @@
 #include <set>
 #include "HalInterfaces.h"
 
-namespace android {
-namespace nn {
+namespace android::nn {
 
 using ::android::hardware::kSynchronizedReadWrite;
 using ::android::hardware::MessageQueue;
@@ -38,29 +37,58 @@
 using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
 
 /**
- */
-class BurstMemoryCache {
-    DISALLOW_IMPLICIT_CONSTRUCTORS(BurstMemoryCache);
-
-   public:
-    BurstMemoryCache(const sp<IBurstCallback>& callback);
-
-    hidl_vec<hidl_memory> getMemories(const std::vector<int32_t>& slots);
-    void freeMemory(int32_t slot);
-
-   private:
-    std::mutex mMutex;
-    const sp<IBurstCallback> mCallback;
-    std::map<int32_t, hidl_memory> mSlotToMemoryCache;
-};
-
-/**
- * NN server burst object
+ * The ExecutionBurstServer class is responsible for waiting for and
+ * deserializing a request object from an FMQ, performing the inference, and
+ * serializing the result back across another FMQ.
  */
 class ExecutionBurstServer : public IBurstContext {
     DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
 
+    /**
+     * BurstMemoryCache is responsible for managing the local memory cache of
+     * the burst object. If the ExecutionBurstServer requests a memory key that
+     * is unrecognized, the BurstMemoryCache object will retrieve the memory
+     * from the client, transparently to the ExecutionBurstServer object.
+     */
+    class BurstMemoryCache {
+        DISALLOW_IMPLICIT_CONSTRUCTORS(BurstMemoryCache);
+
+       public:
+        BurstMemoryCache(const sp<IBurstCallback>& callback);
+
+        hidl_vec<hidl_memory> getMemories(const std::vector<int32_t>& slots);
+        void freeMemory(int32_t slot);
+
+       private:
+        std::mutex mMutex;
+        const sp<IBurstCallback> mCallback;
+        std::map<int32_t, hidl_memory> mSlotToMemoryCache;
+    };
+
    public:
+    /**
+     * Create automated context to manage FMQ-based executions.
+     *
+     * This function is intended to be used by a service to automatically:
+     * 1) Receive data from a provided FMQ
+     * 2) Execute a model with the given information
+     * 3) Send the result to the created FMQ
+     *
+     * @param callback Callback used to retrieve memories corresponding to
+     *     unrecognized slots.
+     * @param requestChannel Input FMQ channel through which the client passes the
+     *     request to the service.
+     * @param resultChannel Output FMQ channel from which the client can retrieve
+     *     the result of the execution.
+     * @param preparedModel PreparedModel that the burst object was created from.
+     *     This will be used to synchronously perform the execution.
+     * @return sp<ExecutionBurstServer> Handle to the burst context.
+     */
+    static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
+                                           const FmqRequestDescriptor& requestChannel,
+                                           const FmqResultDescriptor& resultChannel,
+                                           IPreparedModel* preparedModel);
+
     ExecutionBurstServer(const sp<IBurstCallback>& callback,
                          std::unique_ptr<FmqRequestChannel> requestChannel,
                          std::unique_ptr<FmqResultChannel> resultChannel,
@@ -87,34 +115,6 @@
     const bool mBlocking;
 };
 
-/**
- * Create automated context to manage FMQ-based executions.
- *
- * This function is intended to be used by a service to automatically:
- * 1) Receive data from a provided FMQ
- * 2) Execute a model with the given information
- * 3) Send the result to the created FMQ
- *
- * @param callback Callback used to retrieve memories corresponding to
- *                 unrecognized slots.
- * @param requestChannel Input FMQ channel through which the client passes the
- *                       request to the service.
- * @param requestChannel Output FMQ channel from which the client can retrieve
- *                       the result of the execution.
- * @param preparedModel PreparedModel that the burst object was created from.
- *                      This will be used to synchronously perform the
- *                      execution.
- * @result IBurstContext Handle to the burst context.
- */
-::android::sp<::android::hardware::neuralnetworks::V1_2::IBurstContext> createBurstContext(
-        const sp<::android::hardware::neuralnetworks::V1_2::IBurstCallback>& callback,
-        const ::android::hardware::MQDescriptorSync<
-                ::android::hardware::neuralnetworks::V1_2::FmqRequestDatum>& requestChannel,
-        const ::android::hardware::MQDescriptorSync<
-                ::android::hardware::neuralnetworks::V1_2::FmqResultDatum>& resultChannel,
-        ::android::hardware::neuralnetworks::V1_2::IPreparedModel* preparedModel);
-
-}  // namespace nn
-}  // namespace android
+}  // namespace android::nn
 
 #endif  // ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
diff --git a/driver/sample/SampleDriver.cpp b/driver/sample/SampleDriver.cpp
index c0fd37c..a2289fd 100644
--- a/driver/sample/SampleDriver.cpp
+++ b/driver/sample/SampleDriver.cpp
@@ -353,7 +353,7 @@
                  "SampleDriver::configureExecutionBurst");
 
     const sp<V1_2::IBurstContext> burst =
-            createBurstContext(callback, requestChannel, resultChannel, this);
+            ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
 
     if (burst == nullptr) {
         cb(ErrorStatus::GENERAL_FAILURE, {});
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index 140c0db..035e84a 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -468,7 +468,8 @@
 
 int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback,
                               BurstBuilder* burstBuilder) {
-    assert(synchronizationCallback == nullptr || burstBuilder == nullptr);
+    CHECK(synchronizationCallback == nullptr || burstBuilder == nullptr)
+            << "synchronizationCallback and burstBuilder cannot simultaneously be used";
 
     const bool synchronous = (synchronizationCallback == nullptr);
 
diff --git a/runtime/ExecutionBuilder.h b/runtime/ExecutionBuilder.h
index cd79f40..e2c01f7 100644
--- a/runtime/ExecutionBuilder.h
+++ b/runtime/ExecutionBuilder.h
@@ -117,7 +117,7 @@
     // If burst is provided, then the burst path will be used. If a burst is not
     // provided (i.e., is nullptr), then a synchronous execution will occur.
     //
-    // Providing both synchronizationCallbak and burstBuilder is an error.
+    // Providing both synchronizationCallback and burstBuilder is an error.
     int compute(sp<ExecutionCallback>* synchronizationCallback,
                 BurstBuilder* burstBuilder = nullptr);
 
diff --git a/runtime/VersionedInterfaces.cpp b/runtime/VersionedInterfaces.cpp
index 0dfd894..9c657c2 100644
--- a/runtime/VersionedInterfaces.cpp
+++ b/runtime/VersionedInterfaces.cpp
@@ -93,7 +93,8 @@
     // create death handler object
     sp<IPreparedModelDeathHandler> deathHandler = new (std::nothrow) IPreparedModelDeathHandler();
     if (!deathHandler) {
-        LOG(ERROR) << "VersionedIDevice::create -- Failed to create IPreparedModelDeathHandler.";
+        LOG(ERROR) << "VersionedIPreparedModel::create -- Failed to create "
+                      "IPreparedModelDeathHandler.";
         return nullptr;
     }
 
@@ -103,8 +104,8 @@
     // providing the response.
     const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
     if (!ret.isOk() || ret != true) {
-        LOG(ERROR) << "VersionedIDevice::create -- Failed to register a death recipient for the "
-                      "IPreparedModel object.";
+        LOG(ERROR) << "VersionedIPreparedModel::create -- Failed to register a death recipient for "
+                      "the IPreparedModel object.";
         return nullptr;
     }
 
@@ -185,7 +186,7 @@
 std::unique_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
         bool blocking) const {
     if (mPreparedModelV1_2 != nullptr) {
-        return createExecutionBurstController(mPreparedModelV1_2, blocking);
+        return ExecutionBurstController::create(mPreparedModelV1_2, blocking);
     } else {
         return nullptr;
     }
diff --git a/runtime/test/TestMain.cpp b/runtime/test/TestMain.cpp
index dc32cec..5aa163d 100644
--- a/runtime/test/TestMain.cpp
+++ b/runtime/test/TestMain.cpp
@@ -40,8 +40,7 @@
 // non-public DeviceManager::setSyncExecHal(); we assume the setting is always
 // true, and if we are asked to set it to false, we return 0 ("success") without
 // running tests.
-static int test(bool useCpuOnly, bool computeUsesSynchronousAPI, bool allowSyncExecHal = true,
-                bool computeUsesBurstAPI = false) {
+static int test(bool useCpuOnly, Execution::ComputeMode computeMode, bool allowSyncExecHal = true) {
 #ifdef NNTEST_ONLY_PUBLIC_API
     if (useCpuOnly || !allowSyncExecHal) {
         return 0;
@@ -51,14 +50,24 @@
     android::nn::DeviceManager::get()->setSyncExecHal(allowSyncExecHal);
 #endif
 
-    Execution::setComputeUsesSynchronousAPI(computeUsesSynchronousAPI);
-    Execution::setComputeUsesBurstAPI(computeUsesBurstAPI);
+    Execution::setComputeMode(computeMode);
 
-    LOG(INFO) << "test(useCpuOnly = " << useCpuOnly
-              << ", computeUsesSynchronousAPI = " << computeUsesSynchronousAPI
+    auto computeModeText = [computeMode] {
+        switch (computeMode) {
+            case Execution::ComputeMode::SYNC:
+                return "ComputeMode::SYNC";
+            case Execution::ComputeMode::ASYNC:
+                return "ComputeMode::ASYNC";
+            case Execution::ComputeMode::BURST:
+                return "ComputeMode::BURST";
+        }
+        return "<unknown ComputeMode>";
+    };
+
+    LOG(INFO) << "test(useCpuOnly = " << useCpuOnly << ", computeMode = " << computeModeText()
               << ", allowSyncExecHal = " << allowSyncExecHal << ")";
     std::cout << "[**********] useCpuOnly = " << useCpuOnly
-              << ", computeUsesSynchronousAPI = " << computeUsesSynchronousAPI
+              << ", computeMode = " << computeModeText()
               << ", allowSyncExecHal = " << allowSyncExecHal << std::endl;
     return RUN_ALL_TESTS();
 }
@@ -70,22 +79,25 @@
     android::nn::initVLogMask();
 #endif
 
-    int n = test(false, false) | test(false, true) | test(true, false) | test(true, true);
+    int n = test(/*useCpuOnly=*/false, Execution::ComputeMode::ASYNC) |
+            test(/*useCpuOnly=*/false, Execution::ComputeMode::SYNC) |
+            test(/*useCpuOnly=*/true, Execution::ComputeMode::ASYNC) |
+            test(/*useCpuOnly=*/true, Execution::ComputeMode::SYNC);
 
     // Now try disabling use of synchronous execution HAL.
     //
     // Whether or not the use of synchronous execution HAL is enabled should make no
     // difference when useCpuOnly = true; we already ran test(true, *, true) above,
     // so there's no reason to run test(true, *, false) now.
-    n |= test(false, false, false) | test(false, true, false);
+    n |= test(/*useCpuOnly=*/false, Execution::ComputeMode::ASYNC, /*allowSyncExecHal=*/false) |
+         test(/*useCpuOnly=*/false, Execution::ComputeMode::SYNC, /*allowSyncExecHal=*/false);
 
     // Now try execution using a burst.
     //
     // The burst path is off by default in these tests. This is the first case
-    // where it is turned on. Both "computeUsesSynchronousAPI" and
-    // "allowSyncExecHal" are irrelevant here because the burst path is separate
-    // from both.
-    n |= test(false, false, false, true);
+    // where it is turned on. Both "useCpuOnly" and "allowSyncExecHal" are
+    // irrelevant here because the burst path is separate from both.
+    n |= test(/*useCpuOnly=*/false, Execution::ComputeMode::BURST);
 
     return n;
 }
diff --git a/runtime/test/TestNeuralNetworksWrapper.cpp b/runtime/test/TestNeuralNetworksWrapper.cpp
index 9d61f49..9daa09d 100644
--- a/runtime/test/TestNeuralNetworksWrapper.cpp
+++ b/runtime/test/TestNeuralNetworksWrapper.cpp
@@ -20,9 +20,7 @@
 namespace nn {
 namespace test_wrapper {
 
-bool Execution::mComputeUsesBurstAPI = false;
-
-bool Execution::mComputeUsesSychronousAPI = true;
+Execution::ComputeMode Execution::mComputeMode = Execution::ComputeMode::SYNC;
 
 }  // namespace test_wrapper
 }  // namespace nn
diff --git a/runtime/test/TestNeuralNetworksWrapper.h b/runtime/test/TestNeuralNetworksWrapper.h
index ee3da9a..be1d4ea 100644
--- a/runtime/test/TestNeuralNetworksWrapper.h
+++ b/runtime/test/TestNeuralNetworksWrapper.h
@@ -162,42 +162,46 @@
     }
 
     Result compute() {
-        if (mComputeUsesBurstAPI) {
-            ANeuralNetworksBurst* burst = nullptr;
-            Result result = static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
-            if (result != Result::NO_ERROR) {
+        switch (mComputeMode) {
+            case ComputeMode::SYNC: {
+                return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
+            }
+            case ComputeMode::ASYNC: {
+                ANeuralNetworksEvent* event = nullptr;
+                Result result = static_cast<Result>(
+                        ANeuralNetworksExecution_startCompute(mExecution, &event));
+                if (result != Result::NO_ERROR) {
+                    return result;
+                }
+                // TODO: It is not clear how to manage the lifetime of events when
+                // there are multiple waiters.
+                result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
+                ANeuralNetworksEvent_free(event);
+                return result;
+            }
+            case ComputeMode::BURST: {
+                ANeuralNetworksBurst* burst = nullptr;
+                Result result =
+                        static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
+                if (result != Result::NO_ERROR) {
+                    return result;
+                }
+                result = static_cast<Result>(
+                        ANeuralNetworksExecution_burstCompute(mExecution, burst));
                 ANeuralNetworksBurst_free(burst);
                 return result;
             }
-            result = static_cast<Result>(ANeuralNetworksExecution_burstCompute(mExecution, burst));
-            ANeuralNetworksBurst_free(burst);
-            return result;
         }
-
-        if (!mComputeUsesSychronousAPI) {
-            ANeuralNetworksEvent* event = nullptr;
-            Result result =
-                    static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
-            if (result != Result::NO_ERROR) {
-                return result;
-            }
-            // TODO how to manage the lifetime of events when multiple waiters is not
-            // clear.
-            result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
-            ANeuralNetworksEvent_free(event);
-            return result;
-        }
-
-        return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
+        return Result::BAD_DATA;
     }
 
-    // By default, compute() uses the synchronous API.
-    // setComputeUsesSynchronousAPI() can be used to change the behavior of
-    // compute() to instead use the asynchronous API and then wait for
-    // computation to complete.
-    static void setComputeUsesSynchronousAPI(bool val) { mComputeUsesSychronousAPI = val; }
-
-    static void setComputeUsesBurstAPI(bool val) { mComputeUsesBurstAPI = val; }
+    // By default, compute() uses the synchronous API. setComputeMode() can be
+    // used to change the behavior of compute() to either:
+    // - use the asynchronous API and then wait for computation to complete
+    // or
+    // - use the burst API
+    enum class ComputeMode { SYNC, ASYNC, BURST };
+    static void setComputeMode(ComputeMode mode) { mComputeMode = mode; }
 
     Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
         uint32_t rank = 0;
@@ -217,11 +221,8 @@
     ANeuralNetworksCompilation* mCompilation = nullptr;
     ANeuralNetworksExecution* mExecution = nullptr;
 
-    // Initialized to false in TestNeuralNetworksWrapper.cpp.
-    static bool mComputeUsesBurstAPI;
-
-    // Initialized to true in TestNeuralNetworksWrapper.cpp.
-    static bool mComputeUsesSychronousAPI;
+    // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
+    static ComputeMode mComputeMode;
 };
 
 }  // namespace test_wrapper
diff --git a/runtime/test/TestValidation.cpp b/runtime/test/TestValidation.cpp
index 4e95fa5..c3a6f0f 100644
--- a/runtime/test/TestValidation.cpp
+++ b/runtime/test/TestValidation.cpp
@@ -1111,6 +1111,15 @@
               ANEURALNETWORKS_UNEXPECTED_NULL);
 }
 
+TEST_F(ValidationTestBurst, BurstComputeBadCompilation) {
+    ANeuralNetworksCompilation* compilation = nullptr;
+    ASSERT_EQ(ANeuralNetworksCompilation_create(mModel, &compilation), ANEURALNETWORKS_NO_ERROR);
+    // ANeuralNetworksCompilation_finish is intentionally not called, so burst creation must fail.
+    ANeuralNetworksBurst* burst = nullptr;
+    EXPECT_EQ(ANeuralNetworksBurst_create(compilation, &burst), ANEURALNETWORKS_BAD_STATE);
+    ANeuralNetworksCompilation_free(compilation);  // release the unfinished compilation
+}
+
 TEST_F(ValidationTestBurst, BurstComputeDifferentCompilations) {
     ANeuralNetworksCompilation* secondCompilation;
     ASSERT_EQ(ANeuralNetworksCompilation_create(mModel, &secondCompilation),
@@ -1160,15 +1169,20 @@
                                                  sizeof(outputB0)),
               ANEURALNETWORKS_NO_ERROR);
 
-    // execute on the same burst concurrently
+    // Execute on the same burst concurrently. At least one result must be
+    // ANEURALNETWORKS_NO_ERROR. One may return ANEURALNETWORKS_BAD_STATE if the
+    // other is already executing on the burst.
     auto first = std::async(std::launch::async, [this] {
-        const int result = ANeuralNetworksExecution_burstCompute(mExecution, mBurst);
-        EXPECT_TRUE(result == ANEURALNETWORKS_BAD_STATE || result == ANEURALNETWORKS_NO_ERROR);
+        return ANeuralNetworksExecution_burstCompute(mExecution, mBurst);
     });
     auto second = std::async(std::launch::async, [this, secondExecution] {
-        const int result = ANeuralNetworksExecution_burstCompute(secondExecution, mBurst);
-        EXPECT_TRUE(result == ANEURALNETWORKS_BAD_STATE || result == ANEURALNETWORKS_NO_ERROR);
+        return ANeuralNetworksExecution_burstCompute(secondExecution, mBurst);
     });
+    const int result1 = first.get();
+    const int result2 = second.get();
+    EXPECT_TRUE(result1 == ANEURALNETWORKS_BAD_STATE || result1 == ANEURALNETWORKS_NO_ERROR);
+    EXPECT_TRUE(result2 == ANEURALNETWORKS_BAD_STATE || result2 == ANEURALNETWORKS_NO_ERROR);
+    EXPECT_TRUE(result1 == ANEURALNETWORKS_NO_ERROR || result2 == ANEURALNETWORKS_NO_ERROR);
 }
 
 TEST(ValidationTestIntrospection, GetNumDevices) {