Add rank checks to validation functions

The change adds rank checks to validation of operations that only
support tensors of rank 4 or less. This requirement comes from legacy TF
Lite code and is likely to be relaxed in the future to be on par with TF
Lite.
Adding the checks to validation is beneficial for the TF Lite delegate:
on a validation error, the NNAPI node is fully rejected by the delegate
up front. An execution error, by contrast, would cause TF Lite to run
the NNAPI node on every invocation only to receive an error and then
fall back to the CPU implementation.

Bug: 139957496
Test: NNTest_static
Change-Id: I5cc4c48e775826a237d5ac54c3d2078254bd17a2
Merged-In: I5cc4c48e775826a237d5ac54c3d2078254bd17a2
(cherry picked from commit 1b4e152880742eca784e9c1d11f04f14a26c4836)
diff --git a/common/operations/Activation.cpp b/common/operations/Activation.cpp
index a6a3e82..f12ed7a 100644
--- a/common/operations/Activation.cpp
+++ b/common/operations/Activation.cpp
@@ -375,6 +375,10 @@
     } else {
         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
     }
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateInputTypes(context, {inputType}) && validateOutputTypes(context, {inputType});
 }
 
diff --git a/common/operations/Broadcast.cpp b/common/operations/Broadcast.cpp
index e19ce74..17094af 100644
--- a/common/operations/Broadcast.cpp
+++ b/common/operations/Broadcast.cpp
@@ -466,6 +466,12 @@
     } else {
         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
     }
+    const Shape& input1 = context->getInputShape(kInputTensor1);
+    const Shape& input2 = context->getInputShape(kInputTensor2);
+    if (hasKnownRank(input1) && hasKnownRank(input2)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
+        NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
+    }
     return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
            validateOutputTypes(context, {inputType});
 }
diff --git a/common/operations/ChannelShuffle.cpp b/common/operations/ChannelShuffle.cpp
index c78e496..7abf224 100644
--- a/common/operations/ChannelShuffle.cpp
+++ b/common/operations/ChannelShuffle.cpp
@@ -69,6 +69,10 @@
                  inputType == OperandType::TENSOR_QUANT8_ASYMM ||
                  inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
             << "Unsupported tensor type for operation " << kOperationName;
+    const Shape& inputShape = context->getInputShape(kInputTensor);
+    if (hasKnownRank(inputShape)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(inputShape), 4);
+    }
     NN_RET_CHECK(validateInputTypes(context, {inputType, OperandType::INT32, OperandType::INT32}));
     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
     if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
diff --git a/common/operations/Dequantize.cpp b/common/operations/Dequantize.cpp
index 3505540..2fb2d5c 100644
--- a/common/operations/Dequantize.cpp
+++ b/common/operations/Dequantize.cpp
@@ -83,6 +83,11 @@
     const OperandType inputType = context->getInputType(kInputTensor);
     const OperandType outputType = context->getOutputType(kOutputTensor);
 
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
+
     if (inputType == OperandType::TENSOR_QUANT8_ASYMM &&
         outputType == OperandType::TENSOR_FLOAT32) {
         return validateHalVersion(context, HalVersion::V1_0);
@@ -101,6 +106,7 @@
 
 bool prepare(IOperationExecutionContext* context) {
     const Shape& input = context->getInputShape(kInputTensor);
+    NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
     Shape output = context->getOutputShape(kOutputTensor);
     output.dimensions = input.dimensions;
     return context->setOutputShape(kOutputTensor, output);
diff --git a/common/operations/FullyConnected.cpp b/common/operations/FullyConnected.cpp
index 29fbec7..2afbee0 100644
--- a/common/operations/FullyConnected.cpp
+++ b/common/operations/FullyConnected.cpp
@@ -240,6 +240,13 @@
     }
     NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
+
     return true;
 }
 
@@ -260,6 +267,7 @@
     // The Tensorflow fully connected layer specification says that input should
     // be of at least rank 2, so we check. Tflite doesn't check.
     NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
+    NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
     NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
     uint32_t input_n_elements = getNumberOfElements(input);
     uint32_t num_units = getSizeOfDimension(weights, 0);
diff --git a/common/operations/L2Normalization.cpp b/common/operations/L2Normalization.cpp
index 1925d54..1f0c9d0 100644
--- a/common/operations/L2Normalization.cpp
+++ b/common/operations/L2Normalization.cpp
@@ -221,6 +221,10 @@
     } else if (context->getInputShape(kInputTensor).dimensions.size() != 4) {
         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
     }
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateInputTypes(context, inExpectedTypes) &&
            validateOutputTypes(context, {inputType});
 }
@@ -231,6 +235,7 @@
     int32_t axis = context->getNumInputs() == kNumInputs
                            ? context->getInputValue<int32_t>(kAxisScalar)
                            : -1;
+    NN_RET_CHECK_LE(numDimensions, 4);
     NN_RET_CHECK_GE(axis, -numDimensions);
     NN_RET_CHECK_LT(axis, numDimensions);
     Shape output = context->getOutputShape(kOutputTensor);
diff --git a/common/operations/Reduce.cpp b/common/operations/Reduce.cpp
index b3327c9..8b21552 100644
--- a/common/operations/Reduce.cpp
+++ b/common/operations/Reduce.cpp
@@ -79,6 +79,10 @@
     NN_RET_CHECK(
             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateHalVersion(context, HalVersion::V1_2);
 }
 
@@ -98,6 +102,10 @@
     if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
         minHalVersion = HalVersion::V1_3;
     }
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateHalVersion(context, minHalVersion);
 }
 
@@ -110,12 +118,17 @@
     NN_RET_CHECK(
             validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateHalVersion(context, HalVersion::V1_2);
 }
 
 bool prepare(IOperationExecutionContext* context) {
     Shape inputShape = context->getInputShape(kInputTensor);
     const uint32_t inputRank = getNumberOfDimensions(inputShape);
+    NN_RET_CHECK_LE(inputRank, 4);
 
     std::vector<bool> shouldReduce(inputRank);
     const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
diff --git a/common/operations/Softmax.cpp b/common/operations/Softmax.cpp
index f9b8ed2..8c05628 100644
--- a/common/operations/Softmax.cpp
+++ b/common/operations/Softmax.cpp
@@ -246,12 +246,15 @@
     } else {
         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
     }
+    const auto inputRank = getNumberOfDimensions(context->getInputShape(kInputTensor));
+    if (inputRank != 0) {
+        NN_RET_CHECK_LE(inputRank, 4);
+    }
     if (context->getNumInputs() == kNumInputs) {
         NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
         inExpectedTypes.push_back(OperandType::INT32);
     } else {
-        const size_t ndim = context->getInputShape(kInputTensor).dimensions.size();
-        if (ndim != 2 && ndim != 4 && ndim != 0) {
+        if (inputRank != 2 && inputRank != 4 && inputRank != 0) {
             NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
         }
     }
diff --git a/common/operations/Squeeze.cpp b/common/operations/Squeeze.cpp
index ca09703..977856d 100644
--- a/common/operations/Squeeze.cpp
+++ b/common/operations/Squeeze.cpp
@@ -62,6 +62,10 @@
                                                      OperandType::TENSOR_INT32,
                                              }));
     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateHalVersion(context, minSupportedHalVersion);
 }
 
@@ -75,6 +79,8 @@
     const Shape squeezeDimsShape = context->getInputShape(kSqueezeDims);
     int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
 
+    NN_RET_CHECK_LE(getNumberOfDimensions(inputShape), 4);
+
     // squeezeDims need to be provided as a 1-D int32 tensor.
     NN_OPS_CHECK(squeezeDimsShape.type == OperandType::TENSOR_INT32);
     NN_OPS_CHECK(getNumberOfDimensions(squeezeDimsShape) == 1);
diff --git a/common/operations/StridedSlice.cpp b/common/operations/StridedSlice.cpp
index bcc95f6..8899383 100644
--- a/common/operations/StridedSlice.cpp
+++ b/common/operations/StridedSlice.cpp
@@ -128,6 +128,10 @@
                                                      OperandType::INT32,
                                              }));
     NN_RET_CHECK(validateOutputTypes(context, {inputType}));
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateHalVersion(context, minSupportedHalVersion);
 }
 
diff --git a/common/operations/Transpose.cpp b/common/operations/Transpose.cpp
index e0320c6..ff70f9e 100644
--- a/common/operations/Transpose.cpp
+++ b/common/operations/Transpose.cpp
@@ -87,6 +87,10 @@
     } else {
         NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
     }
+    const Shape& input = context->getInputShape(kInputTensor);
+    if (hasKnownRank(input)) {
+        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
+    }
     return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) &&
            validateOutputTypes(context, {inputType});
 }