Change stateful Ops to stateless ones

Bug: 63905942

Updated the RNN, LSTM, and SVDF ops.
Each state tensor that was previously a single output operand, read and
updated in place, is now split into a separate state input and state output.

Test: NeuralNetworksTest
Change-Id: Ia3d11f640cba4cab1b94d0b9746c46d347c024a4
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 6a1feed..e7b0dec 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -171,13 +171,16 @@
   projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
   projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional
 
+  output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
+  cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
+
   params_.activation_ = static_cast<ActivationFn>(getScalarData<int32_t>(
       *GetInput(operation, operands, kActivationParam)));
   params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
   params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
 
-  output_state_ = GetOutput(operation, operands, kOutputStateTensor);
-  cell_state_ = GetOutput(operation, operands, kCellStateTensor);
+  output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
+  cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
   output_ = GetOutput(operation, operands, kOutputTensor);
 
   scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
@@ -338,8 +341,8 @@
                        Shape *cellStateShape,
                        Shape *outputShape) {
   // Check we have all the inputs and outputs we need.
-  NN_CHECK(NumInputsWithValues(operation, operands) >= 13 &&
-           NumInputsWithValues(operation, operands) <= 21);
+  NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
+           NumInputsWithValues(operation, operands) <= 23);
   NN_CHECK_EQ(NumOutputs(operation), 4);
 
   // Inferring batch size, number of outputs and number of cells from the
@@ -463,24 +466,24 @@
   if (!use_cifg) {
     MatrixBatchVectorMultiplyAccumulate(
         GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
-        GetBuffer<float>(output_state_), n_batch, input_gate_scratch);
+        GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch);
   }
   MatrixBatchVectorMultiplyAccumulate(
       GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
-      GetBuffer<float>(output_state_), n_batch, forget_gate_scratch);
+      GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch);
   MatrixBatchVectorMultiplyAccumulate(
       GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
-      GetBuffer<float>(output_state_), n_batch, cell_scratch);
+      GetBuffer<float>(output_state_in_), n_batch, cell_scratch);
   MatrixBatchVectorMultiplyAccumulate(
       GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
-      GetBuffer<float>(output_state_), n_batch, output_gate_scratch);
+      GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch);
 
   // For each batch and cell: update input gate.
   if (!use_cifg) {
     if (use_peephole) {
       VectorBatchVectorCwiseProductAccumulate(
           GetBuffer<float>(cell_to_input_weights_), n_cell,
-          GetBuffer<float>(cell_state_), n_batch, input_gate_scratch);
+          GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
     }
     ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                          input_gate_scratch);
@@ -490,40 +493,40 @@
   if (use_peephole) {
     VectorBatchVectorCwiseProductAccumulate(
         GetBuffer<float>(cell_to_forget_weights_), n_cell,
-        GetBuffer<float>(cell_state_), n_batch, forget_gate_scratch);
+        GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
   }
   ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                        forget_gate_scratch);
 
   // For each batch and cell: update the cell.
-  VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_),
-                           n_batch * n_cell, GetBuffer<float>(cell_state_));
+  VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_in_),
+                           n_batch * n_cell, GetBuffer<float>(cell_state_out_));
   ApplyActivationToVector(cell_scratch, n_batch * n_cell, params_.activation_,
                           cell_scratch);
   if (use_cifg) {
     Sub1Vector(forget_gate_scratch, n_batch * n_cell, forget_gate_scratch);
     VectorVectorCwiseProductAccumulate(cell_scratch, forget_gate_scratch,
                                        n_batch * n_cell,
-                                       GetBuffer<float>(cell_state_));
+                                       GetBuffer<float>(cell_state_out_));
   } else {
     VectorVectorCwiseProductAccumulate(cell_scratch, input_gate_scratch,
                                        n_batch * n_cell,
-                                       GetBuffer<float>(cell_state_));
+                                       GetBuffer<float>(cell_state_out_));
   }
   if (params_.cell_clip_ > 0.0) {
-    ClipVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
-               params_.cell_clip_, GetBuffer<float>(cell_state_));
+    ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
+               params_.cell_clip_, GetBuffer<float>(cell_state_out_));
   }
 
   // For each batch and cell: update the output gate.
   if (use_peephole) {
     VectorBatchVectorCwiseProductAccumulate(
         GetBuffer<float>(cell_to_output_weights_), n_cell,
-        GetBuffer<float>(cell_state_), n_batch, output_gate_scratch);
+        GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
   }
   ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                        output_gate_scratch);
-  ApplyActivationToVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
+  ApplyActivationToVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
                           params_.activation_, cell_scratch);
   VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, n_batch * n_cell,
                            output_gate_scratch);
@@ -551,7 +554,7 @@
                GetBuffer<float>(output_));
   }
   CopyVector(GetBuffer<float>(output_), n_batch * n_output,
-             GetBuffer<float>(output_state_));
+             GetBuffer<float>(output_state_out_));
 
   return true;
 }