Implements BIDIRECTIONAL_SEQUENCE_LSTM operation.

Test: NeuralNetworksTest_static --gtest_filter=GeneratedTests.*lstm
Bug: 113559542
Change-Id: If30e31c851bfbd97445710d8e1998306a551ac08
Merged-In: If30e31c851bfbd97445710d8e1998306a551ac08
(cherry picked from commit be339f503051e72ed63e88a5645af15b02a44478)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 293b23e..1b36574 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -350,14 +350,19 @@
         const float* recurrent_to_output_weights_buffer,
         const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
         const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
-        const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
-        const float* cell_bias_buffer, const float* output_gate_bias_buffer,
-        const float* projection_weights_buffer, const float* projection_bias_buffer,
-        const float* output_state_in_buffer, const float* cell_state_in_buffer,
-        const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer,
-        const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
-        float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
-        float* scratch_buffer_buffer, bool timeMajor) {
+        const float* aux_input_buffer, const Shape& aux_input_shape,
+        const float* aux_input_to_input_weights_buffer,
+        const float* aux_input_to_forget_weights_buffer,
+        const float* aux_input_to_cell_weights_buffer,
+        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
+        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
+        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
+        const float* projection_bias_buffer, const float* output_state_in_buffer,
+        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
+        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
+        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
+        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
+        bool forwardSequence, bool timeMajor) {
     NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
 
     const uint32_t inputRank = getNumberOfDimensions(input_shape);
@@ -388,13 +393,20 @@
         transposedOutputShape = transposedInputShape;
         transposedOutputShape.dimensions[2] = outputSize;
     }
-    const float* inputCurrentTimeStep = timeMajor ? input_buffer : transposedInput.data();
-    float* outputCurrentTimeStep = timeMajor ? output_buffer : transposedOutput.data();
+    const float* inputData = timeMajor ? input_buffer : transposedInput.data();
+    float* outputData = timeMajor ? output_buffer : transposedOutput.data();
 
     std::vector<float> outputStateInCurrentTimeStep(
             output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
     std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
                                                   cell_state_in_buffer + batchSize * numCells);
+    const float* inputCurrentTimeStep =
+            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
+    float* outputCurrentTimeStep =
+            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
+    const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
+    const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
+
     for (int t = 0; t < maxTime; ++t) {
         LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
                  input_to_forget_weights_buffer, input_to_cell_weights_buffer,
@@ -402,16 +414,20 @@
                  recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
                  recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
                  recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
-                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
-                 input_gate_bias_buffer, forget_gate_bias_buffer, cell_bias_buffer,
-                 output_gate_bias_buffer, projection_weights_buffer, projection_bias_buffer,
-                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
-                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
-                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
-                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
-                 scratch_buffer_buffer);
-        inputCurrentTimeStep += batchInputSize;
-        outputCurrentTimeStep += batchOutputSize;
+                 cell_to_forget_weights_buffer, cell_to_output_weights_buffer, aux_input_buffer,
+                 aux_input_shape, aux_input_to_input_weights_buffer,
+                 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
+                 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
+                 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
+                 projection_weights_buffer, projection_bias_buffer,
+                 // Feed the state carried over from the previous time step, not the initial one.
+                 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
+                 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
+                 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
+                 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
+                 scratch_buffer_buffer);
+        inputCurrentTimeStep += batchInputDelta;
+        outputCurrentTimeStep += batchOutputDelta;
         outputStateInCurrentTimeStep.assign(output_state_out_buffer,
                                             output_state_out_buffer + batchSize * outputSize);
         cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
@@ -439,7 +453,11 @@
         const _Float16* recurrent_to_output_weights_buffer,
         const Shape& recurrent_to_output_weights_shape,
         const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
-        const _Float16* cell_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
+        const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
+        const Shape& aux_input_shape, const _Float16* aux_input_to_input_weights_buffer,
+        const _Float16* aux_input_to_forget_weights_buffer,
+        const _Float16* aux_input_to_cell_weights_buffer,
+        const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
         const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
         const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
         const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
@@ -448,7 +466,7 @@
         const _Float16* cell_layer_norm_weights_buffer,
         const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
         _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
-        bool timeMajor) {
+        bool forwardSequence, bool timeMajor) {
     NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");
 
     const uint32_t inputRank = getNumberOfDimensions(input_shape);
@@ -507,6 +525,33 @@
         convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
     }
 
+    // The auxiliary input and its weights are optional (callers may pass nullptr), so only
+    // convert the buffers that are present; absent ones stay zero-filled.
+    std::vector<float> aux_input_float32(maxTime * batchInputSize);
+    if (aux_input_buffer != nullptr) {
+        convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
+    }
+    std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
+    if (aux_input_to_input_weights_buffer != nullptr) {
+        convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
+                                &aux_input_to_input_weights_float32);
+    }
+    std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
+    if (aux_input_to_forget_weights_buffer != nullptr) {
+        convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
+                                &aux_input_to_forget_weights_float32);
+    }
+    std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
+    if (aux_input_to_cell_weights_buffer != nullptr) {
+        convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
+                                &aux_input_to_cell_weights_float32);
+    }
+    std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
+    if (aux_input_to_output_weights_buffer != nullptr) {
+        convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
+                                &aux_input_to_output_weights_float32);
+    }
+
     std::vector<float> input_gate_bias_float32(numCells);
     if (input_gate_bias_buffer != nullptr) {
         convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
@@ -570,13 +604,21 @@
         transposedOutputShape = transposedInputShape;
         transposedOutputShape.dimensions[2] = outputSize;
     }
-    const float* inputCurrentTimeStep = timeMajor ? input_float32.data() : transposedInput.data();
-    float* outputCurrentTimeStep = timeMajor ? output_float32.data() : transposedOutput.data();
+    const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
+    float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
 
     std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
     convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
     std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
     convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);
+
+    const float* inputCurrentTimeStep =
+            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
+    float* outputCurrentTimeStep =
+            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
+    const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
+    const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize;
+
     for (int t = 0; t < maxTime; ++t) {
         LSTMStep(params, inputCurrentTimeStep, batchInputShape,
                  input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
@@ -586,7 +628,14 @@
                  recurrent_to_cell_weights_float32.data(),
                  recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
                  cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
-                 cell_to_output_weights_float32.data(), input_gate_bias_float32.data(),
+                 cell_to_output_weights_float32.data(),
+                 // Forward nullptr when there is no auxiliary input so LSTMStep skips the
+                 // auxiliary matrix products, as it does on the float32 path.
+                 aux_input_buffer != nullptr ? aux_input_float32.data() : nullptr, aux_input_shape,
+                 aux_input_to_input_weights_float32.data(),
+                 aux_input_to_forget_weights_float32.data(),
+                 aux_input_to_cell_weights_float32.data(),
+                 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
                  forget_gate_bias_float32.data(), cell_bias_float32.data(),
                  output_gate_bias_float32.data(), projection_weights_float32.data(),
                  projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
@@ -595,8 +641,8 @@
                  output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
                  cell_state_out_float32.data(), outputCurrentTimeStep,
                  scratch_buffer_float32.data());
-        inputCurrentTimeStep += batchInputSize;
-        outputCurrentTimeStep += batchOutputSize;
+        inputCurrentTimeStep += batchInputDelta;
+        outputCurrentTimeStep += batchOutputDelta;
         outputStateInCurrentTimeStep = output_state_out_float32;
         cellStateInCurrentTimeStep = cell_state_out_float32;
     }
@@ -624,14 +670,18 @@
         const float* recurrent_to_output_weights_buffer,
         const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
         const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
-        const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
-        const float* cell_bias_buffer, const float* output_gate_bias_buffer,
-        const float* projection_weights_buffer, const float* projection_bias_buffer,
-        const float* output_state_in_buffer, const float* cell_state_in_buffer,
-        const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer,
-        const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
-        float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
-        float* scratch_buffer_buffer) {
+        const float* aux_input_buffer, const Shape& aux_input_shape,
+        const float* aux_input_to_input_weights_buffer,
+        const float* aux_input_to_forget_weights_buffer,
+        const float* aux_input_to_cell_weights_buffer,
+        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
+        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
+        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
+        const float* projection_bias_buffer, const float* output_state_in_buffer,
+        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
+        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
+        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
+        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
     NNTRACE_COMP("LSTMCell::LSTMStep");
 
     const uint32_t n_batch = input_shape.dimensions[0];
@@ -639,6 +689,7 @@
     // n_cell and n_output will be the same size when there is no projection.
     const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
     const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
+    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;
 
     // Index the scratch buffers pointers to the global scratch buffer.
     float* input_gate_scratch = nullptr;
@@ -694,6 +745,26 @@
             input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
             output_gate_scratch, /*result_stride*/ 1);
 
+    // If auxiliary input is available then compute aux_input_weight * aux_input
+    if (aux_input_buffer != nullptr) {
+        if (!params.use_cifg) {
+            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
+                    n_batch, input_gate_scratch,
+                    /*result_stride=*/1);
+        }
+
+        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
+                forget_gate_scratch, /*result_stride=*/1);
+        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
+                cell_scratch, /*result_stride=*/1);
+        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
+                output_gate_scratch, /*result_stride=*/1);
+    }
+
     // For each batch and cell: compute recurrent_weight * output_state.
     if (!params.use_cifg) {
         tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
@@ -837,6 +908,11 @@
                             GetBuffer<const float>(cell_to_input_weights_),
                             GetBuffer<const float>(cell_to_forget_weights_),
                             GetBuffer<const float>(cell_to_output_weights_),
+                            /*aux_input_buffer=*/nullptr, input_->shape(),
+                            /*aux_input_to_input_weights_buffer=*/nullptr,
+                            /*aux_input_to_forget_weights_buffer=*/nullptr,
+                            /*aux_input_to_cell_weights_buffer=*/nullptr,
+                            /*aux_input_to_output_weights_buffer=*/nullptr,
                             GetBuffer<const float>(input_gate_bias_),
                             GetBuffer<const float>(forget_gate_bias_),
                             GetBuffer<const float>(cell_bias_),
@@ -867,6 +943,11 @@
                             GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
                             GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
                             GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
+                            /*aux_input_buffer=*/nullptr, input_->shape(),
+                            /*aux_input_to_input_weights_buffer=*/nullptr,
+                            /*aux_input_to_forget_weights_buffer=*/nullptr,
+                            /*aux_input_to_cell_weights_buffer=*/nullptr,
+                            /*aux_input_to_output_weights_buffer=*/nullptr,
                             GetOptionalBuffer<const _Float16>(input_gate_bias_),
                             GetBuffer<const _Float16>(forget_gate_bias_),
                             GetBuffer<const _Float16>(cell_bias_),