Add Extensions API

Please see the commit message of change
Ia9b99015eec7a48bbf969cbe503862271f09adca for motivation.

Bug: 118604960
Bug: 118606929
Test: NeuralNetworksTest_static
Change-Id: I2703b963f040a846889554888ddd984eac6b6c08
Merged-In: I2703b963f040a846889554888ddd984eac6b6c08
(cherry picked from commit 2543307e0a5caa66abc67cea0a5fe8244e442712)
diff --git a/runtime/NeuralNetworks.cpp b/runtime/NeuralNetworks.cpp
index b277395..d244935 100644
--- a/runtime/NeuralNetworks.cpp
+++ b/runtime/NeuralNetworks.cpp
@@ -29,6 +29,7 @@
 #include "Manager.h"
 #include "Memory.h"
 #include "ModelBuilder.h"
+#include "NeuralNetworksExtensions.h"
 #include "NeuralNetworksOEM.h"
 #include "Tracing.h"
 #include "Utils.h"
@@ -953,3 +954,75 @@
         delete e;
     }
 }
+
+// Sets *isExtensionSupported to whether `device` reports an extension named
+// `extensionName`. If any argument is null, returns
+// ANEURALNETWORKS_UNEXPECTED_NULL without writing the output.
+int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* device,
+                                              const char* extensionName,
+                                              bool* isExtensionSupported) {
+    const bool anyNull =
+            device == nullptr || extensionName == nullptr || isExtensionSupported == nullptr;
+    if (anyNull) {
+        LOG(ERROR) << "ANeuralNetworksDevice_getExtensionSupport passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+
+    // The opaque public handle is cast back to the runtime's Device object.
+    auto* runtimeDevice = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(device));
+    const hidl_vec<Extension> extensions = runtimeDevice->getSupportedExtensions();
+
+    // Scan the advertised extensions, stopping at the first name match.
+    bool found = false;
+    for (size_t i = 0; i < extensions.size() && !found; ++i) {
+        found = (extensions[i].name == extensionName);
+    }
+    *isExtensionSupported = found;
+    return ANEURALNETWORKS_NO_ERROR;
+}
+
+// Forwards to ModelBuilder::getExtensionType, which is expected to store in *type
+// the operand type for `operandCodeWithinExtension` of extension `extensionName`.
+// Returns ANEURALNETWORKS_UNEXPECTED_NULL if any pointer argument is null.
+int ANeuralNetworksModel_getExtensionOperandType(ANeuralNetworksModel* model,
+                                                 const char* extensionName,
+                                                 uint16_t operandCodeWithinExtension,
+                                                 int32_t* type) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_getExtensionOperandType");
+    if (model == nullptr || extensionName == nullptr || type == nullptr) {
+        LOG(ERROR) << "ANeuralNetworksModel_getExtensionOperandType passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    auto* builder = reinterpret_cast<ModelBuilder*>(model);
+    return builder->getExtensionType(extensionName, operandCodeWithinExtension, type);
+}
+
+// Operation-code counterpart of ANeuralNetworksModel_getExtensionOperandType; it
+// uses the same ModelBuilder::getExtensionType lookup and null-argument checks.
+int ANeuralNetworksModel_getExtensionOperationType(ANeuralNetworksModel* model,
+                                                   const char* extensionName,
+                                                   uint16_t operationCodeWithinExtension,
+                                                   ANeuralNetworksOperationType* type) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_getExtensionOperationType");
+    if (model == nullptr || extensionName == nullptr || type == nullptr) {
+        LOG(ERROR) << "ANeuralNetworksModel_getExtensionOperationType passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    auto* builder = reinterpret_cast<ModelBuilder*>(model);
+    return builder->getExtensionType(extensionName, operationCodeWithinExtension, type);
+}
+
+// Forwards extension-specific data for operand `index` to
+// ModelBuilder::setOperandExtensionData. A null `data` is accepted only when
+// `length` is 0; no other argument validation happens here.
+int ANeuralNetworksModel_setOperandExtensionData(ANeuralNetworksModel* model, int32_t index,
+                                                 const void* data, size_t length) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_setOperandExtensionData");
+    const bool missingData = data == nullptr && length != 0;
+    if (model == nullptr || missingData) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    auto* builder = reinterpret_cast<ModelBuilder*>(model);
+    return builder->setOperandExtensionData(index, data, length);
+}