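Add a centralized runtime extension information store

Introduce TypeManager, a process-wide registry that gathers extension
operand type information from all devices and assigns extension type
prefixes. It replaces the per-model extension name-to-prefix map in
ModelBuilder and Device::getSizeOfData(), so operand sizes and tensor-ness
of extension types can be queried without access to a device or model.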

Fix: 124107169
Fix: 123523457
Fix: 124285861
Fix: 124104123
Fix: 123178734
Test: NeuralNetworksTest_static
Test: NeuralNetworksTest_utils
Test: NeuralNetworksTest_FibonacciExtension (from change Ibe0fc5356baa909bce8424138bd5cfac9f74648f)
Change-Id: Id3f105476f42bd747a098f081a07b161036e4922
Merged-In: Id3f105476f42bd747a098f081a07b161036e4922
(cherry picked from commit 93c679813ab8f19a2d696a22f5fce229a1e62a73)
diff --git a/runtime/Android.bp b/runtime/Android.bp
index 0423925..ba9a769 100644
--- a/runtime/Android.bp
+++ b/runtime/Android.bp
@@ -45,6 +45,7 @@
         "Memory.cpp",
         "ModelBuilder.cpp",
         "NeuralNetworks.cpp",
+        "TypeManager.cpp",  // Process-wide extension type registry.
         "VersionedInterfaces.cpp",
     ],
 
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index 7825eb7..870a8b2 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -25,6 +25,7 @@
 #include "Manager.h"
 #include "ModelBuilder.h"
 #include "Tracing.h"
+#include "TypeManager.h"
 #include "Utils.h"
 
 #include <mutex>
@@ -43,7 +44,12 @@
 static bool checkDimensionInfo(const Operand& operand, const ANeuralNetworksOperandType* newType,
                                const char* tag, bool allowUnspecified) {
     if (newType != nullptr) {
-        if (validateOperandType(*newType, tag, allowUnspecified) != ANEURALNETWORKS_NO_ERROR) {
+        const Extension::OperandTypeInformation* info = nullptr;  // Set for extension types only.
+        if (isExtensionOperandType(operand.type)) {
+            NN_RET_CHECK(TypeManager::get()->getExtensionOperandTypeInfo(operand.type, &info));
+        }
+        if (validateOperandType(*newType, info, tag, allowUnspecified) !=
+            ANEURALNETWORKS_NO_ERROR) {
             LOG(ERROR) << tag << ": Invalid newType";
             return false;
         }
@@ -61,7 +67,8 @@
             }
         }
     } else {
-        if (!allowUnspecified && hasUnspecifiedDimensions(operand)) {
+        if (!allowUnspecified && TypeManager::get()->isTensorType(operand.type) &&
+            tensorHasUnspecifiedDimensions(operand)) {
             LOG(ERROR) << tag << ": Setting with operand type that is not fully specified";
             return false;
         }
@@ -82,8 +89,8 @@
         state = ModelArgumentInfo::HAS_NO_VALUE;
     } else {
         NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
-        if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
-            uint32_t neededLength = sizeOfData(operand.type, dimensions);
+        if (operand.type != OperandType::OEM) {  // Covers extension types as well.
+            uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
             if (neededLength != length && neededLength != 0) {
                 LOG(ERROR) << "Setting argument with invalid length: " << length
                            << ", expected length: " << neededLength;
@@ -100,8 +107,8 @@
 int ModelArgumentInfo::setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
                                      uint32_t poolIndex, uint32_t offset, uint32_t length) {
     NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
-    if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
-        uint32_t neededLength = sizeOfData(operand.type, dimensions);
+    if (operand.type != OperandType::OEM) {  // Covers extension types as well.
+        uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
         if (neededLength != length && neededLength != 0) {
             LOG(ERROR) << "Setting argument with invalid length: " << length
                        << ", expected length: " << neededLength;
@@ -118,8 +125,8 @@
 int ModelArgumentInfo::setFromTemporaryMemory(const Operand& operand, uint32_t poolIndex,
                                               uint32_t offset, uint32_t length) {
     NN_RETURN_IF_ERROR(updateDimensionInfo(operand, nullptr));
-    if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
-        uint32_t neededLength = sizeOfData(operand.type, dimensions);
+    if (operand.type != OperandType::OEM) {  // Covers extension types as well.
+        uint32_t neededLength = TypeManager::get()->getSizeOfData(operand.type, dimensions);
         if (neededLength != length) {
             LOG(ERROR) << "Setting argument with invalid length: " << length
                        << ", expected length: " << neededLength;
@@ -687,8 +694,7 @@
     //     ExecutionBuilder::setOutputFromMemory()
 
     uint32_t poolIndex = mMemories.add(memory);
-    uint32_t length =
-            mDevice->getSizeOfData(inputOrOutputOperand, mModel->getExtensionNameToPrefixMap());
+    uint32_t length = TypeManager::get()->getSizeOfData(inputOrOutputOperand);
     return inputOrOutputInfo->setFromTemporaryMemory(inputOrOutputOperand, poolIndex, offset,
                                                      length);
 }
diff --git a/runtime/ExecutionPlan.cpp b/runtime/ExecutionPlan.cpp
index 9f044df..f96c6d4 100644
--- a/runtime/ExecutionPlan.cpp
+++ b/runtime/ExecutionPlan.cpp
@@ -28,6 +28,7 @@
 #include "OperationsUtils.h"
 #include "TokenHasher.h"
 #include "Tracing.h"
+#include "TypeManager.h"  // For TypeManager::getSizeOfData().
 #include "Utils.h"
 
 #include <cutils/native_handle.h>
@@ -491,7 +492,6 @@
     }
 
     mSubModel.relaxComputationFloat32toFloat16(fromModel->isComputationFloat32RelaxedToFloat16());
-    mSubModel.setExtensionNameToPrefixMap(fromModel->getExtensionNameToPrefixMap());
 
     // Input order: mModelInputs, mTempsAsSubModelInputs, mOutputsAsSubModelInputs
     // Output order: mModelOutputs, mTempsAsSubModelOutputs
@@ -709,8 +709,7 @@
                     subModelInputsAndOutputs =
                             std::make_shared<Controller::SubModelInputsAndOutputsType>();
                 }
-                const uint32_t size = step->getDevice()->getSizeOfData(
-                        fromModelOperand, fromModel->getExtensionNameToPrefixMap());
+                const uint32_t size = TypeManager::get()->getSizeOfData(fromModelOperand);
                 totalSizeOfTemporaries += alignBytesNeeded(totalSizeOfTemporaries, size);
                 subModelInputsAndOutputs->insert(std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries));
                 totalSizeOfTemporaries += size;
diff --git a/runtime/Manager.cpp b/runtime/Manager.cpp
index 6c15daf..26df6a1 100644
--- a/runtime/Manager.cpp
+++ b/runtime/Manager.cpp
@@ -37,40 +37,6 @@
 namespace android {
 namespace nn {
 
-uint32_t Device::getSizeOfData(const Operand& operand,
-                               const std::map<std::string, uint16_t>& extensionNameToPrefix) const {
-    if (!isExtensionOperandType(operand.type)) {
-        return sizeOfData(operand);
-    }
-
-    // A slow naive implementation.
-    // TODO(b/123178734): Speed it up.
-    uint32_t operandType = static_cast<uint32_t>(operand.type);
-    uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
-    uint16_t prefix = operandType >> kLowBitsType;
-    uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
-    for (const Extension& extension : getSupportedExtensions()) {
-        if (extensionNameToPrefix.at(extension.name) != prefix) {
-            continue;
-        }
-        for (auto& extensionOperandType : extension.operandTypes) {
-            if (extensionOperandType.type == typeWithinExtension) {
-                uint32_t numElements = 1;
-                if (extensionOperandType.isTensor) {
-                    for (auto dimension : operand.dimensions) {
-                        numElements *= dimension;
-                    }
-                }
-                return numElements * extensionOperandType.byteSize;
-            }
-        }
-    }
-
-    CHECK(false) << "Cannot determine the size of extension operand type "
-                 << toString(operand.type);
-    return 0;
-}
-
 // A Device with actual underlying driver
 class DriverDevice : public Device {
     DISALLOW_IMPLICIT_CONSTRUCTORS(DriverDevice);
diff --git a/runtime/Manager.h b/runtime/Manager.h
index 74b3245..5f70450 100644
--- a/runtime/Manager.h
+++ b/runtime/Manager.h
@@ -56,9 +56,6 @@
             const hidl_handle& modelCache, const hidl_handle& dataCache,
             const hidl_array<uint8_t, ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN>& token,
             std::shared_ptr<VersionedIPreparedModel>* preparedModel) = 0;
-
-    uint32_t getSizeOfData(const Operand& operand,
-                           const std::map<std::string, uint16_t>& extensionNameToPrefix) const;
 };

 // Manages the NN HAL devices.  Only one instance of this class will exist.
diff --git a/runtime/ModelBuilder.cpp b/runtime/ModelBuilder.cpp
index 3e82fc8..dcfa60b 100644
--- a/runtime/ModelBuilder.cpp
+++ b/runtime/ModelBuilder.cpp
@@ -21,6 +21,7 @@
 #include "CompilationBuilder.h"
 #include "GraphDump.h"
 #include "Manager.h"
+#include "TypeManager.h"
 #include "Utils.h"
 #include "ValidateHal.h"
 
@@ -46,9 +47,6 @@
 // The maximum number of operands and operations that a model may have.
 const uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
 const uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;
-const uint32_t MAX_NUMBER_OF_EXTENSIONS_IN_USE =
-        // -2 because prefix 0x0000 corresponds to no extension.
-        (1 << static_cast<uint8_t>(Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 2;
 
 ModelBuilder::ModelBuilder() {
     std::string path = ::android::procpartition::getExe(getpid());
@@ -75,21 +73,9 @@
 
 int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
                                    int32_t* type) {
-    uint16_t prefix;
-    auto it = mExtensionNameToPrefix.find(extensionName);
-    if (it != mExtensionNameToPrefix.end()) {
-        prefix = it->second;
-    } else {
-        if (mExtensionNameToPrefix.size() == MAX_NUMBER_OF_EXTENSIONS_IN_USE) {
-            LOG(ERROR) << "Too many extension types in use";
-            return ANEURALNETWORKS_BAD_DATA;
-        }
-        prefix = mExtensionNameToPrefix.size() + 1;
-        mExtensionNameToPrefix[extensionName] = prefix;
-    }
-    *type = (prefix << static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE)) |
-            typeWithinExtension;
-    return ANEURALNETWORKS_NO_ERROR;
+    return TypeManager::get()->getExtensionType(extensionName, typeWithinExtension, type)
+                   ? ANEURALNETWORKS_NO_ERROR
+                   : ANEURALNETWORKS_BAD_DATA;
 }
 
 int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
@@ -106,7 +92,13 @@
         LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
     }
 
-    NN_RETURN_IF_ERROR(validateOperandType(type, "ANeuralNetworksModel_addOperand", true));
+    const Extension::OperandTypeInformation* info = nullptr;  // Null for non-extension types.
+    if (isExtensionOperandType(operandType) &&
+        !TypeManager::get()->getExtensionOperandTypeInfo(operandType, &info)) {
+        LOG(ERROR) << "Extension operand type " << toString(operandType) << " is not registered";
+        return ANEURALNETWORKS_BAD_DATA;
+    }
+    NN_RETURN_IF_ERROR(validateOperandType(type, info, "ANeuralNetworksModel_addOperand", true));
     size_t idx = mOperands.size();
     if (idx >= MAX_NUMBER_OF_OPERANDS) {
         LOG(ERROR) << "ANeuralNetworksModel_addOperand exceed max operands";
@@ -149,7 +141,8 @@
         // The location is unused and is set to zeros.
         operand.location = {.poolIndex = 0, .offset = 0, .length = 0};
     } else {
-        if (hasUnspecifiedDimensions(operand)) {
+        if (TypeManager::get()->isTensorType(operand.type) &&
+            tensorHasUnspecifiedDimensions(operand)) {
             LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting operand " << index
                        << " which has operand type that is not fully specified";
             return ANEURALNETWORKS_BAD_DATA;
@@ -160,8 +153,8 @@
             return ANEURALNETWORKS_BAD_DATA;
         }
         uint32_t valueLength = static_cast<uint32_t>(length);
-        if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
-            uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
+        if (operand.type != OperandType::OEM) {  // Covers extension types as well.
+            uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
             if (neededLength != valueLength) {
                 LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
                            << " bytes when needing " << neededLength;
@@ -320,7 +313,7 @@
         return ANEURALNETWORKS_BAD_DATA;
     }
     Operand& operand = mOperands[index];
-    if (hasUnspecifiedDimensions(operand)) {
+    if (TypeManager::get()->isTensorType(operand.type) && tensorHasUnspecifiedDimensions(operand)) {
         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting operand " << index
                    << " which has operand type that is not fully specified";
         return ANEURALNETWORKS_BAD_DATA;
@@ -331,13 +324,11 @@
                    << " that is not in AHARDWAREBUFFER_FORMAT_BLOB format";
         return ANEURALNETWORKS_UNMAPPABLE;
     }
-    if (!isExtensionOperandType(operand.type)) {
-        uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
-        if (neededLength != length) {
-            LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
-                       << " bytes when needing " << neededLength;
-            return ANEURALNETWORKS_BAD_DATA;
-        }
+    uint32_t neededLength = TypeManager::get()->getSizeOfData(operand);
+    if (neededLength != length) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
+                   << " bytes when needing " << neededLength;
+        return ANEURALNETWORKS_BAD_DATA;
     }
     if (!memory->validateSize(offset, length)) {
         return ANEURALNETWORKS_BAD_DATA;
@@ -458,15 +449,6 @@
     return ANEURALNETWORKS_NO_ERROR;
 }
 
-void ModelBuilder::setExtensionNameToPrefixMap(
-        const std::map<std::string, uint16_t>& extensionNameToPrefix) {
-    mExtensionNameToPrefix = extensionNameToPrefix;
-}
-
-const std::map<std::string, uint16_t>& ModelBuilder::getExtensionNameToPrefixMap() const {
-    return mExtensionNameToPrefix;
-}
-
 int ModelBuilder::createCompilation(CompilationBuilder** compilation,
                                     const std::vector<std::shared_ptr<Device>>& devices,
                                     bool forceNoFallback) {
@@ -579,21 +561,44 @@
     model->outputIndexes = mOutputIndexes;
     model->operandValues = mSmallOperandValues;
     model->relaxComputationFloat32toFloat16 = mRelaxComputationFloat32toFloat16;
+    model->extensionNameToPrefix = getExtensionNameToPrefixMap();  // Extensions this model uses.
 
     uint32_t count = mMemories.size();
     model->pools.resize(count);
     for (uint32_t i = 0; i < count; i++) {
         model->pools[i] = mMemories[i]->getHidlMemory();
     }
+}
 
-    std::vector<Model::ExtensionNameAndPrefix> extensionNameToPrefixVec;
-    for (auto& nameAndPrefix : mExtensionNameToPrefix) {
-        extensionNameToPrefixVec.push_back({
-                .name = nameAndPrefix.first,
-                .prefix = nameAndPrefix.second,
+std::vector<Model::ExtensionNameAndPrefix> ModelBuilder::getExtensionNameToPrefixMap() const {
+    std::vector<Model::ExtensionNameAndPrefix> extensionNameToPrefix;
+    std::set<uint16_t> prefixSet;  // Prefixes already added, to avoid duplicates.
+
+    auto addExtensionWithPrefix = [&extensionNameToPrefix, &prefixSet](uint16_t prefix) {
+        if (!prefixSet.insert(prefix).second) {
+            return;
+        }
+        const Extension* extension;
+        CHECK(TypeManager::get()->getExtensionInfo(prefix, &extension));
+        extensionNameToPrefix.push_back({
+                .name = extension->name,
+                .prefix = prefix,
         });
+    };
+
+    constexpr uint8_t kLowBitsType =
+            static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
+    for (const auto& operand : mOperands) {
+        if (isExtensionOperandType(operand.type)) {
+            addExtensionWithPrefix(static_cast<uint32_t>(operand.type) >> kLowBitsType);
+        }
     }
-    model->extensionNameToPrefix = extensionNameToPrefixVec;
+    for (const auto& operation : mOperations) {
+        if (isExtensionOperationType(operation.type)) {
+            addExtensionWithPrefix(static_cast<uint32_t>(operation.type) >> kLowBitsType);
+        }
+    }
+    return extensionNameToPrefix;
 }
 
 }  // namespace nn
diff --git a/runtime/ModelBuilder.h b/runtime/ModelBuilder.h
index e5cd468..1ef3ef0 100644
--- a/runtime/ModelBuilder.h
+++ b/runtime/ModelBuilder.h
@@ -54,9 +54,6 @@
     int relaxComputationFloat32toFloat16(bool allow);
     bool isComputationFloat32RelaxedToFloat16() const { return mRelaxComputationFloat32toFloat16; }
 
-    void setExtensionNameToPrefixMap(const std::map<std::string, uint16_t>&);
-    const std::map<std::string, uint16_t>& getExtensionNameToPrefixMap() const;
-
     int finish();
     bool isFinished() const { return mCompletedModel; }
     bool isValid() const { return !mInvalidModel; }
@@ -120,6 +117,12 @@
     // Copies the large values to a shared memory, if we have any.
     int copyLargeValuesToSharedMemory();
 
+    // Returns the name and numeric "prefix" of every extension whose operand
+    // or operation types are used in the model, listing each extension once.
+    //
+    // Devices rely on this mapping to interpret extension types.
+    std::vector<Model::ExtensionNameAndPrefix> getExtensionNameToPrefixMap() const;
+
     // The operations of the graph.
     std::vector<Operation> mOperations;
     // The mapping from sorted index to the original index of operations in mOperations.
@@ -168,12 +171,6 @@
     // 'false' indicates TENSOR_FLOAT32 must be calculated using at least the
     // range and precision of the IEEE 754 32-bit floating-point format.
     bool mRelaxComputationFloat32toFloat16 = false;
-
-    // Maps extension names to numeric "prefixes" of operand and operation
-    // type values. Devices rely on these prefixes to interpret extension types.
-    // TODO(b/123523457): Have a global name-to-prefix mapping instead of
-    // storing it here.
-    std::map<std::string, uint16_t> mExtensionNameToPrefix;
 };
 
 }  // namespace nn
diff --git a/runtime/TypeManager.cpp b/runtime/TypeManager.cpp
new file mode 100644
index 0000000..00219a2
--- /dev/null
+++ b/runtime/TypeManager.cpp
@@ -0,0 +1,209 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "TypeManager"
+
+#include "TypeManager.h"
+
+#include "Utils.h"
+
+#include <algorithm>
+
+namespace android {
+namespace nn {
+namespace {
+
+const uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
+const uint32_t kMaxPrefix =
+        (1 << static_cast<uint8_t>(Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 1;
+
+// Checks if the two structures contain the same information. The order of
+// operand types within the structures does not matter.
+bool equal(const Extension& a, const Extension& b) {
+    NN_RET_CHECK_EQ(a.name, b.name);
+    // Relies on the fact that TypeManager sorts operandTypes.
+    NN_RET_CHECK(a.operandTypes == b.operandTypes);
+    return true;
+}
+
+}  // namespace
+
+TypeManager::TypeManager() {
+    VLOG(MANAGER) << "TypeManager::TypeManager";
+    findAvailableExtensions();
+}
+
+// Registers the extensions reported by each driver known to the DeviceManager.
+// Called from the constructor.
+void TypeManager::findAvailableExtensions() {
+    for (const std::shared_ptr<Device>& device : mDeviceManager->getDrivers()) {
+        // Bind by reference: registerExtension() makes its own copy anyway.
+        for (const Extension& extension : device->getSupportedExtensions()) {
+            registerExtension(extension, device->getName());
+        }
+    }
+}
+
+// Adds |extension|, as reported by the device named |deviceName|, to the
+// registry. Operand types are first sorted by type value, which equal() and
+// getExtensionOperandTypeInfo() rely on. If the information conflicts with an
+// earlier registration of the same extension, the extension is disabled.
+//
+// Returns true if the extension is registered when the call returns.
+bool TypeManager::registerExtension(Extension extension, const std::string& deviceName) {
+    if (mDisabledExtensions.find(extension.name) != mDisabledExtensions.end()) {
+        LOG(ERROR) << "Extension " << extension.name << " is disabled";
+        return false;
+    }
+
+    std::sort(extension.operandTypes.begin(), extension.operandTypes.end(),
+              [](const Extension::OperandTypeInformation& a,
+                 const Extension::OperandTypeInformation& b) {
+                  return static_cast<uint16_t>(a.type) < static_cast<uint16_t>(b.type);
+              });
+
+    std::map<std::string, Extension>::iterator it;
+    bool isNew;
+    std::tie(it, isNew) = mExtensionNameToExtension.emplace(extension.name, extension);
+    if (isNew) {
+        VLOG(MANAGER) << "Registered extension " << extension.name;
+        mExtensionNameToFirstDevice.emplace(extension.name, deviceName);
+    } else if (!equal(extension, it->second)) {
+        LOG(ERROR) << "Devices " << mExtensionNameToFirstDevice[extension.name] << " and "
+                   << deviceName << " provide inconsistent information for extension "
+                   << extension.name << ", which is therefore disabled";
+        // A prefix may already point at the entry erased below. Clear it so that
+        // getExtensionInfo() fails instead of returning a dangling pointer.
+        auto prefixIt = mExtensionNameToPrefix.find(extension.name);
+        if (prefixIt != mExtensionNameToPrefix.end()) {
+            mPrefixToExtension[prefixIt->second] = nullptr;
+        }
+        mExtensionNameToExtension.erase(it);
+        mDisabledExtensions.insert(extension.name);
+        return false;
+    }
+    return true;
+}
+
+// Looks up the prefix of a registered extension, assigning the next unused
+// prefix the first time the extension is asked for.
+//
+// Returns false if the extension is not registered (or has been disabled since
+// its prefix was assigned), or if all prefixes are in use.
+bool TypeManager::getExtensionPrefix(const std::string& extensionName, uint16_t* prefix) {
+    auto it = mExtensionNameToPrefix.find(extensionName);
+    if (it != mExtensionNameToPrefix.end()) {
+        NN_RET_CHECK(mPrefixToExtension[it->second] != nullptr)
+                << "Extension " << extensionName << " is disabled";
+        *prefix = it->second;
+        return true;
+    }
+    // Use find() rather than operator[]: the latter would insert an empty,
+    // nameless placeholder Extension for an unknown extension name.
+    auto extensionIt = mExtensionNameToExtension.find(extensionName);
+    NN_RET_CHECK(extensionIt != mExtensionNameToExtension.end())
+            << "Unknown extension " << extensionName;
+    NN_RET_CHECK_LE(mPrefixToExtension.size(), kMaxPrefix) << "Too many extensions in use";
+    *prefix = mPrefixToExtension.size();
+    mExtensionNameToPrefix[extensionName] = *prefix;
+    // std::map never relocates its elements, so this pointer stays valid until
+    // the entry is erased (see registerExtension()).
+    mPrefixToExtension.push_back(&extensionIt->second);
+    return true;
+}
+
+// Combines the prefix assigned to |extensionName| (in the high bits) with
+// |typeWithinExtension| (in the low kLowBitsType bits) into a type value.
+bool TypeManager::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
+                                   int32_t* type) {
+    uint16_t prefix;
+    NN_RET_CHECK(getExtensionPrefix(extensionName, &prefix));
+    // Shift as uint32_t: with a large prefix the shifted value may not fit in
+    // the signed int that uint16_t promotes to, which would be undefined.
+    const uint32_t encoded =
+            (static_cast<uint32_t>(prefix) << kLowBitsType) | typeWithinExtension;
+    *type = static_cast<int32_t>(encoded);
+    return true;
+}
+
+bool TypeManager::getExtensionInfo(uint16_t prefix, const Extension** extension) const {
+    NN_RET_CHECK_NE(prefix, 0u) << "prefix=0 does not correspond to an extension";
+    NN_RET_CHECK_LT(prefix, mPrefixToExtension.size()) << "Unknown extension prefix";
+    // registerExtension() resets an entry to nullptr when it disables the extension.
+    NN_RET_CHECK(mPrefixToExtension[prefix] != nullptr)
+            << "Extension with prefix " << prefix << " is disabled";
+    *extension = mPrefixToExtension[prefix];
+    return true;
+}
+
+bool TypeManager::getExtensionOperandTypeInfo(
+        OperandType type, const Extension::OperandTypeInformation** info) const {
+    uint32_t operandType = static_cast<uint32_t>(type);
+    uint16_t prefix = operandType >> kLowBitsType;
+    uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
+    const Extension* extension;
+    NN_RET_CHECK(getExtensionInfo(prefix, &extension))
+            << "Cannot find extension corresponding to prefix " << prefix;
+    // operandTypes is sorted by type (see registerExtension()): binary search.
+    auto it = std::lower_bound(
+            extension->operandTypes.begin(), extension->operandTypes.end(), typeWithinExtension,
+            [](const Extension::OperandTypeInformation& entry, uint32_t typeSought) {
+                return static_cast<uint16_t>(entry.type) < typeSought;
+            });
+    NN_RET_CHECK(it != extension->operandTypes.end() &&
+                 static_cast<uint16_t>(it->type) == typeWithinExtension)
+            << "Cannot find operand type " << typeWithinExtension << " in extension "
+            << extension->name;
+    *info = &*it;
+    return true;
+}
+
+bool TypeManager::isTensorType(OperandType type) const {
+    if (!isExtensionOperandType(type)) {
+        return !nonExtensionOperandTypeIsScalar(static_cast<int>(type));
+    }
+    const Extension::OperandTypeInformation* info;
+    CHECK(getExtensionOperandTypeInfo(type, &info));
+    return info->isTensor;
+}
+
+uint32_t TypeManager::getSizeOfData(OperandType type,
+                                    const std::vector<uint32_t>& dimensions) const {
+    if (!isExtensionOperandType(type)) {
+        return nonExtensionOperandSizeOfData(type, dimensions);
+    }
+
+    const Extension::OperandTypeInformation* info;
+    CHECK(getExtensionOperandTypeInfo(type, &info));
+
+    if (!info->isTensor) {
+        return info->byteSize;
+    }
+
+    if (dimensions.empty()) {
+        return 0;
+    }
+
+    // NOTE(review): the product below is not checked for uint32_t overflow.
+    uint32_t size = info->byteSize;
+    for (auto dimension : dimensions) {
+        size *= dimension;
+    }
+    return size;
+}
+
+}  // namespace nn
+}  // namespace android
diff --git a/runtime/TypeManager.h b/runtime/TypeManager.h
new file mode 100644
index 0000000..b1cdcc0
--- /dev/null
+++ b/runtime/TypeManager.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_ML_NN_RUNTIME_TYPE_MANAGER_H
+#define ANDROID_ML_NN_RUNTIME_TYPE_MANAGER_H
+
+#include "HalInterfaces.h"
+#include "Manager.h"
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+namespace android {
+namespace nn {
+
+// Manages runtime operand and operation type information.
+//
+// This class gathers information about extension types from all devices
+// and provides a unified way to access information about any known type.
+//
+// NOTE(review): get() returns one process-wide instance, and getExtensionType()
+// can modify it (by assigning new prefixes) without any locking. Confirm that
+// callers never use it from several threads at once.
+class TypeManager {
+   public:
+    static TypeManager* get() {
+        static TypeManager manager;
+        return &manager;
+    }
+
+    // Not copyable: mPrefixToExtension points into this object's own
+    // mExtensionNameToExtension, so a copy would keep pointing into the source.
+    TypeManager(const TypeManager&) = delete;
+    TypeManager& operator=(const TypeManager&) = delete;
+
+    // Creates an operand/operation type corresponding to a given extension
+    // name and type within extension.
+    //
+    // Returns false if the extension is unknown.
+    bool getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type);
+
+    // Looks up information about the extension corresponding to the given prefix.
+    //
+    // Returns false if no extension corresponds to the given prefix.
+    bool getExtensionInfo(uint16_t prefix, const Extension** extension) const;
+
+    // Looks up information about an extension operand type.
+    //
+    // Returns false if the extension or type is unknown.
+    bool getExtensionOperandTypeInfo(OperandType type,
+                                     const Extension::OperandTypeInformation** info) const;
+
+    // Returns true if an operand type is a tensor type.
+    //
+    // Aborts if the type is an unknown extension type.
+    bool isTensorType(OperandType type) const;
+
+    // Returns the amount of space needed to store a value of the dimensions and
+    // type of this operand. For a tensor with unspecified rank or at least one
+    // unspecified dimension, returns zero.
+    //
+    // Aborts if the type is an unknown extension type.
+    uint32_t getSizeOfData(const Operand& operand) const {
+        return getSizeOfData(operand.type, operand.dimensions);
+    }
+
+    // Returns the amount of space needed to store a value of the specified
+    // dimensions and type. For a tensor with unspecified rank or at least one
+    // unspecified dimension, returns zero.
+    //
+    // Aborts if the type is an unknown extension type.
+    uint32_t getSizeOfData(OperandType type, const std::vector<uint32_t>& dimensions) const;
+
+    // This method is intended for use only by internal unit tests.
+    //
+    // Registers an extension.
+    //
+    // Returns true if the registration was successful.
+    bool forTest_registerExtension(const Extension& extension) {
+        return registerExtension(extension, "INTERNAL TEST");
+    }
+
+    // This method is intended for use only by internal unit tests.
+    //
+    // Resets the internal state.
+    //
+    // After calling forTest_registerExtension() any number of times, call
+    // forTest_reset() to return to the state as if forTest_registerExtension()
+    // had never been called. Note that forTest_reset() resets all internal
+    // state (including assigned prefixes) and re-discovers extensions from
+    // available devices.
+    void forTest_reset() { *this = TypeManager(); }
+
+   private:
+    TypeManager();
+    // Moving is allowed (forTest_reset() relies on it). Moving a std::map
+    // keeps references to its elements valid, so mPrefixToExtension stays valid.
+    TypeManager(TypeManager&&) = default;
+    TypeManager& operator=(TypeManager&&) = default;
+    void findAvailableExtensions();
+    bool registerExtension(Extension extension, const std::string& deviceName);
+
+    // Returns the numeric "prefix" value corresponding to an extension.
+    //
+    // Returns false when assigning a new prefix would overflow uint16_t.
+    bool getExtensionPrefix(const std::string& extensionName, uint16_t* prefix);
+
+    const DeviceManager* mDeviceManager = DeviceManager::get();
+
+    // Contains all registered extensions.
+    std::map<std::string, Extension> mExtensionNameToExtension;
+
+    // Contains the name of the first discovered device that supports an
+    // extension. Used for error reporting.
+    std::map<std::string, std::string> mExtensionNameToFirstDevice;
+
+    // When multiple devices report conflicting information about an extension,
+    // the extension is disabled.
+    std::set<std::string> mDisabledExtensions;
+
+    // The fields below are used to support efficient extension name to
+    // prefix mapping. New prefixes are created by getExtensionPrefix.
+    std::map<std::string, uint16_t> mExtensionNameToPrefix;
+    // Entries of mPrefixToExtension point into mExtensionNameToExtension.
+    // prefix=0 corresponds to no extension and should never be looked up.
+    std::vector<Extension*> mPrefixToExtension = {nullptr};
+};
+
+}  // namespace nn
+}  // namespace android
+
+#endif  // ANDROID_ML_NN_RUNTIME_TYPE_MANAGER_H
diff --git a/runtime/test/TestValidation.cpp b/runtime/test/TestValidation.cpp
index 2cfc008..36efeb6 100644
--- a/runtime/test/TestValidation.cpp
+++ b/runtime/test/TestValidation.cpp
@@ -25,12 +25,13 @@
 
 #ifndef NNTEST_ONLY_PUBLIC_API
 #include "NeuralNetworksExtensions.h"
-const char* kTestExtensionName = "vendor.test.validation_test_extension";
+#include "TypeManager.h"
 #endif
 
 // This file tests all the validations done by the Neural Networks API.
 
 namespace {
+
 class ValidationTest : public ::testing::Test {
    protected:
     virtual void SetUp() {}
@@ -54,15 +55,10 @@
         return mNumOperands++;
     }
 
-#ifndef NNTEST_ONLY_PUBLIC_API
-    int32_t getExtensionOperandType(uint16_t typeWithinExtension) {
-        int32_t result;
-        EXPECT_EQ(ANeuralNetworksModel_getExtensionOperandType(mModel, kTestExtensionName,
-                                                               typeWithinExtension, &result),
-                  ANEURALNETWORKS_NO_ERROR);
-        return result;
+    uint32_t addOperand(const ANeuralNetworksOperandType& operandType) {
+        EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &operandType), ANEURALNETWORKS_NO_ERROR);
+        return mNumOperands++;
     }
-#endif
 
     uint32_t addTensorOperand(int32_t type = ANEURALNETWORKS_TENSOR_FLOAT32) {
         uint32_t dimensions[] = {2};
@@ -71,8 +67,7 @@
                 .dimensionCount = sizeof(dimensions) / sizeof(dimensions[0]),
                 .dimensions = dimensions,
         };
-        EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &operandType), ANEURALNETWORKS_NO_ERROR);
-        return mNumOperands++;
+        return addOperand(operandType);
     }
 
     void createModel() {
@@ -110,6 +105,42 @@
     };
 };
 
+#ifndef NNTEST_ONLY_PUBLIC_API
+constexpr const char* kTestExtensionName = "com.android.test_extension";
+// Operand type code *within* the test extension (numeric value borrowed from a built-in type).
+constexpr int32_t kTestExtensionTensorType = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL;
+
+// Fixture for tests that need an extension operand type: registers a test extension with the
+// runtime TypeManager in SetUp() and calls TypeManager::forTest_reset() in TearDown().
+class ValidationTestModelExtensions : public ValidationTestModel {
+   protected:
+    virtual void SetUp() {
+        ValidationTestModel::SetUp();
+        // ASSERT (not EXPECT) so a failed registration skips the test body.
+        ASSERT_TRUE(::android::nn::TypeManager::get()->forTest_registerExtension({
+                .name = kTestExtensionName,
+                .operandTypes = {{.type = kTestExtensionTensorType,
+                                  .isTensor = true,
+                                  .byteSize = 1}},
+        }));
+    }
+
+    virtual void TearDown() {
+        ::android::nn::TypeManager::get()->forTest_reset();
+        ValidationTestModel::TearDown();
+    }
+
+    // Returns the full operand type code of typeWithinExtension in the test extension.
+    int32_t getExtensionOperandType(uint16_t typeWithinExtension) {
+        int32_t result = -1;  // Never return an indeterminate value on failure.
+        EXPECT_EQ(ANeuralNetworksModel_getExtensionOperandType(mModel, kTestExtensionName,
+                                                               typeWithinExtension, &result),
+                  ANEURALNETWORKS_NO_ERROR);
+        return result;
+    }
+};
+#endif
+
 class ValidationTestIdentify : public ValidationTestModel {
     virtual void SetUp() {
         ValidationTestModel::SetUp();
@@ -259,9 +290,14 @@
 }
 
 #ifndef NNTEST_ONLY_PUBLIC_API
-TEST_F(ValidationTestModel, SetOperandSymmPerChannelQuantParams_ExtensionOperand) {
-    const int32_t operandIndex = addTensorOperand(
-            getExtensionOperandType(ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL));
+TEST_F(ValidationTestModelExtensions, AddOperand_UnknownPrefix) {
+    ANeuralNetworksOperandType unknownType = {.type = -1};  // No extension owns this prefix.
+    ASSERT_EQ(ANeuralNetworksModel_addOperand(mModel, &unknownType), ANEURALNETWORKS_BAD_DATA);
+}
+
+TEST_F(ValidationTestModelExtensions, SetOperandSymmPerChannelQuantParams_ExtensionOperand) {
+    const int32_t operandIndex =
+            addTensorOperand(getExtensionOperandType(kTestExtensionTensorType));
 
     float scales[2] = {1.0, 2.0};
     ANeuralNetworksSymmPerChannelQuantParams channelQuant = {
@@ -275,8 +311,9 @@
               ANEURALNETWORKS_BAD_DATA);
 }
 
-TEST_F(ValidationTestModel, SetOperandExtensionData) {
-    const int32_t operandIndex = addTensorOperand(getExtensionOperandType(0));
+TEST_F(ValidationTestModelExtensions, SetOperandExtensionData) {
+    const int32_t operandIndex =
+            addTensorOperand(getExtensionOperandType(kTestExtensionTensorType));
     const int32_t data = 42;
     const size_t dataLength = sizeof(data);
     EXPECT_EQ(
@@ -294,19 +331,45 @@
               ANEURALNETWORKS_NO_ERROR);
 }
 
-TEST_F(ValidationTestModel, SetOperandExtensionData_Empty) {
-    const int32_t operandIndex = addTensorOperand(getExtensionOperandType(0));
+TEST_F(ValidationTestModelExtensions, SetOperandExtensionData_Empty) {
+    const int32_t operandIndex =
+            addTensorOperand(getExtensionOperandType(kTestExtensionTensorType));
     EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, nullptr, 0),
               ANEURALNETWORKS_NO_ERROR);
 }
 
-TEST_F(ValidationTestModel, SetOperandExtensionData_NonExtensionOperand) {
+TEST_F(ValidationTestModelExtensions, SetOperandExtensionData_NonExtensionOperand) {
     const int32_t operandIndex = addTensorOperand();
     const int32_t data = 42;
     const size_t dataLength = sizeof(data);
     EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, &data, dataLength),
               ANEURALNETWORKS_BAD_DATA);
 }
+
+TEST_F(ValidationTestModelExtensions, SetOperandValue_UnspecifiedDimension) {
+    const uint32_t dimensions[] = {3, 0};  // 0 marks an unspecified dimension.
+    ANeuralNetworksOperandType type = {
+            .type = getExtensionOperandType(kTestExtensionTensorType),
+            .dimensionCount = sizeof(dimensions) / sizeof(dimensions[0]),
+            .dimensions = dimensions,
+    };
+    const int32_t operandIndex = addOperand(type);
+    char buffer[20] = {};  // Zeroed; contents are irrelevant since validation must fail.
+    EXPECT_EQ(ANeuralNetworksModel_setOperandValue(mModel, operandIndex, buffer, sizeof(buffer)),
+              ANEURALNETWORKS_BAD_DATA);
+}
+
+TEST_F(ValidationTestModelExtensions, SetOperandValue_UnspecifiedRank) {
+    ANeuralNetworksOperandType type = {
+            .type = getExtensionOperandType(kTestExtensionTensorType),
+            .dimensionCount = 0,  // No dimensions: the tensor's rank is left unspecified.
+            .dimensions = nullptr,
+    };
+    const int32_t operandIndex = addOperand(type);
+    char buffer[20] = {};  // Zeroed; contents are irrelevant since validation must fail.
+    EXPECT_EQ(ANeuralNetworksModel_setOperandValue(mModel, operandIndex, buffer, sizeof(buffer)),
+              ANEURALNETWORKS_BAD_DATA);
+}
 #endif
 
 TEST_F(ValidationTestModel, SetOptionalOperand) {