Refactors LSTM step function for reuse.

Bug: 113559542
Test: NeuralNetworksTest_static --gtest_filter=GeneratedTests.lstm*
Change-Id: I427bc41bf81f4d0cd0021f333e79b5cab95d7105
Merged-In: I427bc41bf81f4d0cd0021f333e79b5cab95d7105
(cherry picked from commit 72fd4fe3fb705c9dbd543f1fe1c25d3d23545ad3)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index f1e01e2..5159952 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -71,15 +71,15 @@
     output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
     cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
 
-    params_.activation_ = static_cast<TfLiteFusedActivation>(
+    params_.activation = static_cast<TfLiteFusedActivation>(
             getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
     if (input_->type == OperandType::TENSOR_FLOAT32) {
-        params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
-        params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
+        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
+        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
     } else {
-        params_.cell_clip_ = static_cast<float>(
+        params_.cell_clip = static_cast<float>(
                 getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
-        params_.proj_clip_ = static_cast<float>(
+        params_.proj_clip = static_cast<float>(
                 getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
     }
 
@@ -108,100 +108,126 @@
     scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
 }
 
-bool LSTMCell::CheckInputTensorDimensions(const Operation& operation,
-                                          std::vector<RunTimeOperandInfo>& operands,
-                                          uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
+// static
+bool LSTMCell::CheckInputTensorDimensions(
+        const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
+        const RunTimeOperandInfo* input_to_forget_weights,
+        const RunTimeOperandInfo* input_to_cell_weights,
+        const RunTimeOperandInfo* input_to_output_weights,
+        const RunTimeOperandInfo* recurrent_to_input_weights,
+        const RunTimeOperandInfo* recurrent_to_forget_weights,
+        const RunTimeOperandInfo* recurrent_to_cell_weights,
+        const RunTimeOperandInfo* recurrent_to_output_weights,
+        const RunTimeOperandInfo* cell_to_input_weights,
+        const RunTimeOperandInfo* cell_to_forget_weights,
+        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
+        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
+        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
+        const RunTimeOperandInfo* projection_bias,
+        const RunTimeOperandInfo* input_layer_norm_weights,
+        const RunTimeOperandInfo* forget_layer_norm_weights,
+        const RunTimeOperandInfo* cell_layer_norm_weights,
+        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
+        uint32_t n_cell, LSTMParams* params) {
     // Making sure clipping parameters have valid values.
     // == 0 means no clipping
     //  > 0 means clipping
-    NN_CHECK(params_.cell_clip_ >= 0);
-    NN_CHECK(params_.proj_clip_ >= 0);
+    NN_CHECK(params->cell_clip >= 0);
+    NN_CHECK(params->proj_clip >= 0);
 
-    if (!IsNullInput(input_to_input_weights_)) {
-        NN_CHECK_EQ(NumDimensions(input_to_input_weights_), 2);
-        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 0), n_cell);
-        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 1), n_input);
+    if (!IsNullInput(input_to_input_weights)) {
+        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
+        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
+        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
     }
 
-    NN_CHECK_EQ(NumDimensions(input_to_forget_weights_), 2);
-    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 1), n_input);
+    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
+    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
 
-    NN_CHECK_EQ(NumDimensions(input_to_cell_weights_), 2);
-    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 1), n_input);
+    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
+    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
 
-    if (!IsNullInput(recurrent_to_input_weights_)) {
-        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights_), 2);
-        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 0), n_cell);
-        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 1), n_output);
+    if (!IsNullInput(recurrent_to_input_weights)) {
+        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
+        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
+        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
     }
 
-    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights_), 2);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 1), n_output);
+    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
 
-    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights_), 2);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 0), n_cell);
-    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 1), n_output);
+    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
+    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
 
     // We make sure the input-gate's parameters are either both present (regular
     // LSTM) or not at all (CIFG-LSTM).
     const bool cifg_weights_all_or_none =
-            (!IsNullInput(input_to_input_weights_) && !IsNullInput(recurrent_to_input_weights_)) ||
-            (IsNullInput(input_to_input_weights_) && IsNullInput(recurrent_to_input_weights_));
+            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
+            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
     NN_CHECK(cifg_weights_all_or_none);
 
-    if (!IsNullInput(cell_to_input_weights_)) {
-        NN_CHECK_EQ(NumDimensions(cell_to_input_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights_, 0), n_cell);
+    if (!IsNullInput(cell_to_input_weights)) {
+        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
     }
 
-    if (!IsNullInput(cell_to_forget_weights_)) {
-        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights_, 0), n_cell);
+    if (!IsNullInput(cell_to_forget_weights)) {
+        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
     }
 
-    if (!IsNullInput(cell_to_output_weights_)) {
-        NN_CHECK_EQ(NumDimensions(cell_to_output_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights_, 0), n_cell);
+    if (!IsNullInput(cell_to_output_weights)) {
+        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
     }
 
     // Making sure the peephole weights are there all or none.
-    const bool use_cifg = IsNullInput(input_to_input_weights_);
+    params->use_cifg = IsNullInput(input_to_input_weights);
     const bool peephole_weights_all_or_none =
-            ((!IsNullInput(cell_to_input_weights_) || use_cifg) &&
-             !IsNullInput(cell_to_forget_weights_) && !IsNullInput(cell_to_output_weights_)) ||
-            (IsNullInput(cell_to_input_weights_) && IsNullInput(cell_to_forget_weights_) &&
-             IsNullInput(cell_to_output_weights_));
+            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
+             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
+            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
+             IsNullInput(cell_to_output_weights));
     NN_CHECK(peephole_weights_all_or_none);
 
+    // Since we have already checked that weights are all there or none, we can
+    // check the existence of only one to get the condition.
+    params->use_peephole = !IsNullInput(cell_to_output_weights);
+    params->use_layer_norm = !IsNullInput(input_layer_norm_weights);
+
+    params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
+    params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);
+
     // Make sure the input gate bias is present only when not a CIFG-LSTM.
-    if (use_cifg) {
-        NN_CHECK(IsNullInput(input_gate_bias_));
+    if (params->use_cifg) {
+        NN_CHECK(IsNullInput(input_gate_bias));
     } else {
-        NN_CHECK_EQ(NumDimensions(input_gate_bias_), 1);
-        NN_CHECK_EQ(SizeOfDimension(input_gate_bias_, 0), n_cell);
+        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
+        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
     }
 
-    NN_CHECK_EQ(NumDimensions(forget_gate_bias_), 1);
-    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias_, 0), n_cell);
+    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
+    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
 
-    NN_CHECK_EQ(NumDimensions(cell_bias_), 1);
-    NN_CHECK_EQ(SizeOfDimension(cell_bias_, 0), n_cell);
+    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
+    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
 
-    NN_CHECK_EQ(NumDimensions(output_gate_bias_), 1);
-    NN_CHECK_EQ(SizeOfDimension(output_gate_bias_, 0), n_cell);
+    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
+    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
 
-    if (!IsNullInput(projection_weights_)) {
-        NN_CHECK_EQ(NumDimensions(projection_weights_), 2);
-        NN_CHECK_EQ(SizeOfDimension(projection_weights_, 0), n_output);
-        NN_CHECK_EQ(SizeOfDimension(projection_weights_, 1), n_cell);
+    if (!IsNullInput(projection_weights)) {
+        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
+        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
+        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
     }
 
-    if (!IsNullInput(projection_bias_)) {
-        NN_CHECK_EQ(NumDimensions(projection_bias_), 1);
-        NN_CHECK_EQ(SizeOfDimension(projection_bias_, 0), n_output);
+    if (!IsNullInput(projection_bias)) {
+        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
+        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
     }
 
     // Making sure the projection tensors are consistent:
@@ -210,31 +236,31 @@
     // 2) If projection weight is present, then projection bias is optional.
     // TODO: make sure this is correct.
     const bool projecton_tensors_consistent =
-            (!IsNullInput(projection_weights_) || IsNullInput(projection_bias_));
+            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
     NN_CHECK(projecton_tensors_consistent == true);
 
-    if (!IsNullInput(input_layer_norm_weights_)) {
-        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights_, 0), n_cell);
+    if (!IsNullInput(input_layer_norm_weights)) {
+        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
     }
-    if (!IsNullInput(forget_layer_norm_weights_)) {
-        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights_, 0), n_cell);
+    if (!IsNullInput(forget_layer_norm_weights)) {
+        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
     }
-    if (!IsNullInput(cell_layer_norm_weights_)) {
-        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights_, 0), n_cell);
+    if (!IsNullInput(cell_layer_norm_weights)) {
+        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
     }
-    if (!IsNullInput(output_layer_norm_weights_)) {
-        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights_), 1);
-        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights_, 0), n_cell);
+    if (!IsNullInput(output_layer_norm_weights)) {
+        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
+        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
     }
 
     const bool layer_norm_weights_all_or_none =
-            (IsNullInput(input_layer_norm_weights_) && IsNullInput(forget_layer_norm_weights_) &&
-             IsNullInput(cell_layer_norm_weights_) && IsNullInput(input_layer_norm_weights_)) ||
-            (!IsNullInput(input_layer_norm_weights_) && !IsNullInput(forget_layer_norm_weights_) &&
-             !IsNullInput(cell_layer_norm_weights_) && !IsNullInput(input_layer_norm_weights_));
+            (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
+             IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
+            (!IsNullInput(input_layer_norm_weights) && !IsNullInput(forget_layer_norm_weights) &&
+             !IsNullInput(cell_layer_norm_weights) && !IsNullInput(output_layer_norm_weights));
     NN_CHECK(layer_norm_weights_all_or_none);
 
     return true;
@@ -263,7 +289,15 @@
     const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
 
     // Check that input tensor dimensions matches with each other.
-    if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
+    if (!CheckInputTensorDimensions(
+                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
+                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
+                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
+                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
+                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
+                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
+                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
+                &params_)) {
         return false;
     }
 
@@ -285,8 +319,7 @@
     cellStateShape->offset = inputShape.offset;
     cellStateShape->scale = inputShape.scale;
 
-    const bool use_cifg = IsNullInput(input_to_input_weights_);
-    if (use_cifg) {
+    if (params_.use_cifg) {
         // Reserving space for Cell, Forget, Output gates
         scratchShape->dimensions = {n_batch, n_cell * 3};
     } else {
@@ -300,13 +333,16 @@
     return true;
 }
 
-bool LSTMCell::EvalFloat32(
-        const float* input_buffer, const float* input_to_input_weights_buffer,
-        const float* input_to_forget_weights_buffer, const float* input_to_cell_weights_buffer,
-        const float* input_to_output_weights_buffer, const float* recurrent_to_input_weights_buffer,
+// static
+bool LSTMCell::LSTMStep(
+        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
+        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
+        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
+        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
         const float* recurrent_to_forget_weights_buffer,
         const float* recurrent_to_cell_weights_buffer,
-        const float* recurrent_to_output_weights_buffer, const float* cell_to_input_weights_buffer,
+        const float* recurrent_to_output_weights_buffer,
+        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
         const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
         const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
         const float* cell_bias_buffer, const float* output_gate_bias_buffer,
@@ -316,26 +352,20 @@
         const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
         float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
         float* scratch_buffer_buffer) {
-    NNTRACE_COMP("LSTMCell::Eval");
+    NNTRACE_COMP("LSTMCell::LSTMStep");
 
-    const uint32_t n_batch = input_->shape().dimensions[0];
-    const uint32_t n_input = input_->shape().dimensions[1];
+    const uint32_t n_batch = input_shape.dimensions[0];
+    const uint32_t n_input = input_shape.dimensions[1];
     // n_cell and n_output will be the same size when there is no projection.
-    const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
-    const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];
-
-    // Since we have already checked that weights are all there or none, we can
-    // check the existence of only one to the get the condition.
-    const bool use_cifg = IsNullInput(input_to_input_weights_);
-    const bool use_peephole = !IsNullInput(cell_to_output_weights_);
-    const bool use_layer_norm = !IsNullInput(input_layer_norm_weights_);
+    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
+    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
 
     // Index the scratch buffers pointers to the global scratch buffer.
     float* input_gate_scratch = nullptr;
     float* cell_scratch = nullptr;
     float* forget_gate_scratch = nullptr;
     float* output_gate_scratch = nullptr;
-    if (use_cifg) {
+    if (params.use_cifg) {
         cell_scratch = scratch_buffer_buffer;
         forget_gate_scratch = cell_scratch + n_cell * n_batch;
         output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
@@ -346,9 +376,9 @@
         output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
     }
 
-    if (!use_layer_norm) {
+    if (!params.use_layer_norm) {
         // Initialize scratch buffers with bias.
-        if (!use_cifg) {
+        if (!params.use_cifg) {
             tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                           input_gate_scratch);
         }
@@ -360,7 +390,7 @@
                                                       output_gate_scratch);
     } else {
         // Initialize scratch buffers with zeroes.
-        if (!use_cifg) {
+        if (!params.use_cifg) {
             tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
         }
         tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
@@ -369,7 +399,7 @@
     }
 
     // For each batch and cell: compute input_weight * input.
-    if (!use_cifg) {
+    if (!params.use_cifg) {
         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                 input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
                 input_gate_scratch, /*result_stride*/ 1);
@@ -385,7 +415,7 @@
             output_gate_scratch, /*result_stride*/ 1);
 
     // For each batch and cell: compute recurrent_weight * output_state.
-    if (!use_cifg) {
+    if (!params.use_cifg) {
         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                 recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                 n_batch, input_gate_scratch,
@@ -402,13 +432,13 @@
             output_gate_scratch, /*result_stride*/ 1);
 
     // For each batch and cell: update input gate.
-    if (!use_cifg) {
-        if (use_peephole) {
+    if (!params.use_cifg) {
+        if (params.use_peephole) {
             tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                     cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                     input_gate_scratch);
         }
-        if (use_layer_norm) {
+        if (params.use_layer_norm) {
             tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                           n_cell, n_batch, kLayerNormEpsilon);
             tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
@@ -422,12 +452,12 @@
     }
 
     // For each batch and cell: update forget gate.
-    if (use_peephole) {
+    if (params.use_peephole) {
         tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                       n_cell, cell_state_in_buffer,
                                                                       n_batch, forget_gate_scratch);
     }
-    if (use_layer_norm) {
+    if (params.use_layer_norm) {
         tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                       n_cell, n_batch, kLayerNormEpsilon);
         tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
@@ -440,7 +470,7 @@
                                                forget_gate_scratch);
 
     // For each batch and cell: update the cell.
-    if (use_layer_norm) {
+    if (params.use_layer_norm) {
         tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
                                                       kLayerNormEpsilon);
         tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
@@ -449,9 +479,9 @@
     }
     tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                    n_batch * n_cell, cell_state_out_buffer);
-    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
-                                                  params_.activation_, cell_scratch);
-    if (use_cifg) {
+    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
+                                                  cell_scratch);
+    if (params.use_cifg) {
         tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                          forget_gate_scratch);
         tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
@@ -460,18 +490,18 @@
         tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                 cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
     }
-    if (params_.cell_clip_ > 0.0) {
-        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell,
-                                         params_.cell_clip_, cell_state_out_buffer);
+    if (params.cell_clip > 0.0) {
+        tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
+                                         cell_state_out_buffer);
     }
 
     // For each batch and cell: update the output gate.
-    if (use_peephole) {
+    if (params.use_peephole) {
         tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                       n_cell, cell_state_out_buffer,
                                                                       n_batch, output_gate_scratch);
     }
-    if (use_layer_norm) {
+    if (params.use_layer_norm) {
         tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                       n_cell, n_batch, kLayerNormEpsilon);
         tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
@@ -483,15 +513,13 @@
     tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                                output_gate_scratch);
     tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
-                                                  params_.activation_, cell_scratch);
+                                                  params.activation, cell_scratch);
     tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                    n_batch * n_cell, output_gate_scratch);
 
     // For each batch: update the projection and output_state.
-    const bool use_projection_weight = (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
-    const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
-    if (use_projection_weight) {
-        if (use_projection_bias) {
+    if (params.use_projection_weight) {
+        if (params.use_projection_bias) {
             tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                           output_buffer);
         } else {
@@ -501,8 +529,8 @@
                 projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                 output_buffer,
                 /*result_stride*/ 1);
-        if (params_.proj_clip_ > 0.0) {
-            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params_.proj_clip_,
+        if (params.proj_clip > 0.0) {
+            tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
                                              output_buffer);
         }
     } else {
@@ -515,37 +543,33 @@
 bool LSTMCell::Eval() {
     switch (input_->type) {
         case OperandType::TENSOR_FLOAT32: {
-            // clang-format off
-            EvalFloat32(
-                    GetBuffer<const float>(input_),
-                    GetBuffer<const float>(input_to_input_weights_),
-                    GetBuffer<const float>(input_to_forget_weights_),
-                    GetBuffer<const float>(input_to_cell_weights_),
-                    GetBuffer<const float>(input_to_output_weights_),
-                    GetBuffer<const float>(recurrent_to_input_weights_),
-                    GetBuffer<const float>(recurrent_to_forget_weights_),
-                    GetBuffer<const float>(recurrent_to_cell_weights_),
-                    GetBuffer<const float>(recurrent_to_output_weights_),
-                    GetBuffer<const float>(cell_to_input_weights_),
-                    GetBuffer<const float>(cell_to_forget_weights_),
-                    GetBuffer<const float>(cell_to_output_weights_),
-                    GetBuffer<const float>(input_gate_bias_),
-                    GetBuffer<const float>(forget_gate_bias_),
-                    GetBuffer<const float>(cell_bias_),
-                    GetBuffer<const float>(output_gate_bias_),
-                    GetBuffer<const float>(projection_weights_),
-                    GetBuffer<const float>(projection_bias_),
-                    GetBuffer<const float>(output_state_in_),
-                    GetBuffer<const float>(cell_state_in_),
-                    GetBuffer<const float>(input_layer_norm_weights_),
-                    GetBuffer<const float>(forget_layer_norm_weights_),
-                    GetBuffer<const float>(cell_layer_norm_weights_),
-                    GetBuffer<const float>(output_layer_norm_weights_),
-                    GetBuffer<float>(output_state_out_),
-                    GetBuffer<float>(cell_state_out_),
-                    GetBuffer<float>(output_),
-                    GetBuffer<float>(scratch_buffer_));
-            // clang-format on
+            LSTMStep(params_, GetBuffer<const float>(input_), input_->shape(),
+                     GetBuffer<const float>(input_to_input_weights_),
+                     GetBuffer<const float>(input_to_forget_weights_),
+                     GetBuffer<const float>(input_to_cell_weights_),
+                     GetBuffer<const float>(input_to_output_weights_),
+                     input_to_output_weights_->shape(),
+                     GetBuffer<const float>(recurrent_to_input_weights_),
+                     GetBuffer<const float>(recurrent_to_forget_weights_),
+                     GetBuffer<const float>(recurrent_to_cell_weights_),
+                     GetBuffer<const float>(recurrent_to_output_weights_),
+                     recurrent_to_output_weights_->shape(),
+                     GetBuffer<const float>(cell_to_input_weights_),
+                     GetBuffer<const float>(cell_to_forget_weights_),
+                     GetBuffer<const float>(cell_to_output_weights_),
+                     GetBuffer<const float>(input_gate_bias_),
+                     GetBuffer<const float>(forget_gate_bias_), GetBuffer<const float>(cell_bias_),
+                     GetBuffer<const float>(output_gate_bias_),
+                     GetBuffer<const float>(projection_weights_),
+                     GetBuffer<const float>(projection_bias_),
+                     GetBuffer<const float>(output_state_in_),
+                     GetBuffer<const float>(cell_state_in_),
+                     GetBuffer<const float>(input_layer_norm_weights_),
+                     GetBuffer<const float>(forget_layer_norm_weights_),
+                     GetBuffer<const float>(cell_layer_norm_weights_),
+                     GetBuffer<const float>(output_layer_norm_weights_),
+                     GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
+                     GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
         } break;
         case OperandType::TENSOR_FLOAT16: {
             std::vector<float> input_float32(getNumberOfElements(input_->shape()));
@@ -699,37 +723,22 @@
                     getNumberOfElements(scratch_buffer_->shape()));
             convertFloat16ToFloat32(GetBuffer<_Float16>(scratch_buffer_), &scratch_buffer_float32);
 
-            // clang-format off
-            EvalFloat32(
-                    input_float32.data(),
-                    input_to_input_weights_buffer,
-                    input_to_forget_weights_float32.data(),
-                    input_to_cell_weights_float32.data(),
-                    input_to_output_weights_float32.data(),
-                    recurrent_to_input_weights_buffer,
-                    recurrent_to_forget_weights_float32.data(),
-                    recurrent_to_cell_weights_float32.data(),
-                    recurrent_to_output_weights_float32.data(),
-                    cell_to_input_weights_buffer,
-                    cell_to_forget_weights_buffer,
-                    cell_to_output_weights_buffer,
-                    input_gate_bias_buffer,
-                    forget_gate_bias_float32.data(),
-                    cell_bias_float32.data(),
-                    output_gate_bias_float32.data(),
-                    projection_weights_buffer,
-                    projection_bias_buffer,
-                    output_state_in_float32.data(),
-                    cell_state_in_float32.data(),
-                    input_layer_norm_weights_buffer,
-                    forget_layer_norm_weights_buffer,
-                    cell_layer_norm_weights_buffer,
-                    output_layer_norm_weights_buffer,
-                    output_state_out_float32.data(),
-                    cell_state_out_float32.data(),
-                    output_float32.data(),
-                    scratch_buffer_float32.data());
-            // clang-format on
+            LSTMStep(params_, input_float32.data(), input_->shape(), input_to_input_weights_buffer,
+                     input_to_forget_weights_float32.data(), input_to_cell_weights_float32.data(),
+                     input_to_output_weights_float32.data(), input_to_output_weights_->shape(),
+                     recurrent_to_input_weights_buffer, recurrent_to_forget_weights_float32.data(),
+                     recurrent_to_cell_weights_float32.data(),
+                     recurrent_to_output_weights_float32.data(),
+                     recurrent_to_output_weights_->shape(), cell_to_input_weights_buffer,
+                     cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
+                     input_gate_bias_buffer, forget_gate_bias_float32.data(),
+                     cell_bias_float32.data(), output_gate_bias_float32.data(),
+                     projection_weights_buffer, projection_bias_buffer,
+                     output_state_in_float32.data(), cell_state_in_float32.data(),
+                     input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
+                     cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
+                     output_state_out_float32.data(), cell_state_out_float32.data(),
+                     output_float32.data(), scratch_buffer_float32.data());
 
             convertFloat32ToFloat16(output_state_out_float32,
                                     GetBuffer<_Float16>(output_state_out_));