Implement control flow operation validation

Bug: 136735929
Test: m
Change-Id: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
Merged-In: I51c6fe63402eed3d8795a5e68b2f8122b144f4a4
(cherry picked from commit 9790ba2a4c8085d93ecef596c0e6cfa338bfaed2)
diff --git a/common/Utils.cpp b/common/Utils.cpp
index 368ef47..1046baf 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -32,6 +32,7 @@
 #include <utility>
 #include <vector>
 
+#include "ControlFlow.h"
 #include "NeuralNetworks.h"
 #include "NeuralNetworksOEM.h"
 #include "OperationResolver.h"
@@ -604,10 +605,155 @@
     return ANEURALNETWORKS_NO_ERROR;
 }
 
+// Checks if two operands have the same types, shapes, and parameters.
+// Omits lifetime, numberOfConsumers, and location.
+// Used to match an inner (subgraph) operand against the corresponding outer operand of an
+// IF/WHILE operation. Returns true on a full match; on the first mismatch, logs the two
+// differing values via NN_RET_CHECK and returns false.
+static bool compatible(const Operand& a, const Operand& b) {
+    NN_RET_CHECK(a.type == b.type) << toString(a.type) << " != " << toString(b.type);
+    NN_RET_CHECK(a.dimensions == b.dimensions)
+            << toString(a.dimensions) << " != " << toString(b.dimensions);
+    // Exact (bitwise) equality of the quantization parameters is intentional: operands are
+    // only interchangeable if they agree on scale and zeroPoint exactly.
+    NN_RET_CHECK_EQ(a.scale, b.scale);
+    NN_RET_CHECK_EQ(a.zeroPoint, b.zeroPoint);
+    NN_RET_CHECK(a.extraParams == b.extraParams)
+            << toString(a.extraParams) << " != " << toString(b.extraParams);
+    return true;
+}
+
+// Validates that |operand| can serve as a control flow condition: it must be a
+// TENSOR_BOOL8 of rank 1 holding exactly one element. Used both for the condition input of
+// an IF operation and for the single output of a WHILE condition model.
+static bool validateConditionOperand(const Operand& operand) {
+    NN_RET_CHECK(operand.type == OperandType::TENSOR_BOOL8)
+            << "Unexpected condition operand type: " << toString(operand.type);
+    NN_RET_CHECK_EQ(operand.dimensions.size(), 1u) << "Condition operand must be a singleton";
+    NN_RET_CHECK_EQ(operand.dimensions[0], 1u) << "Condition operand must be a singleton";
+    return true;
+}
+
+// Asserts that every callback of |helper| has been provided. Control flow validation cannot
+// proceed without all five accessors, so a missing callback is treated as a programming
+// error in the caller (CHECK-fails) rather than a model validation failure.
+static void checkSubgraphValidationHelper(const SubgraphValidationHelper& helper) {
+    CHECK(helper.isValidSubgraphReference != nullptr);
+    CHECK(helper.getSubgraphInputCount != nullptr);
+    CHECK(helper.getSubgraphOutputCount != nullptr);
+    CHECK(helper.getSubgraphInputOperand != nullptr);
+    CHECK(helper.getSubgraphOutputOperand != nullptr);
+}
+
+// Validates an ANEURALNETWORKS_IF operation.
+// The operation takes the condition operand, the then-model operand, and the else-model
+// operand, followed by the operands to be forwarded to whichever branch executes. Each
+// branch model must consume exactly the forwarded operands and must produce operands
+// compatible with the IF operation's outputs.
+static bool validateIfOperation(uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount,
+                                const uint32_t* outputs, const std::vector<Operand>& operands,
+                                const SubgraphValidationHelper& helper) {
+    namespace op = operation_if;
+    checkSubgraphValidationHelper(helper);
+    NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_IF must have at least 3 inputs";
+    NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_IF must have at least 1 output";
+    // Checks that one branch model's signature matches this IF operation: same number of
+    // inputs/outputs and pairwise-compatible operand types and shapes.
+    auto validateBranchOperand = [&](const Operand& branchModelOperand) -> bool {
+        NN_RET_CHECK(helper.isValidSubgraphReference(branchModelOperand))
+                << "Operand is not a valid subgraph reference";
+        const uint32_t branchModelInputCount = helper.getSubgraphInputCount(branchModelOperand);
+        const uint32_t branchModelOutputCount = helper.getSubgraphOutputCount(branchModelOperand);
+        NN_RET_CHECK_EQ(inputCount, op::kFirstInput + branchModelInputCount);
+        NN_RET_CHECK_EQ(outputCount, branchModelOutputCount);
+        // Each forwarded outer input must be compatible with the branch model's formal input.
+        for (uint32_t i = 0; i < branchModelInputCount; ++i) {
+            const Operand& innerOperand = helper.getSubgraphInputOperand(branchModelOperand, i);
+            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+            NN_RET_CHECK(compatible(innerOperand, outerOperand));
+        }
+        // Each branch model output must be compatible with the IF operation's output.
+        for (uint32_t i = 0; i < branchModelOutputCount; ++i) {
+            const Operand& innerOperand = helper.getSubgraphOutputOperand(branchModelOperand, i);
+            const Operand& outerOperand = operands[outputs[i]];
+            NN_RET_CHECK(compatible(innerOperand, outerOperand));
+        }
+        return true;
+    };
+    NN_RET_CHECK(validateConditionOperand(operands[inputs[op::kCondBoolOperand]]))
+            << "Validation failed for IF condition operand";
+    NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kThenModelOperand]]))
+            << "Validation failed for IF then model";
+    NN_RET_CHECK(validateBranchOperand(operands[inputs[op::kElseModelOperand]]))
+            << "Validation failed for IF else model";
+    return true;
+}
+
+// Validates an ANEURALNETWORKS_WHILE operation against its referenced condition and body
+// models.
+static bool validateWhileOperation(uint32_t inputCount, const uint32_t* inputs,
+                                   uint32_t outputCount, const uint32_t* outputs,
+                                   const std::vector<Operand>& operands,
+                                   const SubgraphValidationHelper& helper) {
+    // Let the loop have
+    // - m >= 1 input-output operands,
+    // - k >= 0 state-only operands, and
+    // - n >= 0 input-only operands.
+    // Then
+    // - the WHILE loop operation has (2 + m + k + n) inputs and m outputs.
+    // - the condition model has (m + k + n) inputs and 1 output.
+    // - the body model has (m + k + n) inputs and (m + k) outputs.
+    namespace op = operation_while;
+    checkSubgraphValidationHelper(helper);
+    NN_RET_CHECK_GE(inputCount, 3u) << "ANEURALNETWORKS_WHILE must have at least 3 inputs";
+    NN_RET_CHECK_GE(outputCount, 1u) << "ANEURALNETWORKS_WHILE must have at least 1 output";
+    // The condition model consumes all (m + k + n) operands and produces a single boolean
+    // singleton that decides whether another iteration runs.
+    auto validateCondOperand = [&](const Operand& condModelOperand) -> bool {
+        NN_RET_CHECK(helper.isValidSubgraphReference(condModelOperand))
+                << "Operand is not a valid subgraph reference";
+        const uint32_t condModelInputCount = helper.getSubgraphInputCount(condModelOperand);
+        const uint32_t condModelOutputCount = helper.getSubgraphOutputCount(condModelOperand);
+        NN_RET_CHECK_EQ(inputCount, op::kFirstInput + condModelInputCount);
+        NN_RET_CHECK_EQ(condModelOutputCount, 1u);
+        for (uint32_t i = 0; i < condModelInputCount; ++i) {
+            const Operand& innerOperand = helper.getSubgraphInputOperand(condModelOperand, i);
+            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+            NN_RET_CHECK(compatible(innerOperand, outerOperand));
+        }
+        NN_RET_CHECK(
+                validateConditionOperand(helper.getSubgraphOutputOperand(condModelOperand, 0)));
+        return true;
+    };
+    // The body model consumes all (m + k + n) operands and produces the (m + k) updated
+    // input-output and state-only operands for the next iteration.
+    auto validateBodyOperand = [&](const Operand& bodyModelOperand) -> bool {
+        NN_RET_CHECK(helper.isValidSubgraphReference(bodyModelOperand))
+                << "Operand is not a valid subgraph reference";
+        const uint32_t bodyModelInputCount = helper.getSubgraphInputCount(bodyModelOperand);
+        const uint32_t bodyModelOutputCount = helper.getSubgraphOutputCount(bodyModelOperand);
+        NN_RET_CHECK_EQ(inputCount, op::kFirstInput + bodyModelInputCount);
+        NN_RET_CHECK_GE(bodyModelOutputCount, outputCount);
+        NN_RET_CHECK_GE(bodyModelInputCount, bodyModelOutputCount);
+        // Recover m, k, n from the counts checked above: m = outputCount,
+        // m + k = bodyModelOutputCount, m + k + n = bodyModelInputCount.
+        const uint32_t inputOutputCount = outputCount;
+        const uint32_t stateOnlyCount = bodyModelOutputCount - inputOutputCount;
+        const uint32_t inputOnlyCount = bodyModelInputCount - bodyModelOutputCount;
+        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount + inputOnlyCount; i < n; ++i) {
+            const Operand& innerOperand = helper.getSubgraphInputOperand(bodyModelOperand, i);
+            const Operand& outerOperand = operands[inputs[op::kFirstInput + i]];
+            NN_RET_CHECK(compatible(innerOperand, outerOperand));
+        }
+        for (uint32_t i = 0; i < inputOutputCount; ++i) {
+            const Operand& innerOperand = helper.getSubgraphOutputOperand(bodyModelOperand, i);
+            const Operand& outerOperand = operands[outputs[i]];
+            NN_RET_CHECK(compatible(innerOperand, outerOperand));
+        }
+        // Input-output and state-only operands are fed back into the next iteration, so each
+        // must keep the same type and shape from a body input to the matching body output.
+        for (uint32_t i = 0, n = inputOutputCount + stateOnlyCount; i < n; ++i) {
+            const Operand& inputOperand = helper.getSubgraphInputOperand(bodyModelOperand, i);
+            const Operand& outputOperand = helper.getSubgraphOutputOperand(bodyModelOperand, i);
+            NN_RET_CHECK(compatible(inputOperand, outputOperand));
+        }
+        return true;
+    };
+    NN_RET_CHECK(validateCondOperand(operands[inputs[op::kCondModelOperand]]))
+            << "Validation failed for WHILE condition model";
+    NN_RET_CHECK(validateBodyOperand(operands[inputs[op::kBodyModelOperand]]))
+            << "Validation failed for WHILE body model";
+    return true;
+}
+
+// Overload of validateOperation() for callers that cannot supply a SubgraphValidationHelper.
+// IF and WHILE reference other subgraphs and therefore cannot be validated without the
+// helper; they are rejected here. The HAL version is still checked first so that a
+// too-old-version error is reported consistently with the helper-taking overload.
+static inline int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
+                                    const uint32_t* inputIndexes, uint32_t outputCount,
+                                    const uint32_t* outputIndexes,
+                                    const std::vector<hal::Operand>& operands,
+                                    HalVersion halVersion) {
+    if (opType == ANEURALNETWORKS_IF || opType == ANEURALNETWORKS_WHILE) {
+        NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+        LOG(ERROR) << "This validateOperation() overload does not support control flow";
+        return ANEURALNETWORKS_BAD_DATA;
+    }
+    // Delegate to the full overload with an empty helper; non-control-flow operations never
+    // touch it.
+    return validateOperation(opType, inputCount, inputIndexes, outputCount, outputIndexes, operands,
+                             halVersion, {});
+}
+
 int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
                       const uint32_t* inputIndexes, uint32_t outputCount,
                       const uint32_t* outputIndexes, const std::vector<Operand>& operands,
-                      HalVersion halVersion) {
+                      HalVersion halVersion, const SubgraphValidationHelper& helper) {
     NN_RETURN_IF_ERROR(validateOperandList(inputCount, inputIndexes,
                                            static_cast<uint32_t>(operands.size()),
                                            "ANeuralNetworksModel_addOperation inputs"));
@@ -1637,6 +1783,20 @@
                                                  inExpectedTypes, outputCount, outputIndexes,
                                                  outExpectedTypes);
         }
+        case ANEURALNETWORKS_IF: {
+            NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+            return validateIfOperation(inputCount, inputIndexes, outputCount, outputIndexes,
+                                       operands, helper)
+                           ? ANEURALNETWORKS_NO_ERROR
+                           : ANEURALNETWORKS_BAD_DATA;
+        }
+        case ANEURALNETWORKS_WHILE: {
+            NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_3));
+            return validateWhileOperation(inputCount, inputIndexes, outputCount, outputIndexes,
+                                          operands, helper)
+                           ? ANEURALNETWORKS_NO_ERROR
+                           : ANEURALNETWORKS_BAD_DATA;
+        }
         default: {
             const OperationRegistration* operationRegistration =
                     BuiltinOperationResolver::get()->findOperation(
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
index 74e2d7b..d3f43b8 100644
--- a/common/ValidateHal.cpp
+++ b/common/ValidateHal.cpp
@@ -101,6 +101,7 @@
         case OperandType::INT32:
         case OperandType::UINT32:
         case OperandType::BOOL:
+        case OperandType::SUBGRAPH:
         case OperandType::TENSOR_FLOAT32:
         case OperandType::TENSOR_FLOAT16:
         case OperandType::TENSOR_INT32:
@@ -164,7 +165,8 @@
 template <typename VersionedOperand>
 static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
                              const hidl_vec<uint8_t>& operandValues,
-                             const hidl_vec<hidl_memory>& pools, bool allowUnspecifiedRank) {
+                             const hidl_vec<hidl_memory>& pools,
+                             const hidl_vec<Subgraph>& subgraphs, bool allowUnspecifiedRank) {
     uint32_t index = 0;
     MemoryAccessVerifier poolVerifier(pools);
     for (auto& versionedOperand : operands) {
@@ -183,6 +185,7 @@
             case OperandType::INT32:
             case OperandType::UINT32:
             case OperandType::BOOL:
+            case OperandType::SUBGRAPH:
             case OperandType::OEM: {
                 size_t count = operand.dimensions.size();
                 if (count != 0) {
@@ -232,6 +235,7 @@
             case OperandType::INT32:
             case OperandType::UINT32:
             case OperandType::BOOL:
+            case OperandType::SUBGRAPH:
             case OperandType::TENSOR_FLOAT16:
             case OperandType::TENSOR_FLOAT32:
             case OperandType::TENSOR_BOOL8:
@@ -281,6 +285,7 @@
             case OperandType::INT32:
             case OperandType::UINT32:
             case OperandType::BOOL:
+            case OperandType::SUBGRAPH:
             case OperandType::TENSOR_FLOAT16:
             case OperandType::TENSOR_FLOAT32:
             case OperandType::TENSOR_INT32:
@@ -375,12 +380,37 @@
                     return false;
                 }
                 break;
+            case OperandLifeTime::SUBGRAPH: {
+                if (location.poolIndex != 0) {
+                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero poolIndex "
+                               << location.poolIndex;
+                    return false;
+                }
+                if (location.offset >= subgraphs.size()) {
+                    LOG(ERROR) << "Subgraph index out of range: " << location.offset
+                               << " >= " << subgraphs.size();
+                    return false;
+                }
+                if (location.length != 0) {
+                    LOG(ERROR) << "Operand " << index << ": SUBGRAPH with a non-zero length "
+                               << location.length;
+                    return false;
+                }
+            } break;
             default:
                 LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
                            << toString(operand.lifetime);
                 return false;
         }
 
+        // Make sure SUBGRAPH operand type and lifetime always go together.
+        if ((operand.type == OperandType::SUBGRAPH) !=
+            (operand.lifetime == OperandLifeTime::SUBGRAPH)) {
+            LOG(ERROR) << "Operand " << index << ": Operand of type " << toString(operand.type)
+                       << " cannot have lifetime " << toString(operand.lifetime);
+            return false;
+        }
+
         // For constants, validate that the length is as expected. The other lifetimes
         // expect the length to be 0. Don't validate for OEM types.
         if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
@@ -420,7 +450,35 @@
 
 template <typename VersionedOperation>
 static bool validateOperations(const hidl_vec<VersionedOperation>& operations,
-                               const hidl_vec<Operand>& operands) {
+                               const hidl_vec<Operand>& operands,
+                               const hidl_vec<Subgraph>& subgraphs) {
+    auto isValidSubgraphReference = [&subgraphs](const Operand& modelOperand) -> bool {
+        NN_RET_CHECK(modelOperand.type == OperandType::SUBGRAPH)
+                << "Unexpected operand type: " << toString(modelOperand.type);
+        NN_RET_CHECK_LT(modelOperand.location.offset, subgraphs.size())
+                << "Invalid subgraph reference";
+        return true;
+    };
+    // Returns a reference into |subgraphs| (not a copy) so that the Operand references
+    // handed out by the accessors below stay valid after this lambda returns.
+    auto getSubgraph = [&subgraphs](const Operand& modelOperand) -> const Subgraph& {
+        CHECK_LT(modelOperand.location.offset, subgraphs.size());
+        return subgraphs[modelOperand.location.offset];
+    };
+    auto getInputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
+        return getSubgraph(modelOperand).inputIndexes.size();
+    };
+    auto getOutputCount = [&getSubgraph](const Operand& modelOperand) -> uint32_t {
+        return getSubgraph(modelOperand).outputIndexes.size();
+    };
+    // The explicit "-> const Operand&" return types are required for correctness: with a
+    // deduced (by-value) return, the std::function<const hal::Operand&(...)> wrappers in
+    // SubgraphValidationHelper would bind their returned reference to a temporary that is
+    // destroyed at the end of the call, yielding a dangling reference.
+    auto getInputOperand = [&getSubgraph](const Operand& modelOperand,
+                                          uint32_t index) -> const Operand& {
+        const Subgraph& subgraph = getSubgraph(modelOperand);
+        CHECK_LT(subgraph.inputIndexes[index], subgraph.operands.size());
+        return subgraph.operands[subgraph.inputIndexes[index]];
+    };
+    auto getOutputOperand = [&getSubgraph](const Operand& modelOperand,
+                                           uint32_t index) -> const Operand& {
+        const Subgraph& subgraph = getSubgraph(modelOperand);
+        CHECK_LT(subgraph.outputIndexes[index], subgraph.operands.size());
+        return subgraph.operands[subgraph.outputIndexes[index]];
+    };
     const size_t operandCount = operands.size();
     // This vector keeps track of whether there's an operation that writes to
     // each operand. It is used to validate that temporary variables and
@@ -432,7 +490,12 @@
         int error = validateOperation(
                 static_cast<int32_t>(op.type), op.inputs.size(),
                 op.inputs.size() > 0 ? op.inputs.data() : nullptr, op.outputs.size(),
-                op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op));
+                op.outputs.size() > 0 ? op.outputs.data() : nullptr, operands, getHalVersion(op),
+                {.isValidSubgraphReference = isValidSubgraphReference,
+                 .getSubgraphInputCount = getInputCount,
+                 .getSubgraphOutputCount = getOutputCount,
+                 .getSubgraphInputOperand = getInputOperand,
+                 .getSubgraphOutputOperand = getOutputOperand});
         if (error != ANEURALNETWORKS_NO_ERROR) {
             LOG(ERROR) << "Invalid operation " << toString(op.type);
             return false;
@@ -541,9 +604,9 @@
     // We only need versioned operands for their validation. For all the other
     // validations we can use operands upcasted to the latest version.
     const hidl_vec<Operand> latestVersionOperands = convertToV1_3(model.operands);
-    return (validateOperands(model.operands, model.operandValues, model.pools,
+    return (validateOperands(model.operands, model.operandValues, model.pools, /*subgraphs=*/{},
                              /*allowUnspecifiedRank=*/version >= HalVersion::V1_2) &&
-            validateOperations(model.operations, latestVersionOperands) &&
+            validateOperations(model.operations, latestVersionOperands, /*subgraphs=*/{}) &&
             validateModelInputOutputs(model.inputIndexes, latestVersionOperands,
                                       OperandLifeTime::SUBGRAPH_INPUT) &&
             validateModelInputOutputs(model.outputIndexes, latestVersionOperands,
@@ -564,8 +627,8 @@
     }
     auto validateSubgraph = [&model](const Subgraph& subgraph) -> bool {
         return (validateOperands(subgraph.operands, model.operandValues, model.pools,
-                                 /*allowUnspecifiedRank=*/true) &&
-                validateOperations(subgraph.operations, subgraph.operands) &&
+                                 model.referenced, /*allowUnspecifiedRank=*/true) &&
+                validateOperations(subgraph.operations, subgraph.operands, model.referenced) &&
                 validateModelInputOutputs(subgraph.inputIndexes, subgraph.operands,
                                           OperandLifeTime::SUBGRAPH_INPUT) &&
                 validateModelInputOutputs(subgraph.outputIndexes, subgraph.operands,
@@ -752,6 +815,7 @@
         case V1_3::OperandType::TENSOR_BOOL8:
         case V1_3::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
         case V1_3::OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
+        case V1_3::OperandType::SUBGRAPH:
         case V1_3::OperandType::OEM:
         case V1_3::OperandType::TENSOR_OEM_BYTE:
             return true;
diff --git a/common/include/Utils.h b/common/include/Utils.h
index a90db35..7797381 100644
--- a/common/include/Utils.h
+++ b/common/include/Utils.h
@@ -319,12 +319,27 @@
 int validateOperandList(uint32_t count, const uint32_t* list, uint32_t operandCount,
                         const char* tag);
 
+// A set of functions to help validate models containing IF or WHILE operations.
+// All five callbacks must be set (non-null) before the helper is passed to
+// validateOperation(); the operand-accessor callbacks return references, so their
+// implementations must return references to storage that outlives the validation call
+// (not to temporaries).
+struct SubgraphValidationHelper {
+    // Checks if a given operand is a SUBGRAPH operand with a valid offset.
+    std::function<bool(const hal::Operand&)> isValidSubgraphReference;
+    // Gets the input count of a subgraph referenced by a given operand.
+    std::function<uint32_t(const hal::Operand&)> getSubgraphInputCount;
+    // Gets the output count of a subgraph referenced by a given operand.
+    std::function<uint32_t(const hal::Operand&)> getSubgraphOutputCount;
+    // Gets the specified input operand of a subgraph referenced by a given operand.
+    std::function<const hal::Operand&(const hal::Operand&, uint32_t)> getSubgraphInputOperand;
+    // Gets the specified output operand of a subgraph referenced by a given operand.
+    std::function<const hal::Operand&(const hal::Operand&, uint32_t)> getSubgraphOutputOperand;
+};
+
 // Returns ANEURALNETWORKS_NO_ERROR if the corresponding operation is defined and can handle the
 // provided operand types in the given HAL version, otherwise returns ANEURALNETWORKS_BAD_DATA.
+// The last argument is only used for validating IF and WHILE operations.
 int validateOperation(ANeuralNetworksOperationType opType, uint32_t inputCount,
                       const uint32_t* inputIndexes, uint32_t outputCount,
                       const uint32_t* outputIndexes, const std::vector<hal::Operand>& operands,
-                      HalVersion halVersion);
+                      HalVersion halVersion, const SubgraphValidationHelper& helper);
 
 inline size_t getSizeFromInts(int lower, int higher) {
     return (uint32_t)(lower) + ((uint64_t)(uint32_t)(higher) << 32);