Implement control flow operation validation

Bug: 136735929
Test: m
Change-Id: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
Merged-In: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
(cherry picked from commit 9790ba2a4c8085d93ecef596c0e6cfa338bfaed2)
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
index 74e2d7b..d3f43b8 100644
--- a/common/ValidateHal.cpp
+++ b/common/ValidateHal.cpp
@@ -101,6 +101,7 @@
         case OperandType::INT32:
         case OperandType::UINT32:
         case OperandType::BOOL:
+        case OperandType::SUBGRAPH:
         case OperandType::TENSOR_FLOAT32:
         case OperandType::TENSOR_FLOAT16:
         case OperandType::TENSOR_INT32:
@@ -164,7 +165,8 @@
 template <typename VersionedOperand>
 static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
                              const hidl_vec<uint8_t>& operandValues,
-                             const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
+                             const hidl_vec<hidl_memory>& pools,
+                             const hidl_vec<Subgraph>& subgraphs, bool allowUnspecifiedRank) {
     uint32_t index = 0;
     MemoryAccessVerifier poolVerifier(pools);
     for (auto& versionedOperand : operands) {
@@ -183,6 +185,7 @@
             case OperandType::INT32:
             case OperandType::UINT32:
             case OperandType::BOOL:
+            case OperandType::SUBGRAPH:
             case OperandType::OEM: {
                 size_t count = operand.dimensions.size();
                 if (count != 0) {
@@ -232,6 +235,7 @@
             case OperandType::INT32:
             case OperandType::UINT32:
             case OperandType::BOOL:
+            case OperandType::SUBGRAPH:
             case OperandType::TENSOR_FLOAT16:
             case OperandType::TENSOR_FLOAT32:
             case OperandType::TENSOR_BOOL8:
@@ -281,6 +285,7 @@
             case OperandType::INT32:
             case OperandType::UINT32:
             case OperandType::BOOL:
+            case OperandType::SUBGRAPH:
             case OperandType::TENSOR_FLOAT16:
             case OperandType::TENSOR_FLOAT32:
             case OperandType::TENSOR_INT32:
@@ -375,12 +380,37 @@
                     return false;
                 }
                 break;
+            case OperandLifeTime::SUBGRAPH: {
+                if (location.poolIndex != 0) {
+                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
+                               << location.poolIndex;
+                    return false;
+                }
+                if (location.offset >= subgraphs.size()) {
+                    LOG(ERROR) << "Subgraph index out of range: " << location.offset
+                               << " >= " << subgraphs.size();
+                    return false;
+                }
+                if (location.length != 0) {
+                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
+                               << location.length;
+                    return false;
+                }
+            } break;
             default:
                 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
                            << toString(operand.lifetime);
                 return false;
         }
 
+        // Make sure SUBGRAPH operand type and lifetime always go together.
+        if ((operand.type == OperandType::SUBGRAPH) !=
+            (operand.lifetime == OperandLifeTime::SUBGRAPH)) {
+            LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
+                       << " cannot have lifetime " << toString(operand.lifetime);
+            return false;
+        }
+
         // For constants, validate that the length is as expected. The other lifetimes
         // expect the length to be 0. Don't validate for OEM types.
         if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
@@ -420,7 +450,35 @@
 
 template <typename VersionedOperation>
 static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
-                               const hidl_vec<Operand>& operands) {
+                               const hidl_vec<Operand>& operands,
+                               const hidl_vec<Subgraph>& subgraphs) {
+    auto isValidSubgraphReference = [&subgraphs](const Operand& modelOperand) -> bool {
+        NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
+                << "Unexpected operand type: " << toString(modelOperand.type);
+        NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size())
+                << "Invalid subgraph reference";
+        return true;
+    };
+    auto getSubgraph = [&subgraphs](const Operand& modelOperand) -> const Subgraph& {
+        CHECK_LT(modelOperand.location.offset, subgraphs.size());
+        return subgraphs[modelOperand.location.offset];  // By reference: no Subgraph copy.
+    };
+    auto getInputCount = [&getSubgraph](const Operand& modelOperand) {
+        return getSubgraph(modelOperand).inputIndexes.size();
+    };
+    auto getOutputCount = [&getSubgraph](const Operand& modelOperand) {
+        return getSubgraph(modelOperand).outputIndexes.size();
+    };
+    auto getInputOperand = [&getSubgraph](const Operand& modelOperand, uint32_t index) {
+        const Subgraph& subgraph = getSubgraph(modelOperand);
+        CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
+        return subgraph.operands[subgraph.inputIndexes[index]];
+    };
+    auto getOutputOperand = [&getSubgraph](const Operand& modelOperand, uint32_t index) {
+        const Subgraph& subgraph = getSubgraph(modelOperand);
+        CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
+        return subgraph.operands[subgraph.outputIndexes[index]];
+    };
     const size_t operandCount = operands.size();
     // This vector keeps track of whether there's an operation that writes to
     // each operand. It is used to validate that temporary variables and
@@ -432,7 +490,12 @@
         int error = validateOperation(
                 static_cast<int32_t>(op.type), op.inputs.size(),
                 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
-                op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
+                op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op),
+                {.isValidSubgraphReference = isValidSubgraphReference,
+                 .getSubgraphInputCount = getInputCount,
+                 .getSubgraphOutputCount = getOutputCount,
+                 .getSubgraphInputOperand = getInputOperand,
+                 .getSubgraphOutputOperand = getOutputOperand});
         if (error != ANEURALNETWORKS_NO_ERROR) {
             LOG(ERROR) << "Invalid operation " << toString(op.type);
             return false;
@@ -541,9 +604,9 @@
     // We only need versioned operands for their validation. For all the other
     // validations we can use operands upcasted to the latest version.
     const hidl_vec<Operand> latestVersionOperands = convertToV1_3(model.operands);
-    return (validateOperands(model.operands, model.operandValues, model.pools,
+    return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
                              /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
-            validateOperations(model.operations, latestVersionOperands) &&
+            validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}) &&
             validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
                                       OperandLifeTime::SUBGRAPH_INPUT) &&
             validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
@@ -564,8 +627,8 @@
     }
     auto validateSubgraph = [&model](const Subgraph& subgraph) -> bool {
         return (validateOperands(subgraph.operands, model.operandValues, model.pools,
-                                 /*allowUnspecifiedRank=*/true) &&
-                validateOperations(subgraph.operations, subgraph.operands) &&
+                                 model.referenced, /*allowUnspecifiedRank=*/true) &&
+                validateOperations(subgraph.operations, subgraph.operands, model.referenced) &&
                 validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
                                           OperandLifeTime::SUBGRAPH_INPUT) &&
                 validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
@@ -752,6 +815,7 @@
         case V1_3::OperandType::TENSOR_BOOL8:
         case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
         case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+        case V1_3::OperandType::SUBGRAPH:
         case V1_3::OperandType::OEM:
         case V1_3::OperandType::TENSOR_OEM_BYTE:
             return true;