Improve validation of the HAL structures.

Add extensive validation of the structures passed through the HAL.
Particularly important are the checks that operands & arguments
don't reach outside of their memory blocks.

Also adds a few missing generated tests.

Bug: 67828197
Test: System tests & VTS tests.

Change-Id: I2edf6219fc660fab7c5b6a73e7a9cb8a358fb29b
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
new file mode 100644
index 0000000..011bc3c
--- /dev/null
+++ b/common/ValidateHal.cpp
@@ -0,0 +1,397 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "ValidateHal"
+
+#include "ValidateHal.h"
+#include "NeuralNetworks.h"
+#include "Utils.h"
+
+#include <android-base/logging.h>
+
+namespace android {
+namespace nn {
+
+// Checks that a DataLocation referring to a memory pool names an existing pool and
+// that the referenced byte range fits entirely inside that pool.
+class MemoryAccessVerifier {
+public:
+    // Records the size of every pool up front; the pools themselves are not kept.
+    explicit MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools)
+        : mPoolCount(pools.size()), mPoolSizes(mPoolCount) {
+        for (size_t i = 0; i < mPoolCount; i++) {
+            mPoolSizes[i] = pools[i].size();
+        }
+    }
+
+    // Returns true if location.poolIndex is a valid pool index and the byte range
+    // [offset, offset + length) lies within that pool; otherwise logs and returns false.
+    bool validate(const DataLocation& location) const {
+        if (location.poolIndex >= mPoolCount) {
+            LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount;
+            return false;
+        }
+        const uint64_t size = mPoolSizes[location.poolIndex];
+        // Do the addition in 64 bits. size_t is only 32 bits wide on 32-bit targets, where
+        // offset + length could wrap around and slip past this check. Assumes offset and
+        // length are 32-bit fields (as in the 1.0 HAL DataLocation), so the sum cannot wrap.
+        if (static_cast<uint64_t>(location.offset) + location.length > size) {
+            LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset "
+                       << location.offset << " and length " << location.length
+                       << " exceeds pool size of " << size;
+            return false;
+        }
+        return true;
+    }
+
+private:
+    size_t mPoolCount;
+    // Held as uint64_t so pool sizes are not truncated where size_t is 32 bits.
+    std::vector<uint64_t> mPoolSizes;
+};
+
+// Validates every operand of a model: its type and rank, its scale, its zeroPoint, its
+// lifetime and, for constants, that its data lies within operandValues (CONSTANT_COPY)
+// or within the referenced memory pool (CONSTANT_REFERENCE) and has the size implied by
+// its type and dimensions. Logs the first problem found and returns false; returns true
+// if every operand is valid.
+static bool validateOperands(const hidl_vec<Operand>& operands,
+                             const hidl_vec<uint8_t>& operandValues,
+                             const hidl_vec<hidl_memory>& pools) {
+    uint32_t index = 0;
+    MemoryAccessVerifier poolVerifier(pools);
+    for (auto& operand : operands) {
+        // Validate type and dimensions: scalars must have rank 0, tensors a non-zero rank.
+        switch (operand.type) {
+            case OperandType::FLOAT32:
+            case OperandType::INT32:
+            case OperandType::UINT32:
+            case OperandType::OEM: {
+                size_t count = operand.dimensions.size();
+                if (count != 0) {
+                    LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank "
+                               << count;
+                    return false;
+                }
+                break;
+            }
+            case OperandType::TENSOR_FLOAT32:
+            case OperandType::TENSOR_INT32:
+            case OperandType::TENSOR_QUANT8_ASYMM:
+            case OperandType::TENSOR_OEM_BYTE: {
+                if (operand.dimensions.size() == 0) {
+                    LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
+                    return false;
+                }
+                break;
+            }
+            default:
+                LOG(ERROR) << "Operand " << index << ": Invalid operand type "
+                           << toString(operand.type);
+                return false;
+        }
+
+        // TODO Validate the numberOfConsumers.
+        // TODO Since we have to validate it, there was no point in including it. For the next
+        // release, consider removing unless we have an additional process in system space
+        // that creates this value. In that case, it would not have to be validated.
+
+        // Validate the scale.
+        switch (operand.type) {
+            case OperandType::FLOAT32:
+            case OperandType::INT32:
+            case OperandType::UINT32:
+            case OperandType::TENSOR_FLOAT32:
+                if (operand.scale != 0.f) {
+                    LOG(ERROR) << "Operand " << index << ": Operand of type "
+                               << getOperandTypeName(operand.type) << " with a non-zero scale ("
+                               << operand.scale << ")";
+                    return false;
+                }
+                break;
+            case OperandType::TENSOR_QUANT8_ASYMM:
+                // The NNAPI spec requires the quantization scale to be greater than zero.
+                // Written as !(scale > 0) so that a NaN scale is rejected as well.
+                if (!(operand.scale > 0.f)) {
+                    LOG(ERROR) << "Operand " << index << ": Operand of type "
+                               << getOperandTypeName(operand.type)
+                               << " with a non-positive scale (" << operand.scale << ")";
+                    return false;
+                }
+                break;
+            default:
+                // No validation for the OEM types. No validation also for TENSOR_INT32,
+                // as tensors of this type may be used with or without scale, depending on
+                // the operation.
+                // TODO We should have had separate types for TENSOR_INT32 tensors that have a
+                // scale and those that don't. Document now and fix in the next release.
+                break;
+        }
+
+        // Validate the zeroPoint.
+        switch (operand.type) {
+            case OperandType::FLOAT32:
+            case OperandType::INT32:
+            case OperandType::UINT32:
+            case OperandType::TENSOR_FLOAT32:
+            case OperandType::TENSOR_INT32:
+                if (operand.zeroPoint != 0) {
+                    LOG(ERROR) << "Operand " << index << ": Operand of type "
+                               << getOperandTypeName(operand.type) << " with a non-zero zeroPoint "
+                               << operand.zeroPoint;
+                    return false;
+                }
+                break;
+            case OperandType::TENSOR_QUANT8_ASYMM:
+                // The NNAPI spec requires the zero point of an 8-bit asymmetric quantized
+                // tensor to be in the range [0, 255].
+                if (operand.zeroPoint < 0 || operand.zeroPoint > 255) {
+                    LOG(ERROR) << "Operand " << index << ": Operand of type "
+                               << getOperandTypeName(operand.type) << " with an invalid zeroPoint "
+                               << operand.zeroPoint << ", must be in range [0, 255]";
+                    return false;
+                }
+                break;
+            default:
+                // No validation for the OEM types.
+                break;
+        }
+
+        // Validate the lifetime and the location.
+        const DataLocation& location = operand.location;
+        switch (operand.lifetime) {
+            case OperandLifeTime::CONSTANT_COPY:
+                // Copied constants live in operandValues, which is addressed as pool 0.
+                if (location.poolIndex != 0) {
+                    LOG(ERROR) << "Operand " << index
+                               << ": CONSTANT_COPY with a non-zero poolIndex "
+                               << location.poolIndex;
+                    return false;
+                }
+                // Do the addition in 64 bits so it cannot wrap around on 32-bit targets,
+                // where size_t is only 32 bits wide.
+                if (static_cast<uint64_t>(location.offset) + location.length >
+                    operandValues.size()) {
+                    LOG(ERROR) << "Operand " << index
+                               << ": OperandValue location out of range.  Starts at "
+                               << location.offset << ", length " << location.length << ", max "
+                               << operandValues.size();
+                    return false;
+                }
+                break;
+            case OperandLifeTime::CONSTANT_REFERENCE:
+                if (!poolVerifier.validate(location)) {
+                    return false;
+                }
+                break;
+            case OperandLifeTime::TEMPORARY_VARIABLE:
+            case OperandLifeTime::MODEL_INPUT:
+            case OperandLifeTime::MODEL_OUTPUT:
+            case OperandLifeTime::NO_VALUE:
+                // These lifetimes have no backing storage in the model, so the location
+                // must be left all-zero.
+                if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) {
+                    LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex "
+                               << location.poolIndex << ", offset " << location.offset
+                               << ", or length " << location.length << " for operand of lifetime "
+                               << toString(operand.lifetime);
+                    return false;
+                }
+                break;
+            default:
+                LOG(ERROR) << "Operand " << index << ": Invalid lifetime "
+                           << toString(operand.lifetime);
+                return false;
+        }
+
+        // For constants, validate that the length is as expected. The other lifetimes
+        // expect the length to be 0. Don't validate for OEM types.
+        if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
+            operand.lifetime == OperandLifeTime::CONSTANT_COPY) {
+            if (operand.type != OperandType::OEM &&
+                operand.type != OperandType::TENSOR_OEM_BYTE) {
+                uint32_t expectedLength = sizeOfData(operand.type, operand.dimensions);
+                if (location.length != expectedLength) {
+                    LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand)
+                               << " expected a size of " << expectedLength << " but got "
+                               << location.length;
+                    return false;
+                }
+            }
+        }
+
+        index++;
+    }
+    return true;
+}
+
+// Checks the operations of a model against its operand table. Every operation type
+// must be valid, and every input and output index must refer to an existing operand.
+// Only TEMPORARY_VARIABLE and MODEL_OUTPUT operands may be written, each by at most one
+// operation. Every operand with one of those two lifetimes must be written by some
+// operation. Logs the first violation and returns false.
+static bool validateOperations(const hidl_vec<Operation>& operations,
+                               const hidl_vec<Operand>& operands) {
+    const size_t numOperands = operands.size();
+    // hasWriter[k] becomes true once some operation lists operand k as an output.
+    std::vector<bool> hasWriter(numOperands, false);
+    for (const Operation& operation : operations) {
+        const uint32_t opCode = static_cast<uint32_t>(operation.type);
+        if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, opCode)) {
+            LOG(ERROR) << "Invalid operation type " << toString(operation.type);
+            return false;
+        }
+        // TODO Validate that the number of inputs and outputs, and their types, is correct
+        // for the operation. This is currently done in CpuExecutor but should be done
+        // here for all drivers.
+        for (const uint32_t inputIndex : operation.inputs) {
+            if (inputIndex < numOperands) {
+                continue;
+            }
+            LOG(ERROR) << "Operation input index out of range " << inputIndex << "/"
+                       << numOperands;
+            return false;
+        }
+        for (const uint32_t outputIndex : operation.outputs) {
+            if (outputIndex >= numOperands) {
+                LOG(ERROR) << "Operation output index out of range " << outputIndex << "/"
+                           << numOperands;
+                return false;
+            }
+            const OperandLifeTime lifetime = operands[outputIndex].lifetime;
+            const bool writable = lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
+                                  lifetime == OperandLifeTime::MODEL_OUTPUT;
+            if (!writable) {
+                LOG(ERROR) << "Writing to an operand with incompatible lifetime "
+                           << toString(lifetime);
+                return false;
+            }
+            // An operand may have only a single writer.
+            if (hasWriter[outputIndex]) {
+                LOG(ERROR) << "Operand " << outputIndex << " written a second time";
+                return false;
+            }
+            hasWriter[outputIndex] = true;
+        }
+    }
+    // Temporaries and model outputs that nobody writes would be left undefined.
+    for (size_t k = 0; k < numOperands; k++) {
+        if (hasWriter[k]) {
+            continue;
+        }
+        const OperandLifeTime lifetime = operands[k].lifetime;
+        if (lifetime == OperandLifeTime::TEMPORARY_VARIABLE ||
+            lifetime == OperandLifeTime::MODEL_OUTPUT) {
+            LOG(ERROR) << "Operand " << k << " with lifetime " << toString(lifetime)
+                       << " is not being written to.";
+            return false;
+        }
+    }
+    // TODO More whole-graph verifications are possible, for example that an
+    // operand is not used as both input and output of the same op, and more
+    // generally that the graph is acyclic.
+    return true;
+}
+
+// Checks that every memory pool is of a supported type ("ashmem" or "mmap_fd") and has
+// a non-null handle. Logs and returns false on the first pool that fails either check.
+static bool validatePools(const hidl_vec<hidl_memory>& pools) {
+    for (const hidl_memory& memory : pools) {
+        // Bind by reference: there is no need to copy the name string for every pool.
+        const auto& name = memory.name();
+        if (name != "ashmem" && name != "mmap_fd") {
+            LOG(ERROR) << "Unsupported memory type " << name;
+            return false;
+        }
+        if (memory.handle() == nullptr) {
+            LOG(ERROR) << "Memory of type " << name << " is null";
+            return false;
+        }
+    }
+    return true;
+}
+
+// Validates a model's list of input (or output) operand indexes. Every index must
+// refer to an existing operand whose lifetime is `lifetime`. No index may appear twice.
+// Every operand that has `lifetime` must appear in the list. Logs the first violation
+// and returns false.
+static bool validateModelInputOutputs(const hidl_vec<uint32_t>& indexes,
+                                      const hidl_vec<Operand>& operands, OperandLifeTime lifetime) {
+    const size_t operandCount = operands.size();
+    // listed[k] is set once operand k has been seen in `indexes`.
+    std::vector<bool> listed(operandCount, false);
+    for (uint32_t i : indexes) {
+        if (i >= operandCount) {
+            LOG(ERROR) << "Model input or output index out of range " << i << "/" << operandCount;
+            return false;
+        }
+        const Operand& operand = operands[i];
+        if (operand.lifetime != lifetime) {
+            LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime)
+                       << " instead of the expected " << toString(lifetime);
+            return false;
+        }
+        // A duplicated index would let a request bind two arguments to one operand.
+        if (listed[i]) {
+            LOG(ERROR) << "Model input or output occurs multiple times: " << i;
+            return false;
+        }
+        listed[i] = true;
+    }
+    // An operand carrying this lifetime but absent from the list could never be bound.
+    for (size_t i = 0; i < operandCount; i++) {
+        if (operands[i].lifetime == lifetime && !listed[i]) {
+            LOG(ERROR) << "Operand " << i << " marked as " << toString(lifetime)
+                       << " but is not included in Model input or output indexes";
+            return false;
+        }
+    }
+    return true;
+}
+
+// Returns true if `model` passes every structural check below. The checks run in order:
+// operands, operations, input indexes, output indexes, then memory pools. Each check logs
+// its own error, and validation stops at the first failure.
+bool validateModel(const Model& model) {
+    if (!validateOperands(model.operands, model.operandValues, model.pools)) {
+        return false;
+    }
+    if (!validateOperations(model.operations, model.operands)) {
+        return false;
+    }
+    if (!validateModelInputOutputs(model.inputIndexes, model.operands,
+                                   OperandLifeTime::MODEL_INPUT)) {
+        return false;
+    }
+    if (!validateModelInputOutputs(model.outputIndexes, model.operands,
+                                   OperandLifeTime::MODEL_OUTPUT)) {
+        return false;
+    }
+    return validatePools(model.pools);
+}
+
+// Validates the arguments of a request. `type` is either "input" or "output" and is only
+// used in error messages. `operandIndexes` is the model's list of input or output operand
+// indexes, i.e. the list that was passed to ANeuralNetworksModel_identifyInputsAndOutputs.
+// Returns false (after logging) on the first argument that is inconsistent with the model
+// or that reaches outside its memory pool.
+static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments,
+                                     const hidl_vec<uint32_t>& operandIndexes,
+                                     const hidl_vec<Operand>& operands,
+                                     const hidl_vec<hidl_memory>& pools, const char* type) {
+    MemoryAccessVerifier poolVerifier(pools);
+    // The request should specify as many arguments as were described in the model.
+    const size_t requestArgumentCount = requestArguments.size();
+    if (requestArgumentCount != operandIndexes.size()) {
+        LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type
+                   << "s but the model has " << operandIndexes.size();
+        return false;
+    }
+    for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount;
+         requestArgumentIndex++) {
+        const RequestArgument& requestArgument = requestArguments[requestArgumentIndex];
+        const DataLocation& location = requestArgument.location;
+        // Get the operand index for this argument from the model's input/output list.
+        // The model is expected to have been validated already, but check the index
+        // anyway: it is cheap and guards the operands[] access below.
+        const uint32_t operandIndex = operandIndexes[requestArgumentIndex];
+        if (operandIndex >= operands.size()) {
+            LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
+                       << " refers to operand " << operandIndex << " but the model only has "
+                       << operands.size() << " operands";
+            return false;
+        }
+        const Operand& operand = operands[operandIndex];
+        if (requestArgument.hasNoValue) {
+            // An omitted argument must not carry a location or dimensions.
+            if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 ||
+                requestArgument.dimensions.size() != 0) {
+                LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
+                           << " has no value yet has details.";
+                return false;
+            }
+        } else {
+            // Validate the location.
+            if (!poolVerifier.validate(location)) {
+                return false;
+            }
+            const size_t rank = requestArgument.dimensions.size();
+            if (rank == 0) {
+                // The request relies on the model's dimensions, so all of them must be known.
+                for (size_t i = 0; i < operand.dimensions.size(); i++) {
+                    if (operand.dimensions[i] == 0) {
+                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
+                                   << ": model has dimension " << i
+                                   << " set to 0 but the request does not specify the dimension.";
+                        return false;
+                    }
+                }
+            } else {
+                // The request specifies dimensions: the rank must match the model's, each
+                // dimension must match wherever the model fixes it, and none may be 0.
+                if (rank != operand.dimensions.size()) {
+                    LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
+                               << " has number of dimensions (" << rank
+                               << ") different than the model's (" << operand.dimensions.size()
+                               << ")";
+                    return false;
+                }
+                for (size_t i = 0; i < rank; i++) {
+                    if (requestArgument.dimensions[i] != operand.dimensions[i] &&
+                        operand.dimensions[i] != 0) {
+                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
+                                   << " has dimension " << i << " of "
+                                   << requestArgument.dimensions[i]
+                                   << " different than the model's " << operand.dimensions[i];
+                        return false;
+                    }
+                    if (requestArgument.dimensions[i] == 0) {
+                        LOG(ERROR) << "Request " << type << " " << requestArgumentIndex
+                                   << " has dimension " << i << " of zero";
+                        return false;
+                    }
+                }
+            }
+        }
+    }
+    return true;
+}
+
+// Returns true if `request` is consistent with `model`, which is assumed to have been
+// validated already. The checks run in order: input arguments, output arguments, then
+// the request's memory pools. Validation stops at the first failing check, which logs
+// the reason.
+bool validateRequest(const Request& request, const Model& model) {
+    if (!validateRequestArguments(request.inputs, model.inputIndexes, model.operands,
+                                  request.pools, "input")) {
+        return false;
+    }
+    if (!validateRequestArguments(request.outputs, model.outputIndexes, model.operands,
+                                  request.pools, "output")) {
+        return false;
+    }
+    return validatePools(request.pools);
+}
+
+}  // namespace nn
+}  // namespace android