Improve validation of the HAL structures. Add extensive validation of the structures passed through the HAL. Particularly important are the checks that operands & arguments don't try to reach outside of their memory blocks. Also adds a few missing generated tests. Bug: 67828197 Test: System tests & VTS tests. Change-Id: I2edf6219fc660fab7c5b6a73e7a9cb8a358fb29b
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp new file mode 100644 index 0000000..011bc3c --- /dev/null +++ b/common/ValidateHal.cpp
@@ -0,0 +1,397 @@ +/* + * Copyright (C) 2017 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "ValidateHal" + +#include "ValidateHal.h" +#include "NeuralNetworks.h" +#include "Utils.h" + +#include <android-base/logging.h> + +namespace android { +namespace nn { + +class MemoryAccessVerifier { +public: + MemoryAccessVerifier(const hidl_vec<hidl_memory>& pools) + : mPoolCount(pools.size()), mPoolSizes(mPoolCount) { + for (size_t i = 0; i < mPoolCount; i++) { + mPoolSizes[i] = pools[i].size(); + } + } + bool validate(const DataLocation& location) { + if (location.poolIndex >= mPoolCount) { + LOG(ERROR) << "Invalid poolIndex " << location.poolIndex << "/" << mPoolCount; + return false; + } + const size_t size = mPoolSizes[location.poolIndex]; + // Do the addition using size_t to avoid potential wrap-around problems. 
+ if (static_cast<size_t>(location.offset) + location.length > size) { + LOG(ERROR) << "Reference to pool " << location.poolIndex << " with offset " + << location.offset << " and length " << location.length + << " exceeds pool size of " << size; + return false; + } + return true; + } + +private: + size_t mPoolCount; + std::vector<size_t> mPoolSizes; +}; + +static bool validateOperands(const hidl_vec<Operand>& operands, + const hidl_vec<uint8_t>& operandValues, + const hidl_vec<hidl_memory>& pools) { + uint32_t index = 0; + MemoryAccessVerifier poolVerifier(pools); + for (auto& operand : operands) { + // Validate type and dimensions. + switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::OEM: { + size_t count = operand.dimensions.size(); + if (count != 0) { + LOG(ERROR) << "Operand " << index << ": Scalar data has dimensions of rank " + << count; + return false; + } + break; + } + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + case OperandType::TENSOR_QUANT8_ASYMM: + case OperandType::TENSOR_OEM_BYTE: { + if (operand.dimensions.size() == 0) { + LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0"; + return false; + } + break; + } + default: + LOG(ERROR) << "Operand " << index << ": Invalid operand type " + << toString(operand.type); + return false; + } + + // TODO Validate the numberOfConsumers. + // TODO Since we have to validate it, there was no point in including it. For the next + // release, consider removing unless we have an additional process in system space + // that creates this value. In that case, it would not have to be validated. + + // Validate the scale. 
+ switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + if (operand.scale != 0.f) { + LOG(ERROR) << "Operand " << index << ": Operand of type " + << getOperandTypeName(operand.type) << " with a non-zero scale (" + << operand.scale; + return false; + } + break; + case OperandType::TENSOR_QUANT8_ASYMM: + if (operand.scale == 0.f) { + LOG(ERROR) << "Operand " << index << ": Operand of type " + << getOperandTypeName(operand.type) << " with a zero scale"; + return false; + } + break; + default: + // No validation for the OEM types. No validation also for TENSOR_INT32, + // as tensors of this type may be used with or without scale, depending on + // the operation. + // TODO We should have had a separate type for TENSOR_INT32 that a scale + // and those who don't. Document now and fix in the next release. + break; + } + + // Validate the zeroPoint. + switch (operand.type) { + case OperandType::FLOAT32: + case OperandType::INT32: + case OperandType::UINT32: + case OperandType::TENSOR_FLOAT32: + case OperandType::TENSOR_INT32: + if (operand.zeroPoint != 0) { + LOG(ERROR) << "Operand " << index << ": Operand of type " + << getOperandTypeName(operand.type) << " with an non-zero zeroPoint " + << operand.zeroPoint; + return false; + } + break; + default: + // No validation for the OEM types. + break; + } + + // Validate the lifetime and the location. + const DataLocation& location = operand.location; + switch (operand.lifetime) { + case OperandLifeTime::CONSTANT_COPY: + if (location.poolIndex != 0) { + LOG(ERROR) << "Operand " << index + << ": CONSTANT_COPY with a non-zero poolIndex " + << location.poolIndex; + return false; + } + // Do the addition using size_t to avoid potential wrap-around problems. + if (static_cast<size_t>(location.offset) + location.length > operandValues.size()) { + LOG(ERROR) << "Operand " << index + << ": OperandValue location out of range. 
Starts at " + << location.offset << ", length " << location.length << ", max " + << operandValues.size(); + return false; + } + break; + case OperandLifeTime::CONSTANT_REFERENCE: + if (!poolVerifier.validate(location)) { + return false; + } + break; + case OperandLifeTime::TEMPORARY_VARIABLE: + case OperandLifeTime::MODEL_INPUT: + case OperandLifeTime::MODEL_OUTPUT: + case OperandLifeTime::NO_VALUE: + if (location.poolIndex != 0 || location.offset != 0 || location.length != 0) { + LOG(ERROR) << "Operand " << index << ": Unexpected poolIndex " + << location.poolIndex << ", offset " << location.offset + << ", or length " << location.length << " for operand of lifetime " + << toString(operand.lifetime); + return false; + } + break; + default: + LOG(ERROR) << "Operand " << index << ": Invalid lifetime " + << toString(operand.lifetime); + return false; + } + + // For constants, validate that the length is as expected. The other lifetimes + // expect the length to be 0. Don't validate for OEM types. + if (operand.lifetime == OperandLifeTime::CONSTANT_REFERENCE || + operand.lifetime == OperandLifeTime::CONSTANT_COPY) { + if (operand.type != OperandType::OEM && + operand.type != OperandType::TENSOR_OEM_BYTE) { + uint32_t expectedLength = sizeOfData(operand.type, operand.dimensions); + if (location.length != expectedLength) { + LOG(ERROR) << "Operand " << index << ": For operand " << toString(operand) + << " expected a size of " << expectedLength << " but got " + << location.length; + return false; + } + } + } + + index++; + } + return true; +} + +static bool validateOperations(const hidl_vec<Operation>& operations, + const hidl_vec<Operand>& operands) { + const size_t operandCount = operands.size(); + // This vector keeps track of whether there's an operation that writes to + // each operand. It is used to validate that temporary variables and + // model outputs will be written to. 
+ std::vector<bool> writtenTo(operandCount, false); + for (auto& op : operations) { + if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, + static_cast<uint32_t>(op.type))) { + LOG(ERROR) << "Invalid operation type " << toString(op.type); + return false; + } + // TODO Validate that the number of inputs and outputs, and their types, is correct + // for the operation. This is currently done in CpuExecutor but should be done + // here for all drivers. + for (uint32_t i : op.inputs) { + if (i >= operandCount) { + LOG(ERROR) << "Operation input index out of range " << i << "/" << operandCount; + return false; + } + } + for (uint32_t i : op.outputs) { + if (i >= operandCount) { + LOG(ERROR) << "Operation output index out of range " << i << "/" << operandCount; + return false; + } + const Operand& operand = operands[i]; + if (operand.lifetime != OperandLifeTime::TEMPORARY_VARIABLE && + operand.lifetime != OperandLifeTime::MODEL_OUTPUT) { + LOG(ERROR) << "Writing to an operand with incompatible lifetime " + << toString(operand.lifetime); + return false; + } + + // Check that we only write once to an operand. + if (writtenTo[i]) { + LOG(ERROR) << "Operand " << i << " written a second time"; + return false; + } + writtenTo[i] = true; + } + } + for (size_t i = 0; i < operandCount; i++) { + if (!writtenTo[i]) { + const Operand& operand = operands[i]; + if (operand.lifetime == OperandLifeTime::TEMPORARY_VARIABLE || + operand.lifetime == OperandLifeTime::MODEL_OUTPUT) { + LOG(ERROR) << "Operand " << i << " with lifetime " << toString(operand.lifetime) + << " is not being written to."; + return false; + } + } + } + // TODO More whole graph verifications are possible, for example that an + // operand is not use as input & output for the same op, and more + // generally that it is acyclic. 
+ return true; +} + +static bool validatePools(const hidl_vec<hidl_memory>& pools) { + for (const hidl_memory& memory : pools) { + const auto name = memory.name(); + if (name != "ashmem" && name != "mmap_fd") { + LOG(ERROR) << "Unsupported memory type " << name; + return false; + } + if (memory.handle() == nullptr) { + LOG(ERROR) << "Memory of type " << name << " is null"; + return false; + } + } + return true; +} + +static bool validateModelInputOutputs(const hidl_vec<uint32_t> indexes, + const hidl_vec<Operand>& operands, OperandLifeTime lifetime) { + const size_t operandCount = operands.size(); + for (uint32_t i : indexes) { + if (i >= operandCount) { + LOG(ERROR) << "Model input or output index out of range " << i << "/" << operandCount; + return false; + } + const Operand& operand = operands[i]; + if (operand.lifetime != lifetime) { + LOG(ERROR) << "Model input or output has lifetime of " << toString(operand.lifetime) + << " instead of the expected " << toString(lifetime); + return false; + } + } + return true; +} + +bool validateModel(const Model& model) { + return (validateOperands(model.operands, model.operandValues, model.pools) && + validateOperations(model.operations, model.operands) && + validateModelInputOutputs(model.inputIndexes, model.operands, + OperandLifeTime::MODEL_INPUT) && + validateModelInputOutputs(model.outputIndexes, model.operands, + OperandLifeTime::MODEL_OUTPUT) && + validatePools(model.pools)); +} + +// Validates the arguments of a request. type is either "input" or "output" and is used +// for printing error messages. The operandIndexes is the appropriate array of input +// or output operand indexes that was passed to the ANeuralNetworksModel_identifyInputsAndOutputs. 
+static bool validateRequestArguments(const hidl_vec<RequestArgument>& requestArguments, + const hidl_vec<uint32_t>& operandIndexes, + const hidl_vec<Operand>& operands, + const hidl_vec<hidl_memory>& pools, const char* type) { + MemoryAccessVerifier poolVerifier(pools); + // The request should specify as many arguments as were described in the model. + const size_t requestArgumentCount = requestArguments.size(); + if (requestArgumentCount != operandIndexes.size()) { + LOG(ERROR) << "Request specifies " << requestArgumentCount << " " << type + << "s but the model has " << operandIndexes.size(); + return false; + } + for (size_t requestArgumentIndex = 0; requestArgumentIndex < requestArgumentCount; + requestArgumentIndex++) { + const RequestArgument& requestArgument = requestArguments[requestArgumentIndex]; + const DataLocation& location = requestArgument.location; + // Get the operand index for this argument. We extract it from the list + // that was provided in the call to ANeuralNetworksModel_identifyInputsAndOutputs. + // We assume in this function that the model has been validated already. + const uint32_t operandIndex = operandIndexes[requestArgumentIndex]; + const Operand& operand = operands[operandIndex]; + if (requestArgument.hasNoValue) { + if (location.poolIndex != 0 || location.offset != 0 || location.length != 0 || + requestArgument.dimensions.size() != 0) { + LOG(ERROR) << "Request " << type << " " << requestArgumentIndex + << " has no value yet has details."; + return false; + } + } else { + // Validate the location. + if (!poolVerifier.validate(location)) { + return false; + } + // If the argument specified a dimension, validate it. + uint32_t rank = requestArgument.dimensions.size(); + if (rank == 0) { + // Validate that all the dimensions are specified in the model. 
+ for (size_t i = 0; i < operand.dimensions.size(); i++) { + if (operand.dimensions[i] == 0) { + LOG(ERROR) << "Model has dimension " << i + << " set to 0 but the request does specify the dimension."; + return false; + } + } + } else { + if (rank != operand.dimensions.size()) { + LOG(ERROR) << "Request " << type << " " << requestArgumentIndex + << " has number of dimensions (" << rank + << ") different than the model's (" << operand.dimensions.size() + << ")"; + return false; + } + for (size_t i = 0; i < rank; i++) { + if (requestArgument.dimensions[i] != operand.dimensions[i] && + operand.dimensions[i] != 0) { + LOG(ERROR) << "Request " << type << " " << requestArgumentIndex + << " has dimension " << i << " of " + << requestArgument.dimensions[i] + << " different than the model's " << operand.dimensions[i]; + return false; + } + if (requestArgument.dimensions[i] == 0) { + LOG(ERROR) << "Request " << type << " " << requestArgumentIndex + << " has dimension " << i << " of zero"; + return false; + } + } + } + } + } + return true; +} + +bool validateRequest(const Request& request, const Model& model) { + return (validateRequestArguments(request.inputs, model.inputIndexes, model.operands, + request.pools, "input") && + validateRequestArguments(request.outputs, model.outputIndexes, model.operands, + request.pools, "output") && + validatePools(request.pools)); +} + +} // namespace nn +} // namespace android