Add TENSOR_QUANT8_SYMM_PER_CHANNEL to operand types.

Add an 'ExtraParams' union to ANeuralNetworksOperandType.
ANeuralNetworksOperandType::ExtraParams contains supplementary
parameters for operand types added in API Level 29 or later.

To allow backwards compatibility with code compiled against earlier NDKs,
ANeuralNetworksOperandType.extraParams is ignored for all operand types
introduced before API Level 29.

Operands of type TENSOR_QUANT8_SYMM_PER_CHANNEL must provide an
ANeuralNetworksOperandParamsChannelQuant value in the extraParams union.

Bug: 119249581
Test: NeuralNetworksTest_static
Test: VtsHalNeuralnetworksV1_0TargetTest
Test: VtsHalNeuralnetworksV1_1TargetTest
Test: VtsHalNeuralnetworksV1_2TargetTest
Change-Id: I8b22098f2c14a5cd0176429c8bb16ce52f58af99
Merged-In: I8b22098f2c14a5cd0176429c8bb16ce52f58af99
(cherry picked from commit efc0ebd1f72ef0d39aa73687ec97b97ceff41eda)
diff --git a/common/ValidateHal.cpp b/common/ValidateHal.cpp
index ba6f64d..9e23955 100644
--- a/common/ValidateHal.cpp
+++ b/common/ValidateHal.cpp
@@ -18,6 +18,7 @@
 
 #include "ValidateHal.h"
 #include "NeuralNetworks.h"
+#include "OperationsUtils.h"
 #include "Tracing.h"
 #include "Utils.h"
 
@@ -55,6 +56,59 @@
     std::vector<size_t> mPoolSizes;
 };
 
+static bool validateOperandExtraParams(const V1_2::Operand& operand, uint32_t index) {
+    switch (operand.type) {
+        case OperandType::FLOAT32:
+        case OperandType::INT32:
+        case OperandType::UINT32:
+        case OperandType::BOOL:
+        case OperandType::TENSOR_FLOAT32:
+        case OperandType::TENSOR_FLOAT16:
+        case OperandType::TENSOR_INT32:
+        case OperandType::TENSOR_QUANT8_ASYMM:
+        case OperandType::TENSOR_QUANT16_SYMM:
+        case OperandType::TENSOR_BOOL8:  // Pre-API-29 types must leave the extraParams union unset.
+            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
+                         V1_2::Operand::ExtraParams::hidl_discriminator::none)
+                    << "Operand " << index << ": Operand of type "
+                    << getOperandTypeName(operand.type) << " with a Channel Quantization params";
+            break;
+        case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL: {  // Must carry channelQuant params.
+            NN_RET_CHECK(operand.extraParams.getDiscriminator() ==
+                         V1_2::Operand::ExtraParams::hidl_discriminator::channelQuant)
+                    << "Operand " << index << ": Operand of type "
+                    << getOperandTypeName(operand.type) << " without a Channel Quantization params";
+            auto& channelQuant = operand.extraParams.channelQuant();
+
+            size_t count = operand.dimensions.size();
+            NN_RET_CHECK_LT(channelQuant.channelDim, count)
+                    << "Operand " << index << ": Operand of type "
+                    << getOperandTypeName(operand.type)
+                    << " with an invalid channelQuant.channelDim " << channelQuant.channelDim
+                    << ", must be valid dimension index in range [0, " << count << ")";
+            uint32_t expected = operand.dimensions[channelQuant.channelDim];
+            NN_RET_CHECK_EQ(channelQuant.scales.size(), expected)
+                    << "Operand " << index << ": Operand of type "
+                    << getOperandTypeName(operand.type) << " with a wrong-sized scales, "
+                    << "expected " << expected << " was " << channelQuant.scales.size();
+            NN_RET_CHECK_NE(expected, 0)
+                    << "Operand " << index << ": Operand of type "
+                    << getOperandTypeName(operand.type) << " channel dimension "
+                    << channelQuant.channelDim << " is underspecified (can't be 0)";
+            for (uint32_t i = 0; i < expected; ++i) {
+                NN_RET_CHECK_GT(channelQuant.scales[i], .0f)  // every per-channel scale must be > 0
+                        << "Operand " << index << ": Operand of type "
+                        << getOperandTypeName(operand.type) << " with a non-positive value in scales["
+                        << i << "]=" << channelQuant.scales[i];
+            }
+        } break;
+        default:
+            // No validation for OEM types (or any operand type not listed above).
+            break;
+    }
+    return true;
+}
+
 template <typename VersionedOperand>
 static bool validateOperands(const hidl_vec<VersionedOperand>& operands,
                              const hidl_vec<uint8_t>& operandValues,
@@ -92,6 +146,7 @@
             case OperandType::TENSOR_QUANT8_ASYMM:
             case OperandType::TENSOR_QUANT16_SYMM:
             case OperandType::TENSOR_BOOL8:
+            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
             case OperandType::TENSOR_OEM_BYTE: {
                 if (operand.dimensions.size() == 0) {
                     LOG(ERROR) << "Operand " << index << ": Tensor has dimensions of rank 0";
@@ -120,6 +175,7 @@
             case OperandType::TENSOR_FLOAT16:
             case OperandType::TENSOR_FLOAT32:
             case OperandType::TENSOR_BOOL8:
+            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                 if (operand.scale != 0.f) {
                     LOG(ERROR) << "Operand " << index << ": Operand of type "
                                << getOperandTypeName(operand.type) << " with a non-zero scale ("
@@ -167,6 +223,7 @@
             case OperandType::TENSOR_FLOAT32:
             case OperandType::TENSOR_INT32:
             case OperandType::TENSOR_BOOL8:
+            case OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
                 if (operand.zeroPoint != 0) {
                     LOG(ERROR) << "Operand " << index << ": Operand of type "
                                << getOperandTypeName(operand.type) << " with an non-zero zeroPoint "
@@ -195,6 +252,8 @@
                 break;
         }
 
+        NN_RET_CHECK(validateOperandExtraParams(operand, index));
+
         // Validate the lifetime and the location.
         const DataLocation& location = operand.location;
         switch (operand.lifetime) {
@@ -529,6 +588,7 @@
         case V1_2::OperandType::TENSOR_QUANT8_ASYMM:
         case V1_2::OperandType::TENSOR_QUANT16_SYMM:
         case V1_2::OperandType::TENSOR_BOOL8:
+        case V1_2::OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL:
         case V1_2::OperandType::OEM:
         case V1_2::OperandType::TENSOR_OEM_BYTE:
             return true;