Change stateful Ops to stateless ones

Updated the RNN, LSTM, and SVDF ops: the outputs that previously carried
state across executions are now split into explicit state inputs and
state outputs, so the ops themselves hold no state between runs.

Bug: 63905942
Test: NeuralNetworksTest
Change-Id: Ia3d11f640cba4cab1b94d0b9746c46d347c024a4
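
The sketch below is not part of the patch; it only illustrates the intended
caller-side pattern, using hypothetical helper names. With the stateless form,
the caller owns the state buffers and feeds each step's state outputs back in
as the next step's state inputs.

    // Hypothetical caller-side sketch; RunLstmStep stands in for a single
    // stateless LSTM evaluation and merely copies state through so the
    // example compiles and runs.
    #include <cstddef>
    #include <utility>
    #include <vector>

    struct LstmState {
        std::vector<float> output_state;  // shape [n_batch, n_output]
        std::vector<float> cell_state;    // shape [n_batch, n_cell]
    };

    void RunLstmStep(const std::vector<float>& /*input*/,
                     const LstmState& state_in, LstmState* state_out,
                     std::vector<float>* output) {
        *state_out = state_in;             // stand-in for the real op
        *output = state_in.output_state;
    }

    void RunSequence(const std::vector<std::vector<float>>& inputs,
                     size_t n_batch, size_t n_cell, size_t n_output) {
        // First step: both state inputs start out zero-initialized.
        LstmState state_in{std::vector<float>(n_batch * n_output, 0.f),
                           std::vector<float>(n_batch * n_cell, 0.f)};
        LstmState state_out = state_in;
        std::vector<float> output(n_batch * n_output);

        for (const auto& input : inputs) {
            RunLstmStep(input, state_in, &state_out, &output);
            // Feed this step's state outputs back as the next step's inputs;
            // the op itself no longer retains anything between executions.
            std::swap(state_in, state_out);
        }
    }

This mirrors what the diff does inside LSTM::Eval: recurrent terms are read
from output_state_in_/cell_state_in_, and the updated values are written to
the corresponding *_out_ buffers.
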
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 6a1feed..e7b0dec 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -171,13 +171,16 @@
projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional
projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor); // optional
+ output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
+ cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
+
params_.activation_ = static_cast<ActivationFn>(getScalarData<int32_t>(
*GetInput(operation, operands, kActivationParam)));
params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
- output_state_ = GetOutput(operation, operands, kOutputStateTensor);
- cell_state_ = GetOutput(operation, operands, kCellStateTensor);
+ output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
+ cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
@@ -338,8 +341,8 @@
Shape *cellStateShape,
Shape *outputShape) {
// Check we have all the inputs and outputs we need.
- NN_CHECK(NumInputsWithValues(operation, operands) >= 13 &&
- NumInputsWithValues(operation, operands) <= 21);
+ NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
+ NumInputsWithValues(operation, operands) <= 23);
NN_CHECK_EQ(NumOutputs(operation), 4);
// Inferring batch size, number of outputs and number of cells from the
@@ -463,24 +466,24 @@
if (!use_cifg) {
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, input_gate_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch);
}
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, forget_gate_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch);
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, cell_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, cell_scratch);
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, output_gate_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch);
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
VectorBatchVectorCwiseProductAccumulate(
GetBuffer<float>(cell_to_input_weights_), n_cell,
- GetBuffer<float>(cell_state_), n_batch, input_gate_scratch);
+ GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
}
ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
@@ -490,40 +493,40 @@
if (use_peephole) {
VectorBatchVectorCwiseProductAccumulate(
GetBuffer<float>(cell_to_forget_weights_), n_cell,
- GetBuffer<float>(cell_state_), n_batch, forget_gate_scratch);
+ GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
}
ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
- VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_),
- n_batch * n_cell, GetBuffer<float>(cell_state_));
+ VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_in_),
+ n_batch * n_cell, GetBuffer<float>(cell_state_out_));
ApplyActivationToVector(cell_scratch, n_batch * n_cell, params_.activation_,
cell_scratch);
if (use_cifg) {
Sub1Vector(forget_gate_scratch, n_batch * n_cell, forget_gate_scratch);
VectorVectorCwiseProductAccumulate(cell_scratch, forget_gate_scratch,
n_batch * n_cell,
- GetBuffer<float>(cell_state_));
+ GetBuffer<float>(cell_state_out_));
} else {
VectorVectorCwiseProductAccumulate(cell_scratch, input_gate_scratch,
n_batch * n_cell,
- GetBuffer<float>(cell_state_));
+ GetBuffer<float>(cell_state_out_));
}
if (params_.cell_clip_ > 0.0) {
- ClipVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
- params_.cell_clip_, GetBuffer<float>(cell_state_));
+ ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
+ params_.cell_clip_, GetBuffer<float>(cell_state_out_));
}
// For each batch and cell: update the output gate.
if (use_peephole) {
VectorBatchVectorCwiseProductAccumulate(
GetBuffer<float>(cell_to_output_weights_), n_cell,
- GetBuffer<float>(cell_state_), n_batch, output_gate_scratch);
+ GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
}
ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
- ApplyActivationToVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
+ ApplyActivationToVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
params_.activation_, cell_scratch);
VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, n_batch * n_cell,
output_gate_scratch);
@@ -551,7 +554,7 @@
GetBuffer<float>(output_));
}
CopyVector(GetBuffer<float>(output_), n_batch * n_output,
- GetBuffer<float>(output_state_));
+ GetBuffer<float>(output_state_out_));
return true;
}