Add layer normalization support to LSTM op

* Add new test to operations test
* Add new test to validation testing
* Make Prepare and CheckInputTensorDimensions non-static.
    This makes it possible to reuse at preparation the inputs that were
    already read in constructor. This is needed to use upcasting between
    versions that is implemented in the constructor by assigning dummy
    values to the inputs that are non-existent in older versions.

Fix: 113562577
Test: NeuralNetworksTest_static
Test: VtsHalNeuralnetworksV1_2TargetTest
Change-Id: I90676abfcdb3d9a969a1418f8474ce383bf7fb07
Merged-In: I90676abfcdb3d9a969a1418f8474ce383bf7fb07
(cherry picked from commit d24a1bb943a970fbf71b25ac64216a60a16ffc37)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 7667c66..ca9f42b 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -21,6 +21,8 @@
 
 #include "Tracing.h"
 
+// TODO(levp): Format the file.
+// clang-format off
 namespace android {
 namespace nn {
 
@@ -73,6 +75,24 @@
   params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
   params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
 
+  // We check the version of LSTM by checking the number of the inputs to the
+  // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
+  if (operation.inputs.size() == 27) {
+    input_layer_norm_weights_ = GetInput(operation, operands, kInputLayerNormWeightsTensor);
+    forget_layer_norm_weights_ = GetInput(operation, operands, kForgetLayerNormWeightsTensor);
+    cell_layer_norm_weights_ = GetInput(operation, operands, kCellLayerNormWeightsTensor);
+    output_layer_norm_weights_ = GetInput(operation, operands, kOutputLayerNormWeightsTensor);
+  } else {
+    // For LSTM from HAL v1.0 assign operands with no values
+    static RunTimeOperandInfo no_value;
+    no_value.lifetime = OperandLifeTime::NO_VALUE;
+
+    input_layer_norm_weights_ = &no_value;
+    forget_layer_norm_weights_ = &no_value;
+    cell_layer_norm_weights_ = &no_value;
+    output_layer_norm_weights_ = &no_value;
+  }
+
   output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
   cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
   output_ = GetOutput(operation, operands, kOutputTensor);
@@ -96,125 +116,95 @@
   NN_CHECK(params.cell_clip_ >= 0);
   NN_CHECK(params.proj_clip_ >= 0);
 
-  const RunTimeOperandInfo *input_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
-  if (!IsNullInput(input_to_input_weights)) {
-    NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
-    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
+  if (!IsNullInput(input_to_input_weights_)) {
+    NN_CHECK_EQ(NumDimensions(input_to_input_weights_), 2);
+    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 1), n_input);
   }
 
-  const RunTimeOperandInfo *input_to_forget_weights =
-      GetInput(operation, operands, LSTMCell::kInputToForgetWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
+  NN_CHECK_EQ(NumDimensions(input_to_forget_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 1), n_input);
 
-  const RunTimeOperandInfo *input_to_cell_weights =
-      GetInput(operation, operands, LSTMCell::kInputToCellWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
+  NN_CHECK_EQ(NumDimensions(input_to_cell_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 1), n_input);
 
-  const RunTimeOperandInfo *recurrent_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToInputWeightsTensor);
-  if (!IsNullInput(recurrent_to_input_weights)) {
-    NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
+  if (!IsNullInput(recurrent_to_input_weights_)) {
+    NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights_), 2);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 1), n_output);
   }
 
-  const RunTimeOperandInfo *recurrent_to_forget_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToForgetWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
+  NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 1), n_output);
 
-  const RunTimeOperandInfo *recurrent_to_cell_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToCellWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
+  NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 0), n_cell);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 1), n_output);
 
   // We make sure the input-gate's parameters are either both present (regular
   // LSTM) or not at all (CIFG-LSTM).
   const bool cifg_weights_all_or_none =
-      (!IsNullInput(input_to_input_weights) &&
-       !IsNullInput(recurrent_to_input_weights)) ||
-      (IsNullInput(input_to_input_weights) &&
-       IsNullInput(recurrent_to_input_weights));
+      (!IsNullInput(input_to_input_weights_) &&
+       !IsNullInput(recurrent_to_input_weights_)) ||
+      (IsNullInput(input_to_input_weights_) &&
+       IsNullInput(recurrent_to_input_weights_));
   NN_CHECK(cifg_weights_all_or_none);
 
-  const RunTimeOperandInfo *cell_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kCellToInputWeightsTensor);
-  if (!IsNullInput(cell_to_input_weights)) {
-    NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
+  if (!IsNullInput(cell_to_input_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_to_input_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights_, 0), n_cell);
   }
 
-  const RunTimeOperandInfo *cell_to_forget_weights =
-      GetInput(operation, operands, LSTMCell::kCellToForgetWeightsTensor);
-  if (!IsNullInput(cell_to_forget_weights)) {
-    NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
+  if (!IsNullInput(cell_to_forget_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_to_forget_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights_, 0), n_cell);
   }
 
-  const RunTimeOperandInfo *cell_to_output_weights =
-      GetInput(operation, operands, LSTMCell::kCellToOutputWeightsTensor);
-  if (!IsNullInput(cell_to_output_weights)) {
-    NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
+  if (!IsNullInput(cell_to_output_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_to_output_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights_, 0), n_cell);
   }
 
   // Making sure the peephole weights are there all or none.
-  const bool use_cifg = IsNullInput(input_to_input_weights);
+  const bool use_cifg = IsNullInput(input_to_input_weights_);
   const bool peephole_weights_all_or_none =
-      ((!IsNullInput(cell_to_input_weights) || use_cifg) &&
-       !IsNullInput(cell_to_forget_weights) &&
-       !IsNullInput(cell_to_output_weights)) ||
-      (IsNullInput(cell_to_input_weights) &&
-       IsNullInput(cell_to_forget_weights) &&
-       IsNullInput(cell_to_output_weights));
+      ((!IsNullInput(cell_to_input_weights_) || use_cifg) &&
+       !IsNullInput(cell_to_forget_weights_) &&
+       !IsNullInput(cell_to_output_weights_)) ||
+      (IsNullInput(cell_to_input_weights_) &&
+       IsNullInput(cell_to_forget_weights_) &&
+       IsNullInput(cell_to_output_weights_));
   NN_CHECK(peephole_weights_all_or_none);
 
   // Make sure the input gate bias is present only when not a CIFG-LSTM.
-  const RunTimeOperandInfo* input_gate_bias =
-      GetInput(operation, operands, LSTMCell::kInputGateBiasTensor);
   if (use_cifg) {
-    NN_CHECK(IsNullInput(input_gate_bias));
+    NN_CHECK(IsNullInput(input_gate_bias_));
   } else {
-    NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
-    NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
+    NN_CHECK_EQ(NumDimensions(input_gate_bias_), 1);
+    NN_CHECK_EQ(SizeOfDimension(input_gate_bias_, 0), n_cell);
   }
 
-  const RunTimeOperandInfo *forget_gate_bias =
-      GetInput(operation, operands, LSTMCell::kForgetGateBiasTensor);
-  NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
-  NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
+  NN_CHECK_EQ(NumDimensions(forget_gate_bias_), 1);
+  NN_CHECK_EQ(SizeOfDimension(forget_gate_bias_, 0), n_cell);
 
-  const RunTimeOperandInfo *cell_bias =
-      GetInput(operation, operands, LSTMCell::kCellGateBiasTensor);
-  NN_CHECK_EQ(NumDimensions(cell_bias), 1);
-  NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
+  NN_CHECK_EQ(NumDimensions(cell_bias_), 1);
+  NN_CHECK_EQ(SizeOfDimension(cell_bias_, 0), n_cell);
 
-  const RunTimeOperandInfo *output_gate_bias =
-      GetInput(operation, operands, LSTMCell::kOutputGateBiasTensor);
-  NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
-  NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
+  NN_CHECK_EQ(NumDimensions(output_gate_bias_), 1);
+  NN_CHECK_EQ(SizeOfDimension(output_gate_bias_, 0), n_cell);
 
-  const RunTimeOperandInfo *projection_weights =
-      GetInput(operation, operands, LSTMCell::kProjectionWeightsTensor);
-  if (!IsNullInput(projection_weights)) {
-    NN_CHECK_EQ(NumDimensions(projection_weights), 2);
-    NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
-    NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
+  if (!IsNullInput(projection_weights_)) {
+    NN_CHECK_EQ(NumDimensions(projection_weights_), 2);
+    NN_CHECK_EQ(SizeOfDimension(projection_weights_, 0), n_output);
+    NN_CHECK_EQ(SizeOfDimension(projection_weights_, 1), n_cell);
   }
 
-  const RunTimeOperandInfo *projection_bias =
-      GetInput(operation, operands, LSTMCell::kProjectionBiasTensor);
-  if (!IsNullInput(projection_bias)) {
-    NN_CHECK_EQ(NumDimensions(projection_bias), 1);
-    NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
+  if (!IsNullInput(projection_bias_)) {
+    NN_CHECK_EQ(NumDimensions(projection_bias_), 1);
+    NN_CHECK_EQ(SizeOfDimension(projection_bias_, 0), n_output);
   }
 
   // Making sure the projection tensors are consistent:
@@ -223,9 +213,37 @@
   // 2) If projection weight is present, then projection bias is optional.
   // TODO: make sure this is correct.
   const bool projecton_tensors_consistent =
-      (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
+      (!IsNullInput(projection_weights_) || IsNullInput(projection_bias_));
   NN_CHECK(projecton_tensors_consistent == true);
 
+  if (!IsNullInput(input_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(input_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights_, 0), n_cell);
+  }
+  if (!IsNullInput(forget_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights_, 0), n_cell);
+  }
+  if (!IsNullInput(cell_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights_, 0), n_cell);
+  }
+  if (!IsNullInput(output_layer_norm_weights_)) {
+    NN_CHECK_EQ(NumDimensions(output_layer_norm_weights_), 1);
+    NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights_, 0), n_cell);
+  }
+
+  const bool layer_norm_weights_all_or_none =
+      (IsNullInput(input_layer_norm_weights_) &&
+       IsNullInput(forget_layer_norm_weights_) &&
+       IsNullInput(cell_layer_norm_weights_) &&
+       IsNullInput(output_layer_norm_weights_)) ||
+      (!IsNullInput(input_layer_norm_weights_) &&
+       !IsNullInput(forget_layer_norm_weights_) &&
+       !IsNullInput(cell_layer_norm_weights_) &&
+       !IsNullInput(output_layer_norm_weights_));
+  NN_CHECK(layer_norm_weights_all_or_none);
+
   return true;
 }
 
@@ -237,29 +255,22 @@
                        Shape *outputShape) {
   // Check we have all the inputs and outputs we need.
   NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
-           NumInputsWithValues(operation, operands) <= 23);
+           NumInputsWithValues(operation, operands) <= 27);
   NN_CHECK_EQ(NumOutputs(operation), 4);
 
   // Inferring batch size, number of outputs and number of cells from the
   // input tensors.
-  const RunTimeOperandInfo *input =
-      GetInput(operation, operands, LSTMCell::kInputTensor);
-  NN_CHECK(NumDimensions(input) > 1);
-  const uint32_t n_batch = SizeOfDimension(input, 0);
-  const uint32_t n_input = SizeOfDimension(input, 1);
+  NN_CHECK(NumDimensions(input_) > 1);
+  const uint32_t n_batch = SizeOfDimension(input_, 0);
+  const uint32_t n_input = SizeOfDimension(input_, 1);
 
-  const RunTimeOperandInfo *input_to_output_weights =
-      GetInput(operation, operands, LSTMCell::kInputToOutputWeightsTensor);
-  const uint32_t n_cell = SizeOfDimension(input_to_output_weights, 0);
-  NN_CHECK_EQ(NumDimensions(input_to_output_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(input_to_output_weights, 1), n_input);
+  const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
+  NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
 
-  const RunTimeOperandInfo *recurrent_to_output_weights =
-      GetInput(operation, operands, LSTMCell::kRecurrentToOutputWeightsTensor);
-  NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights), 2);
-  NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights, 0),
-                    n_cell);
-  const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights, 1);
+  NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
+  NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
+  const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
 
   // Check that input tensor dimensions matches with each other.
   if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
@@ -267,7 +278,7 @@
   }
 
   // Resize the output and output_state tensors.
-  const Shape &inputShape = input->shape();
+  const Shape &inputShape = input_->shape();
 
   outputShape->type = inputShape.type;
   outputShape->dimensions = { n_batch, n_output };
@@ -284,9 +295,7 @@
   cellStateShape->offset = inputShape.offset;
   cellStateShape->scale = inputShape.scale;
 
-  const RunTimeOperandInfo *input_to_input_weights =
-      GetInput(operation, operands, LSTMCell::kInputToInputWeightsTensor);
-  const bool use_cifg = IsNullInput(input_to_input_weights);
+  const bool use_cifg = IsNullInput(input_to_input_weights_);
   if (use_cifg) {
     // Reserving space for Cell, Forget, Output gates
     scratchShape->dimensions = { n_batch, n_cell * 3 };
@@ -312,8 +321,9 @@
 
   // Since we have already checked that weights are all there or none, we can
   // check the existence of only one to the get the condition.
-  const bool use_cifg = (input_to_input_weights_->lifetime == OperandLifeTime::NO_VALUE);
-  const bool use_peephole = (cell_to_output_weights_->lifetime != OperandLifeTime::NO_VALUE);
+  const bool use_cifg = IsNullInput(input_to_input_weights_);
+  const bool use_peephole = !IsNullInput(cell_to_output_weights_);
+  const bool use_layer_norm = !IsNullInput(input_layer_norm_weights_);
 
   // Index the scratch buffers pointers to the global scratch buffer.
   float* input_gate_scratch = nullptr;
@@ -331,17 +341,27 @@
     output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
   }
 
-  // Initialize scratch buffers with bias.
-  if (!use_cifg) {
-    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
-                                                  n_cell, n_batch, input_gate_scratch);
+  if (!use_layer_norm) {
+    // Initialize scratch buffers with bias.
+    if (!use_cifg) {
+      tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(input_gate_bias_),
+                                                    n_cell, n_batch, input_gate_scratch);
+    }
+    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
+                                                  n_cell, n_batch, forget_gate_scratch);
+    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
+                                                  n_cell, n_batch, cell_scratch);
+    tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
+                                                  n_cell, n_batch, output_gate_scratch);
+  } else {
+    // Initialize scratch buffers with zeroes.
+    if (!use_cifg) {
+      tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+    }
+    tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+    tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+    tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
   }
-  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(forget_gate_bias_),
-                                                n_cell, n_batch, forget_gate_scratch);
-  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(cell_bias_),
-                                                n_cell, n_batch, cell_scratch);
-  tflite::tensor_utils::VectorBatchVectorAssign(GetBuffer<float>(output_gate_bias_),
-                                                n_cell, n_batch, output_gate_scratch);
 
   // For each batch and cell: compute input_weight * input.
   if (!use_cifg) {
@@ -382,6 +402,16 @@
           GetBuffer<float>(cell_to_input_weights_), n_cell,
           GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
     }
+    if (use_layer_norm) {
+      tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch,
+                                                    input_gate_scratch, n_cell, n_batch,
+                                                    kLayerNormEpsilon);
+      tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(input_layer_norm_weights_),
+                                                          n_cell, input_gate_scratch,
+                                                          n_batch, input_gate_scratch);
+      tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(input_gate_bias_), n_cell, n_batch,
+                                                 input_gate_scratch);
+    }
     tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch,
                                                n_cell * n_batch,
                                                input_gate_scratch);
@@ -393,11 +423,29 @@
         GetBuffer<float>(cell_to_forget_weights_), n_cell,
         GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
   }
+  if (use_layer_norm) {
+    tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+                                                  forget_gate_scratch, n_cell, n_batch,
+                                                  kLayerNormEpsilon);
+    tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(forget_layer_norm_weights_),
+                                                        n_cell, forget_gate_scratch,
+                                                        n_batch, forget_gate_scratch);
+    tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(forget_gate_bias_), n_cell, n_batch,
+                                               forget_gate_scratch);
+  }
   tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch,
                                              n_cell * n_batch,
                                              forget_gate_scratch);
 
   // For each batch and cell: update the cell.
+  if (use_layer_norm) {
+    tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+						  n_batch, kLayerNormEpsilon);
+    tflite::tensor_utils::VectorBatchVectorCwiseProduct(
+	GetBuffer<float>(cell_layer_norm_weights_), n_cell, cell_scratch, n_batch, cell_scratch);
+    tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(cell_bias_), n_cell, n_batch,
+					       cell_scratch);
+  }
   tflite::tensor_utils::VectorVectorCwiseProduct(
       forget_gate_scratch, GetBuffer<float>(cell_state_in_), n_batch * n_cell,
       GetBuffer<float>(cell_state_out_));
@@ -426,6 +474,16 @@
         GetBuffer<float>(cell_to_output_weights_), n_cell,
         GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
   }
+  if (use_layer_norm) {
+    tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch,
+                                                  output_gate_scratch, n_cell, n_batch,
+                                                  kLayerNormEpsilon);
+    tflite::tensor_utils::VectorBatchVectorCwiseProduct(GetBuffer<float>(output_layer_norm_weights_),
+                                                        n_cell, output_gate_scratch,
+                                                        n_batch, output_gate_scratch);
+    tflite::tensor_utils::VectorBatchVectorAdd(GetBuffer<float>(output_gate_bias_), n_cell, n_batch,
+                                               output_gate_scratch);
+  }
   tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                              output_gate_scratch);
   tflite::tensor_utils::ApplyActivationToVector(GetBuffer<float>(cell_state_out_),
diff --git a/common/operations/LSTM.h b/common/operations/LSTM.h
index 5305e2b..04051aa 100644
--- a/common/operations/LSTM.h
+++ b/common/operations/LSTM.h
@@ -24,6 +24,8 @@
 #include <algorithm>
 #include <cmath>
 
+// TODO(levp): Format the file.
+// clang-format off
 namespace android {
 namespace nn {
 
@@ -41,12 +43,12 @@
   LSTMCell(const Operation &operation,
            std::vector<RunTimeOperandInfo> &operands);
 
-  static bool Prepare(const Operation &operation,
-                      std::vector<RunTimeOperandInfo> &operands,
-                      Shape *scratchShape,
-                      Shape *outputStateShape,
-                      Shape *cellStateShape,
-                      Shape *outputShape);
+  bool Prepare(const Operation &operation,
+               std::vector<RunTimeOperandInfo> &operands,
+               Shape *scratchShape,
+               Shape *outputStateShape,
+               Shape *cellStateShape,
+               Shape *outputShape);
   bool Eval();
 
   // Input Tensors of size {n_batch, n_input}
@@ -87,17 +89,24 @@
   static constexpr int kCellClipParam = 21;
   static constexpr int kProjClipParam = 22;
 
+  // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
+  static constexpr int kInputLayerNormWeightsTensor = 23;
+  static constexpr int kForgetLayerNormWeightsTensor = 24;
+  static constexpr int kCellLayerNormWeightsTensor = 25;
+  static constexpr int kOutputLayerNormWeightsTensor = 26;
+
   // Output tensors.
   static constexpr int kScratchBufferTensor = 0;
   static constexpr int kOutputStateOutTensor = 1;
   static constexpr int kCellStateOutTensor = 2;
   static constexpr int kOutputTensor = 3;
 
+  static constexpr float kLayerNormEpsilon = 1e-8;
+
  private:
-  static bool CheckInputTensorDimensions(
-      const Operation &operation,
-      std::vector<RunTimeOperandInfo> &operands, uint32_t n_input,
-      uint32_t n_output, uint32_t n_cell);
+  bool CheckInputTensorDimensions(const Operation& operation,
+                                  std::vector<RunTimeOperandInfo>& operands,
+                                  uint32_t n_input, uint32_t n_output, uint32_t n_cell);
   LSTMParams params_;
 
   const RunTimeOperandInfo *input_;
@@ -127,6 +136,11 @@
   const RunTimeOperandInfo *output_state_in_;
   const RunTimeOperandInfo *cell_state_in_;
 
+  const RunTimeOperandInfo *input_layer_norm_weights_;
+  const RunTimeOperandInfo *forget_layer_norm_weights_;
+  const RunTimeOperandInfo *cell_layer_norm_weights_;
+  const RunTimeOperandInfo *output_layer_norm_weights_;
+
   RunTimeOperandInfo *output_state_out_;
   RunTimeOperandInfo *cell_state_out_;
   RunTimeOperandInfo *output_;
diff --git a/common/operations/LayerNormLSTMTest.cpp b/common/operations/LayerNormLSTMTest.cpp
new file mode 100644
index 0000000..faf7fef
--- /dev/null
+++ b/common/operations/LayerNormLSTMTest.cpp
@@ -0,0 +1,428 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "LSTM.h"
+
+#include <android-base/logging.h>
+
+#include "NeuralNetworksWrapper.h"
+#include "gmock/gmock-matchers.h"
+#include "gtest/gtest.h"
+
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace android {
+namespace nn {
+namespace wrapper {
+
+using ::testing::Each;
+using ::testing::FloatNear;
+using ::testing::Matcher;
+
+namespace {
+
+std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
+                                           float max_abs_error = 1.e-6) {
+    std::vector<Matcher<float>> matchers;
+    matchers.reserve(values.size());
+    for (const float& v : values) {
+        matchers.emplace_back(FloatNear(v, max_abs_error));
+    }
+    return matchers;
+}
+
+}  // anonymous namespace
+
+#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
+    ACTION(Input)                                \
+    ACTION(InputToInputWeights)                  \
+    ACTION(InputToCellWeights)                   \
+    ACTION(InputToForgetWeights)                 \
+    ACTION(InputToOutputWeights)                 \
+    ACTION(RecurrentToInputWeights)              \
+    ACTION(RecurrentToCellWeights)               \
+    ACTION(RecurrentToForgetWeights)             \
+    ACTION(RecurrentToOutputWeights)             \
+    ACTION(CellToInputWeights)                   \
+    ACTION(CellToForgetWeights)                  \
+    ACTION(CellToOutputWeights)                  \
+    ACTION(InputGateBias)                        \
+    ACTION(CellGateBias)                         \
+    ACTION(ForgetGateBias)                       \
+    ACTION(OutputGateBias)                       \
+    ACTION(ProjectionWeights)                    \
+    ACTION(ProjectionBias)                       \
+    ACTION(OutputStateIn)                        \
+    ACTION(CellStateIn)
+
+#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
+    ACTION(InputLayerNormWeights)          \
+    ACTION(ForgetLayerNormWeights)         \
+    ACTION(CellLayerNormWeights)           \
+    ACTION(OutputLayerNormWeights)
+
+// For all output and intermediate states
+#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
+    ACTION(ScratchBuffer)              \
+    ACTION(OutputStateOut)             \
+    ACTION(CellStateOut)               \
+    ACTION(Output)
+
+class LayerNormLSTMOpModel {
+   public:
+    LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
+                         bool use_cifg, bool use_peephole, bool use_projection_weights,
+                         bool use_projection_bias, float cell_clip, float proj_clip,
+                         const std::vector<std::vector<uint32_t>>& input_shapes0)
+        : n_input_(n_input),
+          n_output_(n_output),
+          use_cifg_(use_cifg),
+          use_peephole_(use_peephole),
+          use_projection_weights_(use_projection_weights),
+          use_projection_bias_(use_projection_bias),
+          activation_(ActivationFn::kActivationTanh),
+          cell_clip_(cell_clip),
+          proj_clip_(proj_clip) {
+        std::vector<uint32_t> inputs;
+        std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
+
+        auto it = input_shapes.begin();
+
+        // Input and weights
+#define AddInput(X)                                     \
+    CHECK(it != input_shapes.end());                    \
+    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
+    inputs.push_back(model_.addOperand(&X##OpndTy));
+
+        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
+
+        // Parameters
+        OperandType ActivationOpndTy(Type::INT32, {});
+        inputs.push_back(model_.addOperand(&ActivationOpndTy));
+        OperandType CellClipOpndTy(Type::FLOAT32, {});
+        inputs.push_back(model_.addOperand(&CellClipOpndTy));
+        OperandType ProjClipOpndTy(Type::FLOAT32, {});
+        inputs.push_back(model_.addOperand(&ProjClipOpndTy));
+
+        FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
+
+#undef AddInput
+
+        // Output and other intermediate state
+        std::vector<std::vector<uint32_t>> output_shapes{
+                {n_batch, n_cell * (use_cifg ? 3 : 4)},
+                {n_batch, n_output},
+                {n_batch, n_cell},
+                {n_batch, n_output},
+        };
+        std::vector<uint32_t> outputs;
+
+        auto it2 = output_shapes.begin();
+
+#define AddOutput(X)                                     \
+    CHECK(it2 != output_shapes.end());                   \
+    OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
+    outputs.push_back(model_.addOperand(&X##OpndTy));
+
+        FOR_ALL_OUTPUT_TENSORS(AddOutput);
+
+#undef AddOutput
+
+        model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
+        model_.identifyInputsAndOutputs(inputs, outputs);
+
+        Input_.insert(Input_.end(), n_batch * n_input, 0.f);
+        OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
+        CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
+
+        auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
+            uint32_t sz = 1;
+            for (uint32_t d : dims) {
+                sz *= d;
+            }
+            return sz;
+        };
+
+        it2 = output_shapes.begin();
+
+#define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
+
+        FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
+
+#undef ReserveOutput
+
+        model_.finish();
+    }
+
+// Generates one Set<TensorName>() appender per tensor: appends the given
+// values to the backing std::vector<float> named <TensorName>_.
+#define DefineSetter(X) \
+    void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
+
+    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
+    FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
+
+#undef DefineSetter
+
+    // Zeroes both output-state buffers (the one fed into the op and the one
+    // it writes) so a test starts from a clean recurrent state.
+    void ResetOutputState() {
+        for (float& v : OutputStateIn_) v = 0.f;
+        for (float& v : OutputStateOut_) v = 0.f;
+    }
+
+    // Zeroes both cell-state buffers (input side and output side) so a test
+    // starts from a clean recurrent state.
+    void ResetCellState() {
+        for (float& v : CellStateIn_) v = 0.f;
+        for (float& v : CellStateOut_) v = 0.f;
+    }
+
+    // Copies the values in [begin, end) into Input_ starting at `offset`.
+    // Assumes Input_ was already sized large enough (done in the constructor).
+    void SetInput(int offset, const float* begin, const float* end) {
+        for (const float* src = begin; src != end; ++src) {
+            Input_[offset++] = *src;
+        }
+    }
+
+    // Number of input features per time step.
+    uint32_t num_inputs() const { return n_input_; }
+    // Number of output units per time step.
+    uint32_t num_outputs() const { return n_output_; }
+
+    // Read-only view of the output produced by the most recent Invoke().
+    const std::vector<float>& GetOutput() const { return Output_; }
+
+    // Runs one LSTM time step: binds all tensor buffers, feeds the previous
+    // step's state back in, and executes the compiled model.
+    void Invoke() {
+        ASSERT_TRUE(model_.isValid());
+
+        // Reuse last step's output/cell state as this step's input state.
+        OutputStateIn_.swap(OutputStateOut_);
+        CellStateIn_.swap(CellStateOut_);
+
+        Compilation compilation(&model_);
+        compilation.finish();
+        Execution execution(&compilation);
+#define SetInputOrWeight(X)                                                                       \
+    ASSERT_EQ(                                                                                    \
+            execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+            Result::NO_ERROR);
+
+        // Bind every input, weight and layer-norm tensor to its buffer.
+        FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
+        FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
+
+#undef SetInputOrWeight
+
+#define SetOutput(X)                                                                               \
+    ASSERT_EQ(                                                                                     \
+            execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
+            Result::NO_ERROR);
+
+        FOR_ALL_OUTPUT_TENSORS(SetOutput);
+
+#undef SetOutput
+
+        // Tensors that do not exist in the chosen configuration are rebound
+        // with a null buffer of size 0 to mark them as omitted.
+        if (use_cifg_) {
+            execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
+            execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
+        }
+
+        if (use_peephole_) {
+            // With CIFG there is no input gate, hence no cell-to-input weight.
+            if (use_cifg_) {
+                execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
+            }
+        } else {
+            // No peephole: all three cell-to-gate weights are absent.
+            execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
+            execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
+            execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
+        }
+
+        if (use_projection_weights_) {
+            if (!use_projection_bias_) {
+                execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
+            }
+        } else {
+            execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
+            execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
+        }
+
+        // Scalar parameters: activation function and clipping thresholds.
+        ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
+                  Result::NO_ERROR);
+        ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
+                  Result::NO_ERROR);
+        ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
+                  Result::NO_ERROR);
+
+        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
+    }
+
+   private:
+    Model model_;
+    // Execution execution_;
+    // Feature / output dimensions of the cell.
+    const uint32_t n_input_;
+    const uint32_t n_output_;
+
+    // Configuration flags controlling which optional tensors are present.
+    const bool use_cifg_;
+    const bool use_peephole_;
+    const bool use_projection_weights_;
+    const bool use_projection_bias_;
+
+    // Scalar op parameters passed to the execution in Invoke().
+    const int activation_;
+    const float cell_clip_;
+    const float proj_clip_;
+
+// Declares one std::vector<float> backing buffer per tensor, named <Tensor>_.
+#define DefineTensor(X) std::vector<float> X##_;
+
+    FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
+    FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
+    FOR_ALL_OUTPUT_TENSORS(DefineTensor);
+
+#undef DefineTensor
+};
+
+// Layer-normalized LSTM with peephole connections and a projection layer,
+// CIFG disabled and no cell/projection clipping. Runs a 3-step sequence for
+// two batches and compares each step's output against golden values.
+TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
+    const int n_batch = 2;
+    const int n_input = 5;
+    // n_cell and n_output have the same size when there is no projection.
+    const int n_cell = 4;
+    const int n_output = 3;
+
+    LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
+                              /*use_cifg=*/false, /*use_peephole=*/true,
+                              /*use_projection_weights=*/true,
+                              /*use_projection_bias=*/false,
+                              /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+                              {
+                                      {n_batch, n_input},  // input tensor
+
+                                      {n_cell, n_input},  // input_to_input_weight tensor
+                                      {n_cell, n_input},  // input_to_forget_weight tensor
+                                      {n_cell, n_input},  // input_to_cell_weight tensor
+                                      {n_cell, n_input},  // input_to_output_weight tensor
+
+                                      {n_cell, n_output},  // recurrent_to_input_weight tensor
+                                      {n_cell, n_output},  // recurrent_to_forget_weight tensor
+                                      {n_cell, n_output},  // recurrent_to_cell_weight tensor
+                                      {n_cell, n_output},  // recurrent_to_output_weight tensor
+
+                                      {n_cell},  // cell_to_input_weight tensor
+                                      {n_cell},  // cell_to_forget_weight tensor
+                                      {n_cell},  // cell_to_output_weight tensor
+
+                                      {n_cell},  // input_gate_bias tensor
+                                      {n_cell},  // forget_gate_bias tensor
+                                      {n_cell},  // cell_bias tensor
+                                      {n_cell},  // output_gate_bias tensor
+
+                                      {n_output, n_cell},  // projection_weight tensor
+                                      {0},                 // projection_bias tensor (omitted)
+
+                                      {n_batch, n_output},  // output_state_in tensor
+                                      {n_batch, n_cell},    // cell_state_in tensor
+
+                                      {n_cell},  // input_layer_norm_weights tensor
+                                      {n_cell},  // forget_layer_norm_weights tensor
+                                      {n_cell},  // cell_layer_norm_weights tensor
+                                      {n_cell},  // output_layer_norm_weights tensor
+                              });
+
+    lstm.SetInputToInputWeights({0.5,  0.6, 0.7,  -0.8, -0.9, 0.1,  0.2,  0.3,  -0.4, 0.5,
+                                 -0.8, 0.7, -0.6, 0.5,  -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});
+
+    lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3,  0.2,  0.9,  -0.5, -0.2, -0.4, 0.3,  -0.8,
+                                  -0.4, 0.3,  -0.5, -0.4, -0.6, 0.3,  -0.4, -0.6, -0.5, -0.5});
+
+    lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
+                                0.6,  -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8,  0.6});
+
+    lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
+                                  0.6,  -0.2, 0.4,  -0.7, -0.3, -0.5, 0.1, 0.5,  -0.6, -0.4});
+
+    lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});
+
+    lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});
+
+    lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});
+
+    lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});
+
+    lstm.SetRecurrentToInputWeights(
+            {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});
+
+    lstm.SetRecurrentToCellWeights(
+            {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});
+
+    lstm.SetRecurrentToForgetWeights(
+            {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});
+
+    lstm.SetRecurrentToOutputWeights(
+            {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});
+
+    // Peephole (cell-to-gate) weights; one value per cell.
+    lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
+    lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
+    lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});
+
+    lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});
+
+    // Per-gate layer normalization weights; one value per cell.
+    lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
+    lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
+    lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
+    lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});
+
+    const std::vector<std::vector<float>> lstm_input = {
+            {                           // Batch0: 3 (input_sequence_size) * 5 (n_input)
+             0.7, 0.8, 0.1, 0.2, 0.3,   // seq 0
+             0.8, 0.1, 0.2, 0.4, 0.5,   // seq 1
+             0.2, 0.7, 0.7, 0.1, 0.7},  // seq 2
+
+            {                           // Batch1: 3 (input_sequence_size) * 5 (n_input)
+             0.3, 0.2, 0.9, 0.8, 0.1,   // seq 0
+             0.1, 0.5, 0.2, 0.4, 0.2,   // seq 1
+             0.6, 0.9, 0.2, 0.5, 0.7},  // seq 2
+    };
+
+    const std::vector<std::vector<float>> lstm_golden_output = {
+            {
+                    // Batch0: 3 (input_sequence_size) * 3 (n_output)
+                    0.0244077, 0.128027, -0.00170918,  // seq 0
+                    0.0137642, 0.140751, 0.0395835,    // seq 1
+                    -0.00459231, 0.155278, 0.0837377,  // seq 2
+            },
+            {
+                    // Batch1: 3 (input_sequence_size) * 3 (n_output)
+                    -0.00692428, 0.0848741, 0.063445,  // seq 0
+                    -0.00403912, 0.139963, 0.072681,   // seq 1
+                    0.00752706, 0.161903, 0.0561371,   // seq 2
+            }};
+
+    // Resetting cell_state and output_state
+    lstm.ResetCellState();
+    lstm.ResetOutputState();
+
+    // Run the sequence one time step at a time; Invoke() carries the
+    // recurrent state across steps internally.
+    const int input_sequence_size = lstm_input[0].size() / n_input;
+    for (int i = 0; i < input_sequence_size; i++) {
+        for (int b = 0; b < n_batch; ++b) {
+            const float* batch_start = lstm_input[b].data() + i * n_input;
+            const float* batch_end = batch_start + n_input;
+
+            lstm.SetInput(b * n_input, batch_start, batch_end);
+        }
+
+        lstm.Invoke();
+
+        // Gather the golden output for this step across both batches.
+        std::vector<float> expected;
+        for (int b = 0; b < n_batch; ++b) {
+            const float* golden_start = lstm_golden_output[b].data() + i * n_output;
+            const float* golden_end = golden_start + n_output;
+            expected.insert(expected.end(), golden_start, golden_end);
+        }
+        EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
+    }
+}
+
+}  // namespace wrapper
+}  // namespace nn
+}  // namespace android