| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // Provides C++ classes to more easily use the Neural Networks API. |
| // TODO(b/117845862): this should be auto generated from NeuralNetworksWrapper.h. |
| |
| #ifndef ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H |
| #define ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H |
| |
| #include "NeuralNetworks.h" |
| #include "NeuralNetworksWrapper.h" |
| #include "NeuralNetworksWrapperExtensions.h" |
| |
| #include <math.h> |
| #include <optional> |
| #include <string> |
| #include <vector> |
| |
| namespace android { |
| namespace nn { |
| namespace test_wrapper { |
| |
| using wrapper::Event; |
| using wrapper::ExecutePreference; |
| using wrapper::ExtensionModel; |
| using wrapper::ExtensionOperandParams; |
| using wrapper::ExtensionOperandType; |
| using wrapper::Memory; |
| using wrapper::Model; |
| using wrapper::OperandType; |
| using wrapper::Result; |
| using wrapper::SymmPerChannelQuantParams; |
| using wrapper::Type; |
| |
| class Compilation { |
| public: |
| Compilation(const Model* model) { |
| int result = ANeuralNetworksCompilation_create(model->getHandle(), &mCompilation); |
| if (result != 0) { |
| // TODO Handle the error |
| } |
| } |
| |
| Compilation() {} |
| |
| ~Compilation() { ANeuralNetworksCompilation_free(mCompilation); } |
| |
| // Disallow copy semantics to ensure the runtime object can only be freed |
| // once. Copy semantics could be enabled if some sort of reference counting |
| // or deep-copy system for runtime objects is added later. |
| Compilation(const Compilation&) = delete; |
| Compilation& operator=(const Compilation&) = delete; |
| |
| // Move semantics to remove access to the runtime object from the wrapper |
| // object that is being moved. This ensures the runtime object will be |
| // freed only once. |
| Compilation(Compilation&& other) { *this = std::move(other); } |
| Compilation& operator=(Compilation&& other) { |
| if (this != &other) { |
| ANeuralNetworksCompilation_free(mCompilation); |
| mCompilation = other.mCompilation; |
| other.mCompilation = nullptr; |
| } |
| return *this; |
| } |
| |
| Result setPreference(ExecutePreference preference) { |
| return static_cast<Result>(ANeuralNetworksCompilation_setPreference( |
| mCompilation, static_cast<int32_t>(preference))); |
| } |
| |
| Result setCaching(const std::string& cacheDir, const std::vector<uint8_t>& token) { |
| if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN) { |
| return Result::BAD_DATA; |
| } |
| return static_cast<Result>(ANeuralNetworksCompilation_setCaching( |
| mCompilation, cacheDir.c_str(), token.data())); |
| } |
| |
| Result finish() { return static_cast<Result>(ANeuralNetworksCompilation_finish(mCompilation)); } |
| |
| ANeuralNetworksCompilation* getHandle() const { return mCompilation; } |
| |
| protected: |
| ANeuralNetworksCompilation* mCompilation = nullptr; |
| }; |
| |
| class Execution { |
| public: |
| Execution(const Compilation* compilation) : mCompilation(compilation->getHandle()) { |
| int result = ANeuralNetworksExecution_create(compilation->getHandle(), &mExecution); |
| if (result != 0) { |
| // TODO Handle the error |
| } |
| } |
| |
| ~Execution() { ANeuralNetworksExecution_free(mExecution); } |
| |
| // Disallow copy semantics to ensure the runtime object can only be freed |
| // once. Copy semantics could be enabled if some sort of reference counting |
| // or deep-copy system for runtime objects is added later. |
| Execution(const Execution&) = delete; |
| Execution& operator=(const Execution&) = delete; |
| |
| // Move semantics to remove access to the runtime object from the wrapper |
| // object that is being moved. This ensures the runtime object will be |
| // freed only once. |
| Execution(Execution&& other) { *this = std::move(other); } |
| Execution& operator=(Execution&& other) { |
| if (this != &other) { |
| ANeuralNetworksExecution_free(mExecution); |
| mCompilation = other.mCompilation; |
| other.mCompilation = nullptr; |
| mExecution = other.mExecution; |
| other.mExecution = nullptr; |
| } |
| return *this; |
| } |
| |
| Result setInput(uint32_t index, const void* buffer, size_t length, |
| const ANeuralNetworksOperandType* type = nullptr) { |
| return static_cast<Result>( |
| ANeuralNetworksExecution_setInput(mExecution, index, type, buffer, length)); |
| } |
| |
| Result setInputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, |
| uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { |
| return static_cast<Result>(ANeuralNetworksExecution_setInputFromMemory( |
| mExecution, index, type, memory->get(), offset, length)); |
| } |
| |
| Result setOutput(uint32_t index, void* buffer, size_t length, |
| const ANeuralNetworksOperandType* type = nullptr) { |
| return static_cast<Result>( |
| ANeuralNetworksExecution_setOutput(mExecution, index, type, buffer, length)); |
| } |
| |
| Result setOutputFromMemory(uint32_t index, const Memory* memory, uint32_t offset, |
| uint32_t length, const ANeuralNetworksOperandType* type = nullptr) { |
| return static_cast<Result>(ANeuralNetworksExecution_setOutputFromMemory( |
| mExecution, index, type, memory->get(), offset, length)); |
| } |
| |
| Result startCompute(Event* event) { |
| ANeuralNetworksEvent* ev = nullptr; |
| Result result = static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &ev)); |
| event->set(ev); |
| return result; |
| } |
| |
| Result compute() { |
| switch (mComputeMode) { |
| case ComputeMode::SYNC: { |
| return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution)); |
| } |
| case ComputeMode::ASYNC: { |
| ANeuralNetworksEvent* event = nullptr; |
| Result result = static_cast<Result>( |
| ANeuralNetworksExecution_startCompute(mExecution, &event)); |
| if (result != Result::NO_ERROR) { |
| return result; |
| } |
| // TODO how to manage the lifetime of events when multiple waiters is not |
| // clear. |
| result = static_cast<Result>(ANeuralNetworksEvent_wait(event)); |
| ANeuralNetworksEvent_free(event); |
| return result; |
| } |
| case ComputeMode::BURST: { |
| ANeuralNetworksBurst* burst = nullptr; |
| Result result = |
| static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst)); |
| if (result != Result::NO_ERROR) { |
| return result; |
| } |
| result = static_cast<Result>( |
| ANeuralNetworksExecution_burstCompute(mExecution, burst)); |
| ANeuralNetworksBurst_free(burst); |
| return result; |
| } |
| } |
| return Result::BAD_DATA; |
| } |
| |
| // By default, compute() uses the synchronous API. setComputeMode() can be |
| // used to change the behavior of compute() to either: |
| // - use the asynchronous API and then wait for computation to complete |
| // or |
| // - use the burst API |
| enum class ComputeMode { SYNC, ASYNC, BURST }; |
| static void setComputeMode(ComputeMode mode) { mComputeMode = mode; } |
| |
| Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) { |
| uint32_t rank = 0; |
| Result result = static_cast<Result>( |
| ANeuralNetworksExecution_getOutputOperandRank(mExecution, index, &rank)); |
| dimensions->resize(rank); |
| if ((result != Result::NO_ERROR && result != Result::OUTPUT_INSUFFICIENT_SIZE) || |
| rank == 0) { |
| return result; |
| } |
| result = static_cast<Result>(ANeuralNetworksExecution_getOutputOperandDimensions( |
| mExecution, index, dimensions->data())); |
| return result; |
| } |
| |
| private: |
| ANeuralNetworksCompilation* mCompilation = nullptr; |
| ANeuralNetworksExecution* mExecution = nullptr; |
| |
| // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp. |
| static ComputeMode mComputeMode; |
| }; |
| |
| } // namespace test_wrapper |
| } // namespace nn |
| } // namespace android |
| |
| #endif // ANDROID_ML_NN_RUNTIME_TEST_TEST_NEURAL_NETWORKS_WRAPPER_H |