Change stateful Ops to stateless ones

Bug: 63905942

Updated the RNN, LSTM, and SVDF ops.
The state tensors, which were previously updated in place as outputs,
are now split into separate input and output operands.
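
For example, a caller of the now-stateless RNN keeps the recurrence
going by feeding the hidden-state output of step t back in as the
hidden-state input of step t+1. A minimal sketch of that driving loop
(illustration only, not part of this change; runStep() is a
hypothetical helper that builds and executes the op with the new
operand order):

    #include <algorithm>
    #include <vector>

    struct RnnStepIO {
        std::vector<float> input;             // {batches, input_size}
        std::vector<float> hidden_state_in;   // {batches, units}, new input operand 4
        std::vector<float> hidden_state_out;  // {batches, units}, output operand 0
        std::vector<float> output;            // {batches, units}, output operand 1
    };

    void runSequence(RnnStepIO& io, int num_steps) {
        // Start from a zeroed state, as the tests' ResetHiddenState() does.
        std::fill(io.hidden_state_in.begin(), io.hidden_state_in.end(), 0.f);
        for (int t = 0; t < num_steps; ++t) {
            // runStep(io);  // set inputs 0-5, execute ANEURALNETWORKS_RNN, read outputs 0-1
            // Thread the state forward: step t's output state is step t+1's input state.
            io.hidden_state_in = io.hidden_state_out;
        }
    }

LSTM (output/cell state) and SVDF (time state) are driven the same way
through their respective state-in/state-out operand pairs.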

Test: NeuralNetworksTest
Change-Id: Ia3d11f640cba4cab1b94d0b9746c46d347c024a4
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 6a1feed..e7b0dec 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -171,13 +171,16 @@
   projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
   projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional
 
+  output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
+  cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
+
   params_.activation_ = static_cast<ActivationFn>(getScalarData<int32_t>(
       *GetInput(operation, operands, kActivationParam)));
   params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
   params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
 
-  output_state_ = GetOutput(operation, operands, kOutputStateTensor);
-  cell_state_ = GetOutput(operation, operands, kCellStateTensor);
+  output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
+  cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
   output_ = GetOutput(operation, operands, kOutputTensor);
 
   scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
@@ -338,8 +341,8 @@
                        Shape *cellStateShape,
                        Shape *outputShape) {
   // Check we have all the inputs and outputs we need.
-  NN_CHECK(NumInputsWithValues(operation, operands) >= 13 &&
-           NumInputsWithValues(operation, operands) <= 21);
+  NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
+           NumInputsWithValues(operation, operands) <= 23);
   NN_CHECK_EQ(NumOutputs(operation), 4);
 
   // Inferring batch size, number of outputs and number of cells from the
@@ -463,24 +466,24 @@
   if (!use_cifg) {
     MatrixBatchVectorMultiplyAccumulate(
         GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
-        GetBuffer<float>(output_state_), n_batch, input_gate_scratch);
+        GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch);
   }
   MatrixBatchVectorMultiplyAccumulate(
       GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
-      GetBuffer<float>(output_state_), n_batch, forget_gate_scratch);
+      GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch);
   MatrixBatchVectorMultiplyAccumulate(
       GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
-      GetBuffer<float>(output_state_), n_batch, cell_scratch);
+      GetBuffer<float>(output_state_in_), n_batch, cell_scratch);
   MatrixBatchVectorMultiplyAccumulate(
       GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
-      GetBuffer<float>(output_state_), n_batch, output_gate_scratch);
+      GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch);
 
   // For each batch and cell: update input gate.
   if (!use_cifg) {
     if (use_peephole) {
       VectorBatchVectorCwiseProductAccumulate(
           GetBuffer<float>(cell_to_input_weights_), n_cell,
-          GetBuffer<float>(cell_state_), n_batch, input_gate_scratch);
+          GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
     }
     ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                          input_gate_scratch);
@@ -490,40 +493,40 @@
   if (use_peephole) {
     VectorBatchVectorCwiseProductAccumulate(
         GetBuffer<float>(cell_to_forget_weights_), n_cell,
-        GetBuffer<float>(cell_state_), n_batch, forget_gate_scratch);
+        GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
   }
   ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                        forget_gate_scratch);
 
   // For each batch and cell: update the cell.
-  VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_),
-                           n_batch * n_cell, GetBuffer<float>(cell_state_));
+  VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_in_),
+                           n_batch * n_cell, GetBuffer<float>(cell_state_out_));
   ApplyActivationToVector(cell_scratch, n_batch * n_cell, params_.activation_,
                           cell_scratch);
   if (use_cifg) {
     Sub1Vector(forget_gate_scratch, n_batch * n_cell, forget_gate_scratch);
     VectorVectorCwiseProductAccumulate(cell_scratch, forget_gate_scratch,
                                        n_batch * n_cell,
-                                       GetBuffer<float>(cell_state_));
+                                       GetBuffer<float>(cell_state_out_));
   } else {
     VectorVectorCwiseProductAccumulate(cell_scratch, input_gate_scratch,
                                        n_batch * n_cell,
-                                       GetBuffer<float>(cell_state_));
+                                       GetBuffer<float>(cell_state_out_));
   }
   if (params_.cell_clip_ > 0.0) {
-    ClipVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
-               params_.cell_clip_, GetBuffer<float>(cell_state_));
+    ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
+               params_.cell_clip_, GetBuffer<float>(cell_state_out_));
   }
 
   // For each batch and cell: update the output gate.
   if (use_peephole) {
     VectorBatchVectorCwiseProductAccumulate(
         GetBuffer<float>(cell_to_output_weights_), n_cell,
-        GetBuffer<float>(cell_state_), n_batch, output_gate_scratch);
+        GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
   }
   ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                        output_gate_scratch);
-  ApplyActivationToVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
+  ApplyActivationToVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
                           params_.activation_, cell_scratch);
   VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, n_batch * n_cell,
                            output_gate_scratch);
@@ -551,7 +554,7 @@
                GetBuffer<float>(output_));
   }
   CopyVector(GetBuffer<float>(output_), n_batch * n_output,
-             GetBuffer<float>(output_state_));
+             GetBuffer<float>(output_state_out_));
 
   return true;
 }
diff --git a/common/operations/LSTM.h b/common/operations/LSTM.h
index 430ee61..cc54612 100644
--- a/common/operations/LSTM.h
+++ b/common/operations/LSTM.h
@@ -88,15 +88,17 @@
   // Projection bias tensor of size {n_output}
   static constexpr int kProjectionBiasTensor = 17;  // Optional
 
-  static constexpr int kActivationParam = 18;
-  static constexpr int kCellClipParam = 19;
-  static constexpr int kProjClipParam = 20;
+  static constexpr int kOutputStateInTensor = 18;
+  static constexpr int kCellStateInTensor = 19;
+
+  static constexpr int kActivationParam = 20;
+  static constexpr int kCellClipParam = 21;
+  static constexpr int kProjClipParam = 22;
 
   // Output tensors.
-  // TODO: Do we have to pre-allocate scratch buffer as outputs?
   static constexpr int kScratchBufferTensor = 0;
-  static constexpr int kOutputStateTensor = 1;
-  static constexpr int kCellStateTensor = 2;
+  static constexpr int kOutputStateOutTensor = 1;
+  static constexpr int kCellStateOutTensor = 2;
   static constexpr int kOutputTensor = 3;
 
  private:
@@ -130,8 +132,11 @@
   const RunTimeOperandInfo *projection_weights_;
   const RunTimeOperandInfo *projection_bias_;
 
-  RunTimeOperandInfo *output_state_;
-  RunTimeOperandInfo *cell_state_;
+  const RunTimeOperandInfo *output_state_in_;
+  const RunTimeOperandInfo *cell_state_in_;
+
+  RunTimeOperandInfo *output_state_out_;
+  RunTimeOperandInfo *cell_state_out_;
   RunTimeOperandInfo *output_;
 
   RunTimeOperandInfo *scratch_buffer_;
diff --git a/common/operations/LSTMTest.cpp b/common/operations/LSTMTest.cpp
index a61ef07..b8a80ec 100644
--- a/common/operations/LSTMTest.cpp
+++ b/common/operations/LSTMTest.cpp
@@ -60,13 +60,15 @@
     ACTION(ForgetGateBias)                              \
     ACTION(OutputGateBias)                              \
     ACTION(ProjectionWeights)                           \
-    ACTION(ProjectionBias)
+    ACTION(ProjectionBias)                              \
+    ACTION(OutputStateIn)                               \
+    ACTION(CellStateIn)
 
 // For all output and intermediate states
 #define FOR_ALL_OUTPUT_TENSORS(ACTION)          \
     ACTION(Output)                              \
-    ACTION(OutputState)                         \
-    ACTION(CellState)                           \
+    ACTION(OutputStateOut)                      \
+    ACTION(CellStateOut)                        \
     ACTION(ScratchBuffer)
 
 class LSTMOpModel {
@@ -75,7 +77,7 @@
                 uint32_t n_cell, uint32_t n_output, bool use_cifg,
                 bool use_peephole, bool use_projection_weights,
                 bool use_projection_bias, float cell_clip, float proj_clip,
-                const std::vector<std::vector<uint32_t>>& input_shapes)
+                const std::vector<std::vector<uint32_t>>& input_shapes0)
         : n_batch_(n_batch), n_input_(n_input),
           n_cell_(n_cell), n_output_(n_output),
           use_cifg_(use_cifg), use_peephole_(use_peephole),
@@ -84,7 +86,10 @@
           activation_(ActivationFn::kActivationTanh),
           cell_clip_(cell_clip), proj_clip_(proj_clip) {
         std::vector<uint32_t> inputs;
+        std::vector<std::vector<uint32_t>> input_shapes(input_shapes0.begin(), input_shapes0.end());
+        input_shapes.push_back({n_batch, n_output});
+        input_shapes.push_back({n_batch, n_cell});
         auto it = input_shapes.begin();
 
         // Input and weights
 #define AddInput(X)                                                     \
@@ -154,11 +159,11 @@
 #undef DefineSetter
 
     void ResetOutputState() {
-        std::fill(OutputState_.begin(), OutputState_.end(), 0.f);
+        std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
     }
 
     void ResetCellState() {
-        std::fill(CellState_.begin(), CellState_.end(), 0.f);
+        std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
     }
 
     void SetInput(int offset, float *begin, float *end) {
diff --git a/common/operations/RNN.cpp b/common/operations/RNN.cpp
index 4e25a88..8a00734 100644
--- a/common/operations/RNN.cpp
+++ b/common/operations/RNN.cpp
@@ -27,12 +27,13 @@
   input_ = GetInput(operation, operands, kInputTensor);
   weights_ = GetInput(operation, operands, kWeightsTensor);
   recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
+  hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
   bias_ = GetInput(operation, operands, kBiasTensor);
 
   activation_ = static_cast<ActivationFn>(
       getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
 
-  hidden_state_ = GetOutput(operation, operands, kHiddenStateTensor);
+  hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
   output_ = GetOutput(operation, operands, kOutputTensor);
 }
 
@@ -42,7 +43,7 @@
                   Shape *outputShape) {
   // Check we have all the inputs and outputs we need.
   const int num_inputs = NumInputsWithValues(operation, operands);
-  NN_CHECK(num_inputs == 4 || num_inputs == 5);
+  NN_CHECK(num_inputs == 5 || num_inputs == 6);
   NN_CHECK_EQ(NumOutputs(operation), 2);
 
   const RunTimeOperandInfo *input =
@@ -91,10 +92,12 @@
     // Initialize the pointer to input, output and bias.
     const float* input_ptr_batch =
         reinterpret_cast<float*>(input_->buffer) + b * input_size;
+    const float* hidden_state_in_ptr_batch =
+        reinterpret_cast<const float*>(hidden_state_in_->buffer) + b * num_units;
     float* output_ptr_batch =
         reinterpret_cast<float*>(output_->buffer) + b * num_units;
-    float* hidden_state_ptr_batch =
-        reinterpret_cast<float*>(hidden_state_->buffer) + b * num_units;
+    float* hidden_state_out_ptr_batch =
+        reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units;
 
     // Initialize input_weights and recurrent_weights.
     const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer);
@@ -118,7 +121,7 @@
     for (uint32_t o = 0; o < num_units; o++) {
       for (uint32_t h = 0; h < num_units; h++) {
         output_ptr_batch[o] +=
-            hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+            hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
       }
       recurrent_weights_ptr += recurrent_weights_stride;
     }
@@ -127,7 +130,7 @@
     for (uint32_t o = 0; o < num_units; o++) {
       output_ptr_batch[o] =
           (ActivationFunctor(activation_))(output_ptr_batch[o]);
-      hidden_state_ptr_batch[o] = output_ptr_batch[o];
+      hidden_state_out_ptr_batch[o] = output_ptr_batch[o];
     }
   }
 
diff --git a/common/operations/RNN.h b/common/operations/RNN.h
index c6ba7dc..0c6a881 100644
--- a/common/operations/RNN.h
+++ b/common/operations/RNN.h
@@ -50,9 +50,10 @@
   static constexpr int kWeightsTensor = 1;  // Optional
   static constexpr int kRecurrentWeightsTensor = 2;
   static constexpr int kBiasTensor = 3;
-  static constexpr int kActivationParam = 4;
+  static constexpr int kHiddenStateInTensor = 4;
+  static constexpr int kActivationParam = 5;
 
-  static constexpr int kHiddenStateTensor = 0;
+  static constexpr int kHiddenStateOutTensor = 0;
   static constexpr int kOutputTensor = 1;
 
  private:
@@ -62,8 +63,9 @@
   const RunTimeOperandInfo *weights_;
   const RunTimeOperandInfo *recurrent_weights_;
   const RunTimeOperandInfo *bias_;
-  const RunTimeOperandInfo *hidden_state_;
+  const RunTimeOperandInfo *hidden_state_in_;
 
+  RunTimeOperandInfo *hidden_state_out_;
   RunTimeOperandInfo *output_;
 };
 
diff --git a/common/operations/RNNTest.cpp b/common/operations/RNNTest.cpp
index d9b2965..7561f33 100644
--- a/common/operations/RNNTest.cpp
+++ b/common/operations/RNNTest.cpp
@@ -140,10 +140,11 @@
   ACTION(Weights)                                \
   ACTION(RecurrentWeights)                       \
-  ACTION(Bias)
+  ACTION(Bias)                                   \
+  ACTION(HiddenStateIn)
 
 // For all output and intermediate states
 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \
-  ACTION(HiddenState)                  \
+  ACTION(HiddenStateOut)               \
   ACTION(Output)
 
 class BasicRNNOpModel {
@@ -163,18 +164,20 @@
     inputs.push_back(model_.addOperand(&RecurrentWeightTy));
     OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
     inputs.push_back(model_.addOperand(&BiasTy));
+    OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
+    inputs.push_back(model_.addOperand(&HiddenStateTy));
     OperandType ActionParamTy(Type::INT32, {});
     inputs.push_back(model_.addOperand(&ActionParamTy));
 
     std::vector<uint32_t> outputs;
 
-    OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
     outputs.push_back(model_.addOperand(&HiddenStateTy));
     OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
     outputs.push_back(model_.addOperand(&OutputTy));
 
     Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
-    HiddenState_.insert(HiddenState_.end(), batches_ * units_, 0.f);
+    HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
+    HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
     Output_.insert(Output_.end(), batches_ * units_, 0.f);
 
     model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
@@ -199,7 +202,7 @@
   }
 
   void ResetHiddenState() {
-    std::fill(HiddenState_.begin(), HiddenState_.end(), 0.f);
+    std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
   }
 
   const std::vector<float>& GetOutput() const { return Output_; }
diff --git a/common/operations/SVDF.cpp b/common/operations/SVDF.cpp
index 42f1a72..f600b26 100644
--- a/common/operations/SVDF.cpp
+++ b/common/operations/SVDF.cpp
@@ -26,11 +26,12 @@
 
 // TODO: Implement this using circular buffer instead.
 // This is here temporarily only to show the logic.
-void svdf_right_shift_state(float* state, int state_len, float shift_value) {
+void svdf_right_shift_state(const float* state_in, int state_len, float shift_value,
+                            float* state_out) {
   for (int i = 0; i < state_len - 1; i++) {
-    state[i] = state[i + 1];
+    state_out[i] = state_in[i + 1];
   }
-  state[state_len - 1] = shift_value;
+  state_out[state_len - 1] = shift_value;
 }
 
 int32_t getInt32ScalarData(RunTimeOperandInfo& info) {
@@ -46,12 +47,13 @@
     weights_feature_ = GetInput(operation, operands, kWeightsFeatureTensor);
     weights_time_ = GetInput(operation, operands, kWeightsTimeTensor);
     bias_ = GetInput(operation, operands, kBiasTensor);
+    state_in_ = GetInput(operation, operands, kStateInTensor);
 
     params_.rank_ = getInt32ScalarData(*GetInput(operation, operands, kRankParam));
     params_.activation_ = static_cast<ActivationFn>(getInt32ScalarData(
         *GetInput(operation, operands, kActivationParam)));
 
-    state_ = GetOutput(operation, operands, kStateTensor);
+    state_out_ = GetOutput(operation, operands, kStateOutTensor);
     output_ = GetOutput(operation, operands, kOutputTensor);
 }
 
@@ -61,7 +63,7 @@
                    Shape *outputShape) {
   // Check we have all the inputs and outputs we need.
   const int num_inputs = NumInputsWithValues(operation, operands);
-  NN_CHECK(num_inputs == 5 || num_inputs == 6);
+  NN_CHECK(num_inputs == 6 || num_inputs == 7);
   NN_CHECK_EQ(NumOutputs(operation), 2);
 
   const RunTimeOperandInfo *input =
@@ -85,11 +87,6 @@
     NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units);
   }
 
-  const RunTimeOperandInfo *state =
-      GetInput(operation, operands, SVDF::kStateTensor);
-  const RunTimeOperandInfo *output =
-      GetInput(operation, operands, SVDF::kOutputTensor);
-
   // Resize state.
   const Shape &inputShape = input->shape();
   stateShape->type = inputShape.type;
@@ -123,7 +120,8 @@
         // Initialize the pointer to input, output and bias.
         const float* input_ptr_batch = reinterpret_cast<float *>(input_->buffer) + b * input_size;
         float* output_ptr_batch = reinterpret_cast<float*>(output_->buffer) + b * num_units;
-        float* state_ptr_batch = reinterpret_cast<float*>(state_->buffer) + b * (memory_size - 1) * num_units;
+        const float* state_in_ptr_batch = reinterpret_cast<const float*>(state_in_->buffer) + b * (memory_size - 1) * num_units;
+        float* state_out_ptr_batch = reinterpret_cast<float*>(state_out_->buffer) + b * (memory_size - 1) * num_units;
 
         // For each unit
         for (int c = 0; c < num_units; c++) {
@@ -135,7 +133,8 @@
             }
 
             // Initialize state pointer for unit 'c'.
-            float* state_ptr = state_ptr_batch + c * (memory_size - 1);
+            const float* state_in_ptr = state_in_ptr_batch + c * (memory_size - 1);
+            float* state_out_ptr = state_out_ptr_batch + c * (memory_size - 1);
 
             // Apply bias if bias tensor exists.
             output_ptr_batch[c] = bias_->buffer ? reinterpret_cast<float *>(bias_->buffer)[c] : 0.f;
@@ -143,7 +142,7 @@
             // output = tf.matmul(state, weights_time)
             output_ptr_batch[c] += weights_time_ptr[memory_size - 1] * activation;
             for (int j = 0; j < memory_size - 1; j++) {
-                output_ptr_batch[c] += weights_time_ptr[j] * state_ptr[j];
+                output_ptr_batch[c] += weights_time_ptr[j] * state_in_ptr[j];
             }
 
             // Apply activation.
@@ -151,7 +150,8 @@
                     (ActivationFunctor(params_.activation_))(output_ptr_batch[c]);
 
             // Right shift the state and concatenate with activation.
-            svdf_right_shift_state(state_ptr, memory_size - 1, activation);
+            svdf_right_shift_state(state_in_ptr, memory_size - 1, activation,
+                                   state_out_ptr);
 
             // Update weight pointers.
             weights_feature_ptr += weights_feature_stride;
diff --git a/common/operations/SVDF.h b/common/operations/SVDF.h
index aafef12..a219fe5 100644
--- a/common/operations/SVDF.h
+++ b/common/operations/SVDF.h
@@ -58,10 +58,11 @@
     static constexpr int kWeightsFeatureTensor = 1;
     static constexpr int kWeightsTimeTensor = 2;
     static constexpr int kBiasTensor = 3;  // Optional
-    static constexpr int kRankParam = 4;
-    static constexpr int kActivationParam = 5;
+    static constexpr int kStateInTensor = 4;
+    static constexpr int kRankParam = 5;
+    static constexpr int kActivationParam = 6;
 
-    static constexpr int kStateTensor = 0;
+    static constexpr int kStateOutTensor = 0;
     static constexpr int kOutputTensor = 1;
 
 private:
@@ -71,8 +72,9 @@
     const RunTimeOperandInfo *weights_feature_;
     const RunTimeOperandInfo *weights_time_;
     const RunTimeOperandInfo *bias_;
+    const RunTimeOperandInfo *state_in_;
 
-    RunTimeOperandInfo *state_;
+    RunTimeOperandInfo *state_out_;
     RunTimeOperandInfo *output_;
 };
 
diff --git a/common/operations/SVDFTest.cpp b/common/operations/SVDFTest.cpp
index 775d3c1..a3c32e2 100644
--- a/common/operations/SVDFTest.cpp
+++ b/common/operations/SVDFTest.cpp
@@ -108,11 +108,12 @@
   ACTION(Input)                                  \
   ACTION(WeightsFeature)                         \
   ACTION(WeightsTime)                            \
-  ACTION(Bias)
+  ACTION(Bias)                                   \
+  ACTION(StateIn)
 
 // For all output and intermediate states
 #define FOR_ALL_OUTPUT_TENSORS(ACTION) \
-  ACTION(State)                        \
+  ACTION(StateOut)                     \
   ACTION(Output)
 
 // Derived class of SingleOpModel, which is used to test SVDF TFLite op.
@@ -128,7 +129,8 @@
         {batches_, input_size_},  // Input tensor
         {units_, input_size_},    // weights_feature tensor
         {units_, memory_size_},   // weights_time tensor
-        {units_}                  // bias tensor
+        {units_},                 // bias tensor
+        {batches_, (memory_size_ - 1) * units_},  // state-in tensor
     };
     std::vector<uint32_t> inputs;
     auto it = input_shapes.begin();
@@ -233,7 +235,7 @@
   }
 
   // Resets the state of SVDF op by filling it with 0's.
-  void ResetState() { std::fill(State_.begin(), State_.end(), 0.f); }
+  void ResetState() { std::fill(StateIn_.begin(), StateIn_.end(), 0.f); }
 
   // Extracts the output tensor from the SVDF op.
   const std::vector<float>& GetOutput() const { return Output_; }