Create utility code to clean up dead operands
This change adds the utility function "removeDeadOperands" to remove
operands that are no longer referenced within a model.
This change also adds "simplifyModel" to ModelBuilder to make it so
ModelBuilder::makeModel automatically removes dead operands from the
model. This fixes a problem where ModelBuilder generated dead operands
when removing arguments from operations that matched the default value.
Finally, this change also makes the following changes to
TestCompliance.cpp:
* Because ModelBuilder now reports the proper version after removing
any dead operands, the remaining testAvailableSinceV1_* test cases are
changed to testAvailableSinceVersion.
* testAvailableSinceVersion is similarly created for Request objects,
and the existing testAvailableSinceV1_* for Requests are changed to
instead use this new testAvailableSinceVersion.
Bug: 213801779
Test: mma
Test: NeuralNetworksTest_static
Change-Id: I518a6b0fbd6382284e6e9a7267b0e5b64360eab2
diff --git a/common/Android.bp b/common/Android.bp
index 6010093..61fa140 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -137,6 +137,7 @@
"LegacyUtils.cpp",
"MemoryUtils.cpp",
"MetaModel.cpp",
+ "ModelUtils.cpp",
"QuantUtils.cpp",
"TokenHasher.cpp",
"ValidateHal.cpp",
@@ -278,6 +279,7 @@
"IndexedShapeWrapper.cpp",
"LegacyUtils.cpp",
"MetaModel.cpp",
+ "ModelUtils.cpp",
"TokenHasher.cpp",
],
header_libs: [
diff --git a/common/ModelUtils.cpp b/common/ModelUtils.cpp
new file mode 100644
index 0000000..b6d0c1a
--- /dev/null
+++ b/common/ModelUtils.cpp
@@ -0,0 +1,292 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "ModelUtils"
+
+#include "ModelUtils.h"
+
+#include <android-base/logging.h>
+
+#include <algorithm>
+#include <numeric>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "nnapi/TypeUtils.h"
+#include "nnapi/Types.h"
+#include "nnapi/Validation.h"
+
+namespace android::nn {
+namespace {
+
+// Map each `true` value in `includes` with a unique integer. `false` values are ignored. E.g.:
+// includes = {false, true, true, false, true}
+// returned = { X, 0, 1, X, 2}
+std::vector<uint32_t> getMapping(const std::vector<bool>& includes) {
+ std::vector<uint32_t> mapping;
+ mapping.reserve(includes.size());
+ std::transform_exclusive_scan(includes.begin(), includes.end(), std::back_inserter(mapping), 0u,
+ std::plus<>{}, [](bool included) { return included ? 1u : 0u; });
+ return mapping;
+}
+
+// Remap indexes in `indexes` by the mapping `mapping`.
+// Precondition: indexes != nullptr
+void remapIndexes(std::vector<uint32_t>* indexes, const std::vector<uint32_t>& mapping) {
+ CHECK(indexes != nullptr);
+ for (uint32_t& index : (*indexes)) {
+ index = mapping.at(index);
+ }
+}
+
+// Keep elements from `elements` specified by `elementsToKeep`, removing all other elements.
+// Precondition: elements != nullptr
+// Precondition: elements->size() == elementsToKeep.size()
+template <typename Type>
+void keepSelectedElements(std::vector<Type>* elements, const std::vector<bool>& elementsToKeep) {
+ CHECK(elements != nullptr);
+ CHECK_EQ(elements->size(), elementsToKeep.size());
+
+ size_t elementsCopied = 0;
+ for (size_t i = 0; i < elementsToKeep.size(); ++i) {
+ if (elementsToKeep[i]) {
+ if (elementsCopied != i) {
+ (*elements)[elementsCopied] = std::move((*elements)[i]);
+ }
+ elementsCopied++;
+ }
+ }
+ elements->resize(elementsCopied);
+}
+
+// Find which operands in model.main.operands are read or written by model.main.operations and
+// model.main.inputIndexes.
+// Postcondition: returned.size() == model.main.operands.size()
+std::vector<bool> identifyUsedOperands(const Model& model) {
+ std::vector<bool> used(model.main.operands.size(), false);
+ auto markUsed = [&used](const std::vector<uint32_t>& indexes) {
+ std::for_each(indexes.begin(), indexes.end(),
+ [&used](uint32_t index) { used.at(index) = true; });
+ };
+ for (const auto& operation : model.main.operations) {
+ markUsed(operation.inputs);
+ markUsed(operation.outputs);
+ }
+ markUsed(model.main.inputIndexes);
+ CHECK_EQ(used.size(), model.main.operands.size());
+ return used;
+}
+
+// Forward declaration.
+void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
+ std::vector<bool>* used);
+
+// Helper function to find which subgraphs are reachable by `operands`.
+// Precondition: used != nullptr
+// Precondition: subgraphs.size() == used->size()
+void identifyUsedSubgraphs(const std::vector<Operand>& operands,
+ const std::vector<Model::Subgraph>& subgraphs, std::vector<bool>* used) {
+ for (const auto& operand : operands) {
+ if (operand.lifetime == Operand::LifeTime::SUBGRAPH) {
+ identifyUsedSubgraphs(operand.location.offset, subgraphs, used);
+ }
+ }
+}
+
+// Helper function to find which subgraphs are reachable by the subgraph at the `current` index, and
+// store when a subgraph is used in `used`. `used` also acts as a cache, ensuring each subgraph is
+// processed at most once.
+// Precondition: used != nullptr
+// Precondition: subgraphs.size() == used->size()
+// Precondition: current < subgraphs.size()
+void identifyUsedSubgraphs(uint32_t current, const std::vector<Model::Subgraph>& subgraphs,
+ std::vector<bool>* used) {
+ CHECK(used != nullptr);
+ CHECK_EQ(subgraphs.size(), used->size());
+ CHECK_LT(current, subgraphs.size());
+
+ // If a subgraph was already marked as used, quickly return to avoid redundant processing.
+ if ((*used)[current]) {
+ return;
+ }
+
+ // Mark the current subgraph as used, then process any subgraph it references recursively.
+ (*used)[current] = true;
+ identifyUsedSubgraphs(subgraphs[current].operands, subgraphs, used);
+}
+
+// Find which subgraphs are reachable by the main operands of `model`.
+// Postcondition: returned.size() == model.referenced.size()
+std::vector<bool> identifyUsedSubgraphs(const Model& model) {
+ std::vector<bool> used(model.referenced.size(), false);
+ identifyUsedSubgraphs(model.main.operands, model.referenced, &used);
+ CHECK_EQ(used.size(), model.referenced.size());
+ return used;
+}
+
+// Helper function to find which pools are used by `subgraph`, and store when a pool is used in
+// `used`.
+// Precondition: used != nullptr
+void identifyUsedPools(const Model::Subgraph& subgraph, std::vector<bool>* used) {
+ CHECK(used != nullptr);
+ for (const auto& operand : subgraph.operands) {
+ if (operand.lifetime == Operand::LifeTime::CONSTANT_REFERENCE) {
+ used->at(operand.location.poolIndex) = true;
+ }
+ }
+}
+
+// Find which pools are used by `model`.
+// Postcondition: returned.size() == model.pools.size()
+std::vector<bool> identifyUsedPools(const Model& model) {
+ std::vector<bool> used(model.pools.size(), false);
+ identifyUsedPools(model.main, &used);
+ for (const auto& subgraph : model.referenced) {
+ identifyUsedPools(subgraph, &used);
+ }
+ CHECK_EQ(used.size(), model.pools.size());
+ return used;
+}
+
+// Fix the DataLocation in `operand` by either remapping an index or by copying constant data.
+// Precondition: operand != nullptr
+// Precondition: newOperandValues != nullptr
+void fixOperandDataLocation(Operand* operand, Model::OperandValues* newOperandValues,
+ const Model::OperandValues& oldOperandValues,
+ const std::vector<uint32_t>& remappedPoolIndex,
+ const std::vector<uint32_t>& remappedSubgraphIndex) {
+ CHECK(operand != nullptr);
+ CHECK(newOperandValues != nullptr);
+
+ switch (operand->lifetime) {
+ case Operand::LifeTime::CONSTANT_COPY: {
+ const uint8_t* data = oldOperandValues.data() + operand->location.offset;
+ const uint32_t length = operand->location.length;
+ operand->location = newOperandValues->append(data, length);
+ break;
+ }
+ case Operand::LifeTime::CONSTANT_REFERENCE:
+ operand->location.poolIndex = remappedPoolIndex.at(operand->location.poolIndex);
+ break;
+ case Operand::LifeTime::SUBGRAPH: {
+ uint32_t& subgraphIndex = operand->location.offset;
+ subgraphIndex = remappedSubgraphIndex.at(subgraphIndex);
+ break;
+ }
+ case Operand::LifeTime::TEMPORARY_VARIABLE:
+ case Operand::LifeTime::SUBGRAPH_INPUT:
+ case Operand::LifeTime::SUBGRAPH_OUTPUT:
+ case Operand::LifeTime::NO_VALUE:
+ case Operand::LifeTime::POINTER:
+ break;
+ }
+}
+
+// Fix all DataLocations in `operands` by either remapping an index or by copying constant data.
+// Precondition: operands != nullptr
+// Precondition: newOperandValues != nullptr
+void fixOperandDataLocations(std::vector<Operand>* operands, Model::OperandValues* newOperandValues,
+ const Model::OperandValues& oldOperandValues,
+ const std::vector<uint32_t>& remappedPoolIndex,
+ const std::vector<uint32_t>& remappedSubgraphIndex) {
+ for (Operand& operand : (*operands)) {
+ fixOperandDataLocation(&operand, newOperandValues, oldOperandValues, remappedPoolIndex,
+ remappedSubgraphIndex);
+ }
+}
+
+// Fix all operands' DataLocations in `model` by either remapping an index or by copying constant
+// data.
+// Precondition: model != nullptr
+void fixOperandDataLocations(Model* model, const std::vector<uint32_t>& remappedPoolIndex,
+ const std::vector<uint32_t>& remappedSubgraphIndex) {
+ const auto operandValues = std::exchange(model->operandValues, Model::OperandValues{});
+ fixOperandDataLocations(&model->main.operands, &model->operandValues, operandValues,
+ remappedPoolIndex, remappedSubgraphIndex);
+ for (auto& subgraph : model->referenced) {
+ fixOperandDataLocations(&subgraph.operands, &model->operandValues, operandValues,
+ remappedPoolIndex, remappedSubgraphIndex);
+ }
+}
+
+// Find which extensions are used in `model`.
+// Postcondition: returned.size() == model.extensionNameToPrefix.size()
+std::vector<bool> identifyUsedExtensions(const Model& model) {
+ std::unordered_set<uint16_t> prefixes;
+ const auto collectPrefix = [&prefixes](const auto& operandOrOperation) {
+ const auto prefix = getExtensionPrefix(static_cast<uint32_t>(operandOrOperation.type));
+ constexpr uint16_t kStandardPrefix = 0u;
+ if (prefix != kStandardPrefix) {
+ prefixes.insert(prefix);
+ }
+ };
+ const auto collectPrefixes = [collectPrefix](const Model::Subgraph& subgraph) {
+ std::for_each(subgraph.operands.begin(), subgraph.operands.end(), collectPrefix);
+ std::for_each(subgraph.operations.begin(), subgraph.operations.end(), collectPrefix);
+ };
+
+ collectPrefixes(model.main);
+ for (const auto& subgraph : model.referenced) {
+ collectPrefixes(subgraph);
+ }
+
+ std::vector<bool> used;
+ used.reserve(model.extensionNameToPrefix.size());
+ for (const auto& extension : model.extensionNameToPrefix) {
+ used.push_back(prefixes.count(extension.prefix) > 0);
+ }
+ CHECK_EQ(used.size(), model.extensionNameToPrefix.size());
+ return used;
+}
+
+} // anonymous namespace
+
+void removeDeadOperands(Model* model) {
+ CHECK(model != nullptr);
+
+ // Keep only the operands which are used.
+ const auto operandsUsed = identifyUsedOperands(*model);
+ keepSelectedElements(&model->main.operands, operandsUsed);
+
+ // Fix operand indexes.
+ const auto mappedOperandIndices = getMapping(operandsUsed);
+ for (auto& operation : model->main.operations) {
+ remapIndexes(&operation.inputs, mappedOperandIndices);
+ remapIndexes(&operation.outputs, mappedOperandIndices);
+ }
+ remapIndexes(&model->main.inputIndexes, mappedOperandIndices);
+ remapIndexes(&model->main.outputIndexes, mappedOperandIndices);
+
+ // Keep only the subgraphs which are used.
+ const auto subgraphsUsed = identifyUsedSubgraphs(*model);
+ keepSelectedElements(&model->referenced, subgraphsUsed);
+
+ // Keep only the pools which are used.
+ const auto poolsUsed = identifyUsedPools(*model);
+ keepSelectedElements(&model->pools, poolsUsed);
+
+ // Fix operand locations.
+ const auto mappedPoolIndices = getMapping(poolsUsed);
+ const auto mappedSubgraphIndices = getMapping(subgraphsUsed);
+ fixOperandDataLocations(model, mappedPoolIndices, mappedSubgraphIndices);
+
+ // Keep only the extensionNameToPrefixes which are used.
+ const auto extensionsUsed = identifyUsedExtensions(*model);
+ keepSelectedElements(&model->extensionNameToPrefix, extensionsUsed);
+}
+
+} // namespace android::nn
diff --git a/common/include/ModelUtils.h b/common/include/ModelUtils.h
new file mode 100644
index 0000000..2a003a3
--- /dev/null
+++ b/common/include/ModelUtils.h
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_MODEL_UTILS_H
+#define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_MODEL_UTILS_H
+
+#include "nnapi/Types.h"
+
+namespace android::nn {
+
+/**
+ * @brief Removes all dead operands from the main subgraph.
+ *
+ * This function is intended as a cleanup after references to operands are removed from a valid
+ * model (e.g., after an operation is removed), possibly causing the model to be invalid. Calling
+ * removeDeadOperands will restore it as a valid model.
+ *
+ * @pre model != nullptr
+ *
+ * @param model The model to have dead operands removed.
+ */
+void removeDeadOperands(Model* model);
+
+} // namespace android::nn
+
+#endif // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_MODEL_UTILS_H
diff --git a/runtime/ModelBuilder.cpp b/runtime/ModelBuilder.cpp
index b8b4ede..2998a74 100644
--- a/runtime/ModelBuilder.cpp
+++ b/runtime/ModelBuilder.cpp
@@ -20,6 +20,7 @@
#include <GraphDump.h>
#include <LegacyUtils.h>
+#include <ModelUtils.h>
#include <android-base/logging.h>
#include <nnapi/Validation.h>
@@ -557,6 +558,7 @@
}
removeTrailingArgumentsWithDefaultValues();
+ simplifyModel();
mCompletedModel = true;
CHECK(calcModelArchHash(modelForValidation, mModelArchHash))
@@ -894,17 +896,18 @@
// A helper class to simplify state management when creating a Model.
class ModelBuilder::ModelMaker {
public:
- static Model run(const ModelBuilder* model);
+ static Model run(const ModelBuilder* model, bool simplifyModel);
private:
static Model::Subgraph makeSubgraph(const ModelBuilder* model);
- ModelMaker() {}
+ explicit ModelMaker(bool simplifyModel) : mSimplifyModel(simplifyModel) {}
Model makeModel(const ModelBuilder* mainModel);
uint32_t addSubgraph(const ModelBuilder* refModel);
void updateOperandLocations(const ModelBuilder* refModel, Model::Subgraph* subgraph);
void addExtensions(const ModelBuilder* model);
void addExtensionWithPrefix(uint16_t prefix);
+ bool mSimplifyModel;
std::vector<Model::Subgraph> mRefSubgraphs;
Model::OperandValues mOperandValues;
MemoryTracker mMemories;
@@ -912,14 +915,18 @@
std::set<uint16_t> mPrefixSet;
};
-Model ModelBuilder::makeModel() const {
- // TODO: Cache the Model to speed up subsequent calls.
- return ModelMaker::run(this);
+void ModelBuilder::simplifyModel() {
+ mSimplifyModel = true;
}
-Model ModelBuilder::ModelMaker::run(const ModelBuilder* model) {
+Model ModelBuilder::makeModel() const {
+ // TODO: Cache the Model to speed up subsequent calls.
+ return ModelMaker::run(this, mSimplifyModel);
+}
+
+Model ModelBuilder::ModelMaker::run(const ModelBuilder* model, bool simplifyModel) {
// run() ensures the state of ModelMaker is destroyed after the call.
- return ModelMaker().makeModel(model);
+ return ModelMaker(simplifyModel).makeModel(model);
}
Model ModelBuilder::ModelMaker::makeModel(const ModelBuilder* mainModel) {
@@ -934,6 +941,9 @@
[](const RuntimeMemory* m) { return m->getMemory(); });
model.relaxComputationFloat32toFloat16 = mainModel->mRelaxComputationFloat32toFloat16;
model.extensionNameToPrefix = std::move(mExtensionNameToPrefix);
+ if (mSimplifyModel) {
+ removeDeadOperands(&model);
+ }
return model;
}
diff --git a/runtime/ModelBuilder.h b/runtime/ModelBuilder.h
index 9553f95..4aa13d6 100644
--- a/runtime/ModelBuilder.h
+++ b/runtime/ModelBuilder.h
@@ -186,6 +186,11 @@
// Copies the large values to a shared memory, if we have any.
int copyLargeValuesToSharedMemory();
+ // Mark that the model should be simplified during ModelBuilder::makeModel, removing arguments
+ // from operations that already match the default values, dead operands, dead pools, dead
+ // subgraphs, and dead extensions.
+ void simplifyModel();
+
// The operations of the graph.
std::vector<Operation> mOperations;
// The mapping from sorted index to the original index of operations in mOperations.
@@ -203,6 +208,10 @@
std::vector<uint32_t> mInputIndexes;
// The indexes of output operands of the model.
std::vector<uint32_t> mOutputIndexes;
+ // Whether the model should be simplified during ModelBuilder::makeModel, removing arguments
+ // from operations that already match the default values, dead operands, dead pools, dead
+ // subgraphs, and dead extensions.
+ bool mSimplifyModel = false;
MemoryTracker mMemories;
diff --git a/runtime/test/TestCompliance.cpp b/runtime/test/TestCompliance.cpp
index 6634778..6f6407e 100644
--- a/runtime/test/TestCompliance.cpp
+++ b/runtime/test/TestCompliance.cpp
@@ -14,11 +14,10 @@
* limitations under the License.
*/
-#include <HalInterfaces.h>
-#include <MemoryUtils.h>
-#include <Utils.h>
#include <android-base/scopeguard.h>
#include <gtest/gtest.h>
+#include <nnapi/SharedMemory.h>
+#include <nnapi/Types.h>
#include <nnapi/Validation.h>
#include "GeneratedTestUtils.h"
@@ -33,7 +32,6 @@
namespace android::nn::compliance_test {
using namespace test_helper;
-using HidlModel = V1_3::Model;
using WrapperModel = test_wrapper::Model;
using WrapperOperandType = test_wrapper::OperandType;
using WrapperType = test_wrapper::Type;
@@ -49,54 +47,15 @@
EXPECT_TRUE(modelBuilder->isValid());
Model model = modelBuilder->makeModel();
const auto modelVersion = validate(model);
- ASSERT_TRUE(modelVersion.ok());
+ ASSERT_TRUE(modelVersion.ok()) << modelVersion.error();
ASSERT_EQ(testVersion, modelVersion.value());
}
-// Creates a HIDL model from a wrapper model.
-static HidlModel createHidlModel(const WrapperModel& wrapperModel) {
- auto modelBuilder = reinterpret_cast<const ModelBuilder*>(wrapperModel.getHandle());
- EXPECT_TRUE(modelBuilder->isFinished());
- EXPECT_TRUE(modelBuilder->isValid());
- return convertToV1_3(modelBuilder->makeModel());
-}
-
-static void testAvailableSinceV1_3(const WrapperModel& wrapperModel) {
- HidlModel hidlModel = createHidlModel(wrapperModel);
- ASSERT_FALSE(compliantWithV1_2(hidlModel));
- ASSERT_FALSE(compliantWithV1_1(hidlModel));
- ASSERT_FALSE(compliantWithV1_0(hidlModel));
-}
-
-static void testAvailableSinceV1_2(const WrapperModel& wrapperModel) {
- HidlModel hidlModel = createHidlModel(wrapperModel);
- ASSERT_TRUE(compliantWithV1_2(hidlModel));
- ASSERT_FALSE(compliantWithV1_1(hidlModel));
- ASSERT_FALSE(compliantWithV1_0(hidlModel));
-}
-
-static void testAvailableSinceV1_1(const WrapperModel& wrapperModel) {
- HidlModel hidlModel = createHidlModel(wrapperModel);
- ASSERT_TRUE(compliantWithV1_2(hidlModel));
- ASSERT_TRUE(compliantWithV1_1(hidlModel));
- ASSERT_FALSE(compliantWithV1_0(hidlModel));
-}
-
-static void testAvailableSinceV1_0(const WrapperModel& wrapperModel) {
- HidlModel hidlModel = createHidlModel(wrapperModel);
- ASSERT_TRUE(compliantWithV1_2(hidlModel));
- ASSERT_TRUE(compliantWithV1_1(hidlModel));
- ASSERT_TRUE(compliantWithV1_0(hidlModel));
-}
-
-[[maybe_unused]] static void testAvailableSinceV1_2(const V1_3::Request& request) {
- ASSERT_FALSE(compliantWithV1_0(request));
- ASSERT_TRUE(compliantWithV1_2(request));
-}
-
-static void testAvailableSinceV1_3(const V1_3::Request& request) {
- ASSERT_FALSE(compliantWithV1_0(request));
- ASSERT_FALSE(compliantWithV1_2(request));
+// Verifies the earliest supported version for the request.
+static void testAvailableSinceVersion(const Request& request, const Version testVersion) {
+ const auto requestVersion = validate(request);
+ ASSERT_TRUE(requestVersion.ok()) << requestVersion.error();
+ ASSERT_EQ(testVersion, requestVersion.value());
}
static const WrapperOperandType kTypeTensorFloat(WrapperType::TENSOR_FLOAT32, {1});
@@ -116,7 +75,7 @@
model.identifyInputsAndOutputs({op1, op2}, {op3});
ASSERT_TRUE(model.isValid());
model.finish();
- testAvailableSinceV1_2(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel3);
}
TEST_F(ComplianceTest, Rank0TensorModelOutput) {
@@ -130,7 +89,7 @@
model.identifyInputsAndOutputs({op1, op2}, {op3});
ASSERT_TRUE(model.isValid());
model.finish();
- testAvailableSinceV1_2(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel3);
}
TEST_F(ComplianceTest, Rank0TensorTemporaryVariable) {
@@ -147,7 +106,7 @@
model.identifyInputsAndOutputs({op1, op2, op4}, {op5});
ASSERT_TRUE(model.isValid());
model.finish();
- testAvailableSinceV1_2(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel3);
}
// Hardware buffers are an Android concept, which aren't necessarily
@@ -182,53 +141,80 @@
model.identifyInputsAndOutputs({op1}, {op3});
ASSERT_TRUE(model.isValid());
model.finish();
- testAvailableSinceV1_2(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel3);
}
TEST_F(ComplianceTest, HardwareBufferRequest) {
- const auto [n, ahwb] = MemoryRuntimeAHWB::create(1024);
+ constexpr size_t kAhwbMemorySize = 1024;
+ const auto [n, ahwb] = MemoryRuntimeAHWB::create(kAhwbMemorySize);
ASSERT_EQ(n, ANEURALNETWORKS_NO_ERROR);
- V1_3::Request::MemoryPool sharedMemoryPool,
- ahwbMemoryPool = convertToV1_3(ahwb->getMemoryPool());
- sharedMemoryPool.hidlMemory(allocateSharedMemory(1024));
- ASSERT_TRUE(sharedMemoryPool.hidlMemory().valid());
- ASSERT_TRUE(ahwbMemoryPool.hidlMemory().valid());
+ const Request::MemoryPool ahwbMemoryPool = ahwb->getMemoryPool();
+
+ constexpr size_t kSharedMemorySize = 1024;
+ auto maybeSharedMemoryPool = createSharedMemory(kSharedMemorySize);
+ ASSERT_TRUE(maybeSharedMemoryPool.ok()) << maybeSharedMemoryPool.error().message;
+ const Request::MemoryPool sharedMemoryPool = std::move(maybeSharedMemoryPool).value();
// AHardwareBuffer as input.
- testAvailableSinceV1_2(V1_3::Request{
- .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
- .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
- .pools = {ahwbMemoryPool, sharedMemoryPool},
- });
+ testAvailableSinceVersion(
+ Request{
+ .inputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 0, .length = kAhwbMemorySize},
+ .dimensions = {}}},
+ .outputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 1, .length = kSharedMemorySize},
+ .dimensions = {}}},
+ .pools = {ahwbMemoryPool, sharedMemoryPool},
+ },
+ kVersionFeatureLevel3);
// AHardwareBuffer as output.
- testAvailableSinceV1_2(V1_3::Request{
- .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
- .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
- .pools = {sharedMemoryPool, ahwbMemoryPool},
- });
+ testAvailableSinceVersion(
+ Request{
+ .inputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 0, .length = kSharedMemorySize},
+ .dimensions = {}}},
+ .outputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 1, .length = kAhwbMemorySize},
+ .dimensions = {}}},
+ .pools = {sharedMemoryPool, ahwbMemoryPool},
+ },
+ kVersionFeatureLevel3);
}
#endif
TEST_F(ComplianceTest, DeviceMemory) {
- V1_3::Request::MemoryPool sharedMemoryPool, deviceMemoryPool;
- sharedMemoryPool.hidlMemory(allocateSharedMemory(1024));
- ASSERT_TRUE(sharedMemoryPool.hidlMemory().valid());
- deviceMemoryPool.token(1);
+ constexpr size_t kSharedMemorySize = 1024;
+ auto maybeSharedMemoryPool = createSharedMemory(kSharedMemorySize);
+ ASSERT_TRUE(maybeSharedMemoryPool.ok()) << maybeSharedMemoryPool.error().message;
+ const Request::MemoryPool sharedMemoryPool = std::move(maybeSharedMemoryPool).value();
+ const Request::MemoryPool deviceMemoryPool = Request::MemoryDomainToken(1);
// Device memory as input.
- testAvailableSinceV1_3(V1_3::Request{
- .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
- .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
- .pools = {deviceMemoryPool, sharedMemoryPool},
- });
+ testAvailableSinceVersion(
+ Request{
+ .inputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 0},
+ .dimensions = {}}},
+ .outputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 1, .length = kSharedMemorySize},
+ .dimensions = {}}},
+ .pools = {deviceMemoryPool, sharedMemoryPool},
+ },
+ kVersionFeatureLevel4);
// Device memory as output.
- testAvailableSinceV1_3(V1_3::Request{
- .inputs = {{.hasNoValue = false, .location = {.poolIndex = 0}, .dimensions = {}}},
- .outputs = {{.hasNoValue = false, .location = {.poolIndex = 1}, .dimensions = {}}},
- .pools = {sharedMemoryPool, deviceMemoryPool},
- });
+ testAvailableSinceVersion(
+ Request{
+ .inputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 0, .length = kSharedMemorySize},
+ .dimensions = {}}},
+ .outputs = {{.lifetime = Request::Argument::LifeTime::POOL,
+ .location = {.poolIndex = 1},
+ .dimensions = {}}},
+ .pools = {sharedMemoryPool, deviceMemoryPool},
+ },
+ kVersionFeatureLevel4);
}
class GeneratedComplianceTest : public generated_tests::GeneratedTestBase {};
@@ -240,18 +226,17 @@
model.finish();
switch (testModel.minSupportedVersion) {
// TODO(b/209797313): Unify HalVersion and Version.
- // TODO(b/213801779): Use testAvailableSinceVersion for HIDL.
case TestHalVersion::V1_0:
- testAvailableSinceV1_0(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel1);
break;
case TestHalVersion::V1_1:
- testAvailableSinceV1_1(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel2);
break;
case TestHalVersion::V1_2:
- testAvailableSinceV1_2(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel3);
break;
case TestHalVersion::V1_3:
- testAvailableSinceV1_3(model);
+ testAvailableSinceVersion(model, kVersionFeatureLevel4);
break;
case TestHalVersion::AIDL_V1:
testAvailableSinceVersion(model, kVersionFeatureLevel5);