LSTM: require input layer norm weights to be omitted in case CIFG is used.

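With CIFG (coupled input and forget gate) LSTM there is no separate input
gate, so the input layer norm weights are not used in the computation.
Whether layer normalization is enabled is therefore now determined from the
output layer norm weights rather than the input ones.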

Bug: 129126572
Test: NeuralNetworksTest_static
Change-Id: Id835e7a8f7fa80354eb280cfc0154d9e235695aa
Merged-In: Id835e7a8f7fa80354eb280cfc0154d9e235695aa
(cherry picked from commit 5a9edc3c1290494f0a339e4b2e0038bfe59df416)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 477ff54..3bc4072 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -22,6 +22,7 @@
 #include "OperationsUtils.h"
 
 #include "Tracing.h"
+#include "Utils.h"
 
 namespace android {
 namespace nn {
@@ -92,10 +93,14 @@
     // We check the version of LSTM by checking the number of the inputs to the
     // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
     if (operation.inputs.size() == 27) {
-        input_layer_norm_weights_ = GetInput(operation, operands, kInputLayerNormWeightsTensor);
-        forget_layer_norm_weights_ = GetInput(operation, operands, kForgetLayerNormWeightsTensor);
-        cell_layer_norm_weights_ = GetInput(operation, operands, kCellLayerNormWeightsTensor);
-        output_layer_norm_weights_ = GetInput(operation, operands, kOutputLayerNormWeightsTensor);
+        input_layer_norm_weights_ =
+                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
+        forget_layer_norm_weights_ =
+                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
+        cell_layer_norm_weights_ =
+                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
+        output_layer_norm_weights_ =
+                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
     } else {
         // For LSTM from HAL v1.0 assign operands with no values
         static RunTimeOperandInfo no_value;
@@ -203,7 +208,9 @@
     // Since we have already checked that weights are all there or none, we can
     // check the existence of only one to the get the condition.
     params->use_peephole = !IsNullInput(cell_to_output_weights);
-    params->use_layer_norm = !IsNullInput(input_layer_norm_weights);
+    // Checking output instead of input layer norm weights because the input
+    // layer norm weights are omitted when CIFG LSTM is used.
+    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);
 
     params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
     params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);
@@ -262,12 +269,24 @@
         NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
     }
 
-    const bool layer_norm_weights_all_or_none =
-            (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
-             IsNullInput(cell_layer_norm_weights) && IsNullInput(input_layer_norm_weights)) ||
-            (!IsNullInput(input_layer_norm_weights) && !IsNullInput(forget_layer_norm_weights) &&
-             !IsNullInput(cell_layer_norm_weights) && !IsNullInput(input_layer_norm_weights));
-    NN_CHECK(layer_norm_weights_all_or_none);
+    if (params->use_cifg) {
+        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
+                << "input_layer_norm_weights are provided while CIFG is used";
+        const bool layer_norm_weights_all_or_none_cifg =
+                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
+                 IsNullInput(output_layer_norm_weights)) ||
+                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
+                 !IsNullInput(output_layer_norm_weights));
+        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
+    } else {
+        const bool layer_norm_weights_all_or_none =
+                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
+                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
+                (!IsNullInput(input_layer_norm_weights) &&
+                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
+                 !IsNullInput(output_layer_norm_weights));
+        NN_RET_CHECK(layer_norm_weights_all_or_none);
+    }
 
     return true;
 }