Add Extensions API
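See the commit message of change Ia9b99015eec7a48bbf969cbe503862271f09adca
for the motivation. This change adds the runtime implementations of
ANeuralNetworksDevice_getExtensionSupport,
ANeuralNetworksModel_getExtensionOperandType,
ANeuralNetworksModel_getExtensionOperationType, and
ANeuralNetworksModel_setOperandExtensionData.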
Bug: 118604960
Bug: 118606929
Test: NeuralNetworksTest_static
Change-Id: I2703b963f040a846889554888ddd984eac6b6c08
Merged-In: I2703b963f040a846889554888ddd984eac6b6c08
(cherry picked from commit 2543307e0a5caa66abc67cea0a5fe8244e442712)
diff --git a/runtime/NeuralNetworks.cpp b/runtime/NeuralNetworks.cpp
index b277395..d244935 100644
--- a/runtime/NeuralNetworks.cpp
+++ b/runtime/NeuralNetworks.cpp
@@ -29,6 +29,7 @@
#include "Manager.h"
#include "Memory.h"
#include "ModelBuilder.h"
+#include "NeuralNetworksExtensions.h"
#include "NeuralNetworksOEM.h"
#include "Tracing.h"
#include "Utils.h"
@@ -953,3 +954,82 @@
delete e;
}
}
+
+// Reports whether |device| supports the extension named |extensionName|.
+//
+// On success, stores the answer in |*isExtensionSupported| and returns
+// ANEURALNETWORKS_NO_ERROR. If any argument is null, logs an error and returns
+// ANEURALNETWORKS_UNEXPECTED_NULL without writing the output. |extensionName| is
+// compared for equality against the names returned by getSupportedExtensions().
+int ANeuralNetworksDevice_getExtensionSupport(const ANeuralNetworksDevice* device,
+                                              const char* extensionName,
+                                              bool* isExtensionSupported) {
+    if (device == nullptr || extensionName == nullptr || isExtensionSupported == nullptr) {
+        LOG(ERROR) << "ANeuralNetworksDevice_getExtensionSupport passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+
+    Device* d = reinterpret_cast<Device*>(const_cast<ANeuralNetworksDevice*>(device));
+    // Bind by const reference instead of copying into a local hidl_vec: no copy
+    // is made if getSupportedExtensions() returns a reference, and a returned
+    // temporary has its lifetime extended to the end of this scope.
+    const auto& supportedExtensions = d->getSupportedExtensions();
+
+    *isExtensionSupported = false;
+    for (const Extension& supportedExtension : supportedExtensions) {
+        if (supportedExtension.name == extensionName) {
+            *isExtensionSupported = true;
+            break;
+        }
+    }
+
+    return ANEURALNETWORKS_NO_ERROR;
+}
+
+// Stores in |*type| the operand type that ModelBuilder::getExtensionType()
+// assigns to operand code |operandCodeWithinExtension| of extension
+// |extensionName| within |model|. Returns ANEURALNETWORKS_UNEXPECTED_NULL if any
+// pointer argument is null; otherwise returns getExtensionType()'s result.
+int ANeuralNetworksModel_getExtensionOperandType(ANeuralNetworksModel* model,
+                                                 const char* extensionName,
+                                                 uint16_t operandCodeWithinExtension,
+                                                 int32_t* type) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_getExtensionOperandType");
+    if (!model || !extensionName || !type) {
+        LOG(ERROR) << "ANeuralNetworksModel_getExtensionOperandType passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
+    return m->getExtensionType(extensionName, operandCodeWithinExtension, type);
+}
+
+// Operation counterpart of ANeuralNetworksModel_getExtensionOperandType: stores in
+// |*type| the operation type that ModelBuilder::getExtensionType() assigns to
+// |operationCodeWithinExtension| of extension |extensionName|. Returns
+// ANEURALNETWORKS_UNEXPECTED_NULL if any pointer argument is null.
+int ANeuralNetworksModel_getExtensionOperationType(ANeuralNetworksModel* model,
+                                                   const char* extensionName,
+                                                   uint16_t operationCodeWithinExtension,
+                                                   ANeuralNetworksOperationType* type) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_getExtensionOperationType");
+    if (!model || !extensionName || !type) {
+        LOG(ERROR) << "ANeuralNetworksModel_getExtensionOperationType passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
+    return m->getExtensionType(extensionName, operationCodeWithinExtension, type);
+}
+
+// Passes |length| bytes at |data| to ModelBuilder::setOperandExtensionData() for
+// operand |index| of |model| and returns its result. |data| may be null only
+// when |length| is 0; a null |model| yields ANEURALNETWORKS_UNEXPECTED_NULL.
+int ANeuralNetworksModel_setOperandExtensionData(ANeuralNetworksModel* model, int32_t index,
+                                                 const void* data, size_t length) {
+    NNTRACE_RT(NNTRACE_PHASE_PREPARATION, "ANeuralNetworksModel_setOperandExtensionData");
+    if (!model || (!data && length != 0)) {
+        LOG(ERROR) << "ANeuralNetworksModel_setOperandExtensionData passed a nullptr";
+        return ANEURALNETWORKS_UNEXPECTED_NULL;
+    }
+    ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model);
+    return m->setOperandExtensionData(index, data, length);
+}