Refactors LSTM step function for reuse.
Bug: 113559542
Test: NeuralNetworksTest_static --gtest_filter=GeneratedTests.lstm*
Change-Id: I427bc41bf81f4d0cd0021f333e79b5cab95d7105
Merged-In: I427bc41bf81f4d0cd0021f333e79b5cab95d7105
(cherry picked from commit 72fd4fe3fb705c9dbd543f1fe1c25d3d23545ad3)
diff --git a/common/operations/LSTM.cpp b/common/operations/LSTM.cpp
index f1e01e2..5159952 100644
--- a/common/operations/LSTM.cpp
+++ b/common/operations/LSTM.cpp
@@ -71,15 +71,15 @@
output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
- params_.activation_ = static_cast<TfLiteFusedActivation>(
+ params_.activation = static_cast<TfLiteFusedActivation>(
getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
if (input_->type == OperandType::TENSOR_FLOAT32) {
- params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
- params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
+ params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
+ params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
} else {
- params_.cell_clip_ = static_cast<float>(
+ params_.cell_clip = static_cast<float>(
getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
- params_.proj_clip_ = static_cast<float>(
+ params_.proj_clip = static_cast<float>(
getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
}
@@ -108,100 +108,126 @@
scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}
-bool LSTMCell::CheckInputTensorDimensions(const Operation& operation,
- std::vector<RunTimeOperandInfo>& operands,
- uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
+// static
+bool LSTMCell::CheckInputTensorDimensions(
+ const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
+ const RunTimeOperandInfo* input_to_forget_weights,
+ const RunTimeOperandInfo* input_to_cell_weights,
+ const RunTimeOperandInfo* input_to_output_weights,
+ const RunTimeOperandInfo* recurrent_to_input_weights,
+ const RunTimeOperandInfo* recurrent_to_forget_weights,
+ const RunTimeOperandInfo* recurrent_to_cell_weights,
+ const RunTimeOperandInfo* recurrent_to_output_weights,
+ const RunTimeOperandInfo* cell_to_input_weights,
+ const RunTimeOperandInfo* cell_to_forget_weights,
+ const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
+ const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
+ const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
+ const RunTimeOperandInfo* projection_bias,
+ const RunTimeOperandInfo* input_layer_norm_weights,
+ const RunTimeOperandInfo* forget_layer_norm_weights,
+ const RunTimeOperandInfo* cell_layer_norm_weights,
+ const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
+ uint32_t n_cell, LSTMParams* params) {
// Making sure clipping parameters have valid values.
// == 0 means no clipping
// > 0 means clipping
- NN_CHECK(params_.cell_clip_ >= 0);
- NN_CHECK(params_.proj_clip_ >= 0);
+ NN_CHECK(params->cell_clip >= 0);
+ NN_CHECK(params->proj_clip >= 0);
- if (!IsNullInput(input_to_input_weights_)) {
- NN_CHECK_EQ(NumDimensions(input_to_input_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 0), n_cell);
- NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 1), n_input);
+ if (!IsNullInput(input_to_input_weights)) {
+ NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
+ NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
}
- NN_CHECK_EQ(NumDimensions(input_to_forget_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 0), n_cell);
- NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 1), n_input);
+ NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
+ NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);
- NN_CHECK_EQ(NumDimensions(input_to_cell_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 0), n_cell);
- NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 1), n_input);
+ NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
+ NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);
- if (!IsNullInput(recurrent_to_input_weights_)) {
- NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 0), n_cell);
- NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 1), n_output);
+ if (!IsNullInput(recurrent_to_input_weights)) {
+ NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
+ NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
}
- NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 0), n_cell);
- NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 1), n_output);
+ NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
+ NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);
- NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 0), n_cell);
- NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 1), n_output);
+ NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
+ NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);
// We make sure the input-gate's parameters are either both present (regular
// LSTM) or not at all (CIFG-LSTM).
const bool cifg_weights_all_or_none =
- (!IsNullInput(input_to_input_weights_) && !IsNullInput(recurrent_to_input_weights_)) ||
- (IsNullInput(input_to_input_weights_) && IsNullInput(recurrent_to_input_weights_));
+ (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
+ (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
NN_CHECK(cifg_weights_all_or_none);
- if (!IsNullInput(cell_to_input_weights_)) {
- NN_CHECK_EQ(NumDimensions(cell_to_input_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights_, 0), n_cell);
+ if (!IsNullInput(cell_to_input_weights)) {
+ NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
}
- if (!IsNullInput(cell_to_forget_weights_)) {
- NN_CHECK_EQ(NumDimensions(cell_to_forget_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights_, 0), n_cell);
+ if (!IsNullInput(cell_to_forget_weights)) {
+ NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
}
- if (!IsNullInput(cell_to_output_weights_)) {
- NN_CHECK_EQ(NumDimensions(cell_to_output_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights_, 0), n_cell);
+ if (!IsNullInput(cell_to_output_weights)) {
+ NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
}
// Making sure the peephole weights are there all or none.
- const bool use_cifg = IsNullInput(input_to_input_weights_);
+ params->use_cifg = IsNullInput(input_to_input_weights);
const bool peephole_weights_all_or_none =
- ((!IsNullInput(cell_to_input_weights_) || use_cifg) &&
- !IsNullInput(cell_to_forget_weights_) && !IsNullInput(cell_to_output_weights_)) ||
- (IsNullInput(cell_to_input_weights_) && IsNullInput(cell_to_forget_weights_) &&
- IsNullInput(cell_to_output_weights_));
+ ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
+ !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
+ (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
+ IsNullInput(cell_to_output_weights));
NN_CHECK(peephole_weights_all_or_none);
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ params->use_peephole = !IsNullInput(cell_to_output_weights);
+ params->use_layer_norm = !IsNullInput(input_layer_norm_weights);
+
+ params->use_projection_weight = (projection_weights->lifetime != OperandLifeTime::NO_VALUE);
+ params->use_projection_bias = (projection_bias->lifetime != OperandLifeTime::NO_VALUE);
+
// Make sure the input gate bias is present only when not a CIFG-LSTM.
- if (use_cifg) {
- NN_CHECK(IsNullInput(input_gate_bias_));
+ if (params->use_cifg) {
+ NN_CHECK(IsNullInput(input_gate_bias));
} else {
- NN_CHECK_EQ(NumDimensions(input_gate_bias_), 1);
- NN_CHECK_EQ(SizeOfDimension(input_gate_bias_, 0), n_cell);
+ NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
+ NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
}
- NN_CHECK_EQ(NumDimensions(forget_gate_bias_), 1);
- NN_CHECK_EQ(SizeOfDimension(forget_gate_bias_, 0), n_cell);
+ NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
+ NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);
- NN_CHECK_EQ(NumDimensions(cell_bias_), 1);
- NN_CHECK_EQ(SizeOfDimension(cell_bias_, 0), n_cell);
+ NN_CHECK_EQ(NumDimensions(cell_bias), 1);
+ NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);
- NN_CHECK_EQ(NumDimensions(output_gate_bias_), 1);
- NN_CHECK_EQ(SizeOfDimension(output_gate_bias_, 0), n_cell);
+ NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
+ NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);
- if (!IsNullInput(projection_weights_)) {
- NN_CHECK_EQ(NumDimensions(projection_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(projection_weights_, 0), n_output);
- NN_CHECK_EQ(SizeOfDimension(projection_weights_, 1), n_cell);
+ if (!IsNullInput(projection_weights)) {
+ NN_CHECK_EQ(NumDimensions(projection_weights), 2);
+ NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
+ NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
}
- if (!IsNullInput(projection_bias_)) {
- NN_CHECK_EQ(NumDimensions(projection_bias_), 1);
- NN_CHECK_EQ(SizeOfDimension(projection_bias_, 0), n_output);
+ if (!IsNullInput(projection_bias)) {
+ NN_CHECK_EQ(NumDimensions(projection_bias), 1);
+ NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
}
// Making sure the projection tensors are consistent:
@@ -210,31 +236,31 @@
// 2) If projection weight is present, then projection bias is optional.
// TODO: make sure this is correct.
const bool projecton_tensors_consistent =
- (!IsNullInput(projection_weights_) || IsNullInput(projection_bias_));
+ (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
NN_CHECK(projecton_tensors_consistent == true);
- if (!IsNullInput(input_layer_norm_weights_)) {
- NN_CHECK_EQ(NumDimensions(input_layer_norm_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights_, 0), n_cell);
+ if (!IsNullInput(input_layer_norm_weights)) {
+ NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
}
- if (!IsNullInput(forget_layer_norm_weights_)) {
- NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights_, 0), n_cell);
+ if (!IsNullInput(forget_layer_norm_weights)) {
+ NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
}
- if (!IsNullInput(cell_layer_norm_weights_)) {
- NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights_, 0), n_cell);
+ if (!IsNullInput(cell_layer_norm_weights)) {
+ NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
}
- if (!IsNullInput(output_layer_norm_weights_)) {
- NN_CHECK_EQ(NumDimensions(output_layer_norm_weights_), 1);
- NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights_, 0), n_cell);
+ if (!IsNullInput(output_layer_norm_weights)) {
+ NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
+ NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
}
const bool layer_norm_weights_all_or_none =
- (IsNullInput(input_layer_norm_weights_) && IsNullInput(forget_layer_norm_weights_) &&
- IsNullInput(cell_layer_norm_weights_) && IsNullInput(input_layer_norm_weights_)) ||
- (!IsNullInput(input_layer_norm_weights_) && !IsNullInput(forget_layer_norm_weights_) &&
- !IsNullInput(cell_layer_norm_weights_) && !IsNullInput(input_layer_norm_weights_));
+ (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
+ IsNullInput(cell_layer_norm_weights) && IsNullInput(input_layer_norm_weights)) ||
+ (!IsNullInput(input_layer_norm_weights) && !IsNullInput(forget_layer_norm_weights) &&
+ !IsNullInput(cell_layer_norm_weights) && !IsNullInput(input_layer_norm_weights));
NN_CHECK(layer_norm_weights_all_or_none);
return true;
@@ -263,7 +289,15 @@
const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
// Check that input tensor dimensions matches with each other.
- if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
+ if (!CheckInputTensorDimensions(
+ input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
+ input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
+ recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
+ cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
+ forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
+ projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
+ cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
+ &params_)) {
return false;
}
@@ -285,8 +319,7 @@
cellStateShape->offset = inputShape.offset;
cellStateShape->scale = inputShape.scale;
- const bool use_cifg = IsNullInput(input_to_input_weights_);
- if (use_cifg) {
+ if (params_.use_cifg) {
// Reserving space for Cell, Forget, Output gates
scratchShape->dimensions = {n_batch, n_cell * 3};
} else {
@@ -300,13 +333,16 @@
return true;
}
-bool LSTMCell::EvalFloat32(
- const float* input_buffer, const float* input_to_input_weights_buffer,
- const float* input_to_forget_weights_buffer, const float* input_to_cell_weights_buffer,
- const float* input_to_output_weights_buffer, const float* recurrent_to_input_weights_buffer,
+// static
+bool LSTMCell::LSTMStep(
+ const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
+ const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
+ const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
+ const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
const float* recurrent_to_forget_weights_buffer,
const float* recurrent_to_cell_weights_buffer,
- const float* recurrent_to_output_weights_buffer, const float* cell_to_input_weights_buffer,
+ const float* recurrent_to_output_weights_buffer,
+ const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
const float* cell_bias_buffer, const float* output_gate_bias_buffer,
@@ -316,26 +352,20 @@
const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
float* scratch_buffer_buffer) {
- NNTRACE_COMP("LSTMCell::Eval");
+ NNTRACE_COMP("LSTMCell::LSTMStep");
- const uint32_t n_batch = input_->shape().dimensions[0];
- const uint32_t n_input = input_->shape().dimensions[1];
+ const uint32_t n_batch = input_shape.dimensions[0];
+ const uint32_t n_input = input_shape.dimensions[1];
// n_cell and n_output will be the same size when there is no projection.
- const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
- const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to the get the condition.
- const bool use_cifg = IsNullInput(input_to_input_weights_);
- const bool use_peephole = !IsNullInput(cell_to_output_weights_);
- const bool use_layer_norm = !IsNullInput(input_layer_norm_weights_);
+ const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
+ const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
// Index the scratch buffers pointers to the global scratch buffer.
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
- if (use_cifg) {
+ if (params.use_cifg) {
cell_scratch = scratch_buffer_buffer;
forget_gate_scratch = cell_scratch + n_cell * n_batch;
output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
@@ -346,9 +376,9 @@
output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
}
- if (!use_layer_norm) {
+ if (!params.use_layer_norm) {
// Initialize scratch buffers with bias.
- if (!use_cifg) {
+ if (!params.use_cifg) {
tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
input_gate_scratch);
}
@@ -360,7 +390,7 @@
output_gate_scratch);
} else {
// Initialize scratch buffers with zeroes.
- if (!use_cifg) {
+ if (!params.use_cifg) {
tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
}
tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
@@ -369,7 +399,7 @@
}
// For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
+ if (!params.use_cifg) {
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
input_gate_scratch, /*result_stride*/ 1);
@@ -385,7 +415,7 @@
output_gate_scratch, /*result_stride*/ 1);
// For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
+ if (!params.use_cifg) {
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
n_batch, input_gate_scratch,
@@ -402,13 +432,13 @@
output_gate_scratch, /*result_stride*/ 1);
// For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
+ if (!params.use_cifg) {
+ if (params.use_peephole) {
tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
input_gate_scratch);
}
- if (use_layer_norm) {
+ if (params.use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
n_cell, n_batch, kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
@@ -422,12 +452,12 @@
}
// For each batch and cell: update forget gate.
- if (use_peephole) {
+ if (params.use_peephole) {
tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
n_cell, cell_state_in_buffer,
n_batch, forget_gate_scratch);
}
- if (use_layer_norm) {
+ if (params.use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
n_cell, n_batch, kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
@@ -440,7 +470,7 @@
forget_gate_scratch);
// For each batch and cell: update the cell.
- if (use_layer_norm) {
+ if (params.use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
@@ -449,9 +479,9 @@
}
tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
n_batch * n_cell, cell_state_out_buffer);
- tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params_.activation_, cell_scratch);
- if (use_cifg) {
+ tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
+ cell_scratch);
+ if (params.use_cifg) {
tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
@@ -460,18 +490,18 @@
tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
}
- if (params_.cell_clip_ > 0.0) {
- tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell,
- params_.cell_clip_, cell_state_out_buffer);
+ if (params.cell_clip > 0.0) {
+ tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell, params.cell_clip,
+ cell_state_out_buffer);
}
// For each batch and cell: update the output gate.
- if (use_peephole) {
+ if (params.use_peephole) {
tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
n_cell, cell_state_out_buffer,
n_batch, output_gate_scratch);
}
- if (use_layer_norm) {
+ if (params.use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
n_cell, n_batch, kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
@@ -483,15 +513,13 @@
tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
- params_.activation_, cell_scratch);
+ params.activation, cell_scratch);
tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);
// For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
- const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
- if (use_projection_weight) {
- if (use_projection_bias) {
+ if (params.use_projection_weight) {
+ if (params.use_projection_bias) {
tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
output_buffer);
} else {
@@ -501,8 +529,8 @@
projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
output_buffer,
/*result_stride*/ 1);
- if (params_.proj_clip_ > 0.0) {
- tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params_.proj_clip_,
+ if (params.proj_clip > 0.0) {
+ tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params.proj_clip,
output_buffer);
}
} else {
@@ -515,37 +543,33 @@
bool LSTMCell::Eval() {
switch (input_->type) {
case OperandType::TENSOR_FLOAT32: {
- // clang-format off
- EvalFloat32(
- GetBuffer<const float>(input_),
- GetBuffer<const float>(input_to_input_weights_),
- GetBuffer<const float>(input_to_forget_weights_),
- GetBuffer<const float>(input_to_cell_weights_),
- GetBuffer<const float>(input_to_output_weights_),
- GetBuffer<const float>(recurrent_to_input_weights_),
- GetBuffer<const float>(recurrent_to_forget_weights_),
- GetBuffer<const float>(recurrent_to_cell_weights_),
- GetBuffer<const float>(recurrent_to_output_weights_),
- GetBuffer<const float>(cell_to_input_weights_),
- GetBuffer<const float>(cell_to_forget_weights_),
- GetBuffer<const float>(cell_to_output_weights_),
- GetBuffer<const float>(input_gate_bias_),
- GetBuffer<const float>(forget_gate_bias_),
- GetBuffer<const float>(cell_bias_),
- GetBuffer<const float>(output_gate_bias_),
- GetBuffer<const float>(projection_weights_),
- GetBuffer<const float>(projection_bias_),
- GetBuffer<const float>(output_state_in_),
- GetBuffer<const float>(cell_state_in_),
- GetBuffer<const float>(input_layer_norm_weights_),
- GetBuffer<const float>(forget_layer_norm_weights_),
- GetBuffer<const float>(cell_layer_norm_weights_),
- GetBuffer<const float>(output_layer_norm_weights_),
- GetBuffer<float>(output_state_out_),
- GetBuffer<float>(cell_state_out_),
- GetBuffer<float>(output_),
- GetBuffer<float>(scratch_buffer_));
- // clang-format on
+ LSTMStep(params_, GetBuffer<const float>(input_), input_->shape(),
+ GetBuffer<const float>(input_to_input_weights_),
+ GetBuffer<const float>(input_to_forget_weights_),
+ GetBuffer<const float>(input_to_cell_weights_),
+ GetBuffer<const float>(input_to_output_weights_),
+ input_to_output_weights_->shape(),
+ GetBuffer<const float>(recurrent_to_input_weights_),
+ GetBuffer<const float>(recurrent_to_forget_weights_),
+ GetBuffer<const float>(recurrent_to_cell_weights_),
+ GetBuffer<const float>(recurrent_to_output_weights_),
+ recurrent_to_output_weights_->shape(),
+ GetBuffer<const float>(cell_to_input_weights_),
+ GetBuffer<const float>(cell_to_forget_weights_),
+ GetBuffer<const float>(cell_to_output_weights_),
+ GetBuffer<const float>(input_gate_bias_),
+ GetBuffer<const float>(forget_gate_bias_), GetBuffer<const float>(cell_bias_),
+ GetBuffer<const float>(output_gate_bias_),
+ GetBuffer<const float>(projection_weights_),
+ GetBuffer<const float>(projection_bias_),
+ GetBuffer<const float>(output_state_in_),
+ GetBuffer<const float>(cell_state_in_),
+ GetBuffer<const float>(input_layer_norm_weights_),
+ GetBuffer<const float>(forget_layer_norm_weights_),
+ GetBuffer<const float>(cell_layer_norm_weights_),
+ GetBuffer<const float>(output_layer_norm_weights_),
+ GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
+ GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
} break;
case OperandType::TENSOR_FLOAT16: {
std::vector<float> input_float32(getNumberOfElements(input_->shape()));
@@ -699,37 +723,22 @@
getNumberOfElements(scratch_buffer_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(scratch_buffer_), &scratch_buffer_float32);
- // clang-format off
- EvalFloat32(
- input_float32.data(),
- input_to_input_weights_buffer,
- input_to_forget_weights_float32.data(),
- input_to_cell_weights_float32.data(),
- input_to_output_weights_float32.data(),
- recurrent_to_input_weights_buffer,
- recurrent_to_forget_weights_float32.data(),
- recurrent_to_cell_weights_float32.data(),
- recurrent_to_output_weights_float32.data(),
- cell_to_input_weights_buffer,
- cell_to_forget_weights_buffer,
- cell_to_output_weights_buffer,
- input_gate_bias_buffer,
- forget_gate_bias_float32.data(),
- cell_bias_float32.data(),
- output_gate_bias_float32.data(),
- projection_weights_buffer,
- projection_bias_buffer,
- output_state_in_float32.data(),
- cell_state_in_float32.data(),
- input_layer_norm_weights_buffer,
- forget_layer_norm_weights_buffer,
- cell_layer_norm_weights_buffer,
- output_layer_norm_weights_buffer,
- output_state_out_float32.data(),
- cell_state_out_float32.data(),
- output_float32.data(),
- scratch_buffer_float32.data());
- // clang-format on
+ LSTMStep(params_, input_float32.data(), input_->shape(), input_to_input_weights_buffer,
+ input_to_forget_weights_float32.data(), input_to_cell_weights_float32.data(),
+ input_to_output_weights_float32.data(), input_to_output_weights_->shape(),
+ recurrent_to_input_weights_buffer, recurrent_to_forget_weights_float32.data(),
+ recurrent_to_cell_weights_float32.data(),
+ recurrent_to_output_weights_float32.data(),
+ recurrent_to_output_weights_->shape(), cell_to_input_weights_buffer,
+ cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
+ input_gate_bias_buffer, forget_gate_bias_float32.data(),
+ cell_bias_float32.data(), output_gate_bias_float32.data(),
+ projection_weights_buffer, projection_bias_buffer,
+ output_state_in_float32.data(), cell_state_in_float32.data(),
+ input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
+ cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
+ output_state_out_float32.data(), cell_state_out_float32.data(),
+ output_float32.data(), scratch_buffer_float32.data());
convertFloat32ToFloat16(output_state_out_float32,
GetBuffer<_Float16>(output_state_out_));