Add Extensions API

Please see the commit message of change Ia9b99015eec7a48bbf969cbe503862271f09adca
for motivation.

Bug: 118604960
Bug: 118606929
Test: NeuralNetworksTest_static
Change-Id: I2703b963f040a846889554888ddd984eac6b6c08
Merged-In: I2703b963f040a846889554888ddd984eac6b6c08
(cherry picked from commit 2543307e0a5caa66abc67cea0a5fe8244e442712)
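
For illustration, a minimal sketch of the client-side flow this change
enables. The NNAPI calls are real (all are introduced or already present in
this change); the extension name, codes, and data are hypothetical, and error
handling is omitted:

    const char* kExtensionName = "vendor.example.my_extension";  // hypothetical

    bool supported = false;
    ANeuralNetworksDevice_getExtensionSupport(device, kExtensionName, &supported);

    // Translate extension-local codes into model-local types.
    int32_t tensorTypeCode;
    ANeuralNetworksModel_getExtensionOperandType(model, kExtensionName,
                                                 /*operandCodeWithinExtension=*/0,
                                                 &tensorTypeCode);
    ANeuralNetworksOperationType opType;
    ANeuralNetworksModel_getExtensionOperationType(model, kExtensionName,
                                                   /*operationCodeWithinExtension=*/0,
                                                   &opType);

    // Add an operand of the extension type and attach vendor-defined data to it.
    uint32_t dims[] = {2};
    ANeuralNetworksOperandType operandType = {
            .type = tensorTypeCode, .dimensionCount = 1, .dimensions = dims};
    ANeuralNetworksModel_addOperand(model, &operandType);
    int32_t params = 42;  // hypothetical vendor parameter blob
    ANeuralNetworksModel_setOperandExtensionData(model, /*index=*/0, &params,
                                                 sizeof(params));
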
diff --git a/runtime/CompilationBuilder.cpp b/runtime/CompilationBuilder.cpp
index 5dab730..1b2dea5 100644
--- a/runtime/CompilationBuilder.cpp
+++ b/runtime/CompilationBuilder.cpp
@@ -63,7 +63,11 @@
                     return n;
                 }
                 if (mModel->hasOEMOperation()) {
-                    LOG(ERROR) << "Because of OEM op cannot fall back to CPU";
+                    LOG(ERROR) << "Cannot fall back to CPU because of an OEM operation";
+                    return n;
+                }
+                if (mModel->hasExtensionOperation()) {
+                    LOG(ERROR) << "Cannot fall back to CPU because of an extension operation";
                     return n;
                 }
                 break;
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index 55572f6..449ede2 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -52,15 +52,14 @@
     if (data == nullptr) {
         state = ModelArgumentInfo::HAS_NO_VALUE;
     } else {
-        int n = updateDimensionInfo(operand, type);
-        if (n != ANEURALNETWORKS_NO_ERROR) {
-            return n;
-        }
-        uint32_t neededLength = sizeOfData(operand.type, dimensions);
-        if (operand.type != OperandType::OEM && neededLength != length) {
-            LOG(ERROR) << "Setting argument with invalid length: " << length
-                       << ", expected length: " << neededLength;
-            return ANEURALNETWORKS_BAD_DATA;
+        NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
+        if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
+            uint32_t neededLength = sizeOfData(operand.type, dimensions);
+            if (neededLength != length) {
+                LOG(ERROR) << "Setting argument with invalid length: " << length
+                           << ", expected length: " << neededLength;
+                return ANEURALNETWORKS_BAD_DATA;
+            }
         }
         state = ModelArgumentInfo::POINTER;
     }
@@ -71,15 +70,14 @@
 
 int ModelArgumentInfo::setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
                                      uint32_t poolIndex, uint32_t offset, uint32_t length) {
-    int n = updateDimensionInfo(operand, type);
-    if (n != ANEURALNETWORKS_NO_ERROR) {
-        return n;
-    }
-    uint32_t neededLength = sizeOfData(operand.type, dimensions);
-    if (operand.type != OperandType::OEM && neededLength != length) {
-        LOG(ERROR) << "Setting argument with invalid length: " << length
-                   << ", expected length: " << neededLength;
-        return ANEURALNETWORKS_BAD_DATA;
+    NN_RETURN_IF_ERROR(updateDimensionInfo(operand, type));
+    if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
+        uint32_t neededLength = sizeOfData(operand.type, dimensions);
+        if (neededLength != length) {
+            LOG(ERROR) << "Setting argument with invalid length: " << length
+                       << ", expected length: " << neededLength;
+            return ANEURALNETWORKS_BAD_DATA;
+        }
     }
 
     state = ModelArgumentInfo::MEMORY;
@@ -89,13 +87,23 @@
 }
 
 int ModelArgumentInfo::setFromTemporaryMemory(const Operand& operand, uint32_t poolIndex,
-                                              uint32_t offset) {
-    int n = updateDimensionInfo(operand, nullptr);
-    if (n != ANEURALNETWORKS_NO_ERROR) {
-        return n;
+                                              uint32_t offset, uint32_t length) {
+    NN_RETURN_IF_ERROR(updateDimensionInfo(operand, nullptr));
+    if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
+        uint32_t neededLength = sizeOfData(operand.type, dimensions);
+        if (neededLength != length) {
+            LOG(ERROR) << "Setting argument with invalid length: " << length
+                       << ", expected length: " << neededLength;
+            return ANEURALNETWORKS_BAD_DATA;
+        }
     }
+
     state = ModelArgumentInfo::MEMORY;
-    locationAndLength = {.poolIndex = poolIndex, .offset = offset, .length = sizeOfData(operand)};
+    locationAndLength = {
+            .poolIndex = poolIndex,
+            .offset = offset,
+            .length = length,
+    };
     buffer = nullptr;
     return ANEURALNETWORKS_NO_ERROR;
 }
@@ -577,7 +585,10 @@
     //     ExecutionBuilder::setOutputFromMemory()
 
     uint32_t poolIndex = mMemories.add(memory);
-    return inputOrOutputInfo->setFromTemporaryMemory(inputOrOutputOperand, poolIndex, offset);
+    uint32_t length =
+            mDevice->getSizeOfData(inputOrOutputOperand, mModel->getExtensionNameToPrefixMap());
+    return inputOrOutputInfo->setFromTemporaryMemory(inputOrOutputOperand, poolIndex, offset,
+                                                     length);
 }
 
 static void logArguments(const char* kind, const std::vector<ModelArgumentInfo>& args) {
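
Note: NN_RETURN_IF_ERROR, used in the hunks above to replace the hand-written
check-and-return pattern, is not defined in this change. Assuming it lives with
the other helpers in Utils.h, a plausible definition is:

    #define NN_RETURN_IF_ERROR(expr)                      \
        do {                                              \
            int _errorCode = (expr);                      \
            if (_errorCode != ANEURALNETWORKS_NO_ERROR) { \
                return _errorCode;                        \
            }                                             \
        } while (0)
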
diff --git a/runtime/ExecutionBuilder.h b/runtime/ExecutionBuilder.h
index a7a6430..cbb3d04 100644
--- a/runtime/ExecutionBuilder.h
+++ b/runtime/ExecutionBuilder.h
@@ -64,7 +64,8 @@
                        uint32_t length);
     int setFromMemory(const Operand& operand, const ANeuralNetworksOperandType* type,
                       uint32_t poolIndex, uint32_t offset, uint32_t length);
-    int setFromTemporaryMemory(const Operand& operand, uint32_t poolIndex, uint32_t offset);
+    int setFromTemporaryMemory(const Operand& operand, uint32_t poolIndex, uint32_t offset,
+                               uint32_t length);
     int updateDimensionInfo(const Operand& operand, const ANeuralNetworksOperandType* newType);
 };
 
diff --git a/runtime/ExecutionPlan.cpp b/runtime/ExecutionPlan.cpp
index 16fac44..dfe2912 100644
--- a/runtime/ExecutionPlan.cpp
+++ b/runtime/ExecutionPlan.cpp
@@ -51,49 +51,34 @@
 
 typedef std::function<void(uint32_t)> OperationReadyCallback;
 
-bool createSymmPerChannelQuantParams(ANeuralNetworksSymmPerChannelQuantParams* outChannelQuant,
-                                     const Operand::ExtraParams& extraParams) {
-    if (extraParams.getDiscriminator() !=
-        V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant) {
-        LOG(ERROR) << "Unexpected extraParams discriminator, expected channelQuant"
-                   << " received " << static_cast<int>(extraParams.getDiscriminator());
-        return false;
-    }
-    auto& fromChannelQuant = extraParams.channelQuant();
-    *outChannelQuant = {
-            .channelDim = fromChannelQuant.channelDim,
-            .scaleCount = static_cast<uint32_t>(fromChannelQuant.scales.size()),
-            .scales = fromChannelQuant.scales.data(),
-    };
-    return true;
-}
-
 int copyOperandExtraParams(ModelBuilder& model, uint32_t toOperandIndex,
                            const Operand& fromOperand) {
-    switch (fromOperand.type) {
-        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {
-            ANeuralNetworksSymmPerChannelQuantParams toChannelQuant;
-            if (!createSymmPerChannelQuantParams(&toChannelQuant, fromOperand.extraParams)) {
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            int n = model.setOperandSymmPerChannelQuantParams(toOperandIndex, toChannelQuant);
-            if (n != ANEURALNETWORKS_NO_ERROR) {
-                LOG(ERROR) << "Failed setOperandSymmPerChannelQuantParams";
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-        } break;
-
-        default: {
-            if (fromOperand.extraParams.getDiscriminator() !=
-                V1_2::Operand::ExtraParams::hidl_discriminator::none) {
-                LOG(ERROR) << "Unexpected extraParams discriminator, expected none"
-                           << " received "
-                           << static_cast<int>(fromOperand.extraParams.getDiscriminator());
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-        }
+    if (fromOperand.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL &&
+        fromOperand.extraParams.getDiscriminator() ==
+                Operand::ExtraParams::hidl_discriminator::channelQuant) {
+        auto& fromChannelQuant = fromOperand.extraParams.channelQuant();
+        ANeuralNetworksSymmPerChannelQuantParams toChannelQuant = {
+                .channelDim = fromChannelQuant.channelDim,
+                .scaleCount = static_cast<uint32_t>(fromChannelQuant.scales.size()),
+                .scales = fromChannelQuant.scales.data(),
+        };
+        return model.setOperandSymmPerChannelQuantParams(toOperandIndex, toChannelQuant);
+    } else if (isExtensionOperandType(fromOperand.type) &&
+               fromOperand.extraParams.getDiscriminator() ==
+                       Operand::ExtraParams::hidl_discriminator::extension) {
+        hidl_vec<uint8_t> extensionData = fromOperand.extraParams.extension();
+        return model.setOperandExtensionData(toOperandIndex, extensionData.data(),
+                                             extensionData.size());
+    } else if (fromOperand.extraParams.getDiscriminator() !=
+                       Operand::ExtraParams::hidl_discriminator::none ||
+               fromOperand.type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL) {
+        LOG(ERROR) << "Type " << toString(fromOperand.type)
+                   << " has an unexpected extraParams discriminator: "
+                   << static_cast<int>(fromOperand.extraParams.getDiscriminator());
+        return ANEURALNETWORKS_BAD_DATA;
+    } else {
+        return ANEURALNETWORKS_NO_ERROR;
     }
-    return ANEURALNETWORKS_NO_ERROR;
 }
 
 // This class tracks whether we know the value of an operand as operations
@@ -386,6 +371,7 @@
     }
 
     mSubModel.relaxComputationFloat32toFloat16(fromModel->isComputationFloat32RelaxedToFloat16());
+    mSubModel.setExtensionNameToPrefixMap(fromModel->getExtensionNameToPrefixMap());
 
     // Input order: mModelInputs, mTempsAsSubModelInputs, mOutputsAsSubModelInputs
     // Output order: mModelOutputs, mTempsAsSubModelOutputs
@@ -596,7 +582,8 @@
                     subModelInputsAndOutputs =
                             std::make_shared<Controller::SubModelInputsAndOutputsType>();
                 }
-                const uint32_t size = sizeOfData(fromModelOperand);
+                const uint32_t size = step->getDevice()->getSizeOfData(
+                        fromModelOperand, fromModel->getExtensionNameToPrefixMap());
                 totalSizeOfTemporaries += alignBytesNeeded(totalSizeOfTemporaries, size);
                 subModelInputsAndOutputs->insert(std::make_pair(fromModelOperandIndex, totalSizeOfTemporaries));
                 totalSizeOfTemporaries += size;
@@ -849,10 +836,8 @@
     // Figure out where each operation will best execute.
     // The value of the vector is the index in the devices vector.
     std::vector<int> bestDeviceForOperation(operationCount);
-    int status = findBestDeviceForEachOperation(preference, devices, &bestDeviceForOperation);
-    if (status != ANEURALNETWORKS_NO_ERROR) {
-        return status;
-    }
+    NN_RETURN_IF_ERROR(
+            findBestDeviceForEachOperation(preference, devices, &bestDeviceForOperation));
 
     // If one device will run all the operations, we don't need to split the work.
     if (std::adjacent_find(bestDeviceForOperation.begin(), bestDeviceForOperation.end(),
@@ -958,7 +943,7 @@
         case OperandType::TENSOR_OEM_BYTE:
             return device->getQuantized8Performance();
         default:
-            nnAssert(false);
+            CHECK(isExtensionOperandType(operandType)) << "Unhandled base operand type";
             return device->getQuantized8Performance();
     }
 }
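
The rewritten copyOperandExtraParams() relies on Operand::ExtraParams being a
HIDL safe union. A sketch of how the three cases it distinguishes can be
inspected, using only the generated accessors that already appear above:

    void describeExtraParams(const Operand& operand) {
        using Discriminator = Operand::ExtraParams::hidl_discriminator;
        switch (operand.extraParams.getDiscriminator()) {
            case Discriminator::none:
                // No extra parameters -- the common case.
                break;
            case Discriminator::channelQuant:
                // Per-channel quantization parameters for
                // TENSOR_QUANT8_SYMM_PER_CHANNEL operands.
                LOG(INFO) << "channelDim = "
                          << operand.extraParams.channelQuant().channelDim;
                break;
            case Discriminator::extension:
                // Opaque vendor data set via setOperandExtensionData().
                LOG(INFO) << operand.extraParams.extension().size() << " bytes";
                break;
        }
    }
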
diff --git a/runtime/Manager.cpp b/runtime/Manager.cpp
index 3d73465..e3439a5 100644
--- a/runtime/Manager.cpp
+++ b/runtime/Manager.cpp
@@ -36,6 +36,40 @@
 namespace android {
 namespace nn {
 
+uint32_t Device::getSizeOfData(const Operand& operand,
+                               const std::map<std::string, uint16_t>& extensionNameToPrefix) const {
+    if (!isExtensionOperandType(operand.type)) {
+        return sizeOfData(operand);
+    }
+
+    // A slow naive implementation.
+    // TODO(b/123178734): Speed it up.
+    uint32_t operandType = static_cast<uint32_t>(operand.type);
+    uint8_t kLowBitsType = static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
+    uint16_t prefix = operandType >> kLowBitsType;
+    uint16_t typeWithinExtension = operandType & ((1 << kLowBitsType) - 1);
+    for (const Extension& extension : getSupportedExtensions()) {
+        if (extensionNameToPrefix.at(extension.name) != prefix) {
+            continue;
+        }
+        for (auto& extensionOperandType : extension.operandTypes) {
+            if (extensionOperandType.type == typeWithinExtension) {
+                uint32_t numElements = 1;
+                if (extensionOperandType.isTensor) {
+                    for (auto dimension : operand.dimensions) {
+                        numElements *= dimension;
+                    }
+                }
+                return numElements * extensionOperandType.byteSize;
+            }
+        }
+    }
+
+    CHECK(false) << "Cannot determine the size of extension operand type "
+                 << toString(operand.type);
+    return 0;
+}
+
 // A Device with actual underlying driver
 class DriverDevice : public Device {
     DISALLOW_IMPLICIT_CONSTRUCTORS(DriverDevice);
@@ -51,7 +85,9 @@
     VersionedIDevice* getInterface() override { return &mInterface; }
     int64_t getFeatureLevel() override { return mInterface.getFeatureLevel(); }
     int32_t getType() const override { return mInterface.getType(); }
-    void getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) override;
+    hidl_vec<Extension> getSupportedExtensions() const override;
+    void getSupportedOperations(const Model& hidlModel,
+                                hidl_vec<bool>* supportedOperations) override;
     PerformanceInfo getFloat32Performance() const override { return mFloat32Performance; }
     PerformanceInfo getQuantized8Performance() const override { return mQuantized8Performance; }
     PerformanceInfo getRelaxedFloat32toFloat16Performance() const override {
@@ -68,6 +104,7 @@
     PerformanceInfo mFloat32Performance;
     PerformanceInfo mQuantized8Performance;
     PerformanceInfo mRelaxedFloat32toFloat16Performance;
+    hidl_vec<Extension> mSupportedExtensions;
 
 #ifdef NN_DEBUGGABLE
     // For debugging: behavior of IDevice::getSupportedOperations for SampleDriver.
@@ -90,12 +127,14 @@
                          : 0;
 #endif  // NN_DEBUGGABLE
 
+    bool success = true;
     ErrorStatus status = ErrorStatus::GENERAL_FAILURE;
+
     Capabilities capabilities;
     std::tie(status, capabilities) = mInterface.getCapabilities();
-
     if (status != ErrorStatus::NONE) {
         LOG(ERROR) << "IDevice::getCapabilities returned the error " << toString(status);
+        success = false;
     } else {
         VLOG(MANAGER) << "Capab " << capabilities.float32Performance.execTime;
         VLOG(MANAGER) << "Capab " << capabilities.quantized8Performance.execTime;
@@ -105,14 +144,24 @@
         mRelaxedFloat32toFloat16Performance = capabilities.relaxedFloat32toFloat16Performance;
     }
 
-    auto result = mInterface.getVersionString();
+    std::tie(status, mVersionString) = mInterface.getVersionString();
     // TODO(miaowang): add a validation test case for the error case.
-    if (result.first != ErrorStatus::NONE) {
+    if (status != ErrorStatus::NONE) {
         LOG(ERROR) << "IDevice::getVersionString returned the error " << toString(status);
-    } else {
-        mVersionString = result.second;
+        success = false;
     }
-    return status == ErrorStatus::NONE;
+
+    std::tie(status, mSupportedExtensions) = mInterface.getSupportedExtensions();
+    if (status != ErrorStatus::NONE) {
+        LOG(ERROR) << "IDevice::getSupportedExtensions returned the error " << toString(status);
+        success = false;
+    }
+
+    return success;
+}
+
+hidl_vec<Extension> DriverDevice::getSupportedExtensions() const {
+    return mSupportedExtensions;
 }
 
 void DriverDevice::getSupportedOperations(const Model& hidlModel,
@@ -232,7 +281,9 @@
     VersionedIDevice* getInterface() override { return nullptr; }
     int64_t getFeatureLevel() override { return kFeatureLevel; }
     int32_t getType() const override { return ANEURALNETWORKS_DEVICE_CPU; }
-    void getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) override;
+    hidl_vec<Extension> getSupportedExtensions() const override { return {/* No extensions. */}; }
+    void getSupportedOperations(const Model& hidlModel,
+                                hidl_vec<bool>* supportedOperations) override;
     PerformanceInfo getFloat32Performance() const override { return kPerformance; }
     PerformanceInfo getQuantized8Performance() const override { return kPerformance; }
     PerformanceInfo getRelaxedFloat32toFloat16Performance() const override { return kPerformance; }
@@ -250,19 +301,16 @@
     const PerformanceInfo kPerformance = {.execTime = 1.0f, .powerUsage = 1.0f};
 };
 
-void CpuDevice::getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) {
+void CpuDevice::getSupportedOperations(const Model& hidlModel,
+                                       hidl_vec<bool>* supportedOperations) {
     const size_t count = hidlModel.operations.size();
-    hidl_vec<bool> supportedOperations(count);
+    hidl_vec<bool> result(count);
     for (size_t i = 0; i < count; i++) {
         // TODO(b/119870033): Decide whether and how post-P operations would be supported on CPU.
         // CPU fallback should support all the operations except for OEM_OPERATION
-        if (hidlModel.operations[i].type == OperationType::OEM_OPERATION) {
-            supportedOperations[i] = false;
-        } else {
-            supportedOperations[i] = true;
-        }
+        result[i] = hidlModel.operations[i].type != OperationType::OEM_OPERATION;
     }
-    *supported = std::move(supportedOperations);
+    *supportedOperations = std::move(result);
 }
 
 int CpuDevice::prepareModel(const Model& hidlModel, ExecutionPreference executionPreference,
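
Device::getSizeOfData() above splits the 32-bit extension operand type back
into a prefix and a type-within-extension before scanning the driver's
extension list. A standalone sketch of the arithmetic, assuming
ExtensionTypeEncoding::LOW_BITS_TYPE == 16 (the enum's value is not shown in
this patch):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint8_t kLowBitsType = 16;          // assumed LOW_BITS_TYPE
        const uint32_t operandType = 0x00010005;  // prefix 1, type 5

        // Decoding, as in Device::getSizeOfData().
        assert(static_cast<uint16_t>(operandType >> kLowBitsType) == 1);
        assert((operandType & ((1u << kLowBitsType) - 1)) == 5);

        // For a tensor operand with dimensions {2, 3} whose extension type
        // reports byteSize == 4, getSizeOfData() would return 2 * 3 * 4 = 24.
        return 0;
    }
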
diff --git a/runtime/Manager.h b/runtime/Manager.h
index 1bf0462..22ea991 100644
--- a/runtime/Manager.h
+++ b/runtime/Manager.h
@@ -42,13 +42,18 @@
     virtual const char* getVersionString() const = 0;
     virtual int64_t getFeatureLevel() = 0;
     virtual int32_t getType() const = 0;
-    virtual void getSupportedOperations(const Model& hidlModel, hidl_vec<bool>* supported) = 0;
+    virtual hidl_vec<Extension> getSupportedExtensions() const = 0;
+    virtual void getSupportedOperations(const Model& hidlModel,
+                                        hidl_vec<bool>* supportedOperations) = 0;
     virtual PerformanceInfo getFloat32Performance() const = 0;
     virtual PerformanceInfo getQuantized8Performance() const = 0;
     virtual PerformanceInfo getRelaxedFloat32toFloat16Performance() const = 0;
 
     virtual int prepareModel(const Model& hidlModel, ExecutionPreference executionPreference,
                              std::shared_ptr<VersionedIPreparedModel>* preparedModel) = 0;
+
+    uint32_t getSizeOfData(const Operand& operand,
+                           const std::map<std::string, uint16_t>& extensionNameToPrefix) const;
 };
 
 // Manages the NN HAL devices.  Only one instance of this class will exist.
diff --git a/runtime/ModelBuilder.cpp b/runtime/ModelBuilder.cpp
index ec7262b..d0145cc 100644
--- a/runtime/ModelBuilder.cpp
+++ b/runtime/ModelBuilder.cpp
@@ -32,6 +32,9 @@
 // The maximum number of operands and operations that a model may have.
 const uint32_t MAX_NUMBER_OF_OPERANDS = 0xFFFFFFFE;
 const uint32_t MAX_NUMBER_OF_OPERATIONS = 0xFFFFFFFE;
+const uint32_t MAX_NUMBER_OF_EXTENSIONS_IN_USE =
+        // -1 because prefix 0x0000 corresponds to no extension.
+        (1 << static_cast<uint8_t>(Model::ExtensionTypeEncoding::HIGH_BITS_PREFIX)) - 1;
 
 bool ModelBuilder::badState(const char* name) {
     if (mCompletedModel) {
@@ -45,6 +48,25 @@
     return false;
 }
 
+int ModelBuilder::getExtensionType(const char* extensionName, uint16_t typeWithinExtension,
+                                   int32_t* type) {
+    uint16_t prefix;
+    auto it = mExtensionNameToPrefix.find(extensionName);
+    if (it != mExtensionNameToPrefix.end()) {
+        prefix = it->second;
+    } else {
+        if (mExtensionNameToPrefix.size() == MAX_NUMBER_OF_EXTENSIONS_IN_USE) {
+            LOG(ERROR) << "Too many extension types in use";
+            return ANEURALNETWORKS_BAD_DATA;
+        }
+        prefix = mExtensionNameToPrefix.size() + 1;
+        mExtensionNameToPrefix[extensionName] = prefix;
+    }
+    *type = (prefix << static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE)) |
+            typeWithinExtension;
+    return ANEURALNETWORKS_NO_ERROR;
+}
+
 int ModelBuilder::addOperand(const ANeuralNetworksOperandType& type) {
     if (badState("addOperand")) {
         return ANEURALNETWORKS_BAD_STATE;
@@ -55,10 +77,7 @@
         LOG(WARNING) << "OEM data type is deprecated. Use Extensions instead.";
     }
 
-    int n = validateOperandType(type, "ANeuralNetworksModel_addOperand", true);
-    if (n != ANEURALNETWORKS_NO_ERROR) {
-        return n;
-    }
+    NN_RETURN_IF_ERROR(validateOperandType(type, "ANeuralNetworksModel_addOperand", true));
     size_t idx = mOperands.size();
     if (idx >= MAX_NUMBER_OF_OPERANDS) {
         LOG(ERROR) << "ANeuralNetworksModel_addOperand exceed max operands";
@@ -107,11 +126,13 @@
             return ANEURALNETWORKS_BAD_DATA;
         }
         uint32_t valueLength = static_cast<uint32_t>(length);
-        uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
-        if (operand.type != OperandType::OEM && neededLength != valueLength) {
-            LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
-                       << " bytes when needing " << neededLength;
-            return ANEURALNETWORKS_BAD_DATA;
+        if (!isExtensionOperandType(operand.type) && operand.type != OperandType::OEM) {
+            uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
+            if (neededLength != valueLength) {
+                LOG(ERROR) << "ANeuralNetworksModel_setOperandValue setting " << valueLength
+                           << " bytes when needing " << neededLength;
+                return ANEURALNETWORKS_BAD_DATA;
+            }
         }
         if (valueLength <= ANEURALNETWORKS_MAX_SIZE_OF_IMMEDIATELY_COPIED_VALUES) {
             uint32_t existingSize = static_cast<uint32_t>(mSmallOperandValues.size());
@@ -146,8 +167,9 @@
     }
 
     if (index >= operandCount()) {
-        LOG(ERROR) << "setOperandSymmPerChannelQuantParams "
-                   << "setting operand extra params " << index << " of " << operandCount();
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandSymmPerChannelQuantParams "
+                   << "setting per-channel quantization parameters for operand " << index << " of "
+                   << operandCount();
         return ANEURALNETWORKS_BAD_DATA;
     }
     Operand& operand = mOperands[index];
@@ -173,6 +195,45 @@
     return ANEURALNETWORKS_NO_ERROR;
 }
 
+int ModelBuilder::setOperandExtensionData(uint32_t index, const void* data, size_t length) {
+    if (badState("setOperandExtensionData")) {
+        return ANEURALNETWORKS_BAD_STATE;
+    }
+
+    if (index >= operandCount()) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
+                   << "setting extension data for operand " << index << " of " << operandCount();
+        return ANEURALNETWORKS_BAD_DATA;
+    }
+    Operand& operand = mOperands[index];
+
+    if (data == nullptr && length != 0) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData data is nullptr but length is "
+                   << length;
+        return ANEURALNETWORKS_BAD_DATA;
+    }
+    if (data != nullptr && length == 0) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData data is not nullptr but length "
+                   << "is zero";
+        return ANEURALNETWORKS_BAD_DATA;
+    }
+    if (!isExtensionOperandType(operand.type)) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData "
+                   << "setting extension data for a base operand type "
+                   << static_cast<int32_t>(operand.type);
+        return ANEURALNETWORKS_BAD_DATA;
+    }
+
+    if (data == nullptr) {
+        operand.extraParams.none();
+    } else {
+        operand.extraParams.extension(
+                hidl_vec<uint8_t>(reinterpret_cast<const uint8_t*>(data),
+                                  reinterpret_cast<const uint8_t*>(data) + length));
+    }
+    return ANEURALNETWORKS_NO_ERROR;
+}
+
 int ModelBuilder::copyLargeValuesToSharedMemory() {
     VLOG(MODEL) << __func__ << " has " << mLargeOperandValues.size() << " values.";
     if (!mLargeOperandValues.empty()) {
@@ -225,24 +286,27 @@
         return ANEURALNETWORKS_BAD_DATA;
     }
     Operand& operand = mOperands[index];
-    uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
     // Only BLOB format AHardwareBuffer can be used for constant data.
     if (memory->getHidlMemory().name() == "hardware_buffer") {
         LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory passed an AHardwareBuffer"
                    << " that is not in AHARDWAREBUFFER_FORMAT_BLOB format";
         return ANEURALNETWORKS_UNMAPPABLE;
     }
-    if (neededLength != length) {
-        LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
-                   << " bytes when needing " << neededLength;
-        return ANEURALNETWORKS_BAD_DATA;
+    if (!isExtensionOperandType(operand.type)) {
+        uint32_t neededLength = sizeOfData(operand.type, operand.dimensions);
+        if (neededLength != length) {
+            LOG(ERROR) << "ANeuralNetworksModel_setOperandValueFromMemory setting " << length
+                       << " bytes when needing " << neededLength;
+            return ANEURALNETWORKS_BAD_DATA;
+        }
     }
     if (!memory->validateSize(offset, length)) {
         return ANEURALNETWORKS_BAD_DATA;
     }
     operand.lifetime = OperandLifeTime::CONSTANT_REFERENCE;
-    operand.location = {
-            .poolIndex = mMemories.add(memory), .offset = offset, .length = neededLength};
+    operand.location = {.poolIndex = mMemories.add(memory),
+                        .offset = offset,
+                        .length = static_cast<uint32_t>(length)};
     return ANEURALNETWORKS_NO_ERROR;
 }
 
@@ -258,15 +322,14 @@
         LOG(WARNING) << "OEM_OPERATION is deprecated. Use Extensions instead.";
     }
 
-    if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, type)) {
-        LOG(ERROR) << "ANeuralNetworksModel_addOperation invalid operations type " << type;
-        return ANEURALNETWORKS_BAD_DATA;
+    if (!isExtensionOperationType(operationType)) {
+        if (!validCode(kNumberOfOperationTypes, kNumberOfOperationTypesOEM, type)) {
+            LOG(ERROR) << "ANeuralNetworksModel_addOperation invalid operation type " << type;
+            return ANEURALNETWORKS_BAD_DATA;
+        }
     }
-    int n = validateOperation(type, inputCount, inputs, outputCount, outputs, mOperands,
-                              HalVersion::LATEST);
-    if (n != ANEURALNETWORKS_NO_ERROR) {
-        return n;
-    }
+    NN_RETURN_IF_ERROR(validateOperation(type, inputCount, inputs, outputCount, outputs, mOperands,
+                                         HalVersion::LATEST));
 
     uint32_t operationIndex = operationCount();
     if (operationIndex >= MAX_NUMBER_OF_OPERATIONS) {
@@ -283,6 +346,7 @@
         mOperands[i].numberOfConsumers++;
     }
     mHasOEMOperation |= (operationType == OperationType::OEM_OPERATION);
+    mHasExtensionOperation |= isExtensionOperationType(operationType);
 
     return ANEURALNETWORKS_NO_ERROR;
 }
@@ -351,6 +415,15 @@
     return ANEURALNETWORKS_NO_ERROR;
 }
 
+void ModelBuilder::setExtensionNameToPrefixMap(
+        const std::map<std::string, uint16_t>& extensionNameToPrefix) {
+    mExtensionNameToPrefix = extensionNameToPrefix;
+}
+
+const std::map<std::string, uint16_t>& ModelBuilder::getExtensionNameToPrefixMap() const {
+    return mExtensionNameToPrefix;
+}
+
 int ModelBuilder::createCompilation(CompilationBuilder** compilation,
                                     const std::vector<std::shared_ptr<Device>>& devices) {
     if (!mCompletedModel || mInvalidModel) {
@@ -465,6 +538,15 @@
     for (uint32_t i = 0; i < count; i++) {
         model->pools[i] = mMemories[i]->getHidlMemory();
     }
+
+    std::vector<Model::ExtensionNameAndPrefix> extensionNameToPrefixVec;
+    for (auto& nameAndPrefix : mExtensionNameToPrefix) {
+        extensionNameToPrefixVec.push_back({
+                .name = nameAndPrefix.first,
+                .prefix = nameAndPrefix.second,
+        });
+    }
+    model->extensionNameToPrefix = extensionNameToPrefixVec;
 }
 
 }  // namespace nn
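
The prefix bookkeeping in getExtensionType() makes extension types model-local:
the first extension a model references gets prefix 1, the next prefix 2, and so
on, with prefix 0 reserved for non-extension types. An illustration with
hypothetical extension names (again assuming LOW_BITS_TYPE == 16):

    int32_t typeA, typeB;
    modelBuilder.getExtensionType("vendor.example.first", /*typeWithinExtension=*/0, &typeA);
    modelBuilder.getExtensionType("vendor.example.second", /*typeWithinExtension=*/7, &typeB);
    // typeA == (1 << 16) | 0 == 0x00010000
    // typeB == (2 << 16) | 7 == 0x00020007

The name-to-prefix map is serialized into Model::extensionNameToPrefix (see the
setHidlModel hunk above) so that drivers can map prefixes back to extension
names.
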
diff --git a/runtime/ModelBuilder.h b/runtime/ModelBuilder.h
index 61edc5a..f3a1a94 100644
--- a/runtime/ModelBuilder.h
+++ b/runtime/ModelBuilder.h
@@ -34,7 +34,9 @@
 class Memory;
 
 class ModelBuilder {
-public:
+   public:
+    // Returns an operand/operation type corresponding to a given extension operand/operation type.
+    int getExtensionType(const char* extensionName, uint16_t typeWithinExtension, int32_t* type);
     // Adds an operand to the model.
     int addOperand(const ANeuralNetworksOperandType& type);
     int setOperandValue(uint32_t index, const void* buffer, size_t length);
@@ -42,6 +44,7 @@
                                   size_t length);
     int setOperandSymmPerChannelQuantParams(
             uint32_t index, const ANeuralNetworksSymmPerChannelQuantParams& extraParams);
+    int setOperandExtensionData(uint32_t index, const void* data, size_t length);
 
     int addOperation(ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs,
                      uint32_t outputCount, const uint32_t* outputs);
@@ -50,11 +53,15 @@
     int relaxComputationFloat32toFloat16(bool allow);
     bool isComputationFloat32RelaxedToFloat16() const { return mRelaxComputationFloat32toFloat16; }
 
+    void setExtensionNameToPrefixMap(const std::map<std::string, uint16_t>&);
+    const std::map<std::string, uint16_t>& getExtensionNameToPrefixMap() const;
+
     int finish();
     bool isFinished() const { return mCompletedModel; }
     bool isValid() const { return !mInvalidModel; }
 
     bool hasOEMOperation() const { return mHasOEMOperation; }
+    bool hasExtensionOperation() const { return mHasExtensionOperation; }
 
     int createCompilation(CompilationBuilder** compilation,
                           const std::vector<std::shared_ptr<Device>>& devices);
@@ -120,6 +127,8 @@
     std::vector<uint32_t> mSortedOperationIndexMap;
     // Is at least one of those operations an OEM_OPERATION?
     bool mHasOEMOperation = false;
+    // Is at least one of those operations an extension operation?
+    bool mHasExtensionOperation = false;
     // The description of the operands of the graph.
     std::vector<Operand> mOperands;
     // Specifies where to find the list of indexes identifying
@@ -156,6 +165,12 @@
     // 'false' indicates TENSOR_FLOAT32 must be calculated using at least the
     // range and precision of the IEEE 754 32-bit floating-point format.
     bool mRelaxComputationFloat32toFloat16 = false;
+
+    // Maps extension names to numeric "prefixes" of operand and operation
+    // type values. Devices rely on these prefixes to interpret extension types.
+    // TODO(b/123523457): Have a global name-to-prefix mapping instead of
+    // storing it here.
+    std::map<std::string, uint16_t> mExtensionNameToPrefix;
 };
 
 }  // namespace nn
diff --git a/runtime/NeuralNetworks.cpp b/runtime/NeuralNetworks.cpp
index b277395..d244935 100644
--- a/runtime/NeuralNetworks.cpp
+++ b/runtime/NeuralNetworks.cpp
@@ -29,6 +29,7 @@
 #include "Manager.h"
 #include "Memory.h"
 #include "ModelBuilder.h"
+#include "NeuralNetworksExtensions.h"
 #include "NeuralNetworksOEM.h"
 #include "Tracing.h"
 #include "Utils.h"
@@ -953,3 +954,62 @@
         delete e;
     }
 }
+
+int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* device,
+                                              const char* extensionName,
+                                              bool* isExtensionSupported) {
+    if (device == nullptr || extensionName == nullptr || isExtensionSupported == nullptr) {
+        LOG(ERROR) << "ANeuralNetworksDevice_getExtensionSupport passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+
+    Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(device));
+    hidl_vec<Extension> supportedExtensions = d->getSupportedExtensions();
+
+    *isExtensionSupported = false;
+    for (const Extension& supportedExtension : supportedExtensions) {
+        if (supportedExtension.name == extensionName) {
+            *isExtensionSupported = true;
+            break;
+        }
+    }
+
+    return ANEURALNETWORKS_NO_ERROR;
+}
+
+int ANeuralNetworksModel_getExtensionOperandType(ANeuralNetworksModel* model,
+                                                 const char* extensionName,
+                                                 uint16_t operandCodeWithinExtension,
+                                                 int32_t* type) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_getExtensionOperandType");
+    if (!model || !extensionName || !type) {
+        LOG(ERROR) << "ANeuralNetworksModel_getExtensionOperandType passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
+    return m->getExtensionType(extensionName, operandCodeWithinExtension, type);
+}
+
+int ANeuralNetworksModel_getExtensionOperationType(ANeuralNetworksModel* model,
+                                                   const char* extensionName,
+                                                   uint16_t operationCodeWithinExtension,
+                                                   ANeuralNetworksOperationType* type) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_getExtensionOperationType");
+    if (!model || !extensionName || !type) {
+        LOG(ERROR) << "ANeuralNetworksModel_getExtensionOperationType passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
+    return m->getExtensionType(extensionName, operationCodeWithinExtension, type);
+}
+
+int ANeuralNetworksModel_setOperandExtensionData(ANeuralNetworksModel* model, int32_t index,
+                                                 const void* data, size_t length) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_setOperandExtensionData");
+    if (!model || (!data && length != 0)) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
+    return m->setOperandExtensionData(index, data, length);
+}
diff --git a/runtime/VersionedInterfaces.cpp b/runtime/VersionedInterfaces.cpp
index 62e5aab..2532f5e 100644
--- a/runtime/VersionedInterfaces.cpp
+++ b/runtime/VersionedInterfaces.cpp
@@ -136,6 +136,27 @@
     return result;
 }
 
+std::pair<ErrorStatus, hidl_vec<Extension>> VersionedIDevice::getSupportedExtensions() {
+    NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_COMPILATION, "getSupportedExtensions");
+    if (mDeviceV1_2 != nullptr) {
+        std::pair<ErrorStatus, hidl_vec<Extension>> result;
+        Return<void> ret = mDeviceV1_2->getSupportedExtensions(
+                [&result](ErrorStatus error, const hidl_vec<Extension>& extensions) {
+                    result = std::make_pair(error, extensions);
+                });
+        if (!ret.isOk()) {
+            LOG(ERROR) << "getSupportedExtensions failure: " << ret.description();
+            return {ErrorStatus::GENERAL_FAILURE, {}};
+        }
+        return result;
+    } else if (mDeviceV1_0 != nullptr) {
+        return {ErrorStatus::NONE, {/* No extensions. */}};
+    } else {
+        LOG(ERROR) << "Device not available!";
+        return {ErrorStatus::DEVICE_UNAVAILABLE, {}};
+    }
+}
+
 std::pair<ErrorStatus, hidl_vec<bool>> VersionedIDevice::getSupportedOperations(
         const Model& model) {
     std::pair<ErrorStatus, hidl_vec<bool>> result;
diff --git a/runtime/VersionedInterfaces.h b/runtime/VersionedInterfaces.h
index ea37003..147c90a 100644
--- a/runtime/VersionedInterfaces.h
+++ b/runtime/VersionedInterfaces.h
@@ -74,6 +74,23 @@
     std::pair<ErrorStatus, Capabilities> getCapabilities();
 
     /**
+     * Gets information about extensions supported by the driver implementation.
+     *
+     * Extensions of category ExtensionCategory::BASE must not appear
+     * in the list.
+     *
+     * All extension operations and operands must be fully supported for the
+     * extension to appear in the list of supported extensions.
+     *
+     * @return status Error status of the call, must be:
+     *     - NONE if successful
+     *     - DEVICE_UNAVAILABLE if driver is offline or busy
+     *     - GENERAL_FAILURE if there is an unspecified error
+     * @return extensions A list of supported extensions.
+     */
+    std::pair<ErrorStatus, hidl_vec<Extension>> getSupportedExtensions();
+
+    /**
      * Gets the supported operations in a model.
      *
     * getSupportedOperations indicates which operations of a model are fully
diff --git a/runtime/include/NeuralNetworks.h b/runtime/include/NeuralNetworks.h
index 75aac93..20af5a2 100644
--- a/runtime/include/NeuralNetworks.h
+++ b/runtime/include/NeuralNetworks.h
@@ -4861,6 +4861,7 @@
 
 /**
  * ANeuralNetworksOperandType describes the type of an operand.
+ *
  * This structure is used to describe both scalars and tensors.
  *
  * A tensor operand type must have a specified rank (number of
@@ -4903,14 +4904,28 @@
  * Available since API level 27.
  */
 typedef struct ANeuralNetworksOperandType {
-    /** The data type, e.g ANEURALNETWORKS_INT8. */
+    /**
+     * The data type, e.g. ANEURALNETWORKS_FLOAT32.
+     */
     int32_t type;
-    /** The number of dimensions (rank). It should be 0 for scalars. */
+
+    /**
+     * The number of dimensions (rank).
+     *
+     * Must be 0 for scalars.
+     */
     uint32_t dimensionCount;
-    /** The dimensions of the tensor. It should be nullptr for scalars. */
+
+    /**
+     * The dimensions of the tensor.
+     *
+     * Must be nullptr for scalars.
+     */
     const uint32_t* dimensions;
-    /** These two fields are only used for quantized tensors.
-     * They should be zero for scalars and non-fixed point tensors.
+
+    /**
+     * These two fields are only used for quantized tensors.
+     * They must be zero for all other types.
      * The dequantized value of each entry is (value - zeroPoint) * scale.
      */
     float scale;
diff --git a/runtime/include/NeuralNetworksExtensions.h b/runtime/include/NeuralNetworksExtensions.h
new file mode 100644
index 0000000..ca2e045
--- /dev/null
+++ b/runtime/include/NeuralNetworksExtensions.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_EXTENSIONS_H
+#define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_EXTENSIONS_H
+
+#include "NeuralNetworks.h"
+
+/******************************************************************
+ *
+ * IMPORTANT NOTICE:
+ *
+ *   This file is not intended for use by general developers -- only
+ *   by OEM applications.
+ *
+ *   Extensions source AND binary code relies on the definitions
+ *   here to be FROZEN ON ALL UPCOMING PLATFORM RELEASES.
+ *
+ *   - DO NOT MODIFY ENUMS (EXCEPT IF YOU ADD NEW 32-BIT VALUES)
+ *   - DO NOT MODIFY CONSTANTS OR FUNCTIONAL MACROS
+ *   - DO NOT CHANGE THE SIGNATURE OF FUNCTIONS IN ANY WAY
+ *   - DO NOT CHANGE THE LAYOUT OR SIZE OF STRUCTURES
+ */
+
+__BEGIN_DECLS
+
+#if __ANDROID_API__ >= __ANDROID_API_Q__
+
+/**
+ * Queries whether an extension is supported by the driver implementation of the specified device.
+ *
+ * @param device The representation of the specified device.
+ * @param extensionName The extension name.
+ * @param isExtensionSupported The boolean value indicating whether the extension is supported.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ *
+ * Available since API level 29.
+ */
+int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* device,
+                                              const char* extensionName, bool* isExtensionSupported)
+        __INTRODUCED_IN(29);
+
+/**
+ * Creates an operand type from an extension name and an extension operand code.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * Available since API level 29.
+ *
+ * @param model The model to contain the operand.
+ * @param extensionName The extension name.
+ * @param operandCodeWithinExtension The extension operand code.
+ * @param type The operand type.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+int ANeuralNetworksModel_getExtensionOperandType(ANeuralNetworksModel* model,
+                                                 const char* extensionName,
+                                                 uint16_t operandCodeWithinExtension, int32_t* type)
+        __INTRODUCED_IN(29);
+
+/**
+ * Creates an operation type from an extension name and an extension operation code.
+ *
+ * See {@link ANeuralNetworksModel} for information on multithreaded usage.
+ *
+ * Available since API level 29.
+ *
+ * @param model The model to contain the operation.
+ * @param extensionName The extension name.
+ * @param operationCodeWithinExtension The extension operation code.
+ * @param type The operation type.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+int ANeuralNetworksModel_getExtensionOperationType(ANeuralNetworksModel* model,
+                                                   const char* extensionName,
+                                                   uint16_t operationCodeWithinExtension,
+                                                   ANeuralNetworksOperationType* type)
+        __INTRODUCED_IN(29);
+
+/**
+ * Sets extension operand parameters.
+ *
+ * Available since API level 29.
+ *
+ * @param model The model to be modified.
+ * @param index The index of the model operand we're setting.
+ * @param data A pointer to the extension operand data.
+ *             The data does not have to outlive the call to this function.
+ * @param length The size in bytes of the data value.
+ *
+ * @return ANEURALNETWORKS_NO_ERROR if successful.
+ */
+int ANeuralNetworksModel_setOperandExtensionData(ANeuralNetworksModel* model, int32_t index,
+                                                 const void* data, size_t length)
+        __INTRODUCED_IN(29);
+
+#endif  // __ANDROID_API__ >= __ANDROID_API_Q__
+
+__END_DECLS
+
+#endif  // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_EXTENSIONS_H
diff --git a/runtime/include/NeuralNetworksWrapper.h b/runtime/include/NeuralNetworksWrapper.h
index 24f4986..67effda 100644
--- a/runtime/include/NeuralNetworksWrapper.h
+++ b/runtime/include/NeuralNetworksWrapper.h
@@ -20,6 +20,7 @@
 #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_H
 
 #include "NeuralNetworks.h"
+#include "NeuralNetworksExtensions.h"
 
 #include <math.h>
 #include <optional>
@@ -194,6 +195,26 @@
         }
     }
 
+    int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) {
+        int32_t result;
+        if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension,
+                                                         &result) != ANEURALNETWORKS_NO_ERROR) {
+            mValid = false;
+        }
+        return result;
+    }
+
+    ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName,
+                                                           uint16_t typeWithinExtension) {
+        ANeuralNetworksOperationType result;
+        if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName,
+                                                           typeWithinExtension,
+                                                           &result) != ANEURALNETWORKS_NO_ERROR) {
+            mValid = false;
+        }
+        return result;
+    }
+
     uint32_t addOperand(const OperandType* type) {
         if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) !=
             ANEURALNETWORKS_NO_ERROR) {
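
With these wrapper additions, tests can obtain extension types fluently; a
short usage sketch (the extension name is hypothetical, and the wrapper class
is assumed to be ::android::nn::wrapper::Model):

    ::android::nn::wrapper::Model model;
    int32_t extTensor = model.getExtensionOperandType("vendor.example.ext", 0);
    ANeuralNetworksOperationType extOp =
            model.getExtensionOperationType("vendor.example.ext", 0);
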
diff --git a/runtime/libneuralnetworks.map.txt b/runtime/libneuralnetworks.map.txt
index f425ebf..59e7262 100644
--- a/runtime/libneuralnetworks.map.txt
+++ b/runtime/libneuralnetworks.map.txt
@@ -66,3 +66,11 @@
   local:
     *;
 };
+
+LIBNEURALNETWORKS_PLATFORM {
+  global:
+    ANeuralNetworksDevice_getExtensionSupport;
+    ANeuralNetworksModel_getExtensionOperandType;
+    ANeuralNetworksModel_getExtensionOperationType;
+    ANeuralNetworksModel_setOperandExtensionData;
+} LIBNEURALNETWORKS;
diff --git a/runtime/test/Android.bp b/runtime/test/Android.bp
index 14d4a90..dee442c 100644
--- a/runtime/test/Android.bp
+++ b/runtime/test/Android.bp
@@ -100,6 +100,7 @@
         "TestPartitioning.cpp",
         "TestPartitioningRandom.cpp",
         "TestIntrospectionControl.cpp",
+        "TestExtensions.cpp",
     ],
     static_libs: [
         "libgmock",
diff --git a/runtime/test/TestExtensions.cpp b/runtime/test/TestExtensions.cpp
new file mode 100644
index 0000000..b903805
--- /dev/null
+++ b/runtime/test/TestExtensions.cpp
@@ -0,0 +1,113 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "HalInterfaces.h"
+#include "Manager.h"
+#include "NeuralNetworks.h"
+#include "NeuralNetworksExtensions.h"
+#include "SampleDriver.h"
+
+namespace {
+
+using DeviceManager = ::android::nn::DeviceManager;
+using SampleDriver = ::android::nn::sample_driver::SampleDriver;
+
+const char* kTestDriverName = "extensions-test-driver";
+const char* kTestExtension1 = "vendor.test.one";
+const char* kTestExtension2 = "vendor.test.two";
+const char* kTestExtension3 = "vendor.test.three";
+
+class TestDriver : public SampleDriver {
+   public:
+    TestDriver() : SampleDriver(kTestDriverName) {}
+    ~TestDriver() override {}
+
+    Return<void> getSupportedExtensions(getSupportedExtensions_cb cb) override {
+        cb(ErrorStatus::NONE, {
+                                      {.name = kTestExtension1},
+                                      {.name = kTestExtension2},
+                                      {.name = kTestExtension3},
+                              });
+        return Void();
+    }
+
+    Return<void> getCapabilities_1_1(getCapabilities_1_1_cb cb) override {
+        cb(ErrorStatus::NONE, {/* Dummy zero-filled capabilities. */});
+        return Void();
+    }
+
+    Return<void> getSupportedOperations_1_2(const Model&, getSupportedOperations_cb) override {
+        CHECK(false) << "not implemented";
+        return Void();
+    }
+};
+
+class ExtensionsTest : public ::testing::Test {
+   protected:
+    virtual void SetUp() {
+        // This is needed before we have the CPU fallback path being treated as a Device.
+        // TODO(miaowang): remove once b/72506261 is fixed.
+        if (DeviceManager::get()->getUseCpuOnly()) {
+            GTEST_SKIP();
+        }
+
+        DeviceManager::get()->forTest_registerDevice(kTestDriverName, new TestDriver());
+        mDevice = getDeviceByName(kTestDriverName);
+        ASSERT_NE(mDevice, nullptr);
+    }
+
+    virtual void TearDown() { DeviceManager::get()->forTest_reInitializeDeviceList(); }
+
+    ANeuralNetworksDevice* getDeviceByName(const std::string& name) {
+        ANeuralNetworksDevice* result = nullptr;
+        uint32_t numDevices = 0;
+        EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
+        EXPECT_GE(numDevices, 1u);
+        for (uint32_t i = 0; i < numDevices; i++) {
+            ANeuralNetworksDevice* device = nullptr;
+            EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
+            const char* buffer = nullptr;
+            EXPECT_EQ(ANeuralNetworksDevice_getName(device, &buffer), ANEURALNETWORKS_NO_ERROR);
+            if (name.compare(buffer) == 0) {
+                EXPECT_EQ(result, nullptr) << "multiple devices named " << name;
+                result = device;
+            }
+        }
+        return result;
+    }
+
+    bool testDriverSupportsExtension(const char* extensionName) {
+        bool result;
+        EXPECT_EQ(ANeuralNetworksDevice_getExtensionSupport(mDevice, extensionName, &result),
+                  ANEURALNETWORKS_NO_ERROR);
+        return result;
+    }
+
+   private:
+    ANeuralNetworksDevice* mDevice;
+};
+
+TEST_F(ExtensionsTest, DeviceReportsSupportedExtensions) {
+    EXPECT_TRUE(testDriverSupportsExtension(kTestExtension1));
+    EXPECT_FALSE(testDriverSupportsExtension("vendor.test.unknown"));
+    EXPECT_FALSE(testDriverSupportsExtension("asdfasdfas"));
+    EXPECT_TRUE(testDriverSupportsExtension(kTestExtension2));
+    EXPECT_TRUE(testDriverSupportsExtension(kTestExtension3));
+}
+
+}  // namespace
diff --git a/runtime/test/TestValidation.cpp b/runtime/test/TestValidation.cpp
index 55e0fbb..954e58e 100644
--- a/runtime/test/TestValidation.cpp
+++ b/runtime/test/TestValidation.cpp
@@ -22,6 +22,11 @@
 #include <sys/mman.h>
 #include <string>
 
+#ifndef NNTEST_ONLY_PUBLIC_API
+#include "NeuralNetworksExtensions.h"
+const char* kTestExtensionName = "vendor.test.validation_test_extension";
+#endif
+
 // This file tests all the validations done by the Neural Networks API.
 
 namespace {
@@ -41,18 +46,39 @@
         ValidationTest::TearDown();
     }
 
-    void createModel() {
-        uint32_t dimensions[]{1};
-        ANeuralNetworksOperandType tensorType{.type = ANEURALNETWORKS_TENSOR_FLOAT32,
-                                              .dimensionCount = 1,
-                                              .dimensions = dimensions};
-        ANeuralNetworksOperandType scalarType{
-                .type = ANEURALNETWORKS_INT32, .dimensionCount = 0, .dimensions = nullptr};
+    uint32_t addScalarOperand(int32_t type = ANEURALNETWORKS_INT32) {
+        ANeuralNetworksOperandType operandType = {
+                .type = type, .dimensionCount = 0, .dimensions = nullptr};
+        EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &operandType), ANEURALNETWORKS_NO_ERROR);
+        return mNumOperands++;
+    }
 
-        ASSERT_EQ(ANeuralNetworksModel_addOperand(mModel, &tensorType), ANEURALNETWORKS_NO_ERROR);
-        ASSERT_EQ(ANeuralNetworksModel_addOperand(mModel, &tensorType), ANEURALNETWORKS_NO_ERROR);
-        ASSERT_EQ(ANeuralNetworksModel_addOperand(mModel, &scalarType), ANEURALNETWORKS_NO_ERROR);
-        ASSERT_EQ(ANeuralNetworksModel_addOperand(mModel, &tensorType), ANEURALNETWORKS_NO_ERROR);
+#ifndef NNTEST_ONLY_PUBLIC_API
+    int32_t getExtensionOperandType(uint16_t typeWithinExtension) {
+        int32_t result;
+        EXPECT_EQ(ANeuralNetworksModel_getExtensionOperandType(mModel, kTestExtensionName,
+                                                               typeWithinExtension, &result),
+                  ANEURALNETWORKS_NO_ERROR);
+        return result;
+    }
+#endif
+
+    uint32_t addTensorOperand(int32_t type = ANEURALNETWORKS_TENSOR_FLOAT32) {
+        uint32_t dimensions[] = {2};
+        ANeuralNetworksOperandType operandType = {
+                .type = type,
+                .dimensionCount = sizeof(dimensions) / sizeof(dimensions[0]),
+                .dimensions = dimensions,
+        };
+        EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &operandType), ANEURALNETWORKS_NO_ERROR);
+        return mNumOperands++;
+    }
+
+    void createModel() {
+        addTensorOperand();
+        addTensorOperand();
+        addScalarOperand();
+        addTensorOperand();
         uint32_t inList[3]{0, 1, 2};
         uint32_t outList[1]{3};
         ASSERT_EQ(ANeuralNetworksModel_addOperation(mModel, ANEURALNETWORKS_ADD, 3, inList, 1,
@@ -64,6 +90,7 @@
         mNumOperations = 1;
     }
 
+    uint32_t mNumOperands = 0;
     uint32_t mNumOperations = 0;
     ANeuralNetworksModel* mModel = nullptr;
 };
@@ -174,33 +201,80 @@
 }
 
 TEST_F(ValidationTestModel, SetOperandSymmPerChannelQuantParams) {
-    uint32_t dim = 2;
+    const int32_t operandIndex = addTensorOperand(ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL);
 
-    ANeuralNetworksOperandType quant8SymmPerChannel{
-            .type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL,
-            .dimensionCount = 1,
-            .dimensions = &dim,
-            .scale = 0.0f,
-            .zeroPoint = 0,
-    };
-    EXPECT_EQ(ANeuralNetworksModel_addOperand(mModel, &quant8SymmPerChannel),
-              ANEURALNETWORKS_NO_ERROR);
-
-    float scale = 1.0f;
-    ANeuralNetworksSymmPerChannelQuantParams channelQuant{
+    float scales[2] = {1.0, 2.0};
+    ANeuralNetworksSymmPerChannelQuantParams channelQuant = {
             .channelDim = 0,
-            .scaleCount = 1,
-            .scales = &scale,
+            .scaleCount = 2,
+            .scales = scales,
     };
 
-    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(nullptr, 0, &channelQuant),
+    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(nullptr, operandIndex,
+                                                                       &channelQuant),
               ANEURALNETWORKS_UNEXPECTED_NULL);
-    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(mModel, 0, nullptr),
-              ANEURALNETWORKS_UNEXPECTED_NULL);
-    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(mModel, 100, &channelQuant),
+    EXPECT_EQ(
+            ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(mModel, operandIndex, nullptr),
+            ANEURALNETWORKS_UNEXPECTED_NULL);
+    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(mModel, operandIndex + 1,
+                                                                       &channelQuant),
+              ANEURALNETWORKS_BAD_DATA);
+    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(mModel, operandIndex,
+                                                                       &channelQuant),
+              ANEURALNETWORKS_NO_ERROR);
+}
+
+#ifndef NNTEST_ONLY_PUBLIC_API
+TEST_F(ValidationTestModel, SetOperandSymmPerChannelQuantParams_ExtensionOperand) {
+    const int32_t operandIndex = addTensorOperand(
+            getExtensionOperandType(ANEURALNETWORKS_TENSOR_QUANT8_SYMM_PER_CHANNEL));
+
+    float scales[2] = {1.0, 2.0};
+    ANeuralNetworksSymmPerChannelQuantParams channelQuant = {
+            .channelDim = 0,
+            .scaleCount = 2,
+            .scales = scales,
+    };
+
+    EXPECT_EQ(ANeuralNetworksModel_setOperandSymmPerChannelQuantParams(mModel, operandIndex,
+                                                                       &channelQuant),
               ANEURALNETWORKS_BAD_DATA);
 }
 
+TEST_F(ValidationTestModel, SetOperandExtensionData) {
+    const int32_t operandIndex = addTensorOperand(getExtensionOperandType(0));
+    const int32_t data = 42;
+    const size_t dataLength = sizeof(data);
+    EXPECT_EQ(
+            ANeuralNetworksModel_setOperandExtensionData(nullptr, operandIndex, &data, dataLength),
+            ANEURALNETWORKS_UNEXPECTED_NULL);
+    EXPECT_EQ(
+            ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, nullptr, dataLength),
+            ANEURALNETWORKS_UNEXPECTED_NULL);
+    EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, &data, 0),
+              ANEURALNETWORKS_BAD_DATA);
+    EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex + 1, &data,
+                                                           dataLength),
+              ANEURALNETWORKS_BAD_DATA);
+    EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, &data, dataLength),
+              ANEURALNETWORKS_NO_ERROR);
+}
+
+TEST_F(ValidationTestModel, SetOperandExtensionData_Empty) {
+    const int32_t operandIndex = addTensorOperand(getExtensionOperandType(0));
+    EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, nullptr, 0),
+              ANEURALNETWORKS_NO_ERROR);
+}
+
+TEST_F(ValidationTestModel, SetOperandExtensionData_NonExtensionOperand) {
+    const int32_t operandIndex = addTensorOperand();
+    const int32_t data = 42;
+    const size_t dataLength = sizeof(data);
+    EXPECT_EQ(ANeuralNetworksModel_setOperandExtensionData(mModel, operandIndex, &data, dataLength),
+              ANEURALNETWORKS_BAD_DATA);
+}
+#endif
+
 TEST_F(ValidationTestModel, SetOptionalOperand) {
     ANeuralNetworksOperandType floatType{
             .type = ANEURALNETWORKS_FLOAT32, .dimensionCount = 0, .dimensions = nullptr};
@@ -1196,4 +1270,27 @@
     }
 }
 
+#ifndef NNTEST_ONLY_PUBLIC_API
+TEST(ValidationTestDevice, GetExtensionSupport) {
+    bool result;
+    EXPECT_EQ(ANeuralNetworksDevice_getExtensionSupport(nullptr, kTestExtensionName, &result),
+              ANEURALNETWORKS_UNEXPECTED_NULL);
+
+    uint32_t numDevices = 0;
+    EXPECT_EQ(ANeuralNetworks_getDeviceCount(&numDevices), ANEURALNETWORKS_NO_ERROR);
+
+    for (uint32_t i = 0; i < numDevices; i++) {
+        SCOPED_TRACE(i);
+        ANeuralNetworksDevice* device;
+        EXPECT_EQ(ANeuralNetworks_getDevice(i, &device), ANEURALNETWORKS_NO_ERROR);
+        EXPECT_EQ(ANeuralNetworksDevice_getExtensionSupport(device, kTestExtensionName, nullptr),
+                  ANEURALNETWORKS_UNEXPECTED_NULL);
+        EXPECT_EQ(ANeuralNetworksDevice_getExtensionSupport(device, nullptr, &result),
+                  ANEURALNETWORKS_UNEXPECTED_NULL);
+        EXPECT_EQ(ANeuralNetworksDevice_getExtensionSupport(device, kTestExtensionName, &result),
+                  ANEURALNETWORKS_NO_ERROR);
+    }
+}
+#endif
+
 }  // namespace