blob: f1e01e2332a45eeed67be1edad1c2daec118e104 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "LSTM.h"
#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "HalInterfaces.h"
#include "Tracing.h"
namespace android {
namespace nn {
namespace {
// Reinterprets an operand's raw byte buffer as a typed pointer.
// The caller is responsible for choosing a T that matches the operand's
// actual data type; no checking is performed here.
template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
return reinterpret_cast<T*>(operand->buffer);
}
// Const overload of GetBuffer for read-only operands.
template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
return reinterpret_cast<const T*>(operand->buffer);
}
} // anonymous namespace
// Gathers pointers to all of the operation's input and output operands and
// decodes the scalar parameters (fused activation, cell clip, projection
// clip). Operands marked "optional" may carry no value, depending on the
// LSTM variant (CIFG, peephole, projection, layer normalization).
LSTMCell::LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands) {
input_ = GetInput(operation, operands, kInputTensor);
input_to_input_weights_ =
GetInput(operation, operands, kInputToInputWeightsTensor); // optional
input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);
recurrent_to_input_weights_ =
GetInput(operation, operands, kRecurrentToInputWeightsTensor); // optional
recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor); // optional
cell_to_forget_weights_ =
GetInput(operation, operands, kCellToForgetWeightsTensor); // optional
cell_to_output_weights_ =
GetInput(operation, operands, kCellToOutputWeightsTensor); // optional
input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);
projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional
projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor); // optional
output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);
params_.activation_ = static_cast<TfLiteFusedActivation>(
getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
// Clip scalars are stored as float internally; for the FLOAT16 variant of
// the op they arrive as _Float16 and are widened here.
if (input_->type == OperandType::TENSOR_FLOAT32) {
params_.cell_clip_ = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
params_.proj_clip_ = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
} else {
params_.cell_clip_ = static_cast<float>(
getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
params_.proj_clip_ = static_cast<float>(
getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
}
// We check the version of LSTM by checking the number of the inputs to the
// op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
if (operation.inputs.size() == 27) {
input_layer_norm_weights_ = GetInput(operation, operands, kInputLayerNormWeightsTensor);
forget_layer_norm_weights_ = GetInput(operation, operands, kForgetLayerNormWeightsTensor);
cell_layer_norm_weights_ = GetInput(operation, operands, kCellLayerNormWeightsTensor);
output_layer_norm_weights_ = GetInput(operation, operands, kOutputLayerNormWeightsTensor);
} else {
// For LSTM from HAL v1.0 assign operands with no values
// NOTE(review): `no_value` is a function-local static shared by every
// LSTMCell instance; its lifetime field is re-assigned (to the same
// value) on each construction. Confirm constructors cannot run
// concurrently if that assumption matters.
static RunTimeOperandInfo no_value;
no_value.lifetime = OperandLifeTime::NO_VALUE;
input_layer_norm_weights_ = &no_value;
forget_layer_norm_weights_ = &no_value;
cell_layer_norm_weights_ = &no_value;
output_layer_norm_weights_ = &no_value;
}
output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
output_ = GetOutput(operation, operands, kOutputTensor);
scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}
// Validates the shapes of all weight, bias and layer-norm tensors against the
// dimensions inferred from the input/output tensors:
//   n_input  - width of the input vector,
//   n_output - width of the output (and recurrent state) vector,
//   n_cell   - number of LSTM cells.
// Returns true if all checks pass; each NN_CHECK* macro makes the function
// return false on the first failing check.
bool LSTMCell::CheckInputTensorDimensions(const Operation& operation,
                                          std::vector<RunTimeOperandInfo>& operands,
                                          uint32_t n_input, uint32_t n_output, uint32_t n_cell) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    //  > 0 means clipping
    NN_CHECK(params_.cell_clip_ >= 0);
    NN_CHECK(params_.proj_clip_ >= 0);

    // Input-to-gate weights are [n_cell, n_input]; the input gate's weights
    // are optional (absent for CIFG).
    if (!IsNullInput(input_to_input_weights_)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights_), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights_, 1), n_input);
    }
    NN_CHECK_EQ(NumDimensions(input_to_forget_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights_, 1), n_input);
    NN_CHECK_EQ(NumDimensions(input_to_cell_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights_, 1), n_input);

    // Recurrent-to-gate weights are [n_cell, n_output].
    if (!IsNullInput(recurrent_to_input_weights_)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights_), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights_, 1), n_output);
    }
    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights_, 1), n_output);
    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights_, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights_) && !IsNullInput(recurrent_to_input_weights_)) ||
            (IsNullInput(input_to_input_weights_) && IsNullInput(recurrent_to_input_weights_));
    NN_CHECK(cifg_weights_all_or_none);

    // Peephole weights are vectors of length n_cell.
    if (!IsNullInput(cell_to_input_weights_)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights_, 0), n_cell);
    }
    if (!IsNullInput(cell_to_forget_weights_)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights_, 0), n_cell);
    }
    if (!IsNullInput(cell_to_output_weights_)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights_, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    const bool use_cifg = IsNullInput(input_to_input_weights_);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights_) || use_cifg) &&
             !IsNullInput(cell_to_forget_weights_) && !IsNullInput(cell_to_output_weights_)) ||
            (IsNullInput(cell_to_input_weights_) && IsNullInput(cell_to_forget_weights_) &&
             IsNullInput(cell_to_output_weights_));
    NN_CHECK(peephole_weights_all_or_none);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias_));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias_), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias_, 0), n_cell);
    }
    NN_CHECK_EQ(NumDimensions(forget_gate_bias_), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias_, 0), n_cell);
    NN_CHECK_EQ(NumDimensions(cell_bias_), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias_, 0), n_cell);
    NN_CHECK_EQ(NumDimensions(output_gate_bias_), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias_, 0), n_cell);

    // Projection weights are [n_output, n_cell]; projection bias is [n_output].
    if (!IsNullInput(projection_weights_)) {
        NN_CHECK_EQ(NumDimensions(projection_weights_), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights_, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights_, 1), n_cell);
    }
    if (!IsNullInput(projection_bias_)) {
        NN_CHECK_EQ(NumDimensions(projection_bias_), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias_, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    const bool projection_tensors_consistent =
            (!IsNullInput(projection_weights_) || IsNullInput(projection_bias_));
    NN_CHECK(projection_tensors_consistent);

    // Layer-norm weights are vectors of length n_cell.
    if (!IsNullInput(input_layer_norm_weights_)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights_, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights_)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights_, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights_)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights_, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights_)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights_), 1);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights_, 0), n_cell);
    }

    // Layer-norm weights must be all present or all absent. (Bug fix: the
    // previous expression tested input_layer_norm_weights_ twice and never
    // checked output_layer_norm_weights_, so an inconsistent set could pass.)
    const bool layer_norm_weights_all_or_none =
            (IsNullInput(input_layer_norm_weights_) && IsNullInput(forget_layer_norm_weights_) &&
             IsNullInput(cell_layer_norm_weights_) && IsNullInput(output_layer_norm_weights_)) ||
            (!IsNullInput(input_layer_norm_weights_) && !IsNullInput(forget_layer_norm_weights_) &&
             !IsNullInput(cell_layer_norm_weights_) && !IsNullInput(output_layer_norm_weights_));
    NN_CHECK(layer_norm_weights_all_or_none);

    return true;
}
// Validates the operation's operand counts and tensor shapes, then computes
// the shapes of the four outputs: scratch buffer, output state, cell state
// and output. Infers n_batch/n_input from the input tensor, n_cell from the
// input-to-output weights and n_output from the recurrent-to-output weights.
// Returns true on success (NN_CHECK* macros return false on failure).
bool LSTMCell::Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
Shape* outputShape) {
// Check we have all the inputs and outputs we need.
// 15 = minimal CIFG LSTM without optional operands; 27 = v1.2 layer-norm LSTM.
NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
NumInputsWithValues(operation, operands) <= 27);
NN_CHECK_EQ(NumOutputs(operation), 4);
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
NN_CHECK(NumDimensions(input_) > 1);
const uint32_t n_batch = SizeOfDimension(input_, 0);
const uint32_t n_input = SizeOfDimension(input_, 1);
const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);
NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);
// Check that input tensor dimensions matches with each other.
if (!CheckInputTensorDimensions(operation, operands, n_input, n_output, n_cell)) {
return false;
}
// Resize the output and output_state tensors.
// All outputs inherit the input's element type, offset and scale.
const Shape& inputShape = input_->shape();
outputShape->type = inputShape.type;
outputShape->dimensions = {n_batch, n_output};
outputShape->offset = inputShape.offset;
outputShape->scale = inputShape.scale;
outputStateShape->type = inputShape.type;
outputStateShape->dimensions = {n_batch, n_output};
outputStateShape->offset = inputShape.offset;
outputStateShape->scale = inputShape.scale;
cellStateShape->type = inputShape.type;
cellStateShape->dimensions = {n_batch, n_cell};
cellStateShape->offset = inputShape.offset;
cellStateShape->scale = inputShape.scale;
const bool use_cifg = IsNullInput(input_to_input_weights_);
if (use_cifg) {
// Reserving space for Cell, Forget, Output gates
scratchShape->dimensions = {n_batch, n_cell * 3};
} else {
// Reserving space for Input, Cell, Forget, Output gates
scratchShape->dimensions = {n_batch, n_cell * 4};
}
scratchShape->type = inputShape.type;
scratchShape->offset = inputShape.offset;
scratchShape->scale = inputShape.scale;
return true;
}
// Runs one step of the LSTM cell on float32 buffers.
// All *_buffer arguments point to float32 data; optional operands may be
// nullptr, in which case the corresponding features (CIFG input gate,
// peephole, layer norm, projection) are disabled — presence is derived from
// the member operand pointers, consistent with CheckInputTensorDimensions.
// Gate pre-activations are accumulated in scratch_buffer_buffer, which holds
// 3 (CIFG) or 4 contiguous gate planes of n_batch * n_cell floats each (see
// Prepare). Writes output_state_out, cell_state_out, output and the scratch
// buffer; always returns true.
bool LSTMCell::EvalFloat32(
const float* input_buffer, const float* input_to_input_weights_buffer,
const float* input_to_forget_weights_buffer, const float* input_to_cell_weights_buffer,
const float* input_to_output_weights_buffer, const float* recurrent_to_input_weights_buffer,
const float* recurrent_to_forget_weights_buffer,
const float* recurrent_to_cell_weights_buffer,
const float* recurrent_to_output_weights_buffer, const float* cell_to_input_weights_buffer,
const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
const float* cell_bias_buffer, const float* output_gate_bias_buffer,
const float* projection_weights_buffer, const float* projection_bias_buffer,
const float* output_state_in_buffer, const float* cell_state_in_buffer,
const float* input_layer_norm_weights_buffer, const float* forget_layer_norm_weights_buffer,
const float* cell_layer_norm_weights_buffer, const float* output_layer_norm_weights_buffer,
float* output_state_out_buffer, float* cell_state_out_buffer, float* output_buffer,
float* scratch_buffer_buffer) {
NNTRACE_COMP("LSTMCell::Eval");
const uint32_t n_batch = input_->shape().dimensions[0];
const uint32_t n_input = input_->shape().dimensions[1];
// n_cell and n_output will be the same size when there is no projection.
const uint32_t n_cell = input_to_output_weights_->shape().dimensions[0];
const uint32_t n_output = recurrent_to_output_weights_->shape().dimensions[1];
// Since we have already checked that weights are all there or none, we can
// check the existence of only one to the get the condition.
const bool use_cifg = IsNullInput(input_to_input_weights_);
const bool use_peephole = !IsNullInput(cell_to_output_weights_);
const bool use_layer_norm = !IsNullInput(input_layer_norm_weights_);
// Index the scratch buffers pointers to the global scratch buffer.
// With CIFG there is no input gate, so the scratch only has three planes.
float* input_gate_scratch = nullptr;
float* cell_scratch = nullptr;
float* forget_gate_scratch = nullptr;
float* output_gate_scratch = nullptr;
if (use_cifg) {
cell_scratch = scratch_buffer_buffer;
forget_gate_scratch = cell_scratch + n_cell * n_batch;
output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
} else {
input_gate_scratch = scratch_buffer_buffer;
cell_scratch = input_gate_scratch + n_cell * n_batch;
forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
}
// With layer norm the bias is added after normalization, so the scratch
// planes start at zero; without it they start at the gate bias.
if (!use_layer_norm) {
// Initialize scratch buffers with bias.
if (!use_cifg) {
tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
input_gate_scratch);
}
tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
forget_gate_scratch);
tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
cell_scratch);
tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
output_gate_scratch);
} else {
// Initialize scratch buffers with zeroes.
if (!use_cifg) {
tflite::tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
}
tflite::tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
tflite::tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
tflite::tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
}
// For each batch and cell: compute input_weight * input.
if (!use_cifg) {
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_input_weights_buffer, n_cell, n_input, input_buffer, n_batch,
input_gate_scratch, /*result_stride*/ 1);
}
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_forget_weights_buffer, n_cell, n_input, input_buffer, n_batch,
forget_gate_scratch, /*result_stride*/ 1);
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_cell_weights_buffer, n_cell,
n_input, input_buffer, n_batch,
cell_scratch, /*result_stride*/ 1);
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_to_output_weights_buffer, n_cell, n_input, input_buffer, n_batch,
output_gate_scratch, /*result_stride*/ 1);
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
n_batch, input_gate_scratch,
/*result_stride*/ 1);
}
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
forget_gate_scratch, /*result_stride*/ 1);
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
cell_scratch, /*result_stride*/ 1);
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
output_gate_scratch, /*result_stride*/ 1);
// For each batch and cell: update input gate.
if (!use_cifg) {
if (use_peephole) {
tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
input_gate_scratch);
}
if (use_layer_norm) {
// Normalize the pre-activation, rescale by the layer-norm weights,
// then add the (deferred) gate bias.
tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
n_cell, n_batch, kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
n_cell, input_gate_scratch, n_batch,
input_gate_scratch);
tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
input_gate_scratch);
}
tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
input_gate_scratch);
}
// For each batch and cell: update forget gate.
if (use_peephole) {
tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
n_cell, cell_state_in_buffer,
n_batch, forget_gate_scratch);
}
if (use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
n_cell, n_batch, kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
n_cell, forget_gate_scratch, n_batch,
forget_gate_scratch);
tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
forget_gate_scratch);
}
tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
forget_gate_scratch);
// For each batch and cell: update the cell.
if (use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch,
kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
cell_scratch, n_batch, cell_scratch);
tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
}
// cell_state_out = forget_gate * cell_state_in ...
tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
n_batch * n_cell, cell_state_out_buffer);
tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
params_.activation_, cell_scratch);
// ... + input_gate * g(cell_input); CIFG uses (1 - forget_gate) instead of
// a separate input gate.
if (use_cifg) {
tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
forget_gate_scratch);
tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
} else {
tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
}
if (params_.cell_clip_ > 0.0) {
tflite::tensor_utils::ClipVector(cell_state_out_buffer, n_batch * n_cell,
params_.cell_clip_, cell_state_out_buffer);
}
// For each batch and cell: update the output gate.
// Note: the peephole here reads the *updated* cell state.
if (use_peephole) {
tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
n_cell, cell_state_out_buffer,
n_batch, output_gate_scratch);
}
if (use_layer_norm) {
tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
n_cell, n_batch, kLayerNormEpsilon);
tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
n_cell, output_gate_scratch, n_batch,
output_gate_scratch);
tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
output_gate_scratch);
}
tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
output_gate_scratch);
// cell_scratch is reused here to hold activation(cell_state_out);
// output_gate_scratch then becomes output_gate * activation(cell_state).
tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
params_.activation_, cell_scratch);
tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
n_batch * n_cell, output_gate_scratch);
// For each batch: update the projection and output_state.
// Presence of the projection is determined from operand lifetime, matching
// IsNullInput's NO_VALUE convention.
const bool use_projection_weight = (projection_weights_->lifetime != OperandLifeTime::NO_VALUE);
const bool use_projection_bias = (projection_bias_->lifetime != OperandLifeTime::NO_VALUE);
if (use_projection_weight) {
if (use_projection_bias) {
tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
output_buffer);
} else {
tflite::tensor_utils::ZeroVector(output_buffer, n_batch * n_output);
}
tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
output_buffer,
/*result_stride*/ 1);
if (params_.proj_clip_ > 0.0) {
tflite::tensor_utils::ClipVector(output_buffer, n_batch * n_output, params_.proj_clip_,
output_buffer);
}
} else {
tflite::tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, output_buffer);
}
// The output also becomes the next step's output state.
tflite::tensor_utils::CopyVector(output_buffer, n_batch * n_output, output_state_out_buffer);
return true;
}
// Dispatches evaluation on the input tensor's data type. TENSOR_FLOAT32 runs
// the float32 kernel directly on the operand buffers. TENSOR_FLOAT16 first
// converts every operand to float32 scratch vectors (optional operands stay
// nullptr when absent), runs the same float32 kernel, then converts the four
// outputs back to float16. Returns false for unsupported types.
bool LSTMCell::Eval() {
switch (input_->type) {
case OperandType::TENSOR_FLOAT32: {
// clang-format off
EvalFloat32(
GetBuffer<const float>(input_),
GetBuffer<const float>(input_to_input_weights_),
GetBuffer<const float>(input_to_forget_weights_),
GetBuffer<const float>(input_to_cell_weights_),
GetBuffer<const float>(input_to_output_weights_),
GetBuffer<const float>(recurrent_to_input_weights_),
GetBuffer<const float>(recurrent_to_forget_weights_),
GetBuffer<const float>(recurrent_to_cell_weights_),
GetBuffer<const float>(recurrent_to_output_weights_),
GetBuffer<const float>(cell_to_input_weights_),
GetBuffer<const float>(cell_to_forget_weights_),
GetBuffer<const float>(cell_to_output_weights_),
GetBuffer<const float>(input_gate_bias_),
GetBuffer<const float>(forget_gate_bias_),
GetBuffer<const float>(cell_bias_),
GetBuffer<const float>(output_gate_bias_),
GetBuffer<const float>(projection_weights_),
GetBuffer<const float>(projection_bias_),
GetBuffer<const float>(output_state_in_),
GetBuffer<const float>(cell_state_in_),
GetBuffer<const float>(input_layer_norm_weights_),
GetBuffer<const float>(forget_layer_norm_weights_),
GetBuffer<const float>(cell_layer_norm_weights_),
GetBuffer<const float>(output_layer_norm_weights_),
GetBuffer<float>(output_state_out_),
GetBuffer<float>(cell_state_out_),
GetBuffer<float>(output_),
GetBuffer<float>(scratch_buffer_));
// clang-format on
} break;
case OperandType::TENSOR_FLOAT16: {
// Widen every fp16 operand to a local float32 vector. Optional
// operands are converted only when present; their *_buffer pointer
// stays nullptr otherwise.
std::vector<float> input_float32(getNumberOfElements(input_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(input_), &input_float32);
const float* input_to_input_weights_buffer = nullptr;
std::vector<float> input_to_input_weights_float32(
getNumberOfElements(input_to_input_weights_->shape()));
if (!IsNullInput(input_to_input_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_input_weights_),
&input_to_input_weights_float32);
input_to_input_weights_buffer = input_to_input_weights_float32.data();
}
std::vector<float> input_to_forget_weights_float32(
getNumberOfElements(input_to_forget_weights_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_forget_weights_),
&input_to_forget_weights_float32);
std::vector<float> input_to_cell_weights_float32(
getNumberOfElements(input_to_cell_weights_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_cell_weights_),
&input_to_cell_weights_float32);
std::vector<float> input_to_output_weights_float32(
getNumberOfElements(input_to_output_weights_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(input_to_output_weights_),
&input_to_output_weights_float32);
const float* recurrent_to_input_weights_buffer = nullptr;
std::vector<float> recurrent_to_input_weights_float32(
getNumberOfElements(recurrent_to_input_weights_->shape()));
if (!IsNullInput(recurrent_to_input_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_input_weights_),
&recurrent_to_input_weights_float32);
recurrent_to_input_weights_buffer = recurrent_to_input_weights_float32.data();
}
std::vector<float> recurrent_to_forget_weights_float32(
getNumberOfElements(recurrent_to_forget_weights_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_forget_weights_),
&recurrent_to_forget_weights_float32);
std::vector<float> recurrent_to_cell_weights_float32(
getNumberOfElements(recurrent_to_cell_weights_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_cell_weights_),
&recurrent_to_cell_weights_float32);
std::vector<float> recurrent_to_output_weights_float32(
getNumberOfElements(recurrent_to_output_weights_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(recurrent_to_output_weights_),
&recurrent_to_output_weights_float32);
const float* cell_to_input_weights_buffer = nullptr;
std::vector<float> cell_to_input_weights_float32(
getNumberOfElements(cell_to_input_weights_->shape()));
if (!IsNullInput(cell_to_input_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_input_weights_),
&cell_to_input_weights_float32);
cell_to_input_weights_buffer = cell_to_input_weights_float32.data();
}
const float* cell_to_forget_weights_buffer = nullptr;
std::vector<float> cell_to_forget_weights_float32(
getNumberOfElements(cell_to_forget_weights_->shape()));
if (!IsNullInput(cell_to_forget_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_forget_weights_),
&cell_to_forget_weights_float32);
cell_to_forget_weights_buffer = cell_to_forget_weights_float32.data();
}
const float* cell_to_output_weights_buffer = nullptr;
std::vector<float> cell_to_output_weights_float32(
getNumberOfElements(cell_to_output_weights_->shape()));
if (!IsNullInput(cell_to_output_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_to_output_weights_),
&cell_to_output_weights_float32);
cell_to_output_weights_buffer = cell_to_output_weights_float32.data();
}
const float* input_gate_bias_buffer = nullptr;
std::vector<float> input_gate_bias_float32(
getNumberOfElements(input_gate_bias_->shape()));
if (!IsNullInput(input_gate_bias_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(input_gate_bias_),
&input_gate_bias_float32);
input_gate_bias_buffer = input_gate_bias_float32.data();
}
std::vector<float> forget_gate_bias_float32(
getNumberOfElements(forget_gate_bias_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(forget_gate_bias_),
&forget_gate_bias_float32);
std::vector<float> cell_bias_float32(getNumberOfElements(cell_bias_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_bias_), &cell_bias_float32);
std::vector<float> output_gate_bias_float32(
getNumberOfElements(output_gate_bias_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(output_gate_bias_),
&output_gate_bias_float32);
const float* projection_weights_buffer = nullptr;
std::vector<float> projection_weights_float32(
getNumberOfElements(projection_weights_->shape()));
if (!IsNullInput(projection_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(projection_weights_),
&projection_weights_float32);
projection_weights_buffer = projection_weights_float32.data();
}
const float* projection_bias_buffer = nullptr;
std::vector<float> projection_bias_float32(
getNumberOfElements(projection_bias_->shape()));
if (!IsNullInput(projection_bias_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(projection_bias_),
&projection_bias_float32);
projection_bias_buffer = projection_bias_float32.data();
}
std::vector<float> output_state_in_float32(
getNumberOfElements(output_state_in_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(output_state_in_),
&output_state_in_float32);
std::vector<float> cell_state_in_float32(getNumberOfElements(cell_state_in_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_state_in_), &cell_state_in_float32);
const float* input_layer_norm_weights_buffer = nullptr;
std::vector<float> input_layer_norm_weights_float32(
getNumberOfElements(input_layer_norm_weights_->shape()));
if (!IsNullInput(input_layer_norm_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(input_layer_norm_weights_),
&input_layer_norm_weights_float32);
input_layer_norm_weights_buffer = input_layer_norm_weights_float32.data();
}
const float* forget_layer_norm_weights_buffer = nullptr;
std::vector<float> forget_layer_norm_weights_float32(
getNumberOfElements(forget_layer_norm_weights_->shape()));
if (!IsNullInput(forget_layer_norm_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(forget_layer_norm_weights_),
&forget_layer_norm_weights_float32);
forget_layer_norm_weights_buffer = forget_layer_norm_weights_float32.data();
}
const float* cell_layer_norm_weights_buffer = nullptr;
std::vector<float> cell_layer_norm_weights_float32(
getNumberOfElements(cell_layer_norm_weights_->shape()));
if (!IsNullInput(cell_layer_norm_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_layer_norm_weights_),
&cell_layer_norm_weights_float32);
cell_layer_norm_weights_buffer = cell_layer_norm_weights_float32.data();
}
const float* output_layer_norm_weights_buffer = nullptr;
std::vector<float> output_layer_norm_weights_float32(
getNumberOfElements(output_layer_norm_weights_->shape()));
if (!IsNullInput(output_layer_norm_weights_)) {
convertFloat16ToFloat32(GetBuffer<_Float16>(output_layer_norm_weights_),
&output_layer_norm_weights_float32);
output_layer_norm_weights_buffer = output_layer_norm_weights_float32.data();
}
// Output-side scratch vectors. These are converted from the existing
// output buffers before the call and written back after it.
std::vector<float> output_state_out_float32(
getNumberOfElements(output_state_out_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(output_state_out_),
&output_state_out_float32);
std::vector<float> cell_state_out_float32(
getNumberOfElements(cell_state_out_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(cell_state_out_), &cell_state_out_float32);
std::vector<float> output_float32(getNumberOfElements(output_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(output_), &output_float32);
std::vector<float> scratch_buffer_float32(
getNumberOfElements(scratch_buffer_->shape()));
convertFloat16ToFloat32(GetBuffer<_Float16>(scratch_buffer_), &scratch_buffer_float32);
// clang-format off
EvalFloat32(
input_float32.data(),
input_to_input_weights_buffer,
input_to_forget_weights_float32.data(),
input_to_cell_weights_float32.data(),
input_to_output_weights_float32.data(),
recurrent_to_input_weights_buffer,
recurrent_to_forget_weights_float32.data(),
recurrent_to_cell_weights_float32.data(),
recurrent_to_output_weights_float32.data(),
cell_to_input_weights_buffer,
cell_to_forget_weights_buffer,
cell_to_output_weights_buffer,
input_gate_bias_buffer,
forget_gate_bias_float32.data(),
cell_bias_float32.data(),
output_gate_bias_float32.data(),
projection_weights_buffer,
projection_bias_buffer,
output_state_in_float32.data(),
cell_state_in_float32.data(),
input_layer_norm_weights_buffer,
forget_layer_norm_weights_buffer,
cell_layer_norm_weights_buffer,
output_layer_norm_weights_buffer,
output_state_out_float32.data(),
cell_state_out_float32.data(),
output_float32.data(),
scratch_buffer_float32.data());
// clang-format on
// Narrow the four outputs back to fp16 in their operand buffers.
convertFloat32ToFloat16(output_state_out_float32,
GetBuffer<_Float16>(output_state_out_));
convertFloat32ToFloat16(cell_state_out_float32, GetBuffer<_Float16>(cell_state_out_));
convertFloat32ToFloat16(output_float32, GetBuffer<_Float16>(output_));
convertFloat32ToFloat16(scratch_buffer_float32, GetBuffer<_Float16>(scratch_buffer_));
} break;
default: {
LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
return false;
}
}
return true;
}
} // namespace nn
} // namespace android