Add a centralized runtime extension information store

Introduce TypeManager, a process-wide singleton that collects extension
type information from every available device and answers type queries
(extension type lookup, whether a type is a tensor type, and operand
data size). It replaces the per-ModelBuilder extension name-to-prefix
map and Device::getSizeOfData().

Fix: 124107169
Fix: 123523457
Fix: 124285861
Fix: 124104123
Fix: 123178734
Test: NeuralNetworksTest_static
Test: NeuralNetworksTest_utils
Test: NeuralNetworksTest_FibonacciExtension (from change Ibe0fc5356baa909bce8424138bd5cfac9f74648f)
Change-Id: Id3f105476f42bd747a098f081a07b161036e4922
Merged-In: Id3f105476f42bd747a098f081a07b161036e4922
(cherry picked from commit 93c679813ab8f19a2d696a22f5fce229a1e62a73)
diff --git a/runtime/Android.bp b/runtime/Android.bp
index 0423925..ba9a769 100644
--- a/runtime/Android.bp
+++ b/runtime/Android.bp
@@ -45,6 +45,7 @@
"Memory.cpp",
"ModelBuilder.cpp",
"NeuralNetworks.cpp",
+ "TypeManager.cpp",
"VersionedInterfaces.cpp",
],
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index 7825eb7..870a8b2 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -25,6 +25,7 @@
#include "Manager.h"
#include "ModelBuilder.h"
#include "Tracing.h"
+#include "TypeManager.h"
#include "Utils.h"
#include <mutex>
@@ -43,7 +44,12 @@
static bool checkDimensionInfo(const Operand& operand, const ANeuralNetworksOperandType* newType,
const char* tag, bool allowUnspecified) {
if (newType != nullptr) {
- if (validateOperandType(*newType, tag, allowUnspecified) != ANEURALNETWORKS_NO_ERROR) {
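+ // For extension operand types, validateOperandType() needs the type
+ // information registered with the TypeManager.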
+ const Extension::OperandTypeInformation* info = nullptr;
+ if (isExtensionOperandType(operand.type)) {
+ NN_RET_CHECK(TypeManager::get()->getExtensionOperandTypeInfo(operand.type, &info));
+ }
+ if (validateOperandType(*newType, info, tag, allowUnspecified) !=
+ ANEURALNETWORKS_NO_ERROR) {
LOG(ERROR) << tag << ": Invalid newType";
return false;
}
@@ -61,7 +67,8 @@
}
}
} else {
- if (!allowUnspecified && hasUnspecifiedDimensions(operand)) {
+ if (!allowUnspecified && TypeManager::get()->isTensorType(operand.type) &&
+ tensorHasUnspecifiedDimensions(operand)) {
LOG(ERROR) << tag << ": Setting with operand type that is not fully specified";
return false;
}
@@ -82,8 +89,8 @@
state = ModelArgumentInfo::HAS_NO_VALUE;
} else {
NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
- if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
- uint32_t neededLength = sizeOfData(operand.type, dimensions);
+ if (operand.type != OperandType::OEM) {
+ uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
if (neededLength != length && neededLength != 0) {
LOG(ERROR) << "Setting argument with invalid length: " << length
<< ", expected length: " << neededLength;
@@ -100,8 +107,8 @@
int ModelArgumentInfo::setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
uint32_t poolIndex, uint32_t offset, uint32_t length) {
NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
- if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
- uint32_t neededLength = sizeOfData(operand.type, dimensions);
+ if (operand.type != OperandType::OEM) {
+ uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
if (neededLength != length && neededLength != 0) {
LOG(ERROR) << "Setting argument with invalid length: " << length
<< ", expected length: " << neededLength;
@@ -118,8 +125,8 @@
int ModelArgumentInfo::setFromTemporaryMemory(const Operand& operand, uint32_t poolIndex,
uint32_t offset, uint32_t length) {
NN_RETURN_IF_ERROR(updateDimensionInfo(operand, nullptr));
- if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
- uint32_t neededLength = sizeOfData(operand.type, dimensions);
+ if (operand.type != OperandType::OEM) {
+ uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
if (neededLength != length) {
LOG(ERROR) << "Setting argument with invalid length: " << length
<< ", expected length: " << neededLength;
@@ -687,8 +694,7 @@
// ExecutionBuilder::setOutputFromMemory()
uint32_t poolIndex = mMemories.add(memory);
- uint32_t length =
- mDevice->getSizeOfData(inputOrOutputOperand, mModel->getExtensionNameToPrefixMap());
+ uint32_t length = TypeManager::get()->getSizeOfData(inputOrOutputOperand);
return inputOrOutputInfo->setFromTemporaryMemory(inputOrOutputOperand, poolIndex, offset,
length);
}
diff --git a/runtime/ExecutionPlan.cpp b/runtime/ExecutionPlan.cpp
index 9f044df..f96c6d4 100644
--- a/runtime/ExecutionPlan.cpp
+++ b/runtime/ExecutionPlan.cpp
@@ -28,6 +28,7 @@
#include "OperationsUtils.h"
#include "TokenHasher.h"
#include "Tracing.h"
+#include "TypeManager.h"
#include "Utils.h"
#include <cutils/native_handle.h>
@@ -491,7 +492,6 @@
}
mSubModel.relaxComputationFloat32toFloat16(fromModel->isComputationFloat32RelaxedToFloat16());
- mSubModel.setExtensionNameToPrefixMap(fromModel->getExtensionNameToPrefixMap());
// Input order: mModelInputs, mTempsAsSubModelInputs, mOutputsAsSubModelInputs
// Output order: mModelOutputs, mTempsAsSubModelOutputs
@@ -709,8 +709,7 @@
subModelInputsAndOutputs =
std::make_shared<Controller::SubModelInputsAndOutputsType>();
}
- const uint32_t size = step->getDevice()->getSizeOfData(
- fromModelOperand, fromModel->getExtensionNameToPrefixMap());
+ const uint32_t size = TypeManager::get()->getSizeOfData(fromModelOperand);
totalSizeOfTemporaries += alignBytesNeeded(totalSizeOfTemporaries, size);
subModelInputsAndOutputs->insert(std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries));
totalSizeOfTemporaries += size;
diff --git a/runtime/Manager.cpp b/runtime/Manager.cpp
index 6c15daf..26df6a1 100644
--- a/runtime/Manager.cpp
+++ b/runtime/Manager.cpp
@@ -37,40 +37,6 @@
namespace android {
namespace nn {
-uint32_t Device::getSizeOfData(const Operand& operand,
- const std::map<std::string, uint16_t>& extensionNameToPrefix) const {
- if (!isExtensionOperandType(operand.type)) {
- return sizeOfData(operand);
- }
-
- // A slow naive implementation.
- // TODO(b/123178734): Speed it up.
- uint32_t operandType = static_cast<uint32_t>(operand.type);
- uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
- uint16_t prefix = operandType >> kLowBitsType;
- uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
- for (const Extension& extension : getSupportedExtensions()) {
- if (extensionNameToPrefix.at(extension.name) != prefix) {
- continue;
- }
- for (auto& extensionOperandType : extension.operandTypes) {
- if (extensionOperandType.type == typeWithinExtension) {
- uint32_t numElements = 1;
- if (extensionOperandType.isTensor) {
- for (auto dimension : operand.dimensions) {
- numElements *= dimension;
- }
- }
- return numElements * extensionOperandType.byteSize;
- }
- }
- }
-
- CHECK(false) << "Cannot determine the size of extension operand type "
- << toString(operand.type);
- return 0;
-}
-
// A Device with actual underlying driver
class DriverDevice : public Device {
DISALLOW_IMPLICIT_CONSTRUCTORS(DriverDevice);
diff --git a/runtime/Manager.h b/runtime/Manager.h
index 74b3245..5f70450 100644
--- a/runtime/Manager.h
+++ b/runtime/Manager.h
@@ -56,9 +56,6 @@
const hidl_handle& modelCache, const hidl_handle& dataCache,
const hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>& token,
std::shared_ptr<VersionedIPreparedModel>* preparedModel) = 0;
-
- uint32_t getSizeOfData(const Operand& operand,
- const std::map<std::string, uint16_t>& extensionNameToPrefix) const;
};
// Manages the NN HAL devices. Only one instance of this class will exist.
diff --git a/runtime/ModelBuilder.cpp b/runtime/ModelBuilder.cpp
index 3e82fc8..dcfa60b 100644
--- a/runtime/ModelBuilder.cpp
+++ b/runtime/ModelBuilder.cpp
@@ -21,6 +21,7 @@
#include "CompilationBuilder.h"
#include "GraphDump.h"
#include "Manager.h"
+#include "TypeManager.h"
#include "Utils.h"
#include "ValidateHal.h"
@@ -46,9 +47,6 @@
// The maximum number of operands and operations that a model may have.
const uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
const uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;
-const uint32_t MAX_NUMBER_OF_EXTENSIONS_IN_USE =
- // -2 because prefix 0x0000 corresponds to no extension.
- (1 << static_cast<uint8_t>(Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 2;
ModelBuilder::ModelBuilder() {
std::string path = ::android::procpartition::getExe(getpid());
@@ -75,21 +73,9 @@
int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
int32_t* type) {
- uint16_t prefix;
- auto it = mExtensionNameToPrefix.find(extensionName);
- if (it != mExtensionNameToPrefix.end()) {
- prefix = it->second;
- } else {
- if (mExtensionNameToPrefix.size() == MAX_NUMBER_OF_EXTENSIONS_IN_USE) {
- LOG(ERROR) << "Too many extension types in use";
- return ANEURALNETWORKS_BAD_DATA;
- }
- prefix = mExtensionNameToPrefix.size() + 1;
- mExtensionNameToPrefix[extensionName] = prefix;
- }
- *type = (prefix << static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE)) |
- typeWithinExtension;
- return ANEURALNETWORKS_NO_ERROR;
+ return TypeManager::get()->getExtensionType(extensionName, typeWithinExtension, type)
+ ? ANEURALNETWORKS_NO_ERROR
+ : ANEURALNETWORKS_BAD_DATA;
}
int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
@@ -106,7 +92,13 @@
LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
}
- NN_RETURN_IF_ERROR(validateOperandType(type, "ANeuralNetworksModel_addOperand", true));
+ const Extension::OperandTypeInformation* info = nullptr;
+ if (isExtensionOperandType(operandType) &&
+ !TypeManager::get()->getExtensionOperandTypeInfo(operandType, &info)) {
+ LOG(ERROR) << "Extension operand type " << toString(operandType) << " is not registered";
+ return ANEURALNETWORKS_BAD_DATA;
+ }
+ NN_RETURN_IF_ERROR(validateOperandType(type, info, "ANeuralNetworksModel_addOperand", true));
size_t idx = mOperands.size();
if (idx >= MAX_NUMBER_OF_OPERANDS) {
LOG(ERROR) << "ANeuralNetworksModel_addOperand exceed max operands";
@@ -149,7 +141,8 @@
// The location is unused and is set to zeros.
operand.location = {.poolIndex = 0, .offset = 0, .length = 0};
} else {
- if (hasUnspecifiedDimensions(operand)) {
+ if (TypeManager::get()->isTensorType(operand.type) &&
+ tensorHasUnspecifiedDimensions(operand)) {
LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index
<< " which has operand type that is not fully specified";
return ANEURALNETWORKS_BAD_DATA;
@@ -160,8 +153,8 @@
return ANEURALNETWORKS_BAD_DATA;
}
uint32_t valueLength = static_cast<uint32_t>(length);
- if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
- uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
+ if (operand.type != OperandType::OEM) {
+ uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
if (neededLength != valueLength) {
LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
<< " bytes when needing " << neededLength;
@@ -320,7 +313,7 @@
return ANEURALNETWORKS_BAD_DATA;
}
Operand& operand = mOperands[index];
- if (hasUnspecifiedDimensions(operand)) {
+ if (TypeManager::get()->isTensorType(operand.type) && tensorHasUnspecifiedDimensions(operand)) {
LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
<< " which has operand type that is not fully specified";
return ANEURALNETWORKS_BAD_DATA;
@@ -331,13 +324,11 @@
<< " that is not in AHARDWAREBUFFER_FORMAT_BLOB format";
return ANEURALNETWORKS_UNMAPPABLE;
}
- if (!isExtensionOperandType(operand.type)) {
- uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
- if (neededLength != length) {
- LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
- << " bytes when needing " << neededLength;
- return ANEURALNETWORKS_BAD_DATA;
- }
+ uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
+ if (neededLength != length) {
+ LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
+ << " bytes when needing " << neededLength;
+ return ANEURALNETWORKS_BAD_DATA;
}
if (!memory->validateSize(offset, length)) {
return ANEURALNETWORKS_BAD_DATA;
@@ -458,15 +449,6 @@
return ANEURALNETWORKS_NO_ERROR;
}
-void ModelBuilder::setExtensionNameToPrefixMap(
- const std::map<std::string, uint16_t>& extensionNameToPrefix) {
- mExtensionNameToPrefix = extensionNameToPrefix;
-}
-
-const std::map<std::string, uint16_t>& ModelBuilder::getExtensionNameToPrefixMap() const {
- return mExtensionNameToPrefix;
-}
-
int ModelBuilder::createCompilation(CompilationBuilder** compilation,
const std::vector<std::shared_ptr<Device>>& devices,
bool forceNoFallback) {
@@ -579,21 +561,44 @@
model->outputIndexes = mOutputIndexes;
model->operandValues = mSmallOperandValues;
model->relaxComputationFloat32toFloat16 = mRelaxComputationFloat32toFloat16;
+ model->extensionNameToPrefix = getExtensionNameToPrefixMap();
uint32_t count = mMemories.size();
model->pools.resize(count);
for (uint32_t i = 0; i < count; i++) {
model->pools[i] = mMemories[i]->getHidlMemory();
}
+}
- std::vector<Model::ExtensionNameAndPrefix> extensionNameToPrefixVec;
- for (auto& nameAndPrefix : mExtensionNameToPrefix) {
- extensionNameToPrefixVec.push_back({
- .name = nameAndPrefix.first,
- .prefix = nameAndPrefix.second,
+std::vector<Model::ExtensionNameAndPrefix> ModelBuilder::getExtensionNameToPrefixMap() const {
+ std::vector<Model::ExtensionNameAndPrefix> extensionNameToPrefix;
+ std::set<uint16_t> prefixSet;
+
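+ // Records the extension corresponding to the given prefix, ignoring
+ // prefixes that have already been seen.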
+ auto addExtensionWithPrefix = [&extensionNameToPrefix, &prefixSet](uint16_t prefix) {
+ if (!prefixSet.insert(prefix).second) {
+ return;
+ }
+ const Extension* extension;
+ CHECK(TypeManager::get()->getExtensionInfo(prefix, &extension));
+ extensionNameToPrefix.push_back({
+ .name = extension->name,
+ .prefix = prefix,
});
+ };
+
+ constexpr uint8_t kLowBitsType =
+ static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
+ for (const auto& operand : mOperands) {
+ if (isExtensionOperandType(operand.type)) {
+ addExtensionWithPrefix(static_cast<uint32_t>(operand.type) >> kLowBitsType);
+ }
}
- model->extensionNameToPrefix = extensionNameToPrefixVec;
+ for (const auto& operation : mOperations) {
+ if (isExtensionOperationType(operation.type)) {
+ addExtensionWithPrefix(static_cast<uint32_t>(operation.type) >> kLowBitsType);
+ }
+ }
+ return extensionNameToPrefix;
}
} // namespace nn
diff --git a/runtime/ModelBuilder.h b/runtime/ModelBuilder.h
index e5cd468..1ef3ef0 100644
--- a/runtime/ModelBuilder.h
+++ b/runtime/ModelBuilder.h
@@ -54,9 +54,6 @@
int relaxComputationFloat32toFloat16(bool allow);
bool isComputationFloat32RelaxedToFloat16() const { return mRelaxComputationFloat32toFloat16; }
- void setExtensionNameToPrefixMap(const std::map<std::string, uint16_t>&);
- const std::map<std::string, uint16_t>& getExtensionNameToPrefixMap() const;
-
int finish();
bool isFinished() const { return mCompletedModel; }
bool isValid() const { return !mInvalidModel; }
@@ -120,6 +117,12 @@
// Copies the large values to a shared memory, if we have any.
int copyLargeValuesToSharedMemory();
+ // Returns the list of extension names and corresponding numeric "prefixes"
+ // of operand and operation type values used in the model.
+ //
+ // Devices rely on this mapping to interpret extension types.
+ std::vector<Model::ExtensionNameAndPrefix> getExtensionNameToPrefixMap() const;
+
// The operations of the graph.
std::vector<Operation> mOperations;
// The mapping from sorted index to the original index of operations in mOperations.
@@ -168,12 +171,6 @@
// 'false' indicates TENSOR_FLOAT32 must be calculated using at least the
// range and precision of the IEEE 754 32-bit floating-point format.
bool mRelaxComputationFloat32toFloat16 = false;
-
- // Maps extension names to numeric "prefixes" of operand and operation
- // type values. Devices rely on these prefixes to interpret extension types.
- // TODO(b/123523457): Have a global name-to-prefix mapping instead of
- // storing it here.
- std::map<std::string, uint16_t> mExtensionNameToPrefix;
};
} // namespace nn
diff --git a/runtime/TypeManager.cpp b/runtime/TypeManager.cpp
new file mode 100644
index 0000000..00219a2
--- /dev/null
+++ b/runtime/TypeManager.cpp
@@ -0,0 +1,169 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "TypeManager"
+
+#include "TypeManager.h"
+
+#include "Utils.h"
+
+#include <algorithm>
+
+namespace android {
+namespace nn {
+namespace {
+
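+// Extension operand and operation type values encode an extension prefix in
+// the high bits (Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX) and a type
+// within the extension in the low bits (Model::ExtensionTypeEncoding::LOW_BITS_TYPE).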
+const uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
+const uint32_t kMaxPrefix =
+ (1 << static_cast<uint8_t>(Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 1;
+
+// Checks if the two structures contain the same information. The order of
+// operand types within the structures does not matter.
+bool equal(const Extension& a, const Extension& b) {
+ NN_RET_CHECK_EQ(a.name, b.name);
+ // Relies on the fact that TypeManager sorts operandTypes.
+ NN_RET_CHECK(a.operandTypes == b.operandTypes);
+ return true;
+}
+
+} // namespace
+
+TypeManager::TypeManager() {
+ VLOG(MANAGER) << "TypeManager::TypeManager";
+ findAvailableExtensions();
+}
+
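+// Registers the extensions supported by every device known to the DeviceManager.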
+void TypeManager::findAvailableExtensions() {
+ for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) {
+ for (const Extension& extension : device->getSupportedExtensions()) {
+ registerExtension(extension, device->getName());
+ }
+ }
+}
+
+bool TypeManager::registerExtension(Extension extension, const std::string& deviceName) {
+ if (mDisabledExtensions.find(extension.name) != mDisabledExtensions.end()) {
+ LOG(ERROR) << "Extension " << extension.name << " is disabled";
+ return false;
+ }
+
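+ // Keep operandTypes sorted by type value: equal() compares the vectors
+ // directly and getExtensionOperandTypeInfo() binary-searches them.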
+ std::sort(extension.operandTypes.begin(), extension.operandTypes.end(),
+ [](const Extension::OperandTypeInformation& a,
+ const Extension::OperandTypeInformation& b) {
+ return static_cast<uint16_t>(a.type) < static_cast<uint16_t>(b.type);
+ });
+
+ std::map<std::string, Extension>::iterator it;
+ bool isNew;
+ std::tie(it, isNew) = mExtensionNameToExtension.emplace(extension.name, extension);
+ if (isNew) {
+ VLOG(MANAGER) << "Registered extension " << extension.name;
+ mExtensionNameToFirstDevice.emplace(extension.name, deviceName);
+ } else if (!equal(extension, it->second)) {
+ LOG(ERROR) << "Devices " << mExtensionNameToFirstDevice[extension.name] << " and "
+ << deviceName << " provide inconsistent information for extension "
+ << extension.name << ", which is therefore disabled";
+ mExtensionNameToExtension.erase(it);
+ mDisabledExtensions.insert(extension.name);
+ return false;
+ }
+ return true;
+}
+
+bool TypeManager::getExtensionPrefix(const std::string& extensionName, uint16_t* prefix) {
+ auto it = mExtensionNameToPrefix.find(extensionName);
+ if (it != mExtensionNameToPrefix.end()) {
+ *prefix = it->second;
+ } else {
+ NN_RET_CHECK_LE(mPrefixToExtension.size(), kMaxPrefix) << "Too many extensions in use";
+ *prefix = mPrefixToExtension.size();
+ mExtensionNameToPrefix[extensionName] = *prefix;
+ mPrefixToExtension.push_back(&mExtensionNameToExtension[extensionName]);
+ }
+ return true;
+}
+
+bool TypeManager::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
+ int32_t* type) {
+ uint16_t prefix;
+ NN_RET_CHECK(getExtensionPrefix(extensionName, &prefix));
+ *type = (prefix << kLowBitsType) | typeWithinExtension;
+ return true;
+}
+
+bool TypeManager::getExtensionInfo(uint16_t prefix, const Extension** extension) const {
+ NN_RET_CHECK_NE(prefix, 0u) << "prefix=0 does not correspond to an extension";
+ NN_RET_CHECK_LT(prefix, mPrefixToExtension.size()) << "Unknown extension prefix";
+ *extension = mPrefixToExtension[prefix];
+ return true;
+}
+
+bool TypeManager::getExtensionOperandTypeInfo(
+ OperandType type, const Extension::OperandTypeInformation** info) const {
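+ // Decode the extension prefix (high bits) and the type within the
+ // extension (low bits) from the operand type value.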
+ uint32_t operandType = static_cast<uint32_t>(type);
+ uint16_t prefix = operandType >> kLowBitsType;
+ uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
+ const Extension* extension;
+ NN_RET_CHECK(getExtensionInfo(prefix, &extension))
+ << "Cannot find extension corresponding to prefix " << prefix;
+ auto it = std::lower_bound(
+ extension->operandTypes.begin(), extension->operandTypes.end(), typeWithinExtension,
+ [](const Extension::OperandTypeInformation& info, uint32_t typeSought) {
+ return static_cast<uint16_t>(info.type) < typeSought;
+ });
+ NN_RET_CHECK(it != extension->operandTypes.end() &&
+ static_cast<uint16_t>(it->type) == typeWithinExtension)
+ << "Cannot find operand type " << typeWithinExtension << " in extension "
+ << extension->name;
+ *info = &*it;
+ return true;
+}
+
+bool TypeManager::isTensorType(OperandType type) const {
+ if (!isExtensionOperandType(type)) {
+ return !nonExtensionOperandTypeIsScalar(static_cast<int>(type));
+ }
+ const Extension::OperandTypeInformation* info;
+ CHECK(getExtensionOperandTypeInfo(type, &info));
+ return info->isTensor;
+}
+
+uint32_t TypeManager::getSizeOfData(OperandType type,
+ const std::vector<uint32_t>& dimensions) const {
+ if (!isExtensionOperandType(type)) {
+ return nonExtensionOperandSizeOfData(type, dimensions);
+ }
+
+ const Extension::OperandTypeInformation* info;
+ CHECK(getExtensionOperandTypeInfo(type, &info));
+
+ if (!info->isTensor) {
+ return info->byteSize;
+ }
+
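+ // An empty dimensions vector means the tensor has unspecified rank, so its
+ // size is not yet known; report zero.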
+ if (dimensions.empty()) {
+ return 0;
+ }
+
+ uint32_t size = info->byteSize;
+ for (auto dimension : dimensions) {
+ size *= dimension;
+ }
+ return size;
+}
+
+} // namespace nn
+} // namespace android
diff --git a/runtime/TypeManager.h b/runtime/TypeManager.h
new file mode 100644
index 0000000..b1cdcc0
--- /dev/null
+++ b/runtime/TypeManager.h
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_ML_NN_RUNTIME_TYPE_MANAGER_H
+#define ANDROID_ML_NN_RUNTIME_TYPE_MANAGER_H
+
+#include "HalInterfaces.h"
+#include "Manager.h"
+
+#include <map>
+#include <set>
+#include <string>
+
+namespace android {
+namespace nn {
+
+// Manages runtime operand and operation type information.
+//
+// This class gathers information about extension types from all devices
+// and provides a unified way to access information about any known type.
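+//
+// Illustrative usage (a minimal sketch; "com.example.my_extension" is a
+// hypothetical extension name):
+//
+//   int32_t type;
+//   if (TypeManager::get()->getExtensionType("com.example.my_extension",
+//                                            /*typeWithinExtension=*/0, &type)) {
+//       // 'type' can now be used as an ANeuralNetworksOperandType.type value.
+//   }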
+class TypeManager {
+ public:
+ static TypeManager* get() {
+ static TypeManager manager;
+ return &manager;
+ }
+
+ // Creates an operand/operation type corresponding to a given extension
+ // name and type within the extension.
+ //
+ // Returns false if the type cannot be created (for example, because too
+ // many extensions are in use).
+ bool getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type);
+
+ // Looks up information about the extension corresponding to the given prefix.
+ //
+ // Returns false if no extension corresponds to the given prefix.
+ bool getExtensionInfo(uint16_t prefix, const Extension** extension) const;
+
+ // Looks up information about an extension operand type.
+ //
+ // Returns false if the extension or type is unknown.
+ bool getExtensionOperandTypeInfo(OperandType type,
+ const Extension::OperandTypeInformation** info) const;
+
+ // Returns true if an operand type is a tensor type.
+ //
+ // Aborts if the type is an unknown extension type.
+ bool isTensorType(OperandType type) const;
+
+ // Returns the amount of space needed to store a value of the dimensions and
+ // type of this operand. For a tensor with unspecified rank or at least one
+ // unspecified dimension, returns zero.
+ //
+ // Aborts if the type is an unknown extension type.
+ uint32_t getSizeOfData(const Operand& operand) const {
+ return getSizeOfData(operand.type, operand.dimensions);
+ }
+
+ // Returns the amount of space needed to store a value of the specified
+ // dimensions and type. For a tensor with unspecified rank or at least one
+ // unspecified dimension, returns zero.
+ //
+ // Aborts if the type is an unknown extension type.
+ uint32_t getSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) const;
+
+ // This method is intended for use only by internal unit tests.
+ //
+ // Registers an extension.
+ //
+ // Returns true if the registration was successful.
+ bool forTest_registerExtension(const Extension& extension) {
+ return registerExtension(extension, "INTERNAL TEST");
+ }
+
+ // This method is intended for use only by internal unit tests.
+ //
+ // Resets the internal state.
+ //
+ // After calling forTest_registerExtension() any number of times, call
+ // forTest_reset() to return to the state as if forTest_registerExtension()
+ // had never been called. Note that forTest_reset() resets all internal
+ // state (including assigned prefixes) and re-discovers extensions from
+ // available devices.
+ void forTest_reset() { *this = TypeManager(); }
+
+ private:
+ TypeManager();
+ void findAvailableExtensions();
+ bool registerExtension(Extension extension, const std::string& deviceName);
+
+ // Looks up the numeric "prefix" value corresponding to an extension,
+ // assigning a new prefix if the extension has not been seen before.
+ //
+ // Returns false if all available prefix values are already in use.
+ bool getExtensionPrefix(const std::string& extensionName, uint16_t* prefix);
+
+ const DeviceManager* mDeviceManager = DeviceManager::get();
+
+ // Contains all registered extensions.
+ std::map<std::string, Extension> mExtensionNameToExtension;
+
+ // Contains the name of the first discovered device that supports an
+ // extension. Used for error reporting.
+ std::map<std::string, std::string> mExtensionNameToFirstDevice;
+
+ // When multiple devices report conflicting information about an extension,
+ // the extension is disabled.
+ std::set<std::string> mDisabledExtensions;
+
+ // The fields below are used to support efficient extension name to
+ // prefix mapping. New prefixes are created by getExtensionPrefix.
+ std::map<std::string, uint16_t> mExtensionNameToPrefix;
+ // Entries of mPrefixToExtension point into mExtensionNameToExtension.
+ // prefix=0 corresponds to no extension and should never be looked up.
+ std::vector<Extension*> mPrefixToExtension = {nullptr};
+};
+
+} // namespace nn
+} // namespace android
+
+#endif // ANDROID_ML_NN_RUNTIME_TYPE_MANAGER_H
diff --git a/runtime/test/TestValidation.cpp b/runtime/test/TestValidation.cpp
index 2cfc008..36efeb6 100644
--- a/runtime/test/TestValidation.cpp
+++ b/runtime/test/TestValidation.cpp
@@ -25,12 +25,13 @@
#ifndef NNTEST_ONLY_PUBLIC_API
#include "NeuralNetworksExtensions.h"
-const char* kTestExtensionName = "vendor.test.validation_test_extension";
+#include "TypeManager.h"
#endif
// This file tests all the validations done by the Neural Networks API.
namespace {
+
class ValidationTest : public ::testing::Test {
protected:
virtual void SetUp() {}
@@ -54,15 +55,10 @@
return mNumOperands++;
}
-#ifndef NNTEST_ONLY_PUBLIC_API
- int32_t getExtensionOperandType(uint16_t typeWithinExtension) {
- int32_t result;
- EXPECT_EQ(ANeuralNetworksModel_getExtensionOperandType(mModel, kTestExtensionName,
- typeWithinExtension, &result),
- ANEURALNETWORKS_NO_ERROR);
- return result;
+ uint32_t addOperand(const ANeuralNetworksOperandType& operandType) {
+ EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &operandType), ANEURALNETWORKS_NO_ERROR);
+ return mNumOperands++;
}
-#endif
uint32_t addTensorOperand(int32_t type = ANEURALNETWORKS_TENSOR_FLOAT32) {
uint32_t dimensions[] = {2};
@@ -71,8 +67,7 @@
.dimensionCount = sizeof(dimensions) / sizeof(dimensions[0]),
.dimensions = dimensions,
};
- EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &operandType), ANEURALNETWORKS_NO_ERROR);
- return mNumOperands++;
+ return addOperand(operandType);
}
void createModel() {
@@ -110,6 +105,42 @@
};
};
+#ifndef NNTEST_ONLY_PUBLIC_API
+constexpr const char* kTestExtensionName = "com.android.test_extension";
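+// An arbitrary value (reusing an existing constant) serves as the numeric
+// type within the test extension.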
+constexpr int32_t kTestExtensionTensorType = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL;
+
+class ValidationTestModelExtensions : public ValidationTestModel {
+ protected:
+ virtual void SetUp() {
+ ValidationTestModel::SetUp();
+ EXPECT_TRUE(::android::nn::TypeManager::get()->forTest_registerExtension({
+ .name = kTestExtensionName,
+ .operandTypes =
+ {
+ {
+ .type = kTestExtensionTensorType,
+ .isTensor = true,
+ .byteSize = 1,
+ },
+ },
+ }));
+ }
+
+ virtual void TearDown() {
+ ::android::nn::TypeManager::get()->forTest_reset();
+ ValidationTestModel::TearDown();
+ }
+
+ int32_t getExtensionOperandType(uint16_t typeWithinExtension) {
+ int32_t result;
+ EXPECT_EQ(ANeuralNetworksModel_getExtensionOperandType(mModel, kTestExtensionName,
+ typeWithinExtension, &result),
+ ANEURALNETWORKS_NO_ERROR);
+ return result;
+ }
+};
+#endif
+
class ValidationTestIdentify : public ValidationTestModel {
virtual void SetUp() {
ValidationTestModel::SetUp();
@@ -259,9 +290,14 @@
}
#ifndef NNTEST_ONLY_PUBLIC_API
-TEST_F(ValidationTestModel, SetOperandSymmPerChannelQuantParams_ExtensionOperand) {
- const int32_t operandIndex = addTensorOperand(
- getExtensionOperandType(ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL));
+TEST_F(ValidationTestModelExtensions, AddOperand_UnknownPrefix) {
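+ // -1 decodes to an extension prefix that has not been registered with the TypeManager.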
+ ANeuralNetworksOperandType type = {.type = -1};
+ ASSERT_EQ(ANeuralNetworksModel_addOperand(mModel, &type), ANEURALNETWORKS_BAD_DATA);
+}
+
+TEST_F(ValidationTestModelExtensions, SetOperandSymmPerChannelQuantParams_ExtensionOperand) {
+ const int32_t operandIndex =
+ addTensorOperand(getExtensionOperandType(kTestExtensionTensorType));
float scales[2] = {1.0, 2.0};
ANeuralNetworksSymmPerChannelQuantParams channelQuant = {
@@ -275,8 +311,9 @@
ANEURALNETWORKS_BAD_DATA);
}
-TEST_F(ValidationTestModel, SetOperandExtensionData) {
- const int32_t operandIndex = addTensorOperand(getExtensionOperandType(0));
+TEST_F(ValidationTestModelExtensions, SetOperandExtensionData) {
+ const int32_t operandIndex =
+ addTensorOperand(getExtensionOperandType(kTestExtensionTensorType));
const int32_t data = 42;
const size_t dataLength = sizeof(data);
EXPECT_EQ(
@@ -294,19 +331,45 @@
ANEURALNETWORKS_NO_ERROR);
}
-TEST_F(ValidationTestModel, SetOperandExtensionData_Empty) {
- const int32_t operandIndex = addTensorOperand(getExtensionOperandType(0));
+TEST_F(ValidationTestModelExtensions, SetOperandExtensionData_Empty) {
+ const int32_t operandIndex =
+ addTensorOperand(getExtensionOperandType(kTestExtensionTensorType));
EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, nullptr, 0),
ANEURALNETWORKS_NO_ERROR);
}
-TEST_F(ValidationTestModel, SetOperandExtensionData_NonExtensionOperand) {
+TEST_F(ValidationTestModelExtensions, SetOperandExtensionData_NonExtensionOperand) {
const int32_t operandIndex = addTensorOperand();
const int32_t data = 42;
const size_t dataLength = sizeof(data);
EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, &data, dataLength),
ANEURALNETWORKS_BAD_DATA);
}
+
+TEST_F(ValidationTestModelExtensions, SetOperandValue_UnspecifiedDimension) {
+ const uint32_t dimensions[2] = {3, 0};
+ ANeuralNetworksOperandType type = {
+ .type = getExtensionOperandType(kTestExtensionTensorType),
+ .dimensionCount = 2,
+ .dimensions = dimensions,
+ };
+ const int32_t operandIndex = addOperand(type);
+ char buffer[20];
+ EXPECT_EQ(ANeuralNetworksModel_setOperandValue(mModel, operandIndex, buffer, sizeof(buffer)),
+ ANEURALNETWORKS_BAD_DATA);
+}
+
+TEST_F(ValidationTestModelExtensions, SetOperandValue_UnspecifiedRank) {
+ ANeuralNetworksOperandType type = {
+ .type = getExtensionOperandType(kTestExtensionTensorType),
+ .dimensionCount = 0,
+ .dimensions = nullptr,
+ };
+ const int32_t operandIndex = addOperand(type);
+ char buffer[20];
+ EXPECT_EQ(ANeuralNetworksModel_setOperandValue(mModel, operandIndex, buffer, sizeof(buffer)),
+ ANEURALNETWORKS_BAD_DATA);
+}
#endif
TEST_F(ValidationTestModel, SetOptionalOperand) {