Change stateful Ops to stateless ones
Bug: 63905942
Updated the RNN, LSTM, and SVDF ops: the state tensors, previously
written in place as outputs, are now split into separate input and
output operands, making each execution stateless.
Test: NeuralNetworksTest
Change-Id: Ia3d11f640cba4cab1b94d0b9746c46d347c024a4
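
With this change, callers thread state explicitly from one execution to
the next instead of relying on the op mutating a state buffer in place.
A minimal caller-side sketch (the helper runRnnOnce and the raw vectors
are hypothetical stand-ins for the real NNAPI model/execution machinery):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Hypothetical: stands in for building and running one
    // ANEURALNETWORKS_RNN execution with the new operand order
    // (hidden state in at input 4, hidden state out at output 0).
    void runRnnOnce(const std::vector<float>& input,
                    const std::vector<float>& hidden_in,
                    std::vector<float>* hidden_out,
                    std::vector<float>* output);

    void runSequence(const std::vector<std::vector<float>>& input_sequence,
                     uint32_t batches, uint32_t units) {
        std::vector<float> hidden_in(batches * units, 0.f);  // initial state
        std::vector<float> hidden_out(batches * units);
        std::vector<float> output(batches * units);
        for (const auto& step : input_sequence) {
            runRnnOnce(step, hidden_in, &hidden_out, &output);
            std::swap(hidden_in, hidden_out);  // feed state forward
        }
    }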
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index 6a1feed..e7b0dec 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -171,13 +171,16 @@
projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional
projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor); // optional
+ output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
+ cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
+
params_.activation_ = static_cast<ActivationFn>(getScalarData<int32_t>(
*GetInput(operation, operands, kActivationParam)));
params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
- output_state_ = GetOutput(operation, operands, kOutputStateTensor);
- cell_state_ = GetOutput(operation, operands, kCellStateTensor);
+ output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
+ cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
@@ -338,8 +341,8 @@
Shape *cellStateShape,
Shape *outputShape) {
// Check we have all the inputs and outputs we need.
- NN_CHECK(NumInputsWithValues(operation, operands) >= 13 &&
- NumInputsWithValues(operation, operands) <= 21);
+ NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
+ NumInputsWithValues(operation, operands) <= 23);
NN_CHECK_EQ(NumOutputs(operation), 4);
// Inferring batch size, number of outputs and number of cells from the
@@ -463,24 +466,24 @@
if (!use_cifg) {
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_input_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, input_gate_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, input_gate_scratch);
}
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_forget_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, forget_gate_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, forget_gate_scratch);
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_cell_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, cell_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, cell_scratch);
MatrixBatchVectorMultiplyAccumulate(
GetBuffer<float>(recurrent_to_output_weights_), n_cell, n_output,
- GetBuffer<float>(output_state_), n_batch, output_gate_scratch);
+ GetBuffer<float>(output_state_in_), n_batch, output_gate_scratch);
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
VectorBatchVectorCwiseProductAccumulate(
GetBuffer<float>(cell_to_input_weights_), n_cell,
- GetBuffer<float>(cell_state_), n_batch, input_gate_scratch);
+ GetBuffer<float>(cell_state_in_), n_batch, input_gate_scratch);
}
ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
@@ -490,40 +493,40 @@
if (use_peephole) {
VectorBatchVectorCwiseProductAccumulate(
GetBuffer<float>(cell_to_forget_weights_), n_cell,
- GetBuffer<float>(cell_state_), n_batch, forget_gate_scratch);
+ GetBuffer<float>(cell_state_in_), n_batch, forget_gate_scratch);
}
ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
- VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_),
- n_batch * n_cell, GetBuffer<float>(cell_state_));
+ VectorVectorCwiseProduct(forget_gate_scratch, GetBuffer<float>(cell_state_in_),
+ n_batch * n_cell, GetBuffer<float>(cell_state_out_));
ApplyActivationToVector(cell_scratch, n_batch * n_cell, params_.activation_,
cell_scratch);
if (use_cifg) {
Sub1Vector(forget_gate_scratch, n_batch * n_cell, forget_gate_scratch);
VectorVectorCwiseProductAccumulate(cell_scratch, forget_gate_scratch,
n_batch * n_cell,
- GetBuffer<float>(cell_state_));
+ GetBuffer<float>(cell_state_out_));
} else {
VectorVectorCwiseProductAccumulate(cell_scratch, input_gate_scratch,
n_batch * n_cell,
- GetBuffer<float>(cell_state_));
+ GetBuffer<float>(cell_state_out_));
}
if (params_.cell_clip_ > 0.0) {
- ClipVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
- params_.cell_clip_, GetBuffer<float>(cell_state_));
+ ClipVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
+ params_.cell_clip_, GetBuffer<float>(cell_state_out_));
}
// For each batch and cell: update the output gate.
if (use_peephole) {
VectorBatchVectorCwiseProductAccumulate(
GetBuffer<float>(cell_to_output_weights_), n_cell,
- GetBuffer<float>(cell_state_), n_batch, output_gate_scratch);
+ GetBuffer<float>(cell_state_out_), n_batch, output_gate_scratch);
}
ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
- ApplyActivationToVector(GetBuffer<float>(cell_state_), n_batch * n_cell,
+ ApplyActivationToVector(GetBuffer<float>(cell_state_out_), n_batch * n_cell,
params_.activation_, cell_scratch);
VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, n_batch * n_cell,
output_gate_scratch);
@@ -551,7 +554,7 @@
GetBuffer<float>(output_));
}
CopyVector(GetBuffer<float>(output_), n_batch * n_output,
- GetBuffer<float>(output_state_));
+ GetBuffer<float>(output_state_out_));
return true;
}
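
The rewritten update reads the previous state and writes the new one; in
the standard LSTM formulation, these hunks implement (with
c_{t-1} = cell_state_in_, c_t = cell_state_out_,
h_{t-1} = output_state_in_, and act = params_.activation_):

    c_t = \mathrm{clip}\big(f_t \odot c_{t-1} + i_t \odot g_t,\ \texttt{cell\_clip}\big),
    \qquad h_t = o_t \odot \mathrm{act}(c_t)

With CIFG the input gate is coupled to the forget gate, i_t = 1 - f_t,
which is the Sub1Vector branch above; clipping applies only when
cell_clip_ > 0.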
diff --git a/common/operations/LSTM.h b/common/operations/LSTM.h
index 430ee61..cc54612 100644
--- a/common/operations/LSTM.h
+++ b/common/operations/LSTM.h
@@ -88,15 +88,17 @@
// Projection bias tensor of size {n_output}
static constexpr int kProjectionBiasTensor = 17; // Optional
- static constexpr int kActivationParam = 18;
- static constexpr int kCellClipParam = 19;
- static constexpr int kProjClipParam = 20;
+ static constexpr int kOutputStateInTensor = 18;
+ static constexpr int kCellStateInTensor = 19;
+
+ static constexpr int kActivationParam = 20;
+ static constexpr int kCellClipParam = 21;
+ static constexpr int kProjClipParam = 22;
// Output tensors.
- // TODO: Do we have to pre-allocate scratch buffer as outputs?
static constexpr int kScratchBufferTensor = 0;
- static constexpr int kOutputStateTensor = 1;
- static constexpr int kCellStateTensor = 2;
+ static constexpr int kOutputStateOutTensor = 1;
+ static constexpr int kCellStateOutTensor = 2;
static constexpr int kOutputTensor = 3;
private:
@@ -130,8 +132,11 @@
const RunTimeOperandInfo *projection_weights_;
const RunTimeOperandInfo *projection_bias_;
- RunTimeOperandInfo *output_state_;
- RunTimeOperandInfo *cell_state_;
+ const RunTimeOperandInfo *output_state_in_;
+ const RunTimeOperandInfo *cell_state_in_;
+
+ RunTimeOperandInfo *output_state_out_;
+ RunTimeOperandInfo *cell_state_out_;
RunTimeOperandInfo *output_;
RunTimeOperandInfo *scratch_buffer_;
diff --git a/common/operations/LSTMTest.cpp b/common/operations/LSTMTest.cpp
index a61ef07..b8a80ec 100644
--- a/common/operations/LSTMTest.cpp
+++ b/common/operations/LSTMTest.cpp
@@ -60,13 +60,15 @@
ACTION(ForgetGateBias) \
ACTION(OutputGateBias) \
ACTION(ProjectionWeights) \
- ACTION(ProjectionBias)
+ ACTION(ProjectionBias) \
+ ACTION(OutputStateIn) \
+ ACTION(CellStateIn)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
ACTION(Output) \
- ACTION(OutputState) \
- ACTION(CellState) \
+ ACTION(OutputStateOut) \
+ ACTION(CellStateOut) \
ACTION(ScratchBuffer)
class LSTMOpModel {
@@ -75,7 +77,7 @@
uint32_t n_cell, uint32_t n_output, bool use_cifg,
bool use_peephole, bool use_projection_weights,
bool use_projection_bias, float cell_clip, float proj_clip,
- const std::vector<std::vector<uint32_t>>& input_shapes)
+ const std::vector<std::vector<uint32_t>>& input_shapes0)
: n_batch_(n_batch), n_input_(n_input),
n_cell_(n_cell), n_output_(n_output),
use_cifg_(use_cifg), use_peephole_(use_peephole),
@@ -84,7 +86,10 @@
activation_(ActivationFn::kActivationTanh),
cell_clip_(cell_clip), proj_clip_(proj_clip) {
std::vector<uint32_t> inputs;
+ std::vector<std::vector<uint32_t>> input_shapes(input_shapes0.begin(), input_shapes0.end());
+ input_shapes.push_back({n_batch, n_output});
+ input_shapes.push_back({n_batch, n_cell});
auto it = input_shapes.begin();
// Input and weights
#define AddInput(X) \
@@ -154,11 +159,11 @@
#undef DefineSetter
void ResetOutputState() {
- std::fill(OutputState_.begin(), OutputState_.end(), 0.f);
+ std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
}
void ResetCellState() {
- std::fill(CellState_.begin(), CellState_.end(), 0.f);
+ std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
}
void SetInput(int offset, float *begin, float *end) {
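
One consequence for tests: the model no longer carries state across
invocations on its own, so a multi-step test must copy the state outputs
back into the state inputs between steps. A sketch, assuming accessors
named after the tensor macros above (the exact names generated by
DefineSetter are not visible in this hunk):

    // Hypothetical step loop for LSTMOpModel; accessor names are assumptions.
    lstm.ResetOutputState();
    lstm.ResetCellState();
    for (int t = 0; t < sequence_length; ++t) {
        lstm.SetInput(0, input[t].data(), input[t].data() + input[t].size());
        lstm.Invoke();
        // This step's state outputs become the next step's state inputs.
        lstm.SetOutputStateIn(lstm.GetOutputStateOut());
        lstm.SetCellStateIn(lstm.GetCellStateOut());
    }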
diff --git a/common/operations/RNN.cpp b/common/operations/RNN.cpp
index 4e25a88..8a00734 100644
--- a/common/operations/RNN.cpp
+++ b/common/operations/RNN.cpp
@@ -27,12 +27,13 @@
input_ = GetInput(operation, operands, kInputTensor);
weights_ = GetInput(operation, operands, kWeightsTensor);
recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
+ hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
bias_ = GetInput(operation, operands, kBiasTensor);
activation_ = static_cast<ActivationFn>(
getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
- hidden_state_ = GetOutput(operation, operands, kHiddenStateTensor);
+ hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
}
@@ -42,7 +43,7 @@
Shape *outputShape) {
// Check we have all the inputs and outputs we need.
const int num_inputs = NumInputsWithValues(operation, operands);
- NN_CHECK(num_inputs == 4 || num_inputs == 5);
+ NN_CHECK(num_inputs == 5 || num_inputs == 6);
NN_CHECK_EQ(NumOutputs(operation), 2);
const RunTimeOperandInfo *input =
@@ -91,10 +92,12 @@
// Initialize the pointer to input, output and bias.
const float* input_ptr_batch =
reinterpret_cast<float*>(input_->buffer) + b * input_size;
+ const float* hidden_state_in_ptr_batch =
+ reinterpret_cast<float*>(hidden_state_in_->buffer) + b * num_units;
float* output_ptr_batch =
reinterpret_cast<float*>(output_->buffer) + b * num_units;
- float* hidden_state_ptr_batch =
- reinterpret_cast<float*>(hidden_state_->buffer) + b * num_units;
+ float* hidden_state_out_ptr_batch =
+ reinterpret_cast<float*>(hidden_state_out_->buffer) + b * num_units;
// Initialize input_weights and recurrent_weights.
const float* input_weights_ptr = reinterpret_cast<float*>(weights_->buffer);
@@ -118,7 +121,7 @@
for (uint32_t o = 0; o < num_units; o++) {
for (uint32_t h = 0; h < num_units; h++) {
output_ptr_batch[o] +=
- hidden_state_ptr_batch[h] * recurrent_weights_ptr[h];
+ hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
}
recurrent_weights_ptr += recurrent_weights_stride;
}
@@ -127,7 +130,7 @@
for (uint32_t o = 0; o < num_units; o++) {
output_ptr_batch[o] =
(ActivationFunctor(activation_))(output_ptr_batch[o]);
- hidden_state_ptr_batch[o] = output_ptr_batch[o];
+ hidden_state_out_ptr_batch[o] = output_ptr_batch[o];
}
}
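
In math, the loop above now computes (with h_{t-1} = hidden_state_in_,
h_t = hidden_state_out_, and \sigma = activation_):

    \mathrm{output}_t = h_t = \sigma\big(W x_t + R\, h_{t-1} + b\big)

so the op reads h_{t-1} from its read-only input operand and writes h_t
to its output operand.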
diff --git a/common/operations/RNN.h b/common/operations/RNN.h
index c6ba7dc..0c6a881 100644
--- a/common/operations/RNN.h
+++ b/common/operations/RNN.h
@@ -50,9 +50,10 @@
static constexpr int kWeightsTensor = 1; // Optional
static constexpr int kRecurrentWeightsTensor = 2;
static constexpr int kBiasTensor = 3;
- static constexpr int kActivationParam = 4;
+ static constexpr int kHiddenStateInTensor = 4;
+ static constexpr int kActivationParam = 5;
- static constexpr int kHiddenStateTensor = 0;
+ static constexpr int kHiddenStateOutTensor = 0;
static constexpr int kOutputTensor = 1;
private:
@@ -62,8 +63,9 @@
const RunTimeOperandInfo *weights_;
const RunTimeOperandInfo *recurrent_weights_;
const RunTimeOperandInfo *bias_;
- const RunTimeOperandInfo *hidden_state_;
+ const RunTimeOperandInfo *hidden_state_in_;
+ RunTimeOperandInfo *hidden_state_out_;
RunTimeOperandInfo *output_;
};
diff --git a/common/operations/RNNTest.cpp b/common/operations/RNNTest.cpp
index d9b2965..7561f33 100644
--- a/common/operations/RNNTest.cpp
+++ b/common/operations/RNNTest.cpp
@@ -140,10 +140,11 @@
ACTION(Weights) \
ACTION(RecurrentWeights) \
- ACTION(Bias)
+ ACTION(Bias) \
+ ACTION(HiddenStateIn)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(HiddenState) \
+ ACTION(HiddenStateOut) \
ACTION(Output)
class BasicRNNOpModel {
@@ -163,18 +164,20 @@
inputs.push_back(model_.addOperand(&RecurrentWeightTy));
OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
inputs.push_back(model_.addOperand(&BiasTy));
- OperandType ActionParamTy(Type::INT32, {});
+ OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
+ inputs.push_back(model_.addOperand(&HiddenStateTy));
+ OperandType ActionParamTy(Type::INT32, {});
inputs.push_back(model_.addOperand(&ActionParamTy));
std::vector<uint32_t> outputs;
- OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
outputs.push_back(model_.addOperand(&HiddenStateTy));
OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
outputs.push_back(model_.addOperand(&OutputTy));
Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
- HiddenState_.insert(HiddenState_.end(), batches_ * units_, 0.f);
+ HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
+ HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
Output_.insert(Output_.end(), batches_ * units_, 0.f);
model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
@@ -199,7 +202,7 @@
}
void ResetHiddenState() {
- std::fill(HiddenState_.begin(), HiddenState_.end(), 0.f);
+ std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
}
const std::vector<float>& GetOutput() const { return Output_; }
diff --git a/common/operations/SVDF.cpp b/common/operations/SVDF.cpp
index 42f1a72..f600b26 100644
--- a/common/operations/SVDF.cpp
+++ b/common/operations/SVDF.cpp
@@ -26,11 +26,12 @@
// TODO: Implement this using circular buffer instead.
// This is here temporarily only to show the logic.
-void svdf_right_shift_state(float* state, int state_len, float shift_value) {
+void svdf_right_shift_state(const float* state_in, int state_len, float shift_value,
+ float* state_out) {
for (int i = 0; i < state_len - 1; i++) {
- state[i] = state[i + 1];
+ state_out[i] = state_in[i + 1];
}
- state[state_len - 1] = shift_value;
+ state_out[state_len - 1] = shift_value;
}
int32_t getInt32ScalarData(RunTimeOperandInfo& info) {
@@ -46,12 +47,13 @@
weights_feature_ = GetInput(operation, operands, kWeightsFeatureTensor);
weights_time_ = GetInput(operation, operands, kWeightsTimeTensor);
bias_ = GetInput(operation, operands, kBiasTensor);
+ state_in_ = GetInput(operation, operands, kStateInTensor);
params_.rank_ = getInt32ScalarData(*GetInput(operation, operands, kRankParam));
params_.activation_ = static_cast<ActivationFn>(getInt32ScalarData(
*GetInput(operation, operands, kActivationParam)));
- state_ = GetOutput(operation, operands, kStateTensor);
+ state_out_ = GetOutput(operation, operands, kStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
}
@@ -61,7 +63,7 @@
Shape *outputShape) {
// Check we have all the inputs and outputs we need.
const int num_inputs = NumInputsWithValues(operation, operands);
- NN_CHECK(num_inputs == 5 || num_inputs == 6);
+ NN_CHECK(num_inputs == 6 || num_inputs == 7);
NN_CHECK_EQ(NumOutputs(operation), 2);
const RunTimeOperandInfo *input =
@@ -85,11 +87,6 @@
NN_CHECK_EQ(SizeOfDimension(bias, 0), num_units);
}
- const RunTimeOperandInfo *state =
- GetInput(operation, operands, SVDF::kStateTensor);
- const RunTimeOperandInfo *output =
- GetInput(operation, operands, SVDF::kOutputTensor);
-
// Resize state.
const Shape &inputShape = input->shape();
stateShape->type = inputShape.type;
@@ -123,7 +120,8 @@
// Initialize the pointer to input, output and bias.
const float* input_ptr_batch = reinterpret_cast<float *>(input_->buffer) + b * input_size;
float* output_ptr_batch = reinterpret_cast<float*>(output_->buffer) + b * num_units;
- float* state_ptr_batch = reinterpret_cast<float*>(state_->buffer) + b * (memory_size - 1) * num_units;
+ const float* state_in_ptr_batch = reinterpret_cast<const float*>(state_in_->buffer) + b * (memory_size - 1) * num_units;
+ float* state_out_ptr_batch = reinterpret_cast<float*>(state_out_->buffer) + b * (memory_size - 1) * num_units;
// For each unit
for (int c = 0; c < num_units; c++) {
@@ -135,7 +133,8 @@
}
// Initialize state pointer for unit 'c'.
- float* state_ptr = state_ptr_batch + c * (memory_size - 1);
+ const float* state_in_ptr = state_in_ptr_batch + c * (memory_size - 1);
+ float* state_out_ptr = state_out_ptr_batch + c * (memory_size - 1);
// Apply bias if bias tensor exists.
output_ptr_batch[c] = bias_->buffer ? reinterpret_cast<float *>(bias_->buffer)[c] : 0.f;
@@ -143,7 +142,7 @@
// output = tf.matmul(state, weights_time)
output_ptr_batch[c] += weights_time_ptr[memory_size - 1] * activation;
for (int j = 0; j < memory_size - 1; j++) {
- output_ptr_batch[c] += weights_time_ptr[j] * state_ptr[j];
+ output_ptr_batch[c] += weights_time_ptr[j] * state_in_ptr[j];
}
// Apply activation.
@@ -151,7 +150,8 @@
(ActivationFunctor(params_.activation_))(output_ptr_batch[c]);
// Right shift the state and concatenate with activation.
- svdf_right_shift_state(state_ptr, memory_size - 1, activation);
+ svdf_right_shift_state(state_in_ptr, memory_size - 1, activation,
+ state_out_ptr);
// Update weight pointers.
weights_feature_ptr += weights_feature_stride;
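
A worked example of the rewritten shift helper, for one unit with
memory_size = 4 (so state_len = memory_size - 1 = 3); the input values
are made up:

    #include <cassert>

    void svdf_shift_example() {
        const float state_in[3] = {0.1f, 0.2f, 0.3f};  // oldest .. newest
        float state_out[3];
        // Drops the oldest activation (index 0) and appends the new one.
        svdf_right_shift_state(state_in, /*state_len=*/3,
                               /*shift_value=*/0.4f, state_out);
        assert(state_out[0] == 0.2f && state_out[1] == 0.3f &&
               state_out[2] == 0.4f);
    }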
diff --git a/common/operations/SVDF.h b/common/operations/SVDF.h
index aafef12..a219fe5 100644
--- a/common/operations/SVDF.h
+++ b/common/operations/SVDF.h
@@ -58,10 +58,11 @@
static constexpr int kWeightsFeatureTensor = 1;
static constexpr int kWeightsTimeTensor = 2;
static constexpr int kBiasTensor = 3; // Optional
- static constexpr int kRankParam = 4;
- static constexpr int kActivationParam = 5;
+ static constexpr int kStateInTensor = 4;
+ static constexpr int kRankParam = 5;
+ static constexpr int kActivationParam = 6;
- static constexpr int kStateTensor = 0;
+ static constexpr int kStateOutTensor = 0;
static constexpr int kOutputTensor = 1;
private:
@@ -71,8 +72,9 @@
const RunTimeOperandInfo *weights_feature_;
const RunTimeOperandInfo *weights_time_;
const RunTimeOperandInfo *bias_;
+ const RunTimeOperandInfo *state_in_;
- RunTimeOperandInfo *state_;
+ RunTimeOperandInfo *state_out_;
RunTimeOperandInfo *output_;
};
diff --git a/common/operations/SVDFTest.cpp b/common/operations/SVDFTest.cpp
index 775d3c1..a3c32e2 100644
--- a/common/operations/SVDFTest.cpp
+++ b/common/operations/SVDFTest.cpp
@@ -108,11 +108,12 @@
ACTION(Input) \
ACTION(WeightsFeature) \
ACTION(WeightsTime) \
- ACTION(Bias)
+ ACTION(Bias) \
+ ACTION(StateIn)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
- ACTION(State) \
+ ACTION(StateOut) \
ACTION(Output)
// Derived class of SingleOpModel, which is used to test SVDF TFLite op.
@@ -128,7 +129,8 @@
{batches_, input_size_}, // Input tensor
{units_, input_size_}, // weights_feature tensor
{units_, memory_size_}, // weights_time tensor
- {units_} // bias tensor
+ {units_}, // bias tensor
+ {batches_, (memory_size_ - 1) * units_}, // state in
};
std::vector<uint32_t> inputs;
auto it = input_shapes.begin();
@@ -233,7 +235,7 @@
}
// Resets the state of SVDF op by filling it with 0's.
- void ResetState() { std::fill(State_.begin(), State_.end(), 0.f); }
+ void ResetState() { std::fill(StateIn_.begin(), StateIn_.end(), 0.f); }
// Extracts the output tensor from the SVDF op.
const std::vector<float>& GetOutput() const { return Output_; }