Support memory domain in sample driver.
Bug: 147777318
Test: NNT_static
Test: 1.3 VTS
Change-Id: I64c2d325d27de36d422e86cd34d7311cededbf48
Merged-In: I64c2d325d27de36d422e86cd34d7311cededbf48
(cherry picked from commit c0622db536ba13cfcc64f0c5e9acea15672978c7)
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
index d391ed2..2e1c235 100644
--- a/common/ValidateHal.cpp
+++ b/common/ValidateHal.cpp
@@ -22,6 +22,7 @@
#include <algorithm>
#include <set>
+#include <utility>
#include <vector>
#include "NeuralNetworks.h"
@@ -781,6 +782,85 @@
validatePools(request.pools, HalVersion::V1_3));
}
+// Validates a memory descriptor against the prepared-model input/output roles
+// the memory is intended to serve.
+//
+// Each NN_RET_CHECK* macro (defined elsewhere) is assumed to log and make this
+// function return false when its condition fails. Checks, in evaluation order:
+//  - preparedModels is non-empty and at least one input or output role is given;
+//  - for every role (inputs first, then outputs): modelIndex is in range of
+//    preparedModels, that entry is non-null, getModel() returns a non-null model
+//    for it, ioIndex is in range of the model's main inputIndexes/outputIndexes,
+//    frequency is in (0.0, 1.0], and the same (prepared model, IO type, ioIndex)
+//    role has not already been listed;
+//  - all referenced operands share the same type, scale and zeroPoint, and, for
+//    non-extension operand types, the same extraParams;
+//  - desc.dimensions and every operand's dimensions can be combined by
+//    combineDimensions() (defined elsewhere);
+//  - for a non-extension scalar operand type the combined dimensions are empty.
+//
+// desc: requested buffer descriptor; only desc.dimensions is read here.
+// preparedModels: prepared models indexed by BufferRole::modelIndex.
+// inputRoles / outputRoles: uses of the memory as a model input / output.
+// getModel: maps a prepared model to its Model; the returned pointer is only
+//     dereferenced during this call.
+// preparedModelRoles: optional (may be nullptr); receives the validated roles.
+// combinedOperand: optional (may be nullptr); receives a copy of the first
+//     role's operand whose dimensions are replaced by the combined dimensions.
+//
+// Returns true when every check passes. Both output parameters are written only
+// after all checks have run.
+bool validateMemoryDesc(const V1_3::BufferDesc& desc,
+ const hidl_vec<sp<V1_3::IPreparedModel>>& preparedModels,
+ const hidl_vec<V1_3::BufferRole>& inputRoles,
+ const hidl_vec<V1_3::BufferRole>& outputRoles,
+ std::function<const V1_3::Model*(const sp<V1_3::IPreparedModel>&)> getModel,
+ std::set<PreparedModelRole>* preparedModelRoles,
+ V1_3::Operand* combinedOperand) {
+ NN_RET_CHECK(preparedModels.size() != 0);
+ NN_RET_CHECK(inputRoles.size() != 0 || outputRoles.size() != 0);
+
+ // Roles seen so far, used to reject a role listed more than once.
+ std::set<PreparedModelRole> roles;
+ // Exactly one operand is collected per role: input roles first, then output roles.
+ std::vector<V1_3::Operand> operands;
+ operands.reserve(inputRoles.size() + outputRoles.size());
+ for (const auto& role : inputRoles) {
+ NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
+ const auto& preparedModel = preparedModels[role.modelIndex];
+ NN_RET_CHECK(preparedModel != nullptr);
+ const auto* model = getModel(preparedModel);
+ NN_RET_CHECK(model != nullptr);
+ const auto& inputIndexes = model->main.inputIndexes;
+ NN_RET_CHECK_LT(role.ioIndex, inputIndexes.size());
+ NN_RET_CHECK_GT(role.frequency, 0.0f);
+ NN_RET_CHECK_LE(role.frequency, 1.0f);
+ // A failed insertion means this exact (model, INPUT, ioIndex) role was already listed.
+ const auto [it, success] = roles.emplace(preparedModel.get(), IOType::INPUT, role.ioIndex);
+ NN_RET_CHECK(success);
+ // ioIndex selects an entry of inputIndexes, which in turn indexes main.operands.
+ operands.push_back(model->main.operands[inputIndexes[role.ioIndex]]);
+ }
+ for (const auto& role : outputRoles) {
+ NN_RET_CHECK_LT(role.modelIndex, preparedModels.size());
+ const auto& preparedModel = preparedModels[role.modelIndex];
+ NN_RET_CHECK(preparedModel != nullptr);
+ const auto* model = getModel(preparedModel);
+ NN_RET_CHECK(model != nullptr);
+ const auto& outputIndexes = model->main.outputIndexes;
+ NN_RET_CHECK_LT(role.ioIndex, outputIndexes.size());
+ NN_RET_CHECK_GT(role.frequency, 0.0f);
+ NN_RET_CHECK_LE(role.frequency, 1.0f);
+ // A failed insertion means this exact (model, OUTPUT, ioIndex) role was already listed.
+ const auto [it, success] = roles.emplace(preparedModel.get(), IOType::OUTPUT, role.ioIndex);
+ NN_RET_CHECK(success);
+ // ioIndex selects an entry of outputIndexes, which in turn indexes main.operands.
+ operands.push_back(model->main.operands[outputIndexes[role.ioIndex]]);
+ }
+
+ // Internal invariant rather than an input check: at least one role was required
+ // above and every accepted role appended one operand.
+ CHECK(!operands.empty());
+ // operands[0] is the reference every other role's operand must agree with.
+ const auto opType = operands[0].type;
+ const bool isExtension = isExtensionOperandType(opType);
+
+ // Seed with the descriptor's own dimensions so they are merged with each role's.
+ std::vector<uint32_t> dimensions = desc.dimensions;
+ // The first iteration compares operands[0] with itself (trivially equal), but its
+ // dimensions still have to be combined, so the loop intentionally starts at 0.
+ for (const auto& operand : operands) {
+ NN_RET_CHECK(operand.type == operands[0].type)
+ << toString(operand.type) << " vs " << toString(operands[0].type);
+ NN_RET_CHECK_EQ(operand.scale, operands[0].scale);
+ NN_RET_CHECK_EQ(operand.zeroPoint, operands[0].zeroPoint);
+ // NOTE: validateMemoryDesc cannot validate extra parameters for extension operand type.
+ if (!isExtension) {
+ NN_RET_CHECK(operand.extraParams == operands[0].extraParams)
+ << toString(operand.extraParams) << " vs " << toString(operands[0].extraParams);
+ }
+ // combineDimensions() yields no value when the two shapes are incompatible.
+ const auto combined = combineDimensions(dimensions, operand.dimensions);
+ NN_RET_CHECK(combined.has_value());
+ dimensions = combined.value();
+ }
+
+ // NOTE: validateMemoryDesc cannot validate scalar dimensions with extension operand type.
+ if (!isExtension) {
+ NN_RET_CHECK(!nonExtensionOperandTypeIsScalar(static_cast<int>(opType)) ||
+ dimensions.empty())
+ << "invalid dimensions with scalar operand type.";
+ }
+
+ // Every check passed; publish the optional results requested by the caller.
+ if (preparedModelRoles != nullptr) {
+ *preparedModelRoles = std::move(roles);
+ }
+ if (combinedOperand != nullptr) {
+ *combinedOperand = operands[0];
+ combinedOperand->dimensions = dimensions;
+ }
+ return true;
+}
+
bool validateExecutionPreference(ExecutionPreference preference) {
return preference == ExecutionPreference::LOW_POWER ||
preference == ExecutionPreference::FAST_SINGLE_ANSWER ||