Handle non-POINTER memory types in getBuffer

Discussion:
http://ag/c/platform/frameworks/ml/+/9911408/42/nn/runtime/ExecutionPlan.cpp#926

Bug: 148458829
Bug: 149693818
Test: NNT_static
Change-Id: I035739db915eecef1859b58a4132fac58497fba2
Merged-In: I035739db915eecef1859b58a4132fac58497fba2
(cherry picked from commit f214bd8bc908d19e1a4669085cedffdfa47cee62)
diff --git a/runtime/ExecutionBuilder.h b/runtime/ExecutionBuilder.h
index 180444b..aa33167 100644
--- a/runtime/ExecutionBuilder.h
+++ b/runtime/ExecutionBuilder.h
@@ -25,6 +25,7 @@
 
 #include "Callbacks.h"
 #include "ControlFlow.h"
+#include "CpuExecutor.h"
 #include "HalInterfaces.h"
 #include "Memory.h"
 #include "ModelArgumentInfo.h"
@@ -113,6 +114,10 @@
     const ModelArgumentInfo& getInputInfo(uint32_t index) const { return mInputs[index]; }
     const ModelArgumentInfo& getOutputInfo(uint32_t index) const { return mOutputs[index]; }
 
+    std::optional<RunTimePoolInfo> getRunTimePoolInfo(uint32_t poolIndex) const {
+        return mMemories[poolIndex]->getRunTimePoolInfo();
+    }
+
    private:
     // If a callback is provided, then this is asynchronous. If a callback is
     // not provided (i.e., is nullptr), then this is synchronous.
diff --git a/runtime/ExecutionPlan.cpp b/runtime/ExecutionPlan.cpp
index 3f19327..f19065b 100644
--- a/runtime/ExecutionPlan.cpp
+++ b/runtime/ExecutionPlan.cpp
@@ -41,6 +41,7 @@
 #include "Callbacks.h"
 #include "CompilationBuilder.h"
 #include "ControlFlow.h"
+#include "CpuExecutor.h"
 #include "ExecutionBuilder.h"
 #include "ExecutionBurstController.h"
 #include "GraphDump.h"
@@ -981,42 +982,81 @@
     return next(controller, executor);
 }
 
-static void* getBufferFromModelArgumentInfo(const ModelArgumentInfo& info) {
-    if (info.state == ModelArgumentInfo::POINTER) {
-        return info.buffer;
-    }
-    // TODO: Handle info.state == MEMORY.
-    return nullptr;
+ExecutionPlan::Buffer::Buffer(void* pointer, uint32_t size)
+    : mInfo(RunTimePoolInfo::createFromExistingBuffer(reinterpret_cast<uint8_t*>(pointer), size)),
+      mOffset(0) {}
+
+ExecutionPlan::Buffer::Buffer(RunTimePoolInfo info, uint32_t offset)
+    : mInfo(std::move(info)), mOffset(offset) {}
+
+void* ExecutionPlan::Buffer::getPointer() const {
+    return mInfo.getBuffer() + mOffset;
 }
 
-void* ExecutionPlan::getBuffer(std::shared_ptr<Controller> controller,
-                               SourceOperandIndex operandIndex) const {
+uint32_t ExecutionPlan::Buffer::getSize() const {
+    return mInfo.getSize() - mOffset;
+}
+
+void ExecutionPlan::Buffer::flush() const {
+    mInfo.flush();
+}
+
+std::optional<ExecutionPlan::Buffer> ExecutionPlan::getBufferFromModelArgumentInfo(
+        const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const {
+    switch (info.state) {
+        case ModelArgumentInfo::POINTER: {
+            return Buffer(info.buffer, info.locationAndLength.length);
+        } break;
+        case ModelArgumentInfo::MEMORY: {
+            if (std::optional<RunTimePoolInfo> poolInfo =
+                        executionBuilder->getRunTimePoolInfo(info.locationAndLength.poolIndex)) {
+                return Buffer(*poolInfo, info.locationAndLength.offset);
+            } else {
+                LOG(ERROR) << "Unable to map operand memory pool";
+                return std::nullopt;
+            }
+        } break;
+        case ModelArgumentInfo::HAS_NO_VALUE: {
+            LOG(ERROR) << "Attempting to read an operand that has no value";
+            return std::nullopt;
+        } break;
+        default: {
+            LOG(ERROR) << "Unexpected operand memory state: " << static_cast<int>(info.state);
+            return std::nullopt;
+        } break;
+    }
+}
+
+std::optional<ExecutionPlan::Buffer> ExecutionPlan::getBuffer(
+        std::shared_ptr<Controller> controller, SourceOperandIndex operandIndex) const {
     const auto& sourceOperandToOffsetOfTemporary = controller->mSourceOperandToOffsetOfTemporary;
     const auto& sourceOperandToInputIndex = controller->mSourceOperandToInputIndex;
     const auto& sourceOperandToOutputIndex = controller->mSourceOperandToOutputIndex;
     if (auto it = sourceOperandToOffsetOfTemporary.find(operandIndex);
         it != sourceOperandToOffsetOfTemporary.end()) {
         const uint32_t offset = it->second;
-        uint8_t* memory = controller->mTemporaries->getPointer();
-        return memory + offset;
+        const std::unique_ptr<MemoryAshmem>& memory = controller->mTemporaries;
+        return Buffer(memory->getPointer() + offset, memory->getSize() - offset);
     } else if (auto it = sourceOperandToInputIndex.find(operandIndex);
                it != sourceOperandToInputIndex.end()) {
         const ModelArgumentInfo& info = controller->mExecutionBuilder->getInputInfo(it->second);
-        return getBufferFromModelArgumentInfo(info);
+        return getBufferFromModelArgumentInfo(info, controller->mExecutionBuilder);
     } else if (auto it = sourceOperandToOutputIndex.find(operandIndex);
                it != sourceOperandToOutputIndex.end()) {
         const ModelArgumentInfo& info = controller->mExecutionBuilder->getOutputInfo(it->second);
-        return getBufferFromModelArgumentInfo(info);
+        return getBufferFromModelArgumentInfo(info, controller->mExecutionBuilder);
     }
-    return nullptr;
+    return std::nullopt;
 }
 
 bool ExecutionPlan::readConditionValue(std::shared_ptr<Controller> controller,
                                        SourceOperandIndex operandIndex) const {
-    auto buffer = reinterpret_cast<const uint8_t*>(getBuffer(controller, operandIndex));
-    CHECK(buffer != nullptr) << "Unable to read operand " << toString(operandIndex);
-    bool value = static_cast<bool>(buffer[0]);
-    VLOG(EXECUTION) << "readConditionValue: " << value;
+    std::optional<ExecutionPlan::Buffer> buffer = getBuffer(controller, operandIndex);
+    CHECK(buffer != std::nullopt) << "Unable to read operand " << toString(operandIndex);
+    bool8 value;
+    CHECK_GE(buffer->getSize(), sizeof(value));
+    std::memcpy(&value, buffer->getPointer(), sizeof(value));
+    VLOG(EXECUTION) << "readConditionValue: " << static_cast<int>(value);
     return value;
 }
 
@@ -1296,15 +1336,22 @@
             // WHILE operation input operand otherwise.
             const SourceOperandIndex& innerOperand = step->condInputOperands[i];
             const SourceOperandIndex& outerOperand = step->outerOutputOperands[i];
-            void* outerBuffer = getBuffer(controller, outerOperand);
-            CHECK(outerBuffer != nullptr);
+            std::optional<Buffer> outerBuffer = getBuffer(controller, outerOperand);
+            if (outerBuffer == std::nullopt) {
+                return ANEURALNETWORKS_OP_FAILED;
+            }
             const Operand& sourceOperand =
                     controller->mExecutionBuilder->getSourceOperand(outerOperand);
             const uint32_t size = TypeManager::get()->getSizeOfData(sourceOperand);
             CHECK_NE(size, 0u);
-            const void* innerBuffer = getBuffer(controller, innerOperand);
-            CHECK(innerBuffer != nullptr);
-            memcpy(outerBuffer, innerBuffer, size);
+            std::optional<Buffer> innerBuffer = getBuffer(controller, innerOperand);
+            if (innerBuffer == std::nullopt) {
+                return ANEURALNETWORKS_OP_FAILED;
+            }
+            CHECK_LE(size, innerBuffer->getSize());
+            CHECK_LE(size, outerBuffer->getSize());
+            memcpy(outerBuffer->getPointer(), innerBuffer->getPointer(), size);
+            outerBuffer->flush();
         }
         state.iteration = WhileState::kOutsideLoop;
     }
diff --git a/runtime/ExecutionPlan.h b/runtime/ExecutionPlan.h
index 8156f99..b01ae54 100644
--- a/runtime/ExecutionPlan.h
+++ b/runtime/ExecutionPlan.h
@@ -35,6 +35,7 @@
 
 #include "HalInterfaces.h"
 #include "Memory.h"
+#include "ModelArgumentInfo.h"
 #include "ModelBuilder.h"
 #include "NeuralNetworks.h"
 #include "TokenHasher.h"
@@ -455,8 +456,10 @@
                            sourceOperandToConstantReference);
 
         // Sets the location of innerOperand to be the same as the location of outerOperand.
-        void setInput(const SourceOperandIndex& outerOperand, const SourceOperandIndex& innerOperand);
-        void setOutput(const SourceOperandIndex& outerOperand, const SourceOperandIndex& innerOperand);
+        void setInput(const SourceOperandIndex& outerOperand,
+                      const SourceOperandIndex& innerOperand);
+        void setOutput(const SourceOperandIndex& outerOperand,
+                       const SourceOperandIndex& innerOperand);
 
         const ExecutionPlan* mPlan;
         ExecutionBuilder* mExecutionBuilder;
@@ -573,8 +576,24 @@
     void becomeCompoundIfEmpty();
     void findTempsAsStepModelOutputs();
 
-    // Returns the buffer associated with a partition boundary operand. Returns nullptr on failure.
-    void* getBuffer(std::shared_ptr<Controller> controller, SourceOperandIndex operandIndex) const;
+    class Buffer {
+       public:
+        Buffer(void* pointer, uint32_t size);
+        Buffer(RunTimePoolInfo info, uint32_t offset);
+        void* getPointer() const;
+        uint32_t getSize() const;
+        void flush() const;
+
+       private:
+        RunTimePoolInfo mInfo;
+        uint32_t mOffset;
+    };
+
+    // Returns the buffer associated with a partition boundary operand.
+    std::optional<Buffer> getBuffer(std::shared_ptr<Controller> controller,
+                                    SourceOperandIndex operandIndex) const;
+    std::optional<Buffer> getBufferFromModelArgumentInfo(
+            const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const;
     // Reads the value of a partition boundary boolean condition operand.
     bool readConditionValue(std::shared_ptr<Controller> controller,
                             SourceOperandIndex operandIndex) const;
diff --git a/runtime/Memory.cpp b/runtime/Memory.cpp
index 34c8c56..5c8136f 100644
--- a/runtime/Memory.cpp
+++ b/runtime/Memory.cpp
@@ -203,6 +203,11 @@
     return pool;
 }
 
+std::optional<RunTimePoolInfo> Memory::getRunTimePoolInfo() const {
+    // TODO(b/147777318): Cache memory mapping within the memory object.
+    return RunTimePoolInfo::createFromHidlMemory(kHidlMemory);
+}
+
 intptr_t Memory::getKey() const {
     return reinterpret_cast<intptr_t>(this);
 }
diff --git a/runtime/Memory.h b/runtime/Memory.h
index 6101d07..8d03417 100644
--- a/runtime/Memory.h
+++ b/runtime/Memory.h
@@ -32,6 +32,7 @@
 #include <utility>
 #include <vector>
 
+#include "CpuExecutor.h"
 #include "HalInterfaces.h"
 #include "NeuralNetworks.h"
 #include "Utils.h"
@@ -168,6 +169,7 @@
     const hal::hidl_memory& getHidlMemory() const { return kHidlMemory; }
     const sp<hal::IBuffer>& getIBuffer() const { return kBuffer; }
     virtual uint32_t getSize() const { return getHidlMemory().size(); }
+    virtual std::optional<RunTimePoolInfo> getRunTimePoolInfo() const;
 
     MemoryValidatorBase& getValidator() const {
         CHECK(mValidator != nullptr);
@@ -264,6 +266,10 @@
     // returns non-null because it was validated during MemoryAshmem::create.
     uint8_t* getPointer() const;
 
+    std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override {
+        return RunTimePoolInfo::createFromExistingBuffer(getPointer(), kHidlMemory.size());
+    }
+
     // prefer using MemoryAshmem::create
     MemoryAshmem(sp<hal::IMemory> mapped, hal::hidl_memory memory);