Handle non-POINTER memory types in getBuffer
Discussion:
http://ag/c/platform/frameworks/ml/+/9911408/42/nn/runtime/ExecutionPlan.cpp#926
Bug: 148458829
Bug: 149693818
Test: NNT_static
Change-Id: I035739db915eecef1859b58a4132fac58497fba2
Merged-In: I035739db915eecef1859b58a4132fac58497fba2
(cherry picked from commit f214bd8bc908d19e1a4669085cedffdfa47cee62)
diff --git a/runtime/ExecutionBuilder.h b/runtime/ExecutionBuilder.h
index 180444b..aa33167 100644
--- a/runtime/ExecutionBuilder.h
+++ b/runtime/ExecutionBuilder.h
@@ -25,6 +25,7 @@
#include "Callbacks.h"
#include "ControlFlow.h"
+#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "Memory.h"
#include "ModelArgumentInfo.h"
@@ -113,6 +114,10 @@
const ModelArgumentInfo& getInputInfo(uint32_t index) const { return mInputs[index]; }
const ModelArgumentInfo& getOutputInfo(uint32_t index) const { return mOutputs[index]; }
+ std::optional<RunTimePoolInfo> getRunTimePoolInfo(uint32_t poolIndex) const {
+ return mMemories[poolIndex]->getRunTimePoolInfo();
+ }
+
private:
// If a callback is provided, then this is asynchronous. If a callback is
// not provided (i.e., is nullptr), then this is synchronous.
diff --git a/runtime/ExecutionPlan.cpp b/runtime/ExecutionPlan.cpp
index 3f19327..f19065b 100644
--- a/runtime/ExecutionPlan.cpp
+++ b/runtime/ExecutionPlan.cpp
@@ -41,6 +41,7 @@
#include "Callbacks.h"
#include "CompilationBuilder.h"
#include "ControlFlow.h"
+#include "CpuExecutor.h"
#include "ExecutionBuilder.h"
#include "ExecutionBurstController.h"
#include "GraphDump.h"
@@ -981,42 +982,81 @@
return next(controller, executor);
}
-static void* getBufferFromModelArgumentInfo(const ModelArgumentInfo& info) {
- if (info.state == ModelArgumentInfo::POINTER) {
- return info.buffer;
- }
- // TODO: Handle info.state == MEMORY.
- return nullptr;
+ExecutionPlan::Buffer::Buffer(void* pointer, uint32_t size)
+ : mInfo(RunTimePoolInfo::createFromExistingBuffer(reinterpret_cast<uint8_t*>(pointer), size)),
+ mOffset(0) {}
+
+ExecutionPlan::Buffer::Buffer(RunTimePoolInfo info, uint32_t offset)
+ : mInfo(std::move(info)), mOffset(offset) {}
+
+void* ExecutionPlan::Buffer::getPointer() const {
+ return mInfo.getBuffer() + mOffset;
}
-void* ExecutionPlan::getBuffer(std::shared_ptr<Controller> controller,
- SourceOperandIndex operandIndex) const {
+uint32_t ExecutionPlan::Buffer::getSize() const {
+ return mInfo.getSize() - mOffset;
+}
+
+void ExecutionPlan::Buffer::flush() const {
+ mInfo.flush();
+}
+
+std::optional<ExecutionPlan::Buffer> ExecutionPlan::getBufferFromModelArgumentInfo(
+ const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const {
+ switch (info.state) {
+ case ModelArgumentInfo::POINTER: {
+ return Buffer(info.buffer, info.locationAndLength.length);
+ } break;
+ case ModelArgumentInfo::MEMORY: {
+ if (std::optional<RunTimePoolInfo> poolInfo =
+ executionBuilder->getRunTimePoolInfo(info.locationAndLength.poolIndex)) {
+ return Buffer(*poolInfo, info.locationAndLength.offset);
+ } else {
+ LOG(ERROR) << "Unable to map operand memory pool";
+ return std::nullopt;
+ }
+ } break;
+ case ModelArgumentInfo::HAS_NO_VALUE: {
+ LOG(ERROR) << "Attempting to read an operand that has no value";
+ return std::nullopt;
+ } break;
+ default: {
+ LOG(ERROR) << "Unexpected operand memory state: " << static_cast<int>(info.state);
+ return std::nullopt;
+ } break;
+ }
+}
+
+std::optional<ExecutionPlan::Buffer> ExecutionPlan::getBuffer(
+ std::shared_ptr<Controller> controller, SourceOperandIndex operandIndex) const {
const auto& sourceOperandToOffsetOfTemporary = controller->mSourceOperandToOffsetOfTemporary;
const auto& sourceOperandToInputIndex = controller->mSourceOperandToInputIndex;
const auto& sourceOperandToOutputIndex = controller->mSourceOperandToOutputIndex;
if (auto it = sourceOperandToOffsetOfTemporary.find(operandIndex);
it != sourceOperandToOffsetOfTemporary.end()) {
const uint32_t offset = it->second;
- uint8_t* memory = controller->mTemporaries->getPointer();
- return memory + offset;
+ const std::unique_ptr<MemoryAshmem>& memory = controller->mTemporaries;
+ return Buffer(memory->getPointer() + offset, memory->getSize() - offset);
} else if (auto it = sourceOperandToInputIndex.find(operandIndex);
it != sourceOperandToInputIndex.end()) {
const ModelArgumentInfo& info = controller->mExecutionBuilder->getInputInfo(it->second);
- return getBufferFromModelArgumentInfo(info);
+ return getBufferFromModelArgumentInfo(info, controller->mExecutionBuilder);
} else if (auto it = sourceOperandToOutputIndex.find(operandIndex);
it != sourceOperandToOutputIndex.end()) {
const ModelArgumentInfo& info = controller->mExecutionBuilder->getOutputInfo(it->second);
- return getBufferFromModelArgumentInfo(info);
+ return getBufferFromModelArgumentInfo(info, controller->mExecutionBuilder);
}
- return nullptr;
+ return std::nullopt;
}
bool ExecutionPlan::readConditionValue(std::shared_ptr<Controller> controller,
SourceOperandIndex operandIndex) const {
- auto buffer = reinterpret_cast<const uint8_t*>(getBuffer(controller, operandIndex));
- CHECK(buffer != nullptr) << "Unable to read operand " << toString(operandIndex);
- bool value = static_cast<bool>(buffer[0]);
- VLOG(EXECUTION) << "readConditionValue: " << value;
+ std::optional<ExecutionPlan::Buffer> buffer = getBuffer(controller, operandIndex);
+ CHECK(buffer != std::nullopt) << "Unable to read operand " << toString(operandIndex);
+ bool8 value;
+ CHECK_GE(buffer->getSize(), sizeof(value));
+ std::memcpy(&value, buffer->getPointer(), sizeof(value));
+ VLOG(EXECUTION) << "readConditionValue: " << static_cast<int>(value);
return value;
}
@@ -1296,15 +1336,22 @@
// WHILE operation input operand otherwise.
const SourceOperandIndex& innerOperand = step->condInputOperands[i];
const SourceOperandIndex& outerOperand = step->outerOutputOperands[i];
- void* outerBuffer = getBuffer(controller, outerOperand);
- CHECK(outerBuffer != nullptr);
+ std::optional<Buffer> outerBuffer = getBuffer(controller, outerOperand);
+ if (outerBuffer == std::nullopt) {
+ return ANEURALNETWORKS_OP_FAILED;
+ }
const Operand& sourceOperand =
controller->mExecutionBuilder->getSourceOperand(outerOperand);
const uint32_t size = TypeManager::get()->getSizeOfData(sourceOperand);
CHECK_NE(size, 0u);
- const void* innerBuffer = getBuffer(controller, innerOperand);
- CHECK(innerBuffer != nullptr);
- memcpy(outerBuffer, innerBuffer, size);
+ std::optional<Buffer> innerBuffer = getBuffer(controller, innerOperand);
+ if (innerBuffer == std::nullopt) {
+ return ANEURALNETWORKS_OP_FAILED;
+ }
+ CHECK_LE(size, innerBuffer->getSize());
+ CHECK_LE(size, outerBuffer->getSize());
+ memcpy(outerBuffer->getPointer(), innerBuffer->getPointer(), size);
+ outerBuffer->flush();
}
state.iteration = WhileState::kOutsideLoop;
}
diff --git a/runtime/ExecutionPlan.h b/runtime/ExecutionPlan.h
index 8156f99..b01ae54 100644
--- a/runtime/ExecutionPlan.h
+++ b/runtime/ExecutionPlan.h
@@ -35,6 +35,7 @@
#include "HalInterfaces.h"
#include "Memory.h"
+#include "ModelArgumentInfo.h"
#include "ModelBuilder.h"
#include "NeuralNetworks.h"
#include "TokenHasher.h"
@@ -455,8 +456,10 @@
sourceOperandToConstantReference);
// Sets the location of innerOperand to be the same as the location of outerOperand.
- void setInput(const SourceOperandIndex& outerOperand, const SourceOperandIndex& innerOperand);
- void setOutput(const SourceOperandIndex& outerOperand, const SourceOperandIndex& innerOperand);
+ void setInput(const SourceOperandIndex& outerOperand,
+ const SourceOperandIndex& innerOperand);
+ void setOutput(const SourceOperandIndex& outerOperand,
+ const SourceOperandIndex& innerOperand);
const ExecutionPlan* mPlan;
ExecutionBuilder* mExecutionBuilder;
@@ -573,8 +576,24 @@
void becomeCompoundIfEmpty();
void findTempsAsStepModelOutputs();
- // Returns the buffer associated with a partition boundary operand. Returns nullptr on failure.
- void* getBuffer(std::shared_ptr<Controller> controller, SourceOperandIndex operandIndex) const;
+ class Buffer {
+ public:
+ Buffer(void* pointer, uint32_t size);
+ Buffer(RunTimePoolInfo info, uint32_t offset);
+ void* getPointer() const;
+ uint32_t getSize() const;
+ void flush() const;
+
+ private:
+ RunTimePoolInfo mInfo;
+ uint32_t mOffset;
+ };
+
+ // Returns the buffer associated with a partition boundary operand.
+ std::optional<Buffer> getBuffer(std::shared_ptr<Controller> controller,
+ SourceOperandIndex operandIndex) const;
+ std::optional<Buffer> getBufferFromModelArgumentInfo(
+ const ModelArgumentInfo& info, const ExecutionBuilder* executionBuilder) const;
// Reads the value of a partition boundary boolean condition operand.
bool readConditionValue(std::shared_ptr<Controller> controller,
SourceOperandIndex operandIndex) const;
diff --git a/runtime/Memory.cpp b/runtime/Memory.cpp
index 34c8c56..5c8136f 100644
--- a/runtime/Memory.cpp
+++ b/runtime/Memory.cpp
@@ -203,6 +203,11 @@
return pool;
}
+std::optional<RunTimePoolInfo> Memory::getRunTimePoolInfo() const {
+ // TODO(b/147777318): Cache memory mapping within the memory object.
+ return RunTimePoolInfo::createFromHidlMemory(kHidlMemory);
+}
+
intptr_t Memory::getKey() const {
return reinterpret_cast<intptr_t>(this);
}
diff --git a/runtime/Memory.h b/runtime/Memory.h
index 6101d07..8d03417 100644
--- a/runtime/Memory.h
+++ b/runtime/Memory.h
@@ -32,6 +32,7 @@
#include <utility>
#include <vector>
+#include "CpuExecutor.h"
#include "HalInterfaces.h"
#include "NeuralNetworks.h"
#include "Utils.h"
@@ -168,6 +169,7 @@
const hal::hidl_memory& getHidlMemory() const { return kHidlMemory; }
const sp<hal::IBuffer>& getIBuffer() const { return kBuffer; }
virtual uint32_t getSize() const { return getHidlMemory().size(); }
+ virtual std::optional<RunTimePoolInfo> getRunTimePoolInfo() const;
MemoryValidatorBase& getValidator() const {
CHECK(mValidator != nullptr);
@@ -264,6 +266,10 @@
// returns non-null because it was validated during MemoryAshmem::create.
uint8_t* getPointer() const;
+ std::optional<RunTimePoolInfo> getRunTimePoolInfo() const override {
+ return RunTimePoolInfo::createFromExistingBuffer(getPointer(), kHidlMemory.size());
+ }
+
// prefer using MemoryAshmem::create
MemoryAshmem(sp<hal::IMemory> mapped, hal::hidl_memory memory);