More discipline for models and requests.

- Must explicitly call the new finish() API on a model before compiling or freeing it.
- Must not modify a request once start() has been called on it.

Bug: 63905942
Test: nn/runtime/tests, nn/common/operations tests
Change-Id: Ifc6e614bda647d729e8702023a02613e629ca6a0
diff --git a/runtime/NeuralNetworks.cpp b/runtime/NeuralNetworks.cpp index 3393292..3ac8476 100644 --- a/runtime/NeuralNetworks.cpp +++ b/runtime/NeuralNetworks.cpp
@@ -112,6 +112,7 @@ "ANEURALNETWORKS_UNEXPECTED_NULL may have changed"); static_assert(ANEURALNETWORKS_BAD_DATA == 4, "ANEURALNETWORKS_BAD_DATA may have changed"); static_assert(ANEURALNETWORKS_OP_FAILED == 5, "ANEURALNETWORKS_OP_FAILED may have changed"); +static_assert(ANEURALNETWORKS_BAD_STATE == 6, "ANEURALNETWORKS_BAD_STATE may have changed"); // Make sure that the constants are compatible with the values defined in // hardware/interfaces/neuralnetworks/1.0/types.hal. @@ -350,6 +351,15 @@ delete m; } +int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) { + if (!model) { + LOG(ERROR) << "ANeuralNetworksModel_finish passed a nullptr"; + return ANEURALNETWORKS_UNEXPECTED_NULL; + } + ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); + return m->finish(); +} + int ANeuralNetworksModel_addOperand(ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type) { if (!model || !type) { @@ -443,13 +453,10 @@ } ModelBuilder* m = reinterpret_cast<ModelBuilder*>(model); - CompilationBuilder* c = m->createCompilation(); - if (c == nullptr) { - *compilation = nullptr; - return ANEURALNETWORKS_OUT_OF_MEMORY; - } + CompilationBuilder* c = nullptr; + int result = m->createCompilation(&c); *compilation = reinterpret_cast<ANeuralNetworksCompilation*>(c); - return ANEURALNETWORKS_NO_ERROR; + return result; } void ANeuralNetworksCompilation_free(ANeuralNetworksCompilation* compilation) { @@ -503,13 +510,10 @@ } CompilationBuilder* c = reinterpret_cast<CompilationBuilder*>(compilation); - RequestBuilder* r = c->createRequest(); - if (r == nullptr) { - *request = nullptr; - return ANEURALNETWORKS_OUT_OF_MEMORY; - } + RequestBuilder* r = nullptr; + int result = c->createRequest(&r); *request = reinterpret_cast<ANeuralNetworksRequest*>(r); - return ANEURALNETWORKS_NO_ERROR; + return result; } void ANeuralNetworksRequest_free(ANeuralNetworksRequest* request) {