Implement control flow operation validation
Bug: 136735929
Test: m
Change-Id: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
Merged-In: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
(cherry picked from commit 9790ba2a4c8085d93ecef596c0e6cfa338bfaed2)
diff --git a/common/Utils.cpp b/common/Utils.cpp
index 368ef47..1046baf 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -32,6 +32,7 @@
#include <utility>
#include <vector>
+#include "ControlFlow.h"
#include "NeuralNetworks.h"
#include "NeuralNetworksOEM.h"
#include "OperationResolver.h"
@@ -604,10 +605,155 @@
return ANEURALNETWORKS_NO_ERROR;
}
+// Checks if two operands have the same type, dimensions, scale, zeroPoint, and
+// extraParams. Omits lifetime, numberOfConsumers, and location.
+// Used by the IF/WHILE validators below to compare an outer operand with the
+// corresponding input/output operand of a referenced subgraph.
+// Returns true when all compared fields match. Otherwise it returns false at the
+// first mismatching field, through the NN_RET_CHECK* macros, which also stream the
+// mismatch message shown.
+// Note: dimensions are compared with vector equality, so e.g. {0} vs {3} is a
+// mismatch.
+static bool compatible(const Operand& a, const Operand& b) {
+ NN_RET_CHECK(a.type == b.type) << toString(a.type) << " != " << toString(b.type);
+ NN_RET_CHECK(a.dimensions == b.dimensions)
+ << toString(a.dimensions) << " != " << toString(b.dimensions);
+ // Exact (==) comparison of the float scale; no tolerance is applied.
+ NN_RET_CHECK_EQ(a.scale, b.scale);
+ NN_RET_CHECK_EQ(a.zeroPoint, b.zeroPoint);
+ NN_RET_CHECK(a.extraParams == b.extraParams)
+ << toString(a.extraParams) << " != " << toString(b.extraParams);
+ return true;
+}
+
+// Checks that an operand is usable as a condition value: a TENSOR_BOOL8 tensor with
+// exactly one dimension of size 1. Used for the IF condition input and for the
+// single output of a WHILE condition model. Returns true if so, false otherwise.
+static bool validateConditionOperand(const Operand& operand) {
+ NN_RET_CHECK(operand.type == OperandType::TENSOR_BOOL8)
+ << "Unexpected condition operand type: " << toString(operand.type);
+ NN_RET_CHECK_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
+ // Indexing [0] is safe: the previous check returns early unless there is exactly one
+ // dimension.
+ NN_RET_CHECK_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
+ return true;
+}
+
+// CHECK-fails (aborts) if any callback in `helper` is empty. It uses CHECK rather than
+// NN_RET_CHECK, so a missing callback is not reported as BAD_DATA. Callers that have
+// no helper should use the validateOperation() overload that rejects IF/WHILE up
+// front.
+static void checkSubgraphValidationHelper(const SubgraphValidationHelper& helper) {
+ CHECK(helper.isValidSubgraphReference != nullptr);
+ CHECK(helper.getSubgraphInputCount != nullptr);
+ CHECK(helper.getSubgraphOutputCount != nullptr);
+ CHECK(helper.getSubgraphInputOperand != nullptr);
+ CHECK(helper.getSubgraphOutputOperand != nullptr);
+}
+
+// Validates an ANEURALNETWORKS_IF operation against its two branch subgraphs.
+// Input layout (indices from operation_if):
+//   inputs[op::kCondBoolOperand]  - condition, checked by validateConditionOperand()
+//   inputs[op::kThenModelOperand] - SUBGRAPH reference to the "then" branch
+//   inputs[op::kElseModelOperand] - SUBGRAPH reference to the "else" branch
+//   inputs[op::kFirstInput + i]   - must be compatible() with branch input i
+// Each branch must have exactly (inputCount - op::kFirstInput) inputs and exactly
+// outputCount outputs, and output i must be compatible() with outer output i.
+// Returns true if the operation is valid, false otherwise.
+// NOTE(review): `operands` is indexed through inputs[]/outputs[] without range checks,
+// so this relies on the caller. validateOperation() runs validateOperandList on the
+// inputs before dispatching here; presumably it does the same for outputs — confirm.
+static bool validateIfOperation(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
+ const uint32_t* outputs, const std::vector<Operand>& operands,
+ const SubgraphValidationHelper& helper) {
+ namespace op = operation_if;
+ checkSubgraphValidationHelper(helper);
+ NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_IF must have at least 3 inputs";
+ NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_IF must have at least 1 output";
+ // Checks that one branch subgraph's inputs and outputs line up with this operation's
+ // operands, both in count and in compatible() types/shapes.
+ auto validateBranchOperand = [&](const Operand& branchModelOperand) -> bool {
+ NN_RET_CHECK(helper.isValidSubgraphReference(branchModelOperand))
+ << "Operand is not a valid subgraph reference";
+ const uint32_t branchModelInputCount = helper.getSubgraphInputCount(branchModelOperand);
+ const uint32_t branchModelOutputCount = helper.getSubgraphOutputCount(branchModelOperand);
+ NN_RET_CHECK_EQ(inputCount, op::kFirstInput + branchModelInputCount);
+ NN_RET_CHECK_EQ(outputCount, branchModelOutputCount);
+ for (uint32_t i = 0; i < branchModelInputCount; ++i) {
+ const Operand& innerOperand = helper.getSubgraphInputOperand(branchModelOperand, i);
+ const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+ NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ }
+ for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
+ const Operand& innerOperand = helper.getSubgraphOutputOperand(branchModelOperand, i);
+ const Operand& outerOperand = operands[outputs[i]];
+ NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ }
+ return true;
+ };
+ NN_RET_CHECK(validateConditionOperand(operands[inputs[op::kCondBoolOperand]]))
+ << "Validation failed for IF condition operand";
+ NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kThenModelOperand]]))
+ << "Validation failed for IF then model";
+ NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kElseModelOperand]]))
+ << "Validation failed for IF else model";
+ return true;
+}
+
+// Validates an ANEURALNETWORKS_WHILE operation against its condition and body
+// subgraphs, using the operand layout described at the top of the function body.
+// Returns true if the operation is valid, false otherwise.
+// NOTE(review): like validateIfOperation(), this indexes `operands` through
+// inputs[]/outputs[] without range checks of its own and relies on the caller.
+static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
+ uint32_t outputCount, const uint32_t* outputs,
+ const std::vector<Operand>& operands,
+ const SubgraphValidationHelper& helper) {
+ // Let the loop have
+ // - m >= 1 input-output operands,
+ // - k >= 0 state-only operands, and
+ // - n >= 0 input-only operands.
+ // Then
+ // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
+ // - the condition model has (m + k + n) inputs and 1 output.
+ // - the body model has (m + k + n) inputs and (m + k) outputs.
+ namespace op = operation_while;
+ checkSubgraphValidationHelper(helper);
+ NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_WHILE must have at least 3 inputs";
+ NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_WHILE must have at least 1 output";
+ // The condition model takes all (m + k + n) operands that follow the model references
+ // and must produce exactly one condition operand.
+ auto validateCondOperand = [&](const Operand& condModelOperand) -> bool {
+ NN_RET_CHECK(helper.isValidSubgraphReference(condModelOperand))
+ << "Operand is not a valid subgraph reference";
+ const uint32_t condModelInputCount = helper.getSubgraphInputCount(condModelOperand);
+ const uint32_t condModelOutputCount = helper.getSubgraphOutputCount(condModelOperand);
+ NN_RET_CHECK_EQ(inputCount, op::kFirstInput + condModelInputCount);
+ NN_RET_CHECK_EQ(condModelOutputCount, 1u);
+ for (uint32_t i = 0; i < condModelInputCount; ++i) {
+ const Operand& innerOperand = helper.getSubgraphInputOperand(condModelOperand, i);
+ const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+ NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ }
+ NN_RET_CHECK(
+ validateConditionOperand(helper.getSubgraphOutputOperand(condModelOperand, 0)));
+ return true;
+ };
+ // The body model takes the same (m + k + n) operands and produces (m + k) outputs.
+ // m is fixed by the WHILE output count; k and n are derived from the body's counts.
+ auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> bool {
+ NN_RET_CHECK(helper.isValidSubgraphReference(bodyModelOperand))
+ << "Operand is not a valid subgraph reference";
+ const uint32_t bodyModelInputCount = helper.getSubgraphInputCount(bodyModelOperand);
+ const uint32_t bodyModelOutputCount = helper.getSubgraphOutputCount(bodyModelOperand);
+ NN_RET_CHECK_EQ(inputCount, op::kFirstInput + bodyModelInputCount);
+ NN_RET_CHECK_GE(bodyModelOutputCount, outputCount);
+ NN_RET_CHECK_GE(bodyModelInputCount, bodyModelOutputCount);
+ // m, k, and n in the notation above. The two GE checks just above guarantee that
+ // the unsigned subtractions below cannot wrap around.
+ const uint32_t inputOutputCount = outputCount;
+ const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
+ const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
+ // Every body input must be compatible with the corresponding outer input.
+ for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
+ const Operand& innerOperand = helper.getSubgraphInputOperand(bodyModelOperand, i);
+ const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+ NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ }
+ // The first m body outputs must be compatible with the WHILE operation's outputs.
+ for (uint32_t i = 0; i < inputOutputCount; ++i) {
+ const Operand& innerOperand = helper.getSubgraphOutputOperand(bodyModelOperand, i);
+ const Operand& outerOperand = operands[outputs[i]];
+ NN_RET_CHECK(compatible(innerOperand, outerOperand));
+ }
+ // For the input-output and state-only operands (the first m + k), body input i
+ // must be compatible with body output i.
+ for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
+ const Operand& inputOperand = helper.getSubgraphInputOperand(bodyModelOperand, i);
+ const Operand& outputOperand = helper.getSubgraphOutputOperand(bodyModelOperand, i);
+ NN_RET_CHECK(compatible(inputOperand, outputOperand));
+ }
+ return true;
+ };
+ NN_RET_CHECK(validateCondOperand(operands[inputs[op::kCondModelOperand]]))
+ << "Validation failed for WHILE condition model";
+ NN_RET_CHECK(validateBodyOperand(operands[inputs[op::kBodyModelOperand]]))
+ << "Validation failed for WHILE body model";
+ return true;
+}
+
+// Overload for callers that have no SubgraphValidationHelper. IF and WHILE cannot be
+// validated without one, so they are rejected with ANEURALNETWORKS_BAD_DATA. The HAL
+// version check runs first, so a version error is returned in preference to the
+// BAD_DATA. All other operations are forwarded with an empty helper, which the
+// non-control-flow paths do not use.
+static inline int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
+ const uint32_t* inputIndexes, uint32_t outputCount,
+ const uint32_t* outputIndexes,
+ const std::vector<hal::Operand>& operands,
+ HalVersion halVersion) {
+ if (opType == ANEURALNETWORKS_IF || opType == ANEURALNETWORKS_WHILE) {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+ LOG(ERROR) << "This validateOperation() overload does not support control flow";
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ return validateOperation(opType, inputCount, inputIndexes, outputCount, outputIndexes, operands,
+ halVersion, {});
+}
+
int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
const uint32_t* inputIndexes, uint32_t outputCount,
const uint32_t* outputIndexes, const std::vector<Operand>& operands,
- HalVersion halVersion) {
+ HalVersion halVersion, const SubgraphValidationHelper& helper) {
NN_RETURN_IF_ERROR(validateOperandList(inputCount, inputIndexes,
static_cast<uint32_t>(operands.size()),
"ANeuralNetworksModel_addOperation inputs"));
@@ -1637,6 +1783,20 @@
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
+ case ANEURALNETWORKS_IF: {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+ return validateIfOperation(inputCount, inputIndexes, outputCount, outputIndexes,
+ operands, helper)
+ ? ANEURALNETWORKS_NO_ERROR
+ : ANEURALNETWORKS_BAD_DATA;
+ }
+ case ANEURALNETWORKS_WHILE: {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+ return validateWhileOperation(inputCount, inputIndexes, outputCount, outputIndexes,
+ operands, helper)
+ ? ANEURALNETWORKS_NO_ERROR
+ : ANEURALNETWORKS_BAD_DATA;
+ }
default: {
const OperationRegistration* operationRegistration =
BuiltinOperationResolver::get()->findOperation(
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
index 74e2d7b..d3f43b8 100644
--- a/common/ValidateHal.cpp
+++ b/common/ValidateHal.cpp
@@ -101,6 +101,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_INT32:
@@ -164,7 +165,8 @@
template <typename VersionedOperand>
static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
const hidl_vec<uint8_t>& operandValues,
- const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
+ const hidl_vec<hidl_memory>& pools,
+ const hidl_vec<Subgraph>& subgraphs, bool allowUnspecifiedRank) {
uint32_t index = 0;
MemoryAccessVerifier poolVerifier(pools);
for (auto& versionedOperand : operands) {
@@ -183,6 +185,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::OEM: {
size_t count = operand.dimensions.size();
if (count != 0) {
@@ -232,6 +235,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_BOOL8:
@@ -281,6 +285,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_INT32:
@@ -375,12 +380,37 @@
return false;
}
break;
+ case OperandLifeTime::SUBGRAPH: {
+ if (location.poolIndex != 0) {
+ LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
+ << location.poolIndex;
+ return false;
+ }
+ if (location.offset >= subgraphs.size()) {
+ LOG(ERROR) << "Subgraph index out of range: " << location.offset
+ << " >= " << subgraphs.size();
+ return false;
+ }
+ if (location.length != 0) {
+ LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
+ << location.length;
+ return false;
+ }
+ } break;
default:
LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
<< toString(operand.lifetime);
return false;
}
+ // Make sure SUBGRAPH operand type and lifetime always go together.
+ if ((operand.type == OperandType::SUBGRAPH) !=
+ (operand.lifetime == OperandLifeTime::SUBGRAPH)) {
+ LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
+ << " cannot have lifetime " << toString(operand.lifetime);
+ return false;
+ }
+
// For constants, validate that the length is as expected. The other lifetimes
// expect the length to be 0. Don't validate for OEM types.
if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
@@ -420,7 +450,35 @@
template <typename VersionedOperation>
static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
- const hidl_vec<Operand>& operands) {
+ const hidl_vec<Operand>& operands,
+ const hidl_vec<Subgraph>& subgraphs) {
+    // Accessors passed to validateOperation() through SubgraphValidationHelper for
+    // IF/WHILE validation. They capture `subgraphs` by reference and are only used
+    // by the validateOperation() call below, while `subgraphs` is alive.
+    auto isValidSubgraphReference = [&subgraphs](const Operand& modelOperand) -> bool {
+        NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
+                << "Unexpected operand type: " << toString(modelOperand.type);
+        NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size())
+                << "Invalid subgraph reference";
+        return true;
+    };
+    // The return types below are spelled out on purpose. With a deduced `auto`
+    // return type, getSubgraph() would copy a whole Subgraph on every call. Also,
+    // getInputOperand()/getOutputOperand() would return an Operand temporary, which
+    // the std::function<const Operand&(...)> members of SubgraphValidationHelper
+    // would then hand back as a dangling reference.
+    auto getSubgraph = [&subgraphs](const Operand& modelOperand) -> const Subgraph& {
+        CHECK_LT(modelOperand.location.offset, subgraphs.size());
+        return subgraphs[modelOperand.location.offset];
+    };
+    auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
+        return static_cast<uint32_t>(getSubgraph(modelOperand).inputIndexes.size());
+    };
+    auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
+        return static_cast<uint32_t>(getSubgraph(modelOperand).outputIndexes.size());
+    };
+    // Returns a reference into the referenced subgraph's operand list; `index` must
+    // be below getInputCount() for the same operand.
+    auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
+                                          uint32_t index) -> const Operand& {
+        const Subgraph& subgraph = getSubgraph(modelOperand);
+        CHECK_LT(index, subgraph.inputIndexes.size());
+        CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
+        return subgraph.operands[subgraph.inputIndexes[index]];
+    };
+    // Returns a reference into the referenced subgraph's operand list; `index` must
+    // be below getOutputCount() for the same operand.
+    auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
+                                           uint32_t index) -> const Operand& {
+        const Subgraph& subgraph = getSubgraph(modelOperand);
+        CHECK_LT(index, subgraph.outputIndexes.size());
+        CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
+        return subgraph.operands[subgraph.outputIndexes[index]];
+    };
const size_t operandCount = operands.size();
// This vector keeps track of whether there's an operation that writes to
// each operand. It is used to validate that temporary variables and
@@ -432,7 +490,12 @@
int error = validateOperation(
static_cast<int32_t>(op.type), op.inputs.size(),
op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
- op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
+ op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op),
+ {.isValidSubgraphReference = isValidSubgraphReference,
+ .getSubgraphInputCount = getInputCount,
+ .getSubgraphOutputCount = getOutputCount,
+ .getSubgraphInputOperand = getInputOperand,
+ .getSubgraphOutputOperand = getOutputOperand});
if (error != ANEURALNETWORKS_NO_ERROR) {
LOG(ERROR) << "Invalid operation " << toString(op.type);
return false;
@@ -541,9 +604,9 @@
// We only need versioned operands for their validation. For all the other
// validations we can use operands upcasted to the latest version.
const hidl_vec<Operand> latestVersionOperands = convertToV1_3(model.operands);
- return (validateOperands(model.operands, model.operandValues, model.pools,
+ return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
/*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
- validateOperations(model.operations, latestVersionOperands) &&
+ validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}) &&
validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
OperandLifeTime::SUBGRAPH_INPUT) &&
validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
@@ -564,8 +627,8 @@
}
auto validateSubgraph = [&model](const Subgraph& subgraph) -> bool {
return (validateOperands(subgraph.operands, model.operandValues, model.pools,
- /*allowUnspecifiedRank=*/true) &&
- validateOperations(subgraph.operations, subgraph.operands) &&
+ model.referenced, /*allowUnspecifiedRank=*/true) &&
+ validateOperations(subgraph.operations, subgraph.operands, model.referenced) &&
validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
OperandLifeTime::SUBGRAPH_INPUT) &&
validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
@@ -752,6 +815,7 @@
case V1_3::OperandType::TENSOR_BOOL8:
case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ case V1_3::OperandType::SUBGRAPH:
case V1_3::OperandType::OEM:
case V1_3::OperandType::TENSOR_OEM_BYTE:
return true;
diff --git a/common/include/Utils.h b/common/include/Utils.h
index a90db35..7797381 100644
--- a/common/include/Utils.h
+++ b/common/include/Utils.h
@@ -319,12 +319,27 @@
int validateOperandList(uint32_t count, const uint32_t* list, uint32_t operandCount,
const char* tag);
+// A set of functions to help validate models containing IF or WHILE operations.
+// When validateOperation() is asked to validate IF or WHILE, every member must be set;
+// the control-flow validators CHECK-fail on an empty member.
+struct SubgraphValidationHelper {
+ // Checks if a given operand is a SUBGRAPH operand with a valid offset.
+ std::function<bool(const hal::Operand&)> isValidSubgraphReference;
+ // Gets the input count of a subgraph referenced by a given operand.
+ std::function<uint32_t(const hal::Operand&)> getSubgraphInputCount;
+ // Gets the output count of a subgraph referenced by a given operand.
+ std::function<uint32_t(const hal::Operand&)> getSubgraphOutputCount;
+ // Gets the specified input operand of a subgraph referenced by a given operand.
+ // The returned reference must stay valid after the call returns. A callable that
+ // returns hal::Operand by value (e.g. a lambda with a deduced `auto` return type)
+ // would make this std::function return a dangling reference.
+ std::function<const hal::Operand&(const hal::Operand&, uint32_t)> getSubgraphInputOperand;
+ // Gets the specified output operand of a subgraph referenced by a given operand.
+ // Same lifetime requirement as getSubgraphInputOperand.
+ std::function<const hal::Operand&(const hal::Operand&, uint32_t)> getSubgraphOutputOperand;
+};
+
// Returns ANEURALNETWORKS_NO_ERROR if the corresponding operation is defined and can handle the
// provided operand types in the given HAL version, otherwise returns ANEURALNETWORKS_BAD_DATA.
+// The last argument is only used for validating IF and WHILE operations; for those,
+// every member of the helper must be set, otherwise validation aborts via CHECK.
int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
const uint32_t* inputIndexes, uint32_t outputCount,
const uint32_t* outputIndexes, const std::vector<hal::Operand>& operands,
- HalVersion halVersion);
+ HalVersion halVersion, const SubgraphValidationHelper& helper);
inline size_t getSizeFromInts(int lower, int higher) {
return (uint32_t)(lower) + ((uint64_t)(uint32_t)(higher) << 32);