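Add tests and fixes for BidirectionalSequenceLSTM op's aux_input.
Previously the whole aux_input buffer was passed unchanged to every LSTM step.
The aux input is now transposed when the op is not time-major and is advanced
one time step per iteration, in lockstep with the main input, in both the
float32 and float16 paths.
Also add the missing generated pad op tests.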
Bug: 129570867
Test: NeuralNetworksTest_static --gtest_filter=GeneratedTests.*lstm*
Change-Id: I08178d966d46dd30414955144e9db3393342514a
Merged-In: I08178d966d46dd30414955144e9db3393342514a
(cherry picked from commit d0fbc245689322136c9b2c5eb5b0838fffc9618f)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 3bc4072..b219e75 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -400,18 +400,27 @@
const uint32_t batchOutputSize = batchSize * outputSize;
std::vector<float> transposedInput;
+ const bool hasAuxInput = (aux_input_buffer != nullptr);
+ std::vector<float> transposedAuxInput;
std::vector<float> transposedOutput;
Shape transposedInputShape;
Shape transposedOutputShape;
if (!timeMajor) {
transposedInput.resize(maxTime * batchInputSize);
- transposedOutput.resize(maxTime * batchOutputSize);
transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
+ if (hasAuxInput) {
+ transposedAuxInput.resize(maxTime * batchInputSize);
+ transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
+ transposedAuxInput.data());
+ }
transposeFirstTwoDimensions(input_shape, &transposedInputShape);
+ transposedOutput.resize(maxTime * batchOutputSize);
transposedOutputShape = transposedInputShape;
transposedOutputShape.dimensions[2] = outputSize;
}
const float* inputData = timeMajor ? input_buffer : transposedInput.data();
+ const float* auxInputData =
+ hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
float* outputData = timeMajor ? output_buffer : transposedOutput.data();
std::vector<float> outputStateInCurrentTimeStep(
@@ -420,6 +429,9 @@
cell_state_in_buffer + batchSize * numCells);
const float* inputCurrentTimeStep =
inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
+ const float* auxInputCurrentTimeStep =
+ hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
+ : nullptr;
float* outputCurrentTimeStep =
outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
@@ -432,17 +444,21 @@
recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
- cell_to_forget_weights_buffer, cell_to_output_weights_buffer, aux_input_buffer,
- aux_input_to_input_weights_buffer, aux_input_to_forget_weights_buffer,
- aux_input_to_cell_weights_buffer, aux_input_to_output_weights_buffer,
- input_gate_bias_buffer, forget_gate_bias_buffer, cell_bias_buffer,
- output_gate_bias_buffer, projection_weights_buffer, projection_bias_buffer,
+ cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
+ auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
+ aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
+ aux_input_to_output_weights_buffer, input_gate_bias_buffer,
+ forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
+ projection_weights_buffer, projection_bias_buffer,
outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
scratch_buffer_buffer);
inputCurrentTimeStep += batchInputDelta;
+ if (hasAuxInput) {
+ auxInputCurrentTimeStep += batchInputDelta;
+ }
outputCurrentTimeStep += batchOutputDelta;
outputStateInCurrentTimeStep.assign(output_state_out_buffer,
output_state_out_buffer + batchSize * outputSize);
@@ -619,19 +635,29 @@
convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);
std::vector<float> transposedInput;
+ const bool hasAuxInput = (aux_input_buffer != nullptr);
+ std::vector<float> transposedAuxInput;
std::vector<float> transposedOutput;
Shape transposedInputShape;
Shape transposedOutputShape;
if (!timeMajor) {
transposedInput.resize(maxTime * batchInputSize);
- transposedOutput.resize(maxTime * batchOutputSize);
transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
transposedInput.data());
+ if (hasAuxInput) {
+ transposedAuxInput.resize(maxTime * batchInputSize);
+ transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
+ transposedAuxInput.data());
+ }
transposeFirstTwoDimensions(input_shape, &transposedInputShape);
+ transposedOutput.resize(maxTime * batchOutputSize);
transposedOutputShape = transposedInputShape;
transposedOutputShape.dimensions[2] = outputSize;
}
const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
+ const float* auxInputData =
+ hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
+ : nullptr;
float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();
std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
@@ -641,6 +667,9 @@
const float* inputCurrentTimeStep =
inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
+ const float* auxInputCurrentTimeStep =
+ hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
+ : nullptr;
float* outputCurrentTimeStep =
outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
const int batchInputDelta = forwardSequence ? batchInputSize : -batchInputSize;
@@ -655,8 +684,7 @@
recurrent_to_cell_weights_float32.data(),
recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
- cell_to_output_weights_float32.data(),
- aux_input_buffer != nullptr ? aux_input_float32.data() : nullptr,
+ cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
aux_input_to_input_weights_float32.data(),
aux_input_to_forget_weights_float32.data(),
aux_input_to_cell_weights_float32.data(),
@@ -670,6 +698,9 @@
cell_state_out_float32.data(), outputCurrentTimeStep,
scratch_buffer_float32.data());
inputCurrentTimeStep += batchInputDelta;
+ if (hasAuxInput) {
+ auxInputCurrentTimeStep += batchInputDelta;
+ }
outputCurrentTimeStep += batchOutputDelta;
outputStateInCurrentTimeStep = output_state_out_float32;
cellStateInCurrentTimeStep = cell_state_out_float32;