Propagate ANNMemory_free to IBurstContext::freeMemory

This CL extends nn::Memory to include a reference to all burst objects
which use its memory. When the nn::Memory object is destroyed (via
ANNMemory_free), it signals all these burst objects so that they can
properly clean their memory caches (via IBurstContext::freeMemory).

This CL also provides a more intelligent memory slot allocator within
ExecutionBurstController to reuse slots after they are freed.

This CL includes some additional miscellaneous code cleanup of the
neighboring test cases, e.g., closing file descriptors when they are no
longer needed.

Bug: 128319484
Test: mma
Test: atest NeuralNetworksTest_static
Change-Id: Ibc19059af5194cd3dd58c9a9d8baa54fa6b26de5
Merged-In: Ibc19059af5194cd3dd58c9a9d8baa54fa6b26de5
(cherry picked from commit be300c175e726e5b41e50af2b814cfd24f3360a1)
diff --git a/common/ExecutionBurstController.cpp b/common/ExecutionBurstController.cpp
index 87f5c1e..286c62a 100644
--- a/common/ExecutionBurstController.cpp
+++ b/common/ExecutionBurstController.cpp
@@ -38,14 +38,15 @@
 
     // get all memories
     hidl_vec<hidl_memory> memories(slots.size());
-    for (size_t i = 0; i < slots.size(); ++i) {
-        // if memory is available, return it; otherwise return error
-        auto iter = mSlotToMemoryCache.find(slots[i]);
-        if (iter == mSlotToMemoryCache.end()) {
-            cb(ErrorStatus::INVALID_ARGUMENT, {});
-            return Void();
-        }
-        memories[i] = iter->second;
+    std::transform(slots.begin(), slots.end(), memories.begin(), [this](int32_t slot) {
+        return slot < mMemoryCache.size() ? mMemoryCache[slot] : hidl_memory{};
+    });
+
+    // ensure all memories are valid
+    if (!std::all_of(memories.begin(), memories.end(),
+                     [](const hidl_memory& memory) { return memory.valid(); })) {
+        cb(ErrorStatus::INVALID_ARGUMENT, {});
+        return Void();
     }
 
     // return successful
@@ -70,26 +71,24 @@
         intptr_t key) {
     std::lock_guard<std::mutex> guard(mMutex);
 
-    auto iter = mMemoryIdToSlotCache.find(key);
-    if (iter != mMemoryIdToSlotCache.end()) {
-        const int32_t slot = iter->second;
-        mMemoryIdToSlotCache.erase(key);
-        mSlotToMemoryCache.erase(slot);
-        return {true, slot};
-    } else {
+    auto iter = mMemoryIdToSlot.find(key);
+    if (iter == mMemoryIdToSlot.end()) {
         return {false, 0};
     }
+    const int32_t slot = iter->second;
+    mMemoryIdToSlot.erase(key);
+    mMemoryCache[slot] = {};
+    mFreeSlots.push(slot);
+    return {true, slot};
 }
 
 int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
                                                                         intptr_t key) {
-    auto iter = mMemoryIdToSlotCache.find(key);
-    if (iter == mMemoryIdToSlotCache.end()) {
-        const int32_t slot = mNextSlot;
-        // TODO: change mNextSlot to uint64_t or maintain a free list of IDs
-        mNextSlot = (mNextSlot + 1) % (1 << 30);
-        mMemoryIdToSlotCache[key] = slot;
-        mSlotToMemoryCache[slot] = memory;
+    auto iter = mMemoryIdToSlot.find(key);
+    if (iter == mMemoryIdToSlot.end()) {
+        const int32_t slot = allocateSlotLocked();
+        mMemoryIdToSlot[key] = slot;
+        mMemoryCache[slot] = memory;
         return slot;
     } else {
         const int32_t slot = iter->second;
@@ -97,6 +96,24 @@
     }
 }
 
+int32_t ExecutionBurstController::ExecutionBurstCallback::allocateSlotLocked() {
+    constexpr size_t kMaxNumberOfSlots = std::numeric_limits<int32_t>::max();
+
+    // if there is a free slot, use it
+    if (mFreeSlots.size() > 0) {
+        const int32_t slot = mFreeSlots.top();
+        mFreeSlots.pop();
+        return slot;
+    }
+
+    // otherwise use a slot for the first time
+    CHECK(mMemoryCache.size() < kMaxNumberOfSlots) << "Exceeded maximum number of slots!";
+    const int32_t slot = static_cast<int32_t>(mMemoryCache.size());
+    mMemoryCache.emplace_back();
+
+    return slot;
+}
+
 std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
         const sp<IPreparedModel>& preparedModel, bool blocking) {
     // check inputs
@@ -386,6 +403,8 @@
         const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
     NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
 
+    std::lock_guard<std::mutex> guard(mMutex);
+
     // serialize request
     std::vector<FmqRequestDatum> requestData = serialize(request, measure, memoryIds);
 
@@ -411,6 +430,8 @@
 }
 
 void ExecutionBurstController::freeMemory(intptr_t key) {
+    std::lock_guard<std::mutex> guard(mMutex);
+
     bool valid;
     int32_t slot;
     std::tie(valid, slot) = mMemoryCache->freeMemory(key);
diff --git a/common/ExecutionBurstServer.cpp b/common/ExecutionBurstServer.cpp
index 6ede34d..84bb424 100644
--- a/common/ExecutionBurstServer.cpp
+++ b/common/ExecutionBurstServer.cpp
@@ -32,45 +32,64 @@
         const std::vector<int32_t>& slots) {
     std::lock_guard<std::mutex> guard(mMutex);
 
+    const auto slotIsKnown = [this](int32_t slot) {
+        return slot < mMemoryCache.size() && mMemoryCache[slot].valid();
+    };
+
     // find unique unknown slots
-    std::set<int32_t> setOfUnknownSlots;
-    for (int32_t slot : slots) {
-        if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
-            setOfUnknownSlots.insert(slot);
-        }
-    }
-    const std::vector<int32_t> vecOfUnknownSlots(setOfUnknownSlots.begin(),
-                                                 setOfUnknownSlots.end());
+    std::vector<int32_t> unknownSlots = slots;
+    auto unknownSlotsEnd = unknownSlots.end();
+    std::sort(unknownSlots.begin(), unknownSlotsEnd);
+    unknownSlotsEnd = std::unique(unknownSlots.begin(), unknownSlotsEnd);
+    unknownSlotsEnd = std::remove_if(unknownSlots.begin(), unknownSlotsEnd, slotIsKnown);
+    unknownSlots.erase(unknownSlotsEnd, unknownSlots.end());
 
     // retrieve unknown slots
-    if (!vecOfUnknownSlots.empty()) {
+    if (!unknownSlots.empty()) {
         ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
         std::vector<hidl_memory> returnedMemories;
-        Return<void> ret = mCallback->getMemories(
-                vecOfUnknownSlots,
-                [&errorStatus, &returnedMemories](ErrorStatus status,
-                                                  const hidl_vec<hidl_memory>& memories) {
-                    errorStatus = status;
-                    if (status == ErrorStatus::NONE) {
-                        returnedMemories = memories;
-                    }
-                });
+        auto cb = [&errorStatus, &returnedMemories](ErrorStatus status,
+                                                    const hidl_vec<hidl_memory>& memories) {
+            errorStatus = status;
+            returnedMemories = memories;
+        };
 
-        if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
+        Return<void> ret = mCallback->getMemories(unknownSlots, cb);
+
+        // Ensure that the memories were successfully returned.
+        // IBurstCallback.hal specifies that the number of memories returned
+        // must match the number of slots requested:
+        //     "slots.size() == buffers.size()"
+        if (!ret.isOk() || errorStatus != ErrorStatus::NONE ||
+            returnedMemories.size() != unknownSlots.size()) {
             LOG(ERROR) << "Error retrieving memories";
             return {};
         }
 
+        // resize cache to fit new slots if necessary
+        const int32_t maxUnknownSlot = unknownSlots.back();
+        if (maxUnknownSlot >= mMemoryCache.size()) {
+            mMemoryCache.resize(maxUnknownSlot + 1);
+        }
+
         // add memories to unknown slots
-        for (size_t i = 0; i < vecOfUnknownSlots.size(); ++i) {
-            mSlotToMemoryCache[vecOfUnknownSlots[i]] = returnedMemories[i];
+        for (size_t i = 0; i < unknownSlots.size(); ++i) {
+            mMemoryCache[unknownSlots[i]] = returnedMemories[i];
         }
     }
 
     // get all slots
     hidl_vec<hidl_memory> memories(slots.size());
-    for (size_t i = 0; i < slots.size(); ++i) {
-        memories[i] = mSlotToMemoryCache[slots[i]];
+    std::transform(slots.begin(), slots.end(), memories.begin(),
+                   [this](int32_t slot) { return mMemoryCache[slot]; });
+
+    // Ensure all slots are valid. Although this case is never expected to
+    // occur, theoretically IBurstCallback::getMemories could return invalid
+    // hidl_memory objects that must be protected against.
+    if (!std::all_of(memories.begin(), memories.end(),
+                     [](const hidl_memory& memory) { return memory.valid(); })) {
+        LOG(ERROR) << "Error, not all slots are valid!";
+        return {};
     }
 
     return memories;
@@ -78,7 +97,9 @@
 
 void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
     std::lock_guard<std::mutex> guard(mMutex);
-    mSlotToMemoryCache.erase(slot);
+    if (slot < mMemoryCache.size()) {
+        mMemoryCache[slot] = {};
+    }
 }
 
 sp<ExecutionBurstServer> ExecutionBurstServer::create(
diff --git a/common/include/ExecutionBurstController.h b/common/include/ExecutionBurstController.h
index 7152325..33564f4 100644
--- a/common/include/ExecutionBurstController.h
+++ b/common/include/ExecutionBurstController.h
@@ -24,6 +24,7 @@
 #include <map>
 #include <memory>
 #include <mutex>
+#include <stack>
 #include <tuple>
 #include "HalInterfaces.h"
 
@@ -93,11 +94,12 @@
 
        private:
         int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
+        int32_t allocateSlotLocked();
 
         std::mutex mMutex;
-        int32_t mNextSlot = 0;
-        std::map<intptr_t, int32_t> mMemoryIdToSlotCache;
-        std::map<int32_t, hidl_memory> mSlotToMemoryCache;
+        std::stack<int32_t, std::vector<int32_t>> mFreeSlots;
+        std::map<intptr_t, int32_t> mMemoryIdToSlot;
+        std::vector<hidl_memory> mMemoryCache;
     };
 
    public:
@@ -150,6 +152,7 @@
     std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> deserialize(
             const std::vector<FmqResultDatum>& data);
 
+    std::mutex mMutex;
     const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
     const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
     const sp<IBurstContext> mBurstContext;
diff --git a/common/include/ExecutionBurstServer.h b/common/include/ExecutionBurstServer.h
index 0a3222f..af9b076 100644
--- a/common/include/ExecutionBurstServer.h
+++ b/common/include/ExecutionBurstServer.h
@@ -22,8 +22,7 @@
 #include <hidl/MQDescriptor.h>
 #include <atomic>
 #include <future>
-#include <map>
-#include <set>
+#include <vector>
 #include "HalInterfaces.h"
 
 namespace android::nn {
@@ -62,7 +61,7 @@
        private:
         std::mutex mMutex;
         const sp<IBurstCallback> mCallback;
-        std::map<int32_t, hidl_memory> mSlotToMemoryCache;
+        std::vector<hidl_memory> mMemoryCache;
     };
 
    public:
diff --git a/runtime/BurstBuilder.cpp b/runtime/BurstBuilder.cpp
index ee4b371..f8aa6be 100644
--- a/runtime/BurstBuilder.cpp
+++ b/runtime/BurstBuilder.cpp
@@ -25,7 +25,7 @@
 namespace nn {
 
 BurstBuilder::BurstBuilder(const CompilationBuilder* compilation,
-                           std::vector<std::unique_ptr<ExecutionBurstController>> burstControllers)
+                           std::vector<std::shared_ptr<ExecutionBurstController>> burstControllers)
     : mCompilation(compilation), mBurstControllers(std::move(burstControllers)) {}
 
 bool BurstBuilder::tryLock() {
@@ -41,8 +41,8 @@
     return mCompilation;
 }
 
-ExecutionBurstController* BurstBuilder::getControllerAt(size_t index) const {
-    return index < mBurstControllers.size() ? mBurstControllers[index].get() : nullptr;
+std::shared_ptr<ExecutionBurstController> BurstBuilder::getControllerAt(size_t index) const {
+    return index < mBurstControllers.size() ? mBurstControllers[index] : nullptr;
 }
 
 }  // namespace nn
diff --git a/runtime/BurstBuilder.h b/runtime/BurstBuilder.h
index 288cf84..8f2982d 100644
--- a/runtime/BurstBuilder.h
+++ b/runtime/BurstBuilder.h
@@ -40,18 +40,18 @@
 class BurstBuilder {
    public:
     BurstBuilder(const CompilationBuilder* compilation,
-                 std::vector<std::unique_ptr<ExecutionBurstController>> burstControllers);
+                 std::vector<std::shared_ptr<ExecutionBurstController>> burstControllers);
 
     bool tryLock();
     void unlock();
 
     const CompilationBuilder* getCompilation() const;
-    ExecutionBurstController* getControllerAt(size_t index) const;
+    std::shared_ptr<ExecutionBurstController> getControllerAt(size_t index) const;
 
    private:
     std::atomic_flag mCurrentlyRunning = ATOMIC_FLAG_INIT;
     const CompilationBuilder* mCompilation;
-    std::vector<std::unique_ptr<ExecutionBurstController>> mBurstControllers;
+    std::vector<std::shared_ptr<ExecutionBurstController>> mBurstControllers;
 };
 
 }  // namespace nn
diff --git a/runtime/CompilationBuilder.cpp b/runtime/CompilationBuilder.cpp
index a578671..33e09fd 100644
--- a/runtime/CompilationBuilder.cpp
+++ b/runtime/CompilationBuilder.cpp
@@ -151,7 +151,7 @@
         *burst = nullptr;
         return ANEURALNETWORKS_BAD_STATE;
     }
-    std::vector<std::unique_ptr<ExecutionBurstController>> burstControllers = mPlan.makeBursts();
+    std::vector<std::shared_ptr<ExecutionBurstController>> burstControllers = mPlan.makeBursts();
     *burst = new (std::nothrow) BurstBuilder(this, std::move(burstControllers));
     return (*burst ? ANEURALNETWORKS_NO_ERROR : ANEURALNETWORKS_OUT_OF_MEMORY);
 }
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index 035e84a..725f2cc 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -404,7 +404,7 @@
     while (true) {
         std::shared_ptr<StepExecutor> executor;
         VLOG(EXECUTION) << "looking for next StepExecutor";
-        ExecutionBurstController* burstController = nullptr;
+        std::shared_ptr<ExecutionBurstController> burstController = nullptr;
         int n = plan->next(controller, &executor, &burstController);
         if (n != ANEURALNETWORKS_NO_ERROR) {
             if (allowFallback) {
@@ -733,7 +733,7 @@
 }
 
 int StepExecutor::startCompute(sp<ExecutionCallback>* synchronizationCallback,
-                               ExecutionBurstController* burstController) {
+                               const std::shared_ptr<ExecutionBurstController>& burstController) {
     if (VLOG_IS_ON(EXECUTION)) {
         logArguments("input", mInputs);
         logArguments("output", mOutputs);
@@ -745,8 +745,9 @@
     }
 }
 
-int StepExecutor::startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback,
-                                       ExecutionBurstController* burstController) {
+int StepExecutor::startComputeOnDevice(
+        sp<ExecutionCallback>* synchronizationCallback,
+        const std::shared_ptr<ExecutionBurstController>& burstController) {
     CHECK(!isCpu());
 
     *synchronizationCallback = nullptr;
@@ -844,9 +845,11 @@
     sp<ExecutionCallback> executionCallback = new ExecutionCallback();
 
     if (burstController != nullptr) {
-        std::vector<intptr_t> memoryIds(mMemories.size());
-        for (size_t i = 0; i < mMemories.size(); ++i) {
-            memoryIds[i] = reinterpret_cast<intptr_t>(mMemories[i]);
+        std::vector<intptr_t> memoryIds;
+        memoryIds.reserve(mMemories.size());
+        for (const Memory* memory : mMemories) {
+            memory->usedBy(burstController);
+            memoryIds.push_back(memory->getKey());
         }
 
         VLOG(EXECUTION) << "Before ExecutionBurstController->compute() "
diff --git a/runtime/ExecutionBuilder.h b/runtime/ExecutionBuilder.h
index e2c01f7..3d471a3 100644
--- a/runtime/ExecutionBuilder.h
+++ b/runtime/ExecutionBuilder.h
@@ -213,7 +213,7 @@
 
     // Executes using the (driver, preparedModel) specified at construction time.
     int startCompute(sp<ExecutionCallback>* synchronizationCallback,
-                     ExecutionBurstController* burstController = nullptr);
+                     const std::shared_ptr<ExecutionBurstController>& burstController = nullptr);
 
     // Executes using the CPU, regardless of the (driver,
     // preparedModel) specified at construction time.
@@ -229,7 +229,7 @@
    private:
     int allocatePointerArgumentsToPool(std::vector<ModelArgumentInfo>* args, Memory* memory);
     int startComputeOnDevice(sp<ExecutionCallback>* synchronizationCallback,
-                             ExecutionBurstController* burstController = nullptr);
+                             const std::shared_ptr<ExecutionBurstController>& burstController);
 
     void mapInputOrOutput(const ModelArgumentInfo& builderInputOrOutput,
                           ModelArgumentInfo* executorInputOrOutput);
diff --git a/runtime/ExecutionPlan.cpp b/runtime/ExecutionPlan.cpp
index 265929d..4f483ec 100644
--- a/runtime/ExecutionPlan.cpp
+++ b/runtime/ExecutionPlan.cpp
@@ -643,11 +643,11 @@
 // indicate the regular execution path should be used. This can occur either
 // because PreparedModel was nullptr (cpu was best choice), or because the
 // IPreparedModel was of insufficient version or failed to configure the burst.
-std::vector<std::unique_ptr<ExecutionBurstController>> ExecutionPlan::makeBursts() const {
+std::vector<std::shared_ptr<ExecutionBurstController>> ExecutionPlan::makeBursts() const {
     switch (mState) {
         // burst object for each partition in the compound case
         case COMPOUND: {
-            std::vector<std::unique_ptr<ExecutionBurstController>> bursts;
+            std::vector<std::shared_ptr<ExecutionBurstController>> bursts;
             bursts.reserve(compound()->mSteps.size());
             for (const auto& step : compound()->mSteps) {
                 if (const auto preparedModel = step->getPreparedSubModel()) {
@@ -660,7 +660,7 @@
         }
         // single burst object for the simple case
         case SIMPLE: {
-            std::vector<std::unique_ptr<ExecutionBurstController>> burst;
+            std::vector<std::shared_ptr<ExecutionBurstController>> burst;
             auto simpleBody = static_cast<const SimpleBody*>(mBody);
             if (const auto preparedModel = simpleBody->mPreparedModel) {
                 burst.push_back(preparedModel->configureExecutionBurst(/*blocking=*/true));
@@ -756,7 +756,7 @@
 
 int ExecutionPlan::next(std::shared_ptr<Controller> controller,
                         std::shared_ptr<StepExecutor>* executor,
-                        ExecutionBurstController** burstController) const {
+                        std::shared_ptr<ExecutionBurstController>* burstController) const {
     *executor = nullptr;
     if (burstController != nullptr) {
         *burstController = nullptr;
diff --git a/runtime/ExecutionPlan.h b/runtime/ExecutionPlan.h
index 56d1dfa..fc5efe9 100644
--- a/runtime/ExecutionPlan.h
+++ b/runtime/ExecutionPlan.h
@@ -210,13 +210,13 @@
         size_t mNextStepIndex;
     };
 
-    std::vector<std::unique_ptr<ExecutionBurstController>> makeBursts() const;
+    std::vector<std::shared_ptr<ExecutionBurstController>> makeBursts() const;
 
     std::shared_ptr<Controller> makeController(ExecutionBuilder* executionBuilder,
                                                const BurstBuilder* burstBuilder) const;
 
     int next(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor,
-             ExecutionBurstController** burstController = nullptr) const;
+             std::shared_ptr<ExecutionBurstController>* burstController = nullptr) const;
 
     // Create the same executor as the last one created by next().
     int fallback(std::shared_ptr<Controller> controller, std::shared_ptr<StepExecutor>* executor) const;
diff --git a/runtime/Memory.cpp b/runtime/Memory.cpp
index 7519bc9..097b140 100644
--- a/runtime/Memory.cpp
+++ b/runtime/Memory.cpp
@@ -18,12 +18,21 @@
 
 #include "Memory.h"
 
+#include "ExecutionBurstController.h"
 #include "HalInterfaces.h"
 #include "Utils.h"
 
 namespace android {
 namespace nn {
 
+Memory::~Memory() {
+    for (const auto [ptr, weakBurst] : mUsedBy) {
+        if (const std::shared_ptr<ExecutionBurstController> burst = weakBurst.lock()) {
+            burst->freeMemory(getKey());
+        }
+    }
+}
+
 int Memory::create(uint32_t size) {
     mHidlMemory = allocateSharedMemory(size);
     mMemory = mapMemory(mHidlMemory);
@@ -43,6 +52,15 @@
     }
 }
 
+intptr_t Memory::getKey() const {
+    return reinterpret_cast<intptr_t>(this);
+}
+
+void Memory::usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const {
+    std::lock_guard<std::mutex> guard(mMutex);
+    mUsedBy.emplace(burst.get(), burst);
+}
+
 MemoryFd::~MemoryFd() {
     // Unmap the memory.
     if (mMapping) {
diff --git a/runtime/Memory.h b/runtime/Memory.h
index 37bb84a..4ebfabf 100644
--- a/runtime/Memory.h
+++ b/runtime/Memory.h
@@ -22,19 +22,21 @@
 
 #include <cutils/native_handle.h>
 #include <sys/mman.h>
+#include <mutex>
 #include <unordered_map>
 #include "vndk/hardware_buffer.h"
 
 namespace android {
 namespace nn {
 
+class ExecutionBurstController;
 class ModelBuilder;
 
 // Represents a memory region.
 class Memory {
    public:
     Memory() {}
-    virtual ~Memory() {}
+    virtual ~Memory();
 
     // Disallow copy semantics to ensure the runtime object can only be freed
     // once. Copy semantics could be enabled if some sort of reference counting
@@ -60,11 +62,30 @@
 
     virtual bool validateSize(uint32_t offset, uint32_t length) const;
 
+    // Unique key representing this memory object.
+    intptr_t getKey() const;
+
+    // Marks a burst object as currently using this memory. When this
+    // memory object is destroyed, it will automatically free this memory from
+    // the bursts' memory cache.
+    void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const;
+
    protected:
     // The hidl_memory handle for this shared memory.  We will pass this value when
     // communicating with the drivers.
     hardware::hidl_memory mHidlMemory;
     sp<IMemory> mMemory;
+
+    mutable std::mutex mMutex;
+    // mUsedBy is essentially a set of burst objects which use this Memory
+    // object. However, std::weak_ptr does not have comparison operations nor a
+    // std::hash implementation. This is because it is either a valid pointer
+    // (non-null) if the shared object is still alive, or it is null if the
+    // object has been freed. To circumvent this, mUsedBy is a map with the raw
+    // pointer as the key and the weak_ptr as the value.
+    mutable std::unordered_map<const ExecutionBurstController*,
+                               std::weak_ptr<ExecutionBurstController>>
+            mUsedBy;
 };
 
 class MemoryFd : public Memory {
diff --git a/runtime/VersionedInterfaces.cpp b/runtime/VersionedInterfaces.cpp
index 9c657c2..d552ab2 100644
--- a/runtime/VersionedInterfaces.cpp
+++ b/runtime/VersionedInterfaces.cpp
@@ -183,7 +183,7 @@
     }
 }
 
-std::unique_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
+std::shared_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
         bool blocking) const {
     if (mPreparedModelV1_2 != nullptr) {
         return ExecutionBurstController::create(mPreparedModelV1_2, blocking);
diff --git a/runtime/VersionedInterfaces.h b/runtime/VersionedInterfaces.h
index 529061c..210532d 100644
--- a/runtime/VersionedInterfaces.h
+++ b/runtime/VersionedInterfaces.h
@@ -636,7 +636,7 @@
      *                                  nullptr is returned if the burst cannot
      *                                  be configured for any reason.
      */
-    std::unique_ptr<ExecutionBurstController> configureExecutionBurst(bool blocking) const;
+    std::shared_ptr<ExecutionBurstController> configureExecutionBurst(bool blocking) const;
 
     /**
      * Returns whether this handle to an IPreparedModel object is valid or not.
diff --git a/runtime/test/TestValidation.cpp b/runtime/test/TestValidation.cpp
index c3a6f0f..c7d8b12 100644
--- a/runtime/test/TestValidation.cpp
+++ b/runtime/test/TestValidation.cpp
@@ -456,6 +456,9 @@
     // This should fail, as the model is already finished.
     EXPECT_EQ(ANeuralNetworksModel_setOperandValueFromMemory(mModel, 0, memory, 0, sizeof(float)),
               ANEURALNETWORKS_BAD_STATE);
+
+    // close memory
+    close(memoryFd);
 }
 
 TEST_F(ValidationTestModel, SetOperandValueFromAHardwareBuffer) {
@@ -888,6 +891,9 @@
     EXPECT_EQ(ANeuralNetworksExecution_setInputFromMemory(mExecution, 0, &kInvalidTensorType2,
                                                           memory, 0, sizeof(float)),
               ANEURALNETWORKS_BAD_DATA);
+
+    // close memory
+    close(memoryFd);
 }
 
 TEST_F(ValidationTestExecution, SetInputFromAHardwareBufferBlob) {
@@ -985,6 +991,9 @@
     EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, &kInvalidTensorType2,
                                                            memory, 0, sizeof(float)),
               ANEURALNETWORKS_BAD_DATA);
+
+    // close memory
+    close(memoryFd);
 }
 
 TEST_F(ValidationTestExecution, SetOutputFromAHardwareBufferBlob) {
@@ -1055,14 +1064,14 @@
     EXPECT_EQ(ANeuralNetworksExecution_create(mCompilation, &execution), ANEURALNETWORKS_NO_ERROR);
 
     float input0[] = {1.0f, 1.0f}, input1[] = {2.0f, 2.0f}, output0[2];
-    int32_t input2 = 0;
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 0, nullptr, &input0, sizeof(input0)),
+    int32_t input2[] = {0};
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 0, nullptr, input0, sizeof(input0)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 1, nullptr, &input1, sizeof(input1)),
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 1, nullptr, input1, sizeof(input1)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 2, nullptr, &input2, sizeof(int32_t)),
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 2, nullptr, input2, sizeof(input2)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setOutput(execution, 0, nullptr, &output0, sizeof(output0)),
+    EXPECT_EQ(ANeuralNetworksExecution_setOutput(execution, 0, nullptr, output0, sizeof(output0)),
               ANEURALNETWORKS_NO_ERROR);
 
     uint32_t rank, dims[4], expectedRank = 1, expectedDims = 2;
@@ -1126,12 +1135,13 @@
               ANEURALNETWORKS_NO_ERROR);
     ASSERT_EQ(ANeuralNetworksCompilation_finish(secondCompilation), ANEURALNETWORKS_NO_ERROR);
 
-    ANeuralNetworksBurst* burst;
-    EXPECT_EQ(ANeuralNetworksBurst_create(secondCompilation, &burst), ANEURALNETWORKS_NO_ERROR);
+    ANeuralNetworksExecution* execution;
+    EXPECT_EQ(ANeuralNetworksExecution_create(secondCompilation, &execution),
+              ANEURALNETWORKS_NO_ERROR);
 
-    EXPECT_EQ(ANeuralNetworksExecution_burstCompute(mExecution, burst), ANEURALNETWORKS_BAD_DATA);
+    EXPECT_EQ(ANeuralNetworksExecution_burstCompute(execution, mBurst), ANEURALNETWORKS_BAD_DATA);
 
-    ANeuralNetworksBurst_free(burst);
+    ANeuralNetworksExecution_free(execution);
     ANeuralNetworksCompilation_free(secondCompilation);
 }
 
@@ -1142,30 +1152,30 @@
 
     // set inputs of first execution
     float inputA0[] = {1.0f, 1.0f}, inputA1[] = {2.0f, 2.0f}, outputA0[2];
-    int32_t inputA2 = 0;
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, &inputA0, sizeof(inputA0)),
+    int32_t inputA2[] = {0};
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 0, nullptr, inputA0, sizeof(inputA0)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 1, nullptr, &inputA1, sizeof(inputA1)),
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 1, nullptr, inputA1, sizeof(inputA1)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 2, nullptr, &inputA2, sizeof(int32_t)),
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(mExecution, 2, nullptr, inputA2, sizeof(inputA2)),
               ANEURALNETWORKS_NO_ERROR);
     EXPECT_EQ(
-            ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, &outputA0, sizeof(outputA0)),
+            ANeuralNetworksExecution_setOutput(mExecution, 0, nullptr, outputA0, sizeof(outputA0)),
             ANEURALNETWORKS_NO_ERROR);
 
     // set inputs of second execution
     float inputB0[] = {1.0f, 1.0f}, inputB1[] = {2.0f, 2.0f}, outputB0[2];
-    int32_t inputB2 = 0;
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 0, nullptr, &inputB0,
+    int32_t inputB2[] = {0};
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 0, nullptr, inputB0,
                                                 sizeof(inputB0)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 1, nullptr, &inputB1,
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 1, nullptr, inputB1,
                                                 sizeof(inputB1)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 2, nullptr, &inputB2,
-                                                sizeof(int32_t)),
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(secondExecution, 2, nullptr, inputB2,
+                                                sizeof(inputB2)),
               ANEURALNETWORKS_NO_ERROR);
-    EXPECT_EQ(ANeuralNetworksExecution_setOutput(secondExecution, 0, nullptr, &outputB0,
+    EXPECT_EQ(ANeuralNetworksExecution_setOutput(secondExecution, 0, nullptr, outputB0,
                                                  sizeof(outputB0)),
               ANEURALNETWORKS_NO_ERROR);
 
@@ -1178,11 +1188,110 @@
     auto second = std::async(std::launch::async, [this, secondExecution] {
         return ANeuralNetworksExecution_burstCompute(secondExecution, mBurst);
     });
+
     const int result1 = first.get();
     const int result2 = second.get();
     EXPECT_TRUE(result1 == ANEURALNETWORKS_BAD_STATE || result1 == ANEURALNETWORKS_NO_ERROR);
     EXPECT_TRUE(result2 == ANEURALNETWORKS_BAD_STATE || result2 == ANEURALNETWORKS_NO_ERROR);
     EXPECT_TRUE(result1 == ANEURALNETWORKS_NO_ERROR || result2 == ANEURALNETWORKS_NO_ERROR);
+
+    ANeuralNetworksExecution_free(secondExecution);
+}
+
+// The burst object maintains a local cache of memory objects. Because the burst
+// is intended to live for multiple executions, and because memory might be
+// created and freed for each execution, burst includes internal mechanisms to
+// purge memory objects from its cache that have been freed by the NNAPI client.
+// The following two test cases (FreeMemoryBeforeBurst and
+// FreeBurstBeforeMemory) ensure that this internal cleanup is tested in both
+// freeing orders.
+//
+// These two test cases explicitly create a new burst object and a new execution
+// object so that the order of freeing can be specified. If these tests instead
+// relied on the provided mExecution and mBurst, mBurst would always be freed
+// before mExecution.
+
+TEST_F(ValidationTestBurst, FreeMemoryBeforeBurst) {
+    ANeuralNetworksBurst* burst;
+    EXPECT_EQ(ANeuralNetworksBurst_create(mCompilation, &burst), ANEURALNETWORKS_NO_ERROR);
+
+    // prepare data for execution
+    float input0[] = {1.0f, 1.0f}, input1[] = {2.0f, 2.0f}, output0[2];
+    int32_t input2[] = {0};
+
+    const size_t memorySize = sizeof(output0);
+    int memoryFd = ASharedMemory_create("nnMemory", memorySize);
+    ASSERT_GT(memoryFd, 0);
+
+    ANeuralNetworksMemory* memory;
+    EXPECT_EQ(ANeuralNetworksMemory_createFromFd(memorySize, PROT_READ | PROT_WRITE, memoryFd, 0,
+                                                 &memory),
+              ANEURALNETWORKS_NO_ERROR);
+
+    // create and configure execution
+    ANeuralNetworksExecution* execution;
+    EXPECT_EQ(ANeuralNetworksExecution_create(mCompilation, &execution), ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 0, nullptr, input0, sizeof(input0)),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 1, nullptr, input1, sizeof(input1)),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 2, nullptr, input2, sizeof(input2)),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, nullptr, memory, 0,
+                                                           sizeof(output0)),
+              ANEURALNETWORKS_NO_ERROR);
+
+    // perform execution to cache memory into burst
+    EXPECT_EQ(ANeuralNetworksExecution_burstCompute(execution, burst), ANEURALNETWORKS_NO_ERROR);
+    ANeuralNetworksExecution_free(execution);
+
+    // free memory before burst
+    ANeuralNetworksMemory_free(memory);
+    ANeuralNetworksBurst_free(burst);
+
+    // close memory
+    close(memoryFd);
+}
+
+TEST_F(ValidationTestBurst, FreeBurstBeforeMemory) {
+    ANeuralNetworksBurst* burst;
+    EXPECT_EQ(ANeuralNetworksBurst_create(mCompilation, &burst), ANEURALNETWORKS_NO_ERROR);
+
+    // prepare data for execution
+    float input0[] = {1.0f, 1.0f}, input1[] = {2.0f, 2.0f}, output0[2];
+    int32_t input2[] = {0};
+    const size_t memorySize = sizeof(output0);
+    int memoryFd = ASharedMemory_create("nnMemory", memorySize);
+    ASSERT_GT(memoryFd, 0);
+
+    ANeuralNetworksMemory* memory;
+    EXPECT_EQ(ANeuralNetworksMemory_createFromFd(memorySize, PROT_READ | PROT_WRITE, memoryFd, 0,
+                                                 &memory),
+              ANEURALNETWORKS_NO_ERROR);
+
+    // create and configure execution
+    ANeuralNetworksExecution* execution;
+    EXPECT_EQ(ANeuralNetworksExecution_create(mCompilation, &execution), ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 0, nullptr, input0, sizeof(input0)),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 1, nullptr, input1, sizeof(input1)),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setInput(execution, 2, nullptr, input2, sizeof(input2)),
+              ANEURALNETWORKS_NO_ERROR);
+    EXPECT_EQ(ANeuralNetworksExecution_setOutputFromMemory(execution, 0, nullptr, memory, 0,
+                                                           sizeof(output0)),
+              ANEURALNETWORKS_NO_ERROR);
+
+    // perform execution to cache memory into burst
+    EXPECT_EQ(ANeuralNetworksExecution_burstCompute(execution, burst), ANEURALNETWORKS_NO_ERROR);
+    ANeuralNetworksExecution_free(execution);
+
+    // free burst before memory
+    ANeuralNetworksBurst_free(burst);
+    ANeuralNetworksMemory_free(memory);
+
+    // close memory
+    close(memoryFd);
 }
 
 TEST(ValidationTestIntrospection, GetNumDevices) {