Merge "NNAPI Burst object cleanup"
diff --git a/common/ExecutionBurstController.cpp b/common/ExecutionBurstController.cpp
index 32231d3..87f5c1e 100644
--- a/common/ExecutionBurstController.cpp
+++ b/common/ExecutionBurstController.cpp
@@ -14,18 +14,26 @@
* limitations under the License.
*/
+#define LOG_TAG "ExecutionBurstController"
+
#include "ExecutionBurstController.h"
#include <android-base/logging.h>
+#include <string>
+#include "Tracing.h"
-namespace android {
-namespace nn {
+namespace android::nn {
namespace {
-constexpr Timing invalidTiming = {UINT64_MAX, UINT64_MAX};
+
+// Aliases for the FMQ descriptor types handed to IPreparedModel::configureExecutionBurst.
+using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
+using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
+
+// Sentinel Timing (all fields UINT64_MAX) returned alongside error statuses.
+constexpr Timing kInvalidTiming = {UINT64_MAX, UINT64_MAX};
+
} // anonymous namespace
-Return<void> ExecutionBurstCallback::getMemories(const hidl_vec<int32_t>& slots,
- getMemories_cb cb) {
+Return<void> ExecutionBurstController::ExecutionBurstCallback::getMemories(
+ const hidl_vec<int32_t>& slots, getMemories_cb cb) {
std::lock_guard<std::mutex> guard(mMutex);
// get all memories
@@ -45,8 +53,8 @@
return Void();
}
-std::vector<int32_t> ExecutionBurstCallback::getSlots(const hidl_vec<hidl_memory>& memories,
- const std::vector<intptr_t>& keys) {
+std::vector<int32_t> ExecutionBurstController::ExecutionBurstCallback::getSlots(
+ const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys) {
std::lock_guard<std::mutex> guard(mMutex);
// retrieve (or bind) all slots corresponding to memories
@@ -58,7 +66,8 @@
return slots;
}
-std::pair<bool, int32_t> ExecutionBurstCallback::freeMemory(intptr_t key) {
+std::pair<bool, int32_t> ExecutionBurstController::ExecutionBurstCallback::freeMemory(
+ intptr_t key) {
std::lock_guard<std::mutex> guard(mMutex);
auto iter = mMemoryIdToSlotCache.find(key);
@@ -72,10 +81,12 @@
}
}
-int32_t ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory, intptr_t key) {
+int32_t ExecutionBurstController::ExecutionBurstCallback::getSlotLocked(const hidl_memory& memory,
+ intptr_t key) {
auto iter = mMemoryIdToSlotCache.find(key);
if (iter == mMemoryIdToSlotCache.end()) {
const int32_t slot = mNextSlot;
+ // TODO: change mNextSlot to uint64_t or maintain a free list of IDs
mNextSlot = (mNextSlot + 1) % (1 << 30);
mMemoryIdToSlotCache[key] = slot;
mSlotToMemoryCache[slot] = memory;
@@ -86,15 +97,72 @@
}
}
+std::unique_ptr<ExecutionBurstController> ExecutionBurstController::create(
+        const sp<IPreparedModel>& preparedModel, bool blocking) {
+    // check inputs
+    if (preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstController::create passed a nullptr";
+        return nullptr;
+    }
+
+    // create callback object (nothrow, so the nullptr check below can actually fire)
+    sp<ExecutionBurstCallback> callback = new (std::nothrow) ExecutionBurstCallback();
+    if (callback == nullptr) {
+        LOG(ERROR) << "ExecutionBurstController::create failed to create callback";
+        return nullptr;
+    }
+
+    // create FMQ objects
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow) FmqRequestChannel(
+            kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
+    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow) FmqResultChannel(
+            kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
+
+    // check FMQ objects
+    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
+        !fmqResultChannel->isValid()) {
+        LOG(ERROR) << "ExecutionBurstController::create failed to create FastMessageQueue";
+        return nullptr;
+    }
+
+    // descriptors
+    const FmqRequestDescriptor& fmqRequestDescriptor = *fmqRequestChannel->getDesc();
+    const FmqResultDescriptor& fmqResultDescriptor = *fmqResultChannel->getDesc();
+
+    // configure burst; errorStatus stays GENERAL_FAILURE unless the callback runs
+    ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+    sp<IBurstContext> burstContext;
+    Return<void> ret = preparedModel->configureExecutionBurst(
+            callback, fmqRequestDescriptor, fmqResultDescriptor,
+            [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
+                errorStatus = status;
+                burstContext = context;
+            });
+
+    // check burst; a HIDL transport failure is reported through ret, not the callback
+    if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
+        LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with "
+                   << (ret.isOk() ? toString(errorStatus) : ret.description());
+        return nullptr;
+    }
+    if (burstContext == nullptr) {
+        LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
+        return nullptr;
+    }
+
+    // make and return controller
+    return std::make_unique<ExecutionBurstController>(std::move(fmqRequestChannel),
+                                                      std::move(fmqResultChannel), burstContext,
+                                                      callback, blocking);
+}
+
+// Prefer ExecutionBurstController::create, which validates the prepared model, creates the
+// callback and FMQ channels, and configures the burst before calling this constructor.
ExecutionBurstController::ExecutionBurstController(
std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
std::unique_ptr<FmqResultChannel> fmqResultChannel, const sp<IBurstContext>& burstContext,
- const sp<IPreparedModel>& preparedModel, const sp<ExecutionBurstCallback>& callback,
- bool blocking)
+ const sp<ExecutionBurstCallback>& callback, bool blocking)
: mFmqRequestChannel(std::move(fmqRequestChannel)),
mFmqResultChannel(std::move(fmqResultChannel)),
mBurstContext(burstContext),
- mPreparedModel(preparedModel),
mMemoryCache(callback),
mUsesFutex(blocking) {}
@@ -110,8 +178,14 @@
using discriminator = FmqResultDatum::hidl_discriminator;
// wait for result packet and read first element of result packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
FmqResultDatum datum;
- bool success = false;
+ bool success = true;
if (mUsesFutex) {
success = mFmqResultChannel->readBlocking(&datum, 1);
} else {
@@ -133,7 +207,11 @@
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data
+ // need to do a blocking wait to wait for more data. This is known because
+ // in FMQ, all writes are published (made available) atomically. Currently,
+ // the producer always publishes the entire packet in one function call, so
+ // if the first element of the packet is available, the remaining elements
+ // are also available.
std::vector<FmqResultDatum> packet(count);
packet.front() = datum;
success = mFmqResultChannel->read(packet.data() + 1, packet.size() - 1);
@@ -238,7 +316,7 @@
// validate packet information
if (data[index].getDiscriminator() != discriminator::packetInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// unpackage packet information
@@ -253,7 +331,7 @@
// validate operand information
if (data[index].getDiscriminator() != discriminator::operandInformation) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// unpackage operand information
@@ -269,7 +347,7 @@
// validate dimension
if (data[index].getDiscriminator() != discriminator::operandDimensionValue) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// unpackage dimension
@@ -287,7 +365,7 @@
// validate execution timing
if (data[index].getDiscriminator() != discriminator::executionTiming) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// unpackage execution timing
@@ -297,7 +375,7 @@
// validate packet information
if (index != packetSize) {
LOG(ERROR) << "FMQ Result packet ill-formed";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// return result
@@ -306,6 +384,8 @@
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> ExecutionBurstController::compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds) {
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstController::compute");
+
// serialize request
std::vector<FmqRequestDatum> requestData = serialize(request, measure, memoryIds);
@@ -316,14 +396,14 @@
bool success = sendPacket(requestData);
if (!success) {
LOG(ERROR) << "Error sending FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// get result packet
const std::vector<FmqResultDatum> resultData = getPacketBlocking();
if (resultData.empty()) {
LOG(ERROR) << "Error retrieving FMQ packet";
- return {ErrorStatus::GENERAL_FAILURE, {}, invalidTiming};
+ return {ErrorStatus::GENERAL_FAILURE, {}, kInvalidTiming};
}
// deserialize result
@@ -339,64 +419,4 @@
}
}
-std::unique_ptr<ExecutionBurstController> createExecutionBurstController(
- const sp<IPreparedModel>& preparedModel, bool blocking) {
- // check inputs
- if (preparedModel == nullptr) {
- LOG(ERROR) << "createExecutionBurstController passed a nullptr";
- return nullptr;
- }
-
- // create callback object
- sp<ExecutionBurstCallback> callback = new ExecutionBurstCallback();
- if (callback == nullptr) {
- LOG(ERROR) << "createExecutionBurstController failed to create callback";
- return nullptr;
- }
-
- // create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow) FmqRequestChannel(
- kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
- std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow) FmqResultChannel(
- kExecutionBurstChannelLength, /*confEventFlag=*/blocking)};
-
- // check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
- !fmqResultChannel->isValid()) {
- LOG(ERROR) << "createExecutionBurstController failed to create FastMessageQueue";
- return nullptr;
- }
-
- // descriptors
- const FmqRequestDescriptor& fmqRequestDescriptor = *fmqRequestChannel->getDesc();
- const FmqResultDescriptor& fmqResultDescriptor = *fmqResultChannel->getDesc();
-
- // configure burst
- ErrorStatus errorStatus;
- sp<IBurstContext> burstContext;
- Return<void> ret = preparedModel->configureExecutionBurst(
- callback, fmqRequestDescriptor, fmqResultDescriptor,
- [&errorStatus, &burstContext](ErrorStatus status, const sp<IBurstContext>& context) {
- errorStatus = status;
- burstContext = context;
- });
-
- // check burst
- if (errorStatus != ErrorStatus::NONE) {
- LOG(ERROR) << "IPreparedModel::configureExecutionBurst failed with "
- << toString(errorStatus);
- return nullptr;
- }
- if (burstContext == nullptr) {
- LOG(ERROR) << "IPreparedModel::configureExecutionBurst returned nullptr for burst";
- return nullptr;
- }
-
- // make and return controller
- return std::make_unique<ExecutionBurstController>(std::move(fmqRequestChannel),
- std::move(fmqResultChannel), burstContext,
- preparedModel, callback, blocking);
-}
-
-} // namespace nn
-} // namespace android
+} // namespace android::nn
diff --git a/common/ExecutionBurstServer.cpp b/common/ExecutionBurstServer.cpp
index 64a4ee2..6ede34d 100644
--- a/common/ExecutionBurstServer.cpp
+++ b/common/ExecutionBurstServer.cpp
@@ -14,44 +14,57 @@
* limitations under the License.
*/
+#define LOG_TAG "ExecutionBurstServer"
+
#include "ExecutionBurstServer.h"
#include <android-base/logging.h>
+#include <set>
+#include <string>
+#include "Tracing.h"
-namespace android {
-namespace nn {
+namespace android::nn {
-BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback) : mCallback(callback) {}
+ExecutionBurstServer::BurstMemoryCache::BurstMemoryCache(const sp<IBurstCallback>& callback)
+ : mCallback(callback) {}
-hidl_vec<hidl_memory> BurstMemoryCache::getMemories(const std::vector<int32_t>& slots) {
+hidl_vec<hidl_memory> ExecutionBurstServer::BurstMemoryCache::getMemories(
+ const std::vector<int32_t>& slots) {
std::lock_guard<std::mutex> guard(mMutex);
// find unique unknown slots
- std::vector<int32_t> unknownSlots = slots;
- std::sort(unknownSlots.begin(), unknownSlots.end());
- auto last = std::unique(unknownSlots.begin(), unknownSlots.end());
- unknownSlots.erase(last, unknownSlots.end());
+ std::set<int32_t> setOfUnknownSlots;
+ for (int32_t slot : slots) {
+ if (mSlotToMemoryCache.find(slot) == mSlotToMemoryCache.end()) {
+ setOfUnknownSlots.insert(slot);
+ }
+ }
+ const std::vector<int32_t> vecOfUnknownSlots(setOfUnknownSlots.begin(),
+ setOfUnknownSlots.end());
// retrieve unknown slots
- ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
- std::vector<hidl_memory> returnedMemories;
- Return<void> ret = mCallback->getMemories(
- unknownSlots, [&errorStatus, &returnedMemories](ErrorStatus status,
- const hidl_vec<hidl_memory>& memories) {
- errorStatus = status;
- if (status == ErrorStatus::NONE) {
- returnedMemories = memories;
- }
- });
+ if (!vecOfUnknownSlots.empty()) {
+ ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
+ std::vector<hidl_memory> returnedMemories;
+ Return<void> ret = mCallback->getMemories(
+ vecOfUnknownSlots,
+ [&errorStatus, &returnedMemories](ErrorStatus status,
+ const hidl_vec<hidl_memory>& memories) {
+ errorStatus = status;
+ if (status == ErrorStatus::NONE) {
+ returnedMemories = memories;
+ }
+ });
- if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
- LOG(ERROR) << "Error retrieving memories";
- return {};
- }
+ if (!ret.isOk() || errorStatus != ErrorStatus::NONE) {
+ LOG(ERROR) << "Error retrieving memories";
+ return {};
+ }
- // add memories to unknown slots
- for (size_t i = 0; i < unknownSlots.size(); ++i) {
- mSlotToMemoryCache[unknownSlots[i]] = returnedMemories[i];
+ // add memories to unknown slots
+ for (size_t i = 0; i < vecOfUnknownSlots.size(); ++i) {
+ mSlotToMemoryCache[vecOfUnknownSlots[i]] = returnedMemories[i];
+ }
}
// get all slots
@@ -59,14 +72,42 @@
for (size_t i = 0; i < slots.size(); ++i) {
memories[i] = mSlotToMemoryCache[slots[i]];
}
+
return memories;
}
-void BurstMemoryCache::freeMemory(int32_t slot) {
+void ExecutionBurstServer::BurstMemoryCache::freeMemory(int32_t slot) {
std::lock_guard<std::mutex> guard(mMutex);
mSlotToMemoryCache.erase(slot);
}
+// Creates a burst server from client-provided FMQ descriptors; returns nullptr on any failure.
+sp<ExecutionBurstServer> ExecutionBurstServer::create(
+        const sp<IBurstCallback>& callback, const MQDescriptorSync<FmqRequestDatum>& requestChannel,
+        const MQDescriptorSync<FmqResultDatum>& resultChannel, IPreparedModel* preparedModel) {
+    // check inputs: both the callback and the prepared model are required
+    if (callback == nullptr || preparedModel == nullptr) {
+        LOG(ERROR) << "ExecutionBurstServer::create passed a nullptr";
+        return nullptr;
+    }
+
+    // create FMQ objects from the client-provided descriptors
+    std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
+            FmqRequestChannel(requestChannel)};
+    std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
+            FmqResultChannel(resultChannel)};
+
+    // check FMQ objects: allocation must succeed and both queues must be valid
+    if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
+        !fmqResultChannel->isValid()) {
+        LOG(ERROR) << "ExecutionBurstServer::create failed to create FastMessageQueue";
+        return nullptr;
+    }
+
+    // make and return context
+    // NOTE(review): preparedModel is passed as a raw pointer; presumably the caller must keep
+    // it alive for the lifetime of the returned server — confirm.
+    return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
+                                    std::move(fmqResultChannel), preparedModel);
+}
+
ExecutionBurstServer::ExecutionBurstServer(const sp<IBurstCallback>& callback,
std::unique_ptr<FmqRequestChannel> requestChannel,
std::unique_ptr<FmqResultChannel> resultChannel,
@@ -85,9 +126,13 @@
mTeardown = true;
// force unblock
+ // ExecutionBurstServer is by default waiting on a request packet. If the
+ // client process destroys its burst object, the server will still be
+ // waiting on the futex (assuming mBlocking is true). This force unblock
+ // wakes up any thread waiting on the futex.
if (mBlocking) {
- // TODO: look for a different/better way to signal/notify the futex to wake
- // up any thread waiting on it
+ // TODO: look for a different/better way to signal/notify the futex to
+ // wake up any thread waiting on it
FmqRequestDatum datum;
datum.packetInformation({/*.packetSize=*/0, /*.numberOfInputOperands=*/0,
/*.numberOfOutputOperands=*/0, /*.numberOfPools=*/0});
@@ -117,7 +162,13 @@
return {};
}
- // wait for request packet and read first element of result packet
+ // wait for request packet and read first element of request packet
+ // TODO: have a more elegant way to wait for data, and read it all at once.
+ // For example, EventFlag can be used to directly wait on the futex, and all
+ // the data can be read at once with a non-blocking call to
+ // MessageQueue::read. For further optimization, MessageQueue::beginRead and
+ // MessageQueue::commitRead can be used to avoid an extra copy of the
+ // metadata.
FmqRequestDatum datum;
bool success = false;
if (mBlocking) {
@@ -139,13 +190,19 @@
return {};
}
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION, "ExecutionBurstServer getting packet");
+
// unpack packet information
const auto& packetInfo = datum.packetInformation();
const size_t count = packetInfo.packetSize;
// retrieve remaining elements
// NOTE: all of the data is already available at this point, so there's no
- // need to do a blocking wait to wait for more data
+ // need to do a blocking wait to wait for more data. This is known because
+ // in FMQ, all writes are published (made available) atomically. Currently,
+ // the producer always publishes the entire packet in one function call, so
+ // if the first element of the packet is available, the remaining elements
+ // are also available.
std::vector<FmqRequestDatum> packet(count);
packet.front() = datum;
success = mFmqRequestChannel->read(packet.data() + 1, packet.size() - 1);
@@ -365,6 +422,9 @@
return;
}
+ NNTRACE_FULL(NNTRACE_LAYER_IPC, NNTRACE_PHASE_EXECUTION,
+ "ExecutionBurstServer processing packet and returning results");
+
// continue processing
Request request;
MeasureTiming measure;
@@ -374,6 +434,10 @@
ErrorStatus errorStatus = ErrorStatus::GENERAL_FAILURE;
std::vector<OutputShape> outputShapes;
Timing returnedTiming;
+ // This call to IPreparedModel::executeSynchronously occurs entirely
+ // within the same process, so ignore the Return<> errors via .isOk().
+ // TODO: verify it is safe to always call isOk() here, or if there is
+ // any benefit to checking any potential errors.
mPreparedModel
->executeSynchronously(request, measure,
[&errorStatus, &outputShapes, &returnedTiming](
@@ -392,33 +456,4 @@
}
}
-sp<IBurstContext> createBurstContext(const sp<IBurstCallback>& callback,
- const MQDescriptorSync<FmqRequestDatum>& requestChannel,
- const MQDescriptorSync<FmqResultDatum>& resultChannel,
- IPreparedModel* preparedModel) {
- // check inputs
- if (callback == nullptr || preparedModel == nullptr) {
- LOG(ERROR) << "createExecutionBurstServer passed a nullptr";
- return nullptr;
- }
-
- // create FMQ objects
- std::unique_ptr<FmqRequestChannel> fmqRequestChannel{new (std::nothrow)
- FmqRequestChannel(requestChannel)};
- std::unique_ptr<FmqResultChannel> fmqResultChannel{new (std::nothrow)
- FmqResultChannel(resultChannel)};
-
- // check FMQ objects
- if (!fmqRequestChannel || !fmqResultChannel || !fmqRequestChannel->isValid() ||
- !fmqResultChannel->isValid()) {
- LOG(ERROR) << "createExecutionBurstServer failed to create FastMessageQueue";
- return nullptr;
- }
-
- // make and return context
- return new ExecutionBurstServer(callback, std::move(fmqRequestChannel),
- std::move(fmqResultChannel), preparedModel);
-}
-
-} // namespace nn
-} // namespace android
+} // namespace android::nn
diff --git a/common/include/ExecutionBurstController.h b/common/include/ExecutionBurstController.h
index bf36470..7152325 100644
--- a/common/include/ExecutionBurstController.h
+++ b/common/include/ExecutionBurstController.h
@@ -27,16 +27,13 @@
#include <tuple>
#include "HalInterfaces.h"
-namespace android {
-namespace nn {
+namespace android::nn {
using ::android::hardware::kSynchronizedReadWrite;
using ::android::hardware::MessageQueue;
using ::android::hardware::MQDescriptorSync;
using FmqRequestChannel = MessageQueue<FmqRequestDatum, kSynchronizedReadWrite>;
using FmqResultChannel = MessageQueue<FmqResultDatum, kSynchronizedReadWrite>;
-using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
-using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
/**
* Number of elements in the FMQ.
@@ -44,62 +41,96 @@
constexpr const size_t kExecutionBurstChannelLength = 1024;
/**
- * NN runtime burst callback object and memory cache.
- *
- * ExecutionBurstCallback associates a hidl_memory object with a slot number to
- * be passed across FMQ. The ExecutionBurstServer can use this callback to
- * retrieve this hidl_memory corresponding to the slot via HIDL.
- *
- * Whenever a hidl_memory object is copied, it will duplicate the underlying
- * file descriptor. Because the NN runtime currently copies the hidl_memory on
- * each execution, it is difficult to associate hidl_memory objects with
- * previously cached hidl_memory objects. For this reason, callers of this class
- * must pair each hidl_memory object with an associated key. For efficiency, if
- * two hidl_memory objects represent the same underlying buffer, they must use
- * the same key.
- */
-class ExecutionBurstCallback : public IBurstCallback {
- DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
-
- public:
- ExecutionBurstCallback() = default;
-
- Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
-
- std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
- const std::vector<intptr_t>& keys);
- int32_t getSlot(const hidl_memory& memory, intptr_t key);
- std::pair<bool, int32_t> freeMemory(intptr_t key);
-
- private:
- int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);
-
- std::mutex mMutex;
- int32_t mNextSlot = 0;
- std::map<intptr_t, int32_t> mMemoryIdToSlotCache;
- std::map<int32_t, hidl_memory> mSlotToMemoryCache;
-};
-
-/**
- * NN runtime burst object
- *
- * TODO: provide high-level description of class
+ * The ExecutionBurstController class manages both the serialization and
+ * deserialization of data across FMQ, making it appear to the runtime as a
+ * regular synchronous inference. Additionally, this class manages the burst's
+ * memory cache.
*/
class ExecutionBurstController {
DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController);
+    /**
+     * NN runtime burst callback object and memory cache.
+     *
+     * ExecutionBurstCallback associates a hidl_memory object with a slot number
+     * to be passed across FMQ. The ExecutionBurstServer can use this callback
+     * to retrieve this hidl_memory corresponding to the slot via HIDL.
+     *
+     * Whenever a hidl_memory object is copied, it will duplicate the underlying
+     * file descriptor. Because the NN runtime currently copies the hidl_memory
+     * on each execution, it is difficult to associate hidl_memory objects with
+     * previously cached hidl_memory objects. For this reason, callers of this
+     * class must pair each hidl_memory object with an associated key. For
+     * efficiency, if two hidl_memory objects represent the same underlying
+     * buffer, they must use the same key.
+     */
+    class ExecutionBurstCallback : public IBurstCallback {
+        DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback);
+
+       public:
+        ExecutionBurstCallback() = default;
+
+        Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override;
+
+        std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories,
+                                      const std::vector<intptr_t>& keys);
+
+        int32_t getSlot(const hidl_memory& memory, intptr_t key);
+
+        /*
+         * This function performs two different actions:
+         * 1) Removes an entry from the cache (if present), including the local
+         *    storage of the hidl_memory object. Note that this call does not
+         *    free any corresponding hidl_memory object in ExecutionBurstServer,
+         *    which is separately freed via IBurstContext::freeMemory.
+         * 2) Returns whether a cache entry was removed and, if so, which slot
+         *    was removed. If the key does not correspond to any entry in the
+         *    cache, a slot number of 0 is returned. The slot number and whether
+         *    the entry existed are useful so the same slot can be freed in the
+         *    ExecutionBurstServer's cache via IBurstContext::freeMemory.
+         */
+        std::pair<bool, int32_t> freeMemory(intptr_t key);
+
+       private:
+        int32_t getSlotLocked(const hidl_memory& memory, intptr_t key);  // presumably mMutex held
+
+        std::mutex mMutex;  // taken by getMemories, getSlots and freeMemory
+        int32_t mNextSlot = 0;  // next slot to hand out; wraps modulo 2^30 in getSlotLocked
+        std::map<intptr_t, int32_t> mMemoryIdToSlotCache;  // caller-provided key -> slot
+        std::map<int32_t, hidl_memory> mSlotToMemoryCache;  // slot -> hidl_memory
+    };
+
public:
+ /**
+ * Creates a burst controller on a prepared model.
+ *
+ * Prefer this over ExecutionBurstController's constructor.
+ *
+ * @param preparedModel Model prepared for execution to execute on.
+ * @param blocking 'true' if the FMQ should use a futex to perform blocking
+ * until data is available in a less responsive, but more energy
+ * efficient manner. 'false' if the FMQ should use spin-looping to
+ * wait until data is available in a more responsive, but less energy
+ * efficient manner.
+     * @return ExecutionBurstController Execution burst controller object, or nullptr on failure.
+ */
+ static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel,
+ bool blocking);
+
ExecutionBurstController(std::unique_ptr<FmqRequestChannel> fmqRequestChannel,
std::unique_ptr<FmqResultChannel> fmqResultChannel,
const sp<IBurstContext>& burstContext,
- const sp<IPreparedModel>& preparedModel,
const sp<ExecutionBurstCallback>& callback, bool blocking);
/**
* Execute a request on a model.
*
* @param request Arguments to be executed on a model.
- * @return status and output shape of the execution.
+ * @param measure Whether to collect timing measurements, either YES or NO
+ * @param memoryIds Identifiers corresponding to each memory object in the
+ * request's pools.
+ * @return status and output shape of the execution and any execution time
+ * measurements.
*/
std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute(
const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds);
@@ -122,22 +153,10 @@
const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
const sp<IBurstContext> mBurstContext;
- const sp<IPreparedModel> mPreparedModel;
const sp<ExecutionBurstCallback> mMemoryCache;
const bool mUsesFutex;
};
-/**
- * Creates a burst controller on a prepared model.
- *
- * @param preparedModel Model prepared for execution to execute on.
- * @param blocking 'true' if the FMQ should block until data is available.
- * @return ExecutionBurstController Execution burst controller object.
- */
-std::unique_ptr<ExecutionBurstController> createExecutionBurstController(
- const sp<IPreparedModel>& preparedModel, bool blocking);
-
-} // namespace nn
-} // namespace android
+} // namespace android::nn
#endif // ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H
diff --git a/common/include/ExecutionBurstServer.h b/common/include/ExecutionBurstServer.h
index 13dfaaf..0a3222f 100644
--- a/common/include/ExecutionBurstServer.h
+++ b/common/include/ExecutionBurstServer.h
@@ -26,8 +26,7 @@
#include <set>
#include "HalInterfaces.h"
-namespace android {
-namespace nn {
+namespace android::nn {
using ::android::hardware::kSynchronizedReadWrite;
using ::android::hardware::MessageQueue;
@@ -38,29 +37,58 @@
using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
/**
- */
-class BurstMemoryCache {
- DISALLOW_IMPLICIT_CONSTRUCTORS(BurstMemoryCache);
-
- public:
- BurstMemoryCache(const sp<IBurstCallback>& callback);
-
- hidl_vec<hidl_memory> getMemories(const std::vector<int32_t>& slots);
- void freeMemory(int32_t slot);
-
- private:
- std::mutex mMutex;
- const sp<IBurstCallback> mCallback;
- std::map<int32_t, hidl_memory> mSlotToMemoryCache;
-};
-
-/**
- * NN server burst object
+ * The ExecutionBurstServer class is responsible for waiting for and
+ * deserializing a request object from a FMQ, performing the inference, and
+ * serializing the result back across another FMQ.
*/
class ExecutionBurstServer : public IBurstContext {
DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
+    /**
+     * BurstMemoryCache is responsible for managing the local memory cache of
+     * the burst object. If the ExecutionBurstServer requests a memory slot that
+     * is unrecognized, the BurstMemoryCache object will retrieve the memory
+     * from the client, transparently to the ExecutionBurstServer object.
+     */
+    class BurstMemoryCache {
+        DISALLOW_IMPLICIT_CONSTRUCTORS(BurstMemoryCache);
+
+       public:
+        BurstMemoryCache(const sp<IBurstCallback>& callback);
+
+        hidl_vec<hidl_memory> getMemories(const std::vector<int32_t>& slots);
+        void freeMemory(int32_t slot);
+
+       private:
+        std::mutex mMutex;  // taken by getMemories and freeMemory
+        const sp<IBurstCallback> mCallback;  // used to fetch memories for unknown slots
+        std::map<int32_t, hidl_memory> mSlotToMemoryCache;  // slot -> hidl_memory
+    };
+
public:
+ /**
+ * Create automated context to manage FMQ-based executions.
+ *
+ * This function is intended to be used by a service to automatically:
+ * 1) Receive data from a provided FMQ
+ * 2) Execute a model with the given information
+     * 3) Send the result to the provided result FMQ
+ *
+ * @param callback Callback used to retrieve memories corresponding to
+ * unrecognized slots.
+ * @param requestChannel Input FMQ channel through which the client passes the
+ * request to the service.
+ * @param resultChannel Output FMQ channel from which the client can retrieve
+ * the result of the execution.
+ * @param preparedModel PreparedModel that the burst object was created from.
+ * This will be used to synchronously perform the execution.
+     * @return ExecutionBurstServer Handle to the burst context, or nullptr on failure.
+ */
+ static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
+ const FmqRequestDescriptor& requestChannel,
+ const FmqResultDescriptor& resultChannel,
+ IPreparedModel* preparedModel);
+
ExecutionBurstServer(const sp<IBurstCallback>& callback,
std::unique_ptr<FmqRequestChannel> requestChannel,
std::unique_ptr<FmqResultChannel> resultChannel,
@@ -87,34 +115,6 @@
const bool mBlocking;
};
-/**
- * Create automated context to manage FMQ-based executions.
- *
- * This function is intended to be used by a service to automatically:
- * 1) Receive data from a provided FMQ
- * 2) Execute a model with the given information
- * 3) Send the result to the created FMQ
- *
- * @param callback Callback used to retrieve memories corresponding to
- * unrecognized slots.
- * @param requestChannel Input FMQ channel through which the client passes the
- * request to the service.
- * @param requestChannel Output FMQ channel from which the client can retrieve
- * the result of the execution.
- * @param preparedModel PreparedModel that the burst object was created from.
- * This will be used to synchronously perform the
- * execution.
- * @result IBurstContext Handle to the burst context.
- */
-::android::sp<::android::hardware::neuralnetworks::V1_2::IBurstContext> createBurstContext(
- const sp<::android::hardware::neuralnetworks::V1_2::IBurstCallback>& callback,
- const ::android::hardware::MQDescriptorSync<
- ::android::hardware::neuralnetworks::V1_2::FmqRequestDatum>& requestChannel,
- const ::android::hardware::MQDescriptorSync<
- ::android::hardware::neuralnetworks::V1_2::FmqResultDatum>& resultChannel,
- ::android::hardware::neuralnetworks::V1_2::IPreparedModel* preparedModel);
-
-} // namespace nn
-} // namespace android
+} // namespace android::nn
#endif // ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
diff --git a/driver/sample/SampleDriver.cpp b/driver/sample/SampleDriver.cpp
index c0fd37c..a2289fd 100644
--- a/driver/sample/SampleDriver.cpp
+++ b/driver/sample/SampleDriver.cpp
@@ -353,7 +353,7 @@
"SampleDriver::configureExecutionBurst");
const sp<V1_2::IBurstContext> burst =
- createBurstContext(callback, requestChannel, resultChannel, this);
+ ExecutionBurstServer::create(callback, requestChannel, resultChannel, this);
if (burst == nullptr) {
cb(ErrorStatus::GENERAL_FAILURE, {});
diff --git a/runtime/ExecutionBuilder.cpp b/runtime/ExecutionBuilder.cpp
index 140c0db..035e84a 100644
--- a/runtime/ExecutionBuilder.cpp
+++ b/runtime/ExecutionBuilder.cpp
@@ -468,7 +468,8 @@
int ExecutionBuilder::compute(sp<ExecutionCallback>* synchronizationCallback,
BurstBuilder* burstBuilder) {
- assert(synchronizationCallback == nullptr || burstBuilder == nullptr);
+ CHECK(synchronizationCallback == nullptr || burstBuilder == nullptr)
+ << "synchronizationCallback and burstBuilder cannot simultaneously be used";
const bool synchronous = (synchronizationCallback == nullptr);
diff --git a/runtime/ExecutionBuilder.h b/runtime/ExecutionBuilder.h
index cd79f40..e2c01f7 100644
--- a/runtime/ExecutionBuilder.h
+++ b/runtime/ExecutionBuilder.h
@@ -117,7 +117,7 @@
// If burst is provided, then the burst path will be used. If a burst is not
// provided (i.e., is nullptr), then a synchronous execution will occur.
//
- // Providing both synchronizationCallbak and burstBuilder is an error.
+ // Providing both synchronizationCallback and burstBuilder is an error.
int compute(sp<ExecutionCallback>* synchronizationCallback,
BurstBuilder* burstBuilder = nullptr);
diff --git a/runtime/VersionedInterfaces.cpp b/runtime/VersionedInterfaces.cpp
index 0dfd894..9c657c2 100644
--- a/runtime/VersionedInterfaces.cpp
+++ b/runtime/VersionedInterfaces.cpp
@@ -93,7 +93,8 @@
// create death handler object
sp<IPreparedModelDeathHandler> deathHandler = new (std::nothrow) IPreparedModelDeathHandler();
if (!deathHandler) {
- LOG(ERROR) << "VersionedIDevice::create -- Failed to create IPreparedModelDeathHandler.";
+ LOG(ERROR) << "VersionedIPreparedModel::create -- Failed to create "
+ "IPreparedModelDeathHandler.";
return nullptr;
}
@@ -103,8 +104,8 @@
// providing the response.
const Return<bool> ret = preparedModel->linkToDeath(deathHandler, 0);
if (!ret.isOk() || ret != true) {
- LOG(ERROR) << "VersionedIDevice::create -- Failed to register a death recipient for the "
- "IPreparedModel object.";
+ LOG(ERROR) << "VersionedIPreparedModel::create -- Failed to register a death recipient for "
+ "the IPreparedModel object.";
return nullptr;
}
@@ -185,7 +186,7 @@
std::unique_ptr<ExecutionBurstController> VersionedIPreparedModel::configureExecutionBurst(
bool blocking) const {
if (mPreparedModelV1_2 != nullptr) {
- return createExecutionBurstController(mPreparedModelV1_2, blocking);
+ return ExecutionBurstController::create(mPreparedModelV1_2, blocking);
} else {
return nullptr;
}
diff --git a/runtime/test/TestMain.cpp b/runtime/test/TestMain.cpp
index dc32cec..5aa163d 100644
--- a/runtime/test/TestMain.cpp
+++ b/runtime/test/TestMain.cpp
@@ -40,8 +40,7 @@
// non-public DeviceManager::setSyncExecHal(); we assume the setting is always
// true, and if we are asked to set it to false, we return 0 ("success") without
// running tests.
-static int test(bool useCpuOnly, bool computeUsesSynchronousAPI, bool allowSyncExecHal = true,
- bool computeUsesBurstAPI = false) {
+static int test(bool useCpuOnly, Execution::ComputeMode computeMode, bool allowSyncExecHal = true) {
#ifdef NNTEST_ONLY_PUBLIC_API
if (useCpuOnly || !allowSyncExecHal) {
return 0;
@@ -51,14 +50,24 @@
android::nn::DeviceManager::get()->setSyncExecHal(allowSyncExecHal);
#endif
- Execution::setComputeUsesSynchronousAPI(computeUsesSynchronousAPI);
- Execution::setComputeUsesBurstAPI(computeUsesBurstAPI);
+ Execution::setComputeMode(computeMode);
- LOG(INFO) << "test(useCpuOnly = " << useCpuOnly
- << ", computeUsesSynchronousAPI = " << computeUsesSynchronousAPI
+ auto computeModeText = [computeMode] {
+ switch (computeMode) {
+ case Execution::ComputeMode::SYNC:
+ return "ComputeMode::SYNC";
+ case Execution::ComputeMode::ASYNC:
+ return "ComputeMode::ASYNC";
+ case Execution::ComputeMode::BURST:
+ return "ComputeMode::BURST";
+ }
+ return "<unknown ComputeMode>";
+ };
+
+ LOG(INFO) << "test(useCpuOnly = " << useCpuOnly << ", computeMode = " << computeModeText()
<< ", allowSyncExecHal = " << allowSyncExecHal << ")";
std::cout << "[**********] useCpuOnly = " << useCpuOnly
- << ", computeUsesSynchronousAPI = " << computeUsesSynchronousAPI
+ << ", computeMode = " << computeModeText()
<< ", allowSyncExecHal = " << allowSyncExecHal << std::endl;
return RUN_ALL_TESTS();
}
@@ -70,22 +79,25 @@
android::nn::initVLogMask();
#endif
- int n = test(false, false) | test(false, true) | test(true, false) | test(true, true);
+ int n = test(/*useCpuOnly=*/false, Execution::ComputeMode::ASYNC) |
+ test(/*useCpuOnly=*/false, Execution::ComputeMode::SYNC) |
+ test(/*useCpuOnly=*/true, Execution::ComputeMode::ASYNC) |
+ test(/*useCpuOnly=*/true, Execution::ComputeMode::SYNC);
// Now try disabling use of synchronous execution HAL.
//
// Whether or not the use of synchronous execution HAL is enabled should make no
// difference when useCpuOnly = true; we already ran test(true, *, true) above,
// so there's no reason to run test(true, *, false) now.
- n |= test(false, false, false) | test(false, true, false);
+ n |= test(/*useCpuOnly=*/false, Execution::ComputeMode::ASYNC, /*allowSyncExecHal=*/false) |
+ test(/*useCpuOnly=*/false, Execution::ComputeMode::SYNC, /*allowSyncExecHal=*/false);
// Now try execution using a burst.
//
// The burst path is off by default in these tests. This is the first case
- // where it is turned on. Both "computeUsesSynchronousAPI" and
- // "allowSyncExecHal" are irrelevant here because the burst path is separate
- // from both.
- n |= test(false, false, false, true);
+ // where it is turned on. Both "useCpuOnly" and "allowSyncExecHal" are
+ // irrelevant here because the burst path is separate from both.
+ n |= test(/*useCpuOnly=*/false, Execution::ComputeMode::BURST);
return n;
}
diff --git a/runtime/test/TestNeuralNetworksWrapper.cpp b/runtime/test/TestNeuralNetworksWrapper.cpp
index 9d61f49..9daa09d 100644
--- a/runtime/test/TestNeuralNetworksWrapper.cpp
+++ b/runtime/test/TestNeuralNetworksWrapper.cpp
@@ -20,9 +20,7 @@
namespace nn {
namespace test_wrapper {
-bool Execution::mComputeUsesBurstAPI = false;
-
-bool Execution::mComputeUsesSychronousAPI = true;
+Execution::ComputeMode Execution::mComputeMode = Execution::ComputeMode::SYNC;
} // namespace test_wrapper
} // namespace nn
diff --git a/runtime/test/TestNeuralNetworksWrapper.h b/runtime/test/TestNeuralNetworksWrapper.h
index ee3da9a..be1d4ea 100644
--- a/runtime/test/TestNeuralNetworksWrapper.h
+++ b/runtime/test/TestNeuralNetworksWrapper.h
@@ -162,42 +162,46 @@
}
Result compute() {
- if (mComputeUsesBurstAPI) {
- ANeuralNetworksBurst* burst = nullptr;
- Result result = static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
- if (result != Result::NO_ERROR) {
+ switch (mComputeMode) {
+ case ComputeMode::SYNC: {
+ return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
+ }
+ case ComputeMode::ASYNC: {
+ ANeuralNetworksEvent* event = nullptr;
+ Result result = static_cast<Result>(
+ ANeuralNetworksExecution_startCompute(mExecution, &event));
+ if (result != Result::NO_ERROR) {
+ return result;
+ }
+ // TODO: it is not yet clear how to manage the lifetime of events when
+ // there are multiple waiters.
+ result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
+ ANeuralNetworksEvent_free(event);
+ return result;
+ }
+ case ComputeMode::BURST: {
+ ANeuralNetworksBurst* burst = nullptr;
+ Result result =
+ static_cast<Result>(ANeuralNetworksBurst_create(mCompilation, &burst));
+ if (result != Result::NO_ERROR) {
+ return result;
+ }
+ result = static_cast<Result>(
+ ANeuralNetworksExecution_burstCompute(mExecution, burst));
ANeuralNetworksBurst_free(burst);
return result;
}
- result = static_cast<Result>(ANeuralNetworksExecution_burstCompute(mExecution, burst));
- ANeuralNetworksBurst_free(burst);
- return result;
}
-
- if (!mComputeUsesSychronousAPI) {
- ANeuralNetworksEvent* event = nullptr;
- Result result =
- static_cast<Result>(ANeuralNetworksExecution_startCompute(mExecution, &event));
- if (result != Result::NO_ERROR) {
- return result;
- }
- // TODO how to manage the lifetime of events when multiple waiters is not
- // clear.
- result = static_cast<Result>(ANeuralNetworksEvent_wait(event));
- ANeuralNetworksEvent_free(event);
- return result;
- }
-
- return static_cast<Result>(ANeuralNetworksExecution_compute(mExecution));
+ return Result::BAD_DATA;  // Only reached for an out-of-range ComputeMode value.
}
- // By default, compute() uses the synchronous API.
- // setComputeUsesSynchronousAPI() can be used to change the behavior of
- // compute() to instead use the asynchronous API and then wait for
- // computation to complete.
- static void setComputeUsesSynchronousAPI(bool val) { mComputeUsesSychronousAPI = val; }
-
- static void setComputeUsesBurstAPI(bool val) { mComputeUsesBurstAPI = val; }
+ // Selects how compute() runs an execution. The mode is a static member, so it
+ // applies to every Execution object. By default (SYNC), compute() uses the
+ // synchronous API. ASYNC uses the asynchronous API and then waits for the
+ // computation to complete; BURST creates a burst object, executes through it,
+ // and frees it again on each call.
+ enum class ComputeMode { SYNC, ASYNC, BURST };
+ static void setComputeMode(ComputeMode mode) { mComputeMode = mode; }
Result getOutputOperandDimensions(uint32_t index, std::vector<uint32_t>* dimensions) {
uint32_t rank = 0;
@@ -217,11 +221,8 @@
ANeuralNetworksCompilation* mCompilation = nullptr;
ANeuralNetworksExecution* mExecution = nullptr;
- // Initialized to false in TestNeuralNetworksWrapper.cpp.
- static bool mComputeUsesBurstAPI;
-
- // Initialized to true in TestNeuralNetworksWrapper.cpp.
- static bool mComputeUsesSychronousAPI;
+ // Initialized to ComputeMode::SYNC in TestNeuralNetworksWrapper.cpp.
+ static ComputeMode mComputeMode;
};
} // namespace test_wrapper
diff --git a/runtime/test/TestValidation.cpp b/runtime/test/TestValidation.cpp
index 4e95fa5..c3a6f0f 100644
--- a/runtime/test/TestValidation.cpp
+++ b/runtime/test/TestValidation.cpp
@@ -1111,6 +1111,19 @@
ANEURALNETWORKS_UNEXPECTED_NULL);
}
+TEST_F(ValidationTestBurst, BurstComputeBadCompilation) {
+    ANeuralNetworksCompilation* compilation;
+    ASSERT_EQ(ANeuralNetworksCompilation_create(mModel, &compilation), ANEURALNETWORKS_NO_ERROR);
+    // NOTE: ANeuralNetworksCompilation_finish not called
+
+    ANeuralNetworksBurst* burst = nullptr;
+    EXPECT_EQ(ANeuralNetworksBurst_create(compilation, &burst), ANEURALNETWORKS_BAD_STATE);
+
+    // Release the objects created above so that this test does not leak them.
+    ANeuralNetworksBurst_free(burst);
+    ANeuralNetworksCompilation_free(compilation);
+}
+
TEST_F(ValidationTestBurst, BurstComputeDifferentCompilations) {
ANeuralNetworksCompilation* secondCompilation;
ASSERT_EQ(ANeuralNetworksCompilation_create(mModel, &secondCompilation),
@@ -1160,15 +1173,20 @@
sizeof(outputB0)),
ANEURALNETWORKS_NO_ERROR);
- // execute on the same burst concurrently
+ // Execute on the same burst concurrently. At least one result must be
+ // ANEURALNETWORKS_NO_ERROR. One may return ANEURALNETWORKS_BAD_STATE if the
+ // other is already executing on the burst.
auto first = std::async(std::launch::async, [this] {
- const int result = ANeuralNetworksExecution_burstCompute(mExecution, mBurst);
- EXPECT_TRUE(result == ANEURALNETWORKS_BAD_STATE || result == ANEURALNETWORKS_NO_ERROR);
+ return ANeuralNetworksExecution_burstCompute(mExecution, mBurst);
});
auto second = std::async(std::launch::async, [this, secondExecution] {
- const int result = ANeuralNetworksExecution_burstCompute(secondExecution, mBurst);
- EXPECT_TRUE(result == ANEURALNETWORKS_BAD_STATE || result == ANEURALNETWORKS_NO_ERROR);
+ return ANeuralNetworksExecution_burstCompute(secondExecution, mBurst);
});
+ const int result1 = first.get();
+ const int result2 = second.get();
+ EXPECT_TRUE(result1 == ANEURALNETWORKS_BAD_STATE || result1 == ANEURALNETWORKS_NO_ERROR);
+ EXPECT_TRUE(result2 == ANEURALNETWORKS_BAD_STATE || result2 == ANEURALNETWORKS_NO_ERROR);
+ EXPECT_TRUE(result1 == ANEURALNETWORKS_NO_ERROR || result2 == ANEURALNETWORKS_NO_ERROR);
}
TEST(ValidationTestIntrospection, GetNumDevices) {