Add layer normalization support to LSTM op

* Add new test to the operations tests
* Add new test to the validation tests
* Make Prepare and CheckInputTensorDimensions non-static.
    This makes it possible to reuse, at preparation time, the inputs that
    were already read in the constructor. This is needed to support
    upcasting between HAL versions, which the constructor implements by
    assigning dummy values to the inputs that do not exist in older
    versions (see the sketch below).
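
A sketch of the constructor fallback referenced above, excerpted from the
LSTM.cpp change below (HAL 1.0 models get a shared NO_VALUE placeholder for
the new layer norm operands):

  if (operation.inputs.size() == 27) {
    input_layer_norm_weights_ =
        GetInput(operation, operands, kInputLayerNormWeightsTensor);
    // ... likewise for the forget/cell/output layer norm weights.
  } else {
    // For LSTM from HAL v1.0, assign operands with no values.
    static RunTimeOperandInfo no_value;
    no_value.lifetime = OperandLifeTime::NO_VALUE;
    input_layer_norm_weights_ = &no_value;
    // ... likewise for the forget/cell/output layer norm weights.
  }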

Fix: 113562577
Test: NeuralNetworksTest_static
Test: VtsHalNeuralnetworksV1_2TargetTest
Change-Id: I90676abfcdb3d9a969a1418f8474ce383bf7fb07
Merged-In: I90676abfcdb3d9a969a1418f8474ce383bf7fb07
(cherry picked from commit d24a1bb943a970fbf71b25ac64216a60a16ffc37)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 7667c66..ca9f42b 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -21,6 +21,8 @@
 
 #include "Tracing.h"
 
+// TODO(levp): Format the file.
+// clang-format off
 namespace android {
 namespace nn {
 
@@ -73,6 +75,24 @@
   params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
   params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
 
+  // We determine the LSTM version by the number of inputs to the op. For
+  // LSTM version 1.0 there were 23 inputs; for 1.2 there are 27.
+  if (operation.inputs.size() == 27) {
+    input_layer_norm_weights_ = GetInput(operation, operands, kInputLayerNormWeightsTensor);
+    forget_layer_norm_weights_ = GetInput(operation, operands, kForgetLayerNormWeightsTensor);
+    cell_layer_norm_weights_ = GetInput(operation, operands, kCellLayerNormWeightsTensor);
+    output_layer_norm_weights_ = GetInput(operation, operands, kOutputLayerNormWeightsTensor);
+  } else {
+    // For LSTM from HAL v1.0, assign operands with no values.
+    static RunTimeOperandInfo no_value;
+    no_value.lifetime = OperandLifeTime::NO_VALUE;
+
+    input_layer_norm_weights_ = &no_value;
+    forget_layer_norm_weights_ = &no_value;
+    cell_layer_norm_weights_ = &no_value;
+    output_layer_norm_weights_ = &no_value;
+  }
+
   output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
   cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
   output_ = GetOutput(operation, operands, kOutputTensor);
@@ -96,125 +116,95 @@
   NN_CHECK(params.cell_clip_ >= 0);
   NN_CHECK(params.proj_clip_ >= 0);
 
-  const RunTimeOperandInfo *input_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
-  if (!IsNullInput(input_to_input_weights)) {
-    NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
-    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
+  if (!IsNullInput(input_to_input_weights_)) {
+    NN_CHECK_EQ(NumDimensions(input_to_input_weights_), 2);
+    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 1), n_input);
   }
 
-  const RunTimeOperandInfo *input_to_forget_weights =
-      GetInput(operation, operands, LSTMCell::kInputToForgetWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
+  NN_CHECK_EQ(NumDimensions(input_to_forget_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 1), n_input);
 
-  const RunTimeOperandInfo *input_to_cell_weights =
-      GetInput(operation, operands, LSTMCell::kInputToCellWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
+  NN_CHECK_EQ(NumDimensions(input_to_cell_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 1), n_input);
 
-  const RunTimeOperandInfo *recurrent_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToInputWeightsTensor);
-  if (!IsNullInput(recurrent_to_input_weights)) {
-    NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
+  if (!IsNullInput(recurrent_to_input_weights_)) {
+    NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights_), 2);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 1), n_output);
   }
 
-  const RunTimeOperandInfo *recurrent_to_forget_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToForgetWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
+  NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 1), n_output);
 
-  const RunTimeOperandInfo *recurrent_to_cell_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToCellWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
+  NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 1), n_output);
 
   // We make sure the input-gate's parameters are either both present (regular
   // LSTM) or not at all (CIFG-LSTM).
   const bool cifg_weights_all_or_none =
-      (!IsNullInput(input_to_input_weights) &&
-       !IsNullInput(recurrent_to_input_weights)) ||
-      (IsNullInput(input_to_input_weights) &&
-       IsNullInput(recurrent_to_input_weights));
+      (!IsNullInput(input_to_input_weights_) &&
+       !IsNullInput(recurrent_to_input_weights_)) ||
+      (IsNullInput(input_to_input_weights_) &&
+       IsNullInput(recurrent_to_input_weights_));
   NN_CHECK(cifg_weights_all_or_none);
 
-  const RunTimeOperandInfo *cell_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kCellToInputWeightsTensor);
-  if (!IsNullInput(cell_to_input_weights)) {
-    NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
+  if (!IsNullInput(cell_to_input_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_to_input_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights_, 0), n_cell);
   }
 
-  const RunTimeOperandInfo *cell_to_forget_weights =
-      GetInput(operation, operands, LSTMCell::kCellToForgetWeightsTensor);
-  if (!IsNullInput(cell_to_forget_weights)) {
-    NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
+  if (!IsNullInput(cell_to_forget_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_to_forget_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights_, 0), n_cell);
   }
 
-  const RunTimeOperandInfo *cell_to_output_weights =
-      GetInput(operation, operands, LSTMCell::kCellToOutputWeightsTensor);
-  if (!IsNullInput(cell_to_output_weights)) {
-    NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
+  if (!IsNullInput(cell_to_output_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_to_output_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights_, 0), n_cell);
   }
 
   // Making sure the peephole weights are there all or none.
-  const bool use_cifg = IsNullInput(input_to_input_weights);
+  const bool use_cifg = IsNullInput(input_to_input_weights_);
   const bool peephole_weights_all_or_none =
-      ((!IsNullInput(cell_to_input_weights) || use_cifg) &&
-       !IsNullInput(cell_to_forget_weights) &&
-       !IsNullInput(cell_to_output_weights)) ||
-      (IsNullInput(cell_to_input_weights) &&
-       IsNullInput(cell_to_forget_weights) &&
-       IsNullInput(cell_to_output_weights));
+      ((!IsNullInput(cell_to_input_weights_) || use_cifg) &&
+       !IsNullInput(cell_to_forget_weights_) &&
+       !IsNullInput(cell_to_output_weights_)) ||
+      (IsNullInput(cell_to_input_weights_) &&
+       IsNullInput(cell_to_forget_weights_) &&
+       IsNullInput(cell_to_output_weights_));
   NN_CHECK(peephole_weights_all_or_none);
 
   // Make sure the input gate bias is present only when not a CIFG-LSTM.
-  const RunTimeOperandInfo* input_gate_bias =
-      GetInput(operation, operands, LSTMCell::kInputGateBiasTensor);
   if (use_cifg) {
-    NN_CHECK(IsNullInput(input_gate_bias));
+    NN_CHECK(IsNullInput(input_gate_bias_));
   } else {
-    NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
-    NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
+    NN_CHECK_EQ(NumDimensions(input_gate_bias_), 1);
+    NN_CHECK_EQ(SizeOfDimension(input_gate_bias_, 0), n_cell);
   }
 
-  const RunTimeOperandInfo *forget_gate_bias =
-      GetInput(operation, operands, LSTMCell::kForgetGateBiasTensor);
-  NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
-  NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
+  NN_CHECK_EQ(NumDimensions(forget_gate_bias_), 1);
+  NN_CHECK_EQ(SizeOfDimension(forget_gate_bias_, 0), n_cell);
 
-  const RunTimeOperandInfo *cell_bias =
-      GetInput(operation, operands, LSTMCell::kCellGateBiasTensor);
-  NN_CHECK_EQ(NumDimensions(cell_bias), 1);
-  NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
+  NN_CHECK_EQ(NumDimensions(cell_bias_), 1);
+  NN_CHECK_EQ(SizeOfDimension(cell_bias_, 0), n_cell);
 
-  const RunTimeOperandInfo *output_gate_bias =
-      GetInput(operation, operands, LSTMCell::kOutputGateBiasTensor);
-  NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
-  NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
+  NN_CHECK_EQ(NumDimensions(output_gate_bias_), 1);
+  NN_CHECK_EQ(SizeOfDimension(output_gate_bias_, 0), n_cell);
 
-  const RunTimeOperandInfo *projection_weights =
-      GetInput(operation, operands, LSTMCell::kProjectionWeightsTensor);
-  if (!IsNullInput(projection_weights)) {
-    NN_CHECK_EQ(NumDimensions(projection_weights), 2);
-    NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
-    NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
+  if (!IsNullInput(projection_weights_)) {
+    NN_CHECK_EQ(NumDimensions(projection_weights_), 2);
+    NN_CHECK_EQ(SizeOfDimension(projection_weights_, 0), n_output);
+    NN_CHECK_EQ(SizeOfDimension(projection_weights_, 1), n_cell);
   }
 
-  const RunTimeOperandInfo *projection_bias =
-      GetInput(operation, operands, LSTMCell::kProjectionBiasTensor);
-  if (!IsNullInput(projection_bias)) {
-    NN_CHECK_EQ(NumDimensions(projection_bias), 1);
-    NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
+  if (!IsNullInput(projection_bias_)) {
+    NN_CHECK_EQ(NumDimensions(projection_bias_), 1);
+    NN_CHECK_EQ(SizeOfDimension(projection_bias_, 0), n_output);
   }
 
   // Making sure the projection tensors are consistent:
@@ -223,9 +213,37 @@
   // 2) If projection weight is present, then projection bias is optional.
   // TODO: make sure this is correct.
   const bool projecton_tensors_consistent =
-      (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
+      (!IsNullInput(projection_weights_) || IsNullInput(projection_bias_));
   NN_CHECK(projecton_tensors_consistent == true);
 
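+  // Layer normalization weights, when present, are vectors of size n_cell.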
+  if (!IsNullInput(input_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(input_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights_, 0), n_cell);
+  }
+  if (!IsNullInput(forget_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights_, 0), n_cell);
+  }
+  if (!IsNullInput(cell_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights_, 0), n_cell);
+  }
+  if (!IsNullInput(output_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(output_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights_, 0), n_cell);
+  }
+
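+  // The four layer norm weight tensors must be either all present or all
+  // omitted.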
+  const bool layer_norm_weights_all_or_none =
+      (IsNullInput(input_layer_norm_weights_) &&
+       IsNullInput(forget_layer_norm_weights_) &&
+       IsNullInput(cell_layer_norm_weights_) &&
+       IsNullInput(output_layer_norm_weights_)) ||
+      (!IsNullInput(input_layer_norm_weights_) &&
+       !IsNullInput(forget_layer_norm_weights_) &&
+       !IsNullInput(cell_layer_norm_weights_) &&
+       !IsNullInput(output_layer_norm_weights_));
+  NN_CHECK(layer_norm_weights_all_or_none);
+
   return true;
 }
 
@@ -237,29 +255,22 @@
                        Shape *outputShape) {
   // Check we have all the inputs and outputs we need.
   NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
-           NumInputsWithValues(operation, operands) <= 23);
+           NumInputsWithValues(operation, operands) <= 27);
   NN_CHECK_EQ(NumOutputs(operation), 4);
 
   // Inferring batch size, number of outputs and number of cells from the
   // input tensors.
-  const RunTimeOperandInfo *input =
-      GetInput(operation, operands, LSTMCell::kInputTensor);
-  NN_CHECK(NumDimensions(input) > 1);
-  const uint32_t n_batch = SizeOfDimension(input, 0);
-  const uint32_t n_input = SizeOfDimension(input, 1);
+  NN_CHECK(NumDimensions(input_) > 1);
+  const uint32_t n_batch = SizeOfDimension(input_, 0);
+  const uint32_t n_input = SizeOfDimension(input_, 1);
 
-  const RunTimeOperandInfo *input_to_output_weights =
-      GetInput(operation, operands, LSTMCell::kInputToOutputWeightsTensor);
-  const uint32_t n_cell = SizeOfDimension(input_to_output_weights, 0);
-  NN_CHECK_EQ(NumDimensions(input_to_output_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(input_to_output_weights, 1), n_input);
+  const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
+  NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
 
-  const RunTimeOperandInfo *recurrent_to_output_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToOutputWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights, 0),
-                    n_cell);
-  const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights, 1);
+  NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
+  const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
 
   // Check that input tensor dimensions matches with each other.
   if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
@@ -267,7 +278,7 @@
   }
 
   // Resize the output and output_state tensors.
-  const Shape &inputShape = input->shape();
+  const Shape &inputShape = input_->shape();
 
   outputShape->type = inputShape.type;
   outputShape->dimensions = { n_batch, n_output };
@@ -284,9 +295,7 @@
   cellStateShape->offset = inputShape.offset;
   cellStateShape->scale = inputShape.scale;
 
-  const RunTimeOperandInfo *input_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
-  const bool use_cifg = IsNullInput(input_to_input_weights);
+  const bool use_cifg = IsNullInput(input_to_input_weights_);
   if (use_cifg) {
     // Reserving space for Cell, Forget, Output gates
     scratchShape->dimensions = { n_batch, n_cell * 3 };
@@ -312,8 +321,9 @@
 
   // Since we have already checked that weights are all there or none, we can
  // check the existence of only one to get the condition.
-  const bool use_cifg = (input_to_input_weights_->lifetime == OperandLifeTime::NO_VALUE);
-  const bool use_peephole = (cell_to_output_weights_->lifetime != OperandLifeTime::NO_VALUE);
+  const bool use_cifg = IsNullInput(input_to_input_weights_);
+  const bool use_peephole = !IsNullInput(cell_to_output_weights_);
+  const bool use_layer_norm = !IsNullInput(input_layer_norm_weights_);
 
   // Index the scratch buffers pointers to the global scratch buffer.
   float* input_gate_scratch = nullptr;
@@ -331,17 +341,27 @@
     output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
   }
 
-  // Initialize scratch buffers with bias.
-  if (!use_cifg) {
-    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
-                                                  n_cell, n_batch, input_gate_scratch);
+  if (!use_layer_norm) {
+    // Initialize scratch buffers with bias.
+    if (!use_cifg) {
+      tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
+                                                    n_cell, n_batch, input_gate_scratch);
+    }
+    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
+                                                  n_cell, n_batch, forget_gate_scratch);
+    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
+                                                  n_cell, n_batch, cell_scratch);
+    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
+                                                  n_cell, n_batch, output_gate_scratch);
+  } else {
+    // Initialize scratch buffers with zeroes.
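+    // When layer normalization is used, the gate biases are added after
+    // normalization instead (see the use_layer_norm blocks below).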
+    if (!use_cifg) {
+      tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+    }
+    tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+    tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+    tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
   }
-  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
-                                                n_cell, n_batch, forget_gate_scratch);
-  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
-                                                n_cell, n_batch, cell_scratch);
-  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
-                                                n_cell, n_batch, output_gate_scratch);
 
   // For each batch and cell: compute input_weight * input.
   if (!use_cifg) {
@@ -382,6 +402,16 @@
           GetBuffer<float>(cell_to_input_weights_), n_cell,
           GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
     }
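+    // Layer normalization of the input gate: normalize, scale by the layer
+    // norm weights, then add the input gate bias.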
+    if (use_layer_norm) {
+      tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch,
+                                                    input_gate_scratch, n_cell, n_batch,
+                                                    kLayerNormEpsilon);
+      tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(input_layer_norm_weights_),
+                                                          n_cell, input_gate_scratch,
+                                                          n_batch, input_gate_scratch);
+      tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(input_gate_bias_), n_cell, n_batch,
+                                                 input_gate_scratch);
+    }
     tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch,
                                                n_cell * n_batch,
                                                input_gate_scratch);
@@ -393,11 +423,29 @@
         GetBuffer<float>(cell_to_forget_weights_), n_cell,
         GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
   }
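+  // Layer normalization of the forget gate (same sequence as above).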
+  if (use_layer_norm) {
+    tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+                                                  forget_gate_scratch, n_cell, n_batch,
+                                                  kLayerNormEpsilon);
+    tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(forget_layer_norm_weights_),
+                                                        n_cell, forget_gate_scratch,
+                                                        n_batch, forget_gate_scratch);
+    tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(forget_gate_bias_), n_cell, n_batch,
+                                               forget_gate_scratch);
+  }
   tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch,
                                              n_cell * n_batch,
                                              forget_gate_scratch);
 
   // For each batch and cell: update the cell.
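+  // Layer normalization of the cell update, then add the cell bias.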
+  if (use_layer_norm) {
+    tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+                                                  n_batch, kLayerNormEpsilon);
+    tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(cell_layer_norm_weights_),
+                                                        n_cell, cell_scratch,
+                                                        n_batch, cell_scratch);
+    tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(cell_bias_), n_cell, n_batch,
+                                               cell_scratch);
+  }
   tflite::tensor_utils::VectorVectorCwiseProduct(
       forget_gate_scratch, GetBuffer<float>(cell_state_in_), n_batch * n_cell,
       GetBuffer<float>(cell_state_out_));
@@ -426,6 +474,16 @@
         GetBuffer<float>(cell_to_output_weights_), n_cell,
         GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
   }
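+  // Layer normalization of the output gate, then add the output gate bias.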
+  if (use_layer_norm) {
+    tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch,
+                                                  output_gate_scratch, n_cell, n_batch,
+                                                  kLayerNormEpsilon);
+    tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(output_layer_norm_weights_),
+                                                        n_cell, output_gate_scratch,
+                                                        n_batch, output_gate_scratch);
+    tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(output_gate_bias_), n_cell, n_batch,
+                                               output_gate_scratch);
+  }
   tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                              output_gate_scratch);
   tflite::tensor_utils::ApplyActivationToVector(GetBuffer<float>(cell_state_out_),