Add rank checks to validation functions
The change adds rank checks to validation of operations that only
support tensors of rank 4 or less. This requirement comes from legacy TF
Lite code and is likely to be relaxed in the future to be on par with TF
Lite.
Adding the checks to validation is beneficial for the TF Lite delegate
since in case of a validation error the NNAPI node will be fully
rejected during delegation, whereas an execution error will cause TF
Lite to run the NNAPI node during every invocation, only to receive an
error and do the calculation using the CPU implementation.
Bug: 139957496
Test: NNTest_static
Change-Id: I5cc4c48e775826a237d5ac54c3d2078254bd17a2
Merged-In: I5cc4c48e775826a237d5ac54c3d2078254bd17a2
(cherry picked from commit 1b4e152880742eca784e9c1d11f04f14a26c4836)
diff --git a/common/operations/Reduce.cpp b/common/operations/Reduce.cpp
index b3327c9..8b21552 100644
--- a/common/operations/Reduce.cpp
+++ b/common/operations/Reduce.cpp
@@ -79,6 +79,10 @@
NN_RET_CHECK(
validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, HalVersion::V1_2);
}
@@ -98,6 +102,10 @@
if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
minHalVersion = HalVersion::V1_3;
}
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, minHalVersion);
}
@@ -110,12 +118,17 @@
NN_RET_CHECK(
validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+ const Shape& input = context->getInputShape(kInputTensor);
+ if (hasKnownRank(input)) {
+ NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+ }
return validateHalVersion(context, HalVersion::V1_2);
}
bool prepare(IOperationExecutionContext* context) {
Shape inputShape = context->getInputShape(kInputTensor);
const uint32_t inputRank = getNumberOfDimensions(inputShape);
+ NN_RET_CHECK_LE(inputRank, 4);
std::vector<bool> shouldReduce(inputRank);
const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);