Implements BIDIRECTIONAL_SEQUENCE_LSTM operation. Test: NeuralNetworksTest_static --gtest_filter=GeneratedTests.*lstm Bug: 113559542 Change-Id: If30e31c851bfbd97445710d8e1998306a551ac08 Merged-In: If30e31c851bfbd97445710d8e1998306a551ac08 (cherry picked from commit be339f503051e72ed63e88a5645af15b02a44478)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp index 293b23e..1b36574 100644 --- a/common/operations/LSTM.cpp +++ b/common/operations/LSTM.cpp
@@ -350,14 +350,19 @@ const float* recurrent_to_output_weights_buffer, const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer, - const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, - const float* cell_bias_buffer, const float* output_gate_bias_buffer, - const float* projection_weights_buffer, const float* projection_bias_buffer, - const float* output_state_in_buffer, const float* cell_state_in_buffer, - const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer, - const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer, - float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer, - float* scratch_buffer_buffer, bool timeMajor) { + const float* aux_input_buffer, const Shape& aux_input_shape, + const float* aux_input_to_input_weights_buffer, + const float* aux_input_to_forget_weights_buffer, + const float* aux_input_to_cell_weights_buffer, + const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer, + const float* forget_gate_bias_buffer, const float* cell_bias_buffer, + const float* output_gate_bias_buffer, const float* projection_weights_buffer, + const float* projection_bias_buffer, const float* output_state_in_buffer, + const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer, + const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer, + const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, + float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer, + bool forwardSequence, bool timeMajor) { NNTRACE_COMP("LSTMCell::LSTMEvalFloat32"); const uint32_t inputRank = getNumberOfDimensions(input_shape); @@ -388,13 +393,20 @@ transposedOutputShape = transposedInputShape; transposedOutputShape.dimensions[2] = outputSize; } - const float* inputCurrentTimeStep = timeMajor ? input_buffer : transposedInput.data(); - float* outputCurrentTimeStep = timeMajor ? output_buffer : transposedOutput.data(); + const float* inputData = timeMajor ? input_buffer : transposedInput.data(); + float* outputData = timeMajor ? output_buffer : transposedOutput.data(); std::vector<float> outputStateInCurrentTimeStep( output_state_in_buffer, output_state_in_buffer + batchSize * outputSize); std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer, cell_state_in_buffer + batchSize * numCells); + const float* inputCurrentTimeStep = + inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)); + float* outputCurrentTimeStep = + outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1)); + const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize; + const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize; + for (int t = 0; t < maxTime; ++t) { LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer, input_to_forget_weights_buffer, input_to_cell_weights_buffer, @@ -402,16 +414,18 @@ recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer, recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer, recurrent_to_output_weights_shape, cell_to_input_weights_buffer, - cell_to_forget_weights_buffer, cell_to_output_weights_buffer, - input_gate_bias_buffer, forget_gate_bias_buffer, cell_bias_buffer, - output_gate_bias_buffer, projection_weights_buffer, projection_bias_buffer, - outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(), - input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer, - cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer, - output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep, - scratch_buffer_buffer); - inputCurrentTimeStep += batchInputSize; - outputCurrentTimeStep += batchOutputSize; + cell_to_forget_weights_buffer, cell_to_output_weights_buffer, aux_input_buffer, + aux_input_shape, aux_input_to_input_weights_buffer, + aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer, + aux_input_to_output_weights_buffer, input_gate_bias_buffer, + forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer, + projection_weights_buffer, projection_bias_buffer, output_state_in_buffer, + cell_state_in_buffer, input_layer_norm_weights_buffer, + forget_layer_norm_weights_buffer, cell_layer_norm_weights_buffer, + output_layer_norm_weights_buffer, output_state_out_buffer, cell_state_out_buffer, + outputCurrentTimeStep, scratch_buffer_buffer); + inputCurrentTimeStep += batchInputDelta; + outputCurrentTimeStep += batchOutputDelta; outputStateInCurrentTimeStep.assign(output_state_out_buffer, output_state_out_buffer + batchSize * outputSize); cellStateInCurrentTimeStep.assign(cell_state_out_buffer, @@ -439,7 +453,11 @@ const _Float16* recurrent_to_output_weights_buffer, const Shape& recurrent_to_output_weights_shape, const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer, - const _Float16* cell_to_output_weights_buffer, const _Float16* input_gate_bias_buffer, + const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer, + const Shape& aux_input_shape, const _Float16* aux_input_to_input_weights_buffer, + const _Float16* aux_input_to_forget_weights_buffer, + const _Float16* aux_input_to_cell_weights_buffer, + const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer, @@ -448,7 +466,7 @@ const _Float16* cell_layer_norm_weights_buffer, const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer, _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer, - bool timeMajor) { + bool forwardSequence, bool timeMajor) { NNTRACE_COMP("LSTMCell::LSTMEvalFloat16"); const uint32_t inputRank = getNumberOfDimensions(input_shape); @@ -507,6 +525,22 @@ convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32); } + std::vector<float> aux_input_float32(maxTime * batchInputSize); + convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32); + std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize); + if (aux_input_to_input_weights_buffer != nullptr) { + convertFloat16ToFloat32(aux_input_to_input_weights_buffer, + &aux_input_to_input_weights_float32); + } + std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize); + convertFloat16ToFloat32(aux_input_to_forget_weights_buffer, + &aux_input_to_forget_weights_float32); + std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize); + convertFloat16ToFloat32(aux_input_to_cell_weights_buffer, &aux_input_to_cell_weights_float32); + std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize); + convertFloat16ToFloat32(aux_input_to_output_weights_buffer, + &aux_input_to_output_weights_float32); + std::vector<float> input_gate_bias_float32(numCells); if (input_gate_bias_buffer != nullptr) { convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32); @@ -570,13 +604,21 @@ transposedOutputShape = transposedInputShape; transposedOutputShape.dimensions[2] = outputSize; } - const float* inputCurrentTimeStep = timeMajor ? input_float32.data() : transposedInput.data(); - float* outputCurrentTimeStep = timeMajor ? output_float32.data() : transposedOutput.data(); + const float* inputData = timeMajor ? input_float32.data() : transposedInput.data(); + float* outputData = timeMajor ? output_float32.data() : transposedOutput.data(); std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize); convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep); std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells); convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep); + + const float* inputCurrentTimeStep = + inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)); + float* outputCurrentTimeStep = + outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1)); + const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize; + const int batchOutputDelta = forwardSequence ? batchOutputSize : -batchOutputSize; + for (int t = 0; t < maxTime; ++t) { LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(), @@ -586,7 +628,11 @@ recurrent_to_cell_weights_float32.data(), recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape, cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(), - cell_to_output_weights_float32.data(), input_gate_bias_float32.data(), + cell_to_output_weights_float32.data(), aux_input_float32.data(), aux_input_shape, + aux_input_to_input_weights_float32.data(), + aux_input_to_forget_weights_float32.data(), + aux_input_to_cell_weights_float32.data(), + aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(), forget_gate_bias_float32.data(), cell_bias_float32.data(), output_gate_bias_float32.data(), projection_weights_float32.data(), projection_bias_float32.data(), outputStateInCurrentTimeStep.data(), @@ -595,8 +641,8 @@ output_layer_norm_weights_float32.data(), output_state_out_float32.data(), cell_state_out_float32.data(), outputCurrentTimeStep, scratch_buffer_float32.data()); - inputCurrentTimeStep += batchInputSize; - outputCurrentTimeStep += batchOutputSize; + inputCurrentTimeStep += batchInputDelta; + outputCurrentTimeStep += batchOutputDelta; outputStateInCurrentTimeStep = output_state_out_float32; cellStateInCurrentTimeStep = cell_state_out_float32; } @@ -624,14 +670,18 @@ const float* recurrent_to_output_weights_buffer, const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer, - const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer, - const float* cell_bias_buffer, const float* output_gate_bias_buffer, - const float* projection_weights_buffer, const float* projection_bias_buffer, - const float* output_state_in_buffer, const float* cell_state_in_buffer, - const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer, - const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer, - float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer, - float* scratch_buffer_buffer) { + const float* aux_input_buffer, const Shape& aux_input_shape, + const float* aux_input_to_input_weights_buffer, + const float* aux_input_to_forget_weights_buffer, + const float* aux_input_to_cell_weights_buffer, + const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer, + const float* forget_gate_bias_buffer, const float* cell_bias_buffer, + const float* output_gate_bias_buffer, const float* projection_weights_buffer, + const float* projection_bias_buffer, const float* output_state_in_buffer, + const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer, + const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer, + const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, + float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) { NNTRACE_COMP("LSTMCell::LSTMStep"); const uint32_t n_batch = input_shape.dimensions[0]; @@ -639,6 +689,7 @@ // n_cell and n_output will be the same size when there is no projection. const uint32_t n_cell = input_to_output_weights_shape.dimensions[0]; const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1]; + const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input; // Index the scratch buffers pointers to the global scratch buffer. float* input_gate_scratch = nullptr; @@ -694,6 +745,26 @@ input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch, output_gate_scratch, /*result_stride*/ 1); + // If auxiliary input is available then compute aux_input_weight * aux_input + if (aux_input_buffer != nullptr) { + if (!params.use_cifg) { + tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer, + n_batch, input_gate_scratch, + /*result_stride=*/1); + } + + tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch, + cell_scratch, /*result_stride=*/1); + tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( + aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch, + output_gate_scratch, /*result_stride=*/1); + } + // For each batch and cell: compute recurrent_weight * output_state. if (!params.use_cifg) { tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( @@ -837,6 +908,11 @@ GetBuffer<const float>(cell_to_input_weights_), GetBuffer<const float>(cell_to_forget_weights_), GetBuffer<const float>(cell_to_output_weights_), + /*aux_input_buffer=*/nullptr, input_->shape(), + /*aux_input_to_input_weights_buffer=*/nullptr, + /*aux_input_to_forget_weights_buffer=*/nullptr, + /*aux_input_to_cell_weights_buffer=*/nullptr, + /*aux_input_to_output_weights_buffer=*/nullptr, GetBuffer<const float>(input_gate_bias_), GetBuffer<const float>(forget_gate_bias_), GetBuffer<const float>(cell_bias_), @@ -867,6 +943,11 @@ GetOptionalBuffer<const _Float16>(cell_to_input_weights_), GetOptionalBuffer<const _Float16>(cell_to_forget_weights_), GetOptionalBuffer<const _Float16>(cell_to_output_weights_), + /*aux_input_buffer=*/nullptr, input_->shape(), + /*aux_input_to_input_weights_buffer=*/nullptr, + /*aux_input_to_forget_weights_buffer=*/nullptr, + /*aux_input_to_cell_weights_buffer=*/nullptr, + /*aux_input_to_output_weights_buffer=*/nullptr, GetOptionalBuffer<const _Float16>(input_gate_bias_), GetBuffer<const _Float16>(forget_gate_bias_), GetBuffer<const _Float16>(cell_bias_),