Implement control flow operation validation
Bug: 136735929
Test: m
Change-Id: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
Merged-In: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
(cherry picked from commit 9790ba2a4c8085d93ecef596c0e6cfa338bfaed2)
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
index 74e2d7b..d3f43b8 100644
--- a/common/ValidateHal.cpp
+++ b/common/ValidateHal.cpp
@@ -101,6 +101,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_INT32:
@@ -164,7 +165,8 @@
template <typename VersionedOperand>
static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
const hidl_vec<uint8_t>& operandValues,
- const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
+ const hidl_vec<hidl_memory>& pools,
+ const hidl_vec<Subgraph>& subgraphs, bool allowUnspecifiedRank) {
uint32_t index = 0;
MemoryAccessVerifier poolVerifier(pools);
for (auto& versionedOperand : operands) {
@@ -183,6 +185,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::OEM: {
size_t count = operand.dimensions.size();
if (count != 0) {
@@ -232,6 +235,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_BOOL8:
@@ -281,6 +285,7 @@
case OperandType::INT32:
case OperandType::UINT32:
case OperandType::BOOL:
+ case OperandType::SUBGRAPH:
case OperandType::TENSOR_FLOAT16:
case OperandType::TENSOR_FLOAT32:
case OperandType::TENSOR_INT32:
@@ -375,12 +380,37 @@
return false;
}
break;
+ case OperandLifeTime::SUBGRAPH: {
+ if (location.poolIndex != 0) {
+ LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
+ << location.poolIndex;
+ return false;
+ }
+ if (location.offset >= subgraphs.size()) {
+ LOG(ERROR) << "Subgraph index out of range: " << location.offset
+ << " >= " << subgraphs.size();
+ return false;
+ }
+ if (location.length != 0) {
+ LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
+ << location.length;
+ return false;
+ }
+ } break;
default:
LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
<< toString(operand.lifetime);
return false;
}
+ // Make sure SUBGRAPH operand type and lifetime always go together.
+ if ((operand.type == OperandType::SUBGRAPH) !=
+ (operand.lifetime == OperandLifeTime::SUBGRAPH)) {
+ LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
+ << " cannot have lifetime " << toString(operand.lifetime);
+ return false;
+ }
+
// For constants, validate that the length is as expected. The other lifetimes
// expect the length to be 0. Don't validate for OEM types.
if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
@@ -420,7 +450,35 @@
template <typename VersionedOperation>
static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
- const hidl_vec<Operand>& operands) {
+ const hidl_vec<Operand>& operands,
+ const hidl_vec<Subgraph>& subgraphs) {
+ auto isValidSubgraphReference = [&subgraphs](const Operand& modelOperand) -> bool {  // SUBGRAPH type + in-range offset.
+ NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
+ << "Unexpected operand type: " << toString(modelOperand.type);
+ NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size())
+ << "Invalid subgraph reference";
+ return true;  // NOTE(review): assumes NN_RET_CHECK* return false on failure -- confirm.
+ };
+ auto getSubgraph = [&subgraphs](const Operand& modelOperand) -> const Subgraph& {
+ CHECK_LT(modelOperand.location.offset, subgraphs.size());  // Offset is the subgraph index.
+ return subgraphs[modelOperand.location.offset];  // By reference: no per-call Subgraph copy.
+ };
+ auto getInputCount = [&getSubgraph](const Operand& modelOperand) {  // Input count of the referenced subgraph.
+ return getSubgraph(modelOperand).inputIndexes.size();
+ };
+ auto getOutputCount = [&getSubgraph](const Operand& modelOperand) {  // Output count of the referenced subgraph.
+ return getSubgraph(modelOperand).outputIndexes.size();
+ };
+ auto getInputOperand = [&getSubgraph](const Operand& modelOperand, uint32_t index) {  // index-th subgraph input.
+ const Subgraph& subgraph = getSubgraph(modelOperand);
+ CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());  // NOTE(review): `index` itself is unchecked.
+ return subgraph.operands[subgraph.inputIndexes[index]];  // No trailing return type: returned by value.
+ };
+ auto getOutputOperand = [&getSubgraph](const Operand& modelOperand, uint32_t index) {  // index-th subgraph output.
+ const Subgraph& subgraph = getSubgraph(modelOperand);
+ CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());  // NOTE(review): `index` itself is unchecked.
+ return subgraph.operands[subgraph.outputIndexes[index]];  // No trailing return type: returned by value.
+ };
const size_t operandCount = operands.size();
// This vector keeps track of whether there's an operation that writes to
// each operand. It is used to validate that temporary variables and
@@ -432,7 +490,12 @@
int error = validateOperation(
static_cast<int32_t>(op.type), op.inputs.size(),
op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
- op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
+ op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op),
+ {.isValidSubgraphReference = isValidSubgraphReference,
+ .getSubgraphInputCount = getInputCount,
+ .getSubgraphOutputCount = getOutputCount,
+ .getSubgraphInputOperand = getInputOperand,
+ .getSubgraphOutputOperand = getOutputOperand});
if (error != ANEURALNETWORKS_NO_ERROR) {
LOG(ERROR) << "Invalid operation " << toString(op.type);
return false;
@@ -541,9 +604,9 @@
// We only need versioned operands for their validation. For all the other
// validations we can use operands upcasted to the latest version.
const hidl_vec<Operand> latestVersionOperands = convertToV1_3(model.operands);
- return (validateOperands(model.operands, model.operandValues, model.pools,
+ return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
/*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
- validateOperations(model.operations, latestVersionOperands) &&
+ validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}) &&
validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
OperandLifeTime::SUBGRAPH_INPUT) &&
validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
@@ -564,8 +627,8 @@
}
auto validateSubgraph = [&model](const Subgraph& subgraph) -> bool {
return (validateOperands(subgraph.operands, model.operandValues, model.pools,
- /*allowUnspecifiedRank=*/true) &&
- validateOperations(subgraph.operations, subgraph.operands) &&
+ model.referenced, /*allowUnspecifiedRank=*/true) &&
+ validateOperations(subgraph.operations, subgraph.operands, model.referenced) &&
validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
OperandLifeTime::SUBGRAPH_INPUT) &&
validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
@@ -752,6 +815,7 @@
case V1_3::OperandType::TENSOR_BOOL8:
case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+ case V1_3::OperandType::SUBGRAPH:
case V1_3::OperandType::OEM:
case V1_3::OperandType::TENSOR_OEM_BYTE:
return true;