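Fix the input shape passed to LSTMStep when evaluating Sequence LSTM.
Each LSTMStep call in the time-step loop now receives a shape with
dimensions {batchSize, inputSize} instead of the original input_shape.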
Bug: 113563341
Test: Covered by existing test NeuralNetworksTest_static
--gtest_filter=GeneratedTests.lstm*
Change-Id: I4ef145531b8ee5524137922b58e689bbc579fd81
Merged-In: I4ef145531b8ee5524137922b58e689bbc579fd81
(cherry picked from commit 02c3e76fc3fb2f9af7476bc1986445efe1ce76be)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index bbadfa6..ac37ad3 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -369,13 +369,15 @@
const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
+ Shape batchInputShape = input_shape;
+ batchInputShape.dimensions = {batchSize, inputSize};
const uint32_t batchInputSize = batchSize * inputSize;
const uint32_t batchOutputSize = batchSize * outputSize;
const float* inputCurrentTimeStep = input_buffer;
float* outputCurrentTimeStep = output_buffer;
for (int t = 0; t < maxTime; ++t) {
- LSTMStep(params, inputCurrentTimeStep, input_shape, input_to_input_weights_buffer,
+ LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
input_to_forget_weights_buffer, input_to_cell_weights_buffer,
input_to_output_weights_buffer, input_to_output_weights_shape,
recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
@@ -430,6 +432,8 @@
const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
+ Shape batchInputShape = input_shape;
+ batchInputShape.dimensions = {batchSize, inputSize};
const uint32_t batchInputSize = batchSize * inputSize;
const uint32_t batchOutputSize = batchSize * outputSize;
@@ -531,10 +535,10 @@
const float* inputCurrentTimeStep = input_float32.data();
float* outputCurrentTimeStep = output_float32.data();
for (int t = 0; t < maxTime; ++t) {
- LSTMStep(params, inputCurrentTimeStep, input_shape, input_to_input_weights_float32.data(),
- input_to_forget_weights_float32.data(), input_to_cell_weights_float32.data(),
- input_to_output_weights_float32.data(), input_to_output_weights_shape,
- recurrent_to_input_weights_float32.data(),
+ LSTMStep(params, inputCurrentTimeStep, batchInputShape,
+ input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
+ input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
+ input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
recurrent_to_forget_weights_float32.data(),
recurrent_to_cell_weights_float32.data(),
recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,