| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #define LOG_TAG "Operations" |
| |
| #include "LSTM.h" |
| |
| #pragma clang diagnostic push |
| #pragma clang diagnostic ignored "-Wunused-parameter" |
| #include <tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h> |
| #pragma clang diagnostic pop |
| |
| #include <tensorflow/lite/kernels/internal/tensor_utils.h> |
| |
| #include <vector> |
| |
| #include "CpuExecutor.h" |
| #include "CpuOperationUtils.h" |
| #include "LegacyUtils.h" |
| #include "OperationsExecutionUtils.h" |
| #include "Tracing.h" |
| #include "nnapi/Types.h" |
| |
| namespace android { |
| namespace nn { |
| |
| namespace { |
| |
| template <typename T> |
| inline T* GetBuffer(RunTimeOperandInfo* operand) { |
| return reinterpret_cast<T*>(operand->buffer); |
| } |
| |
| template <typename T> |
| inline const T* GetBuffer(const RunTimeOperandInfo* operand) { |
| return reinterpret_cast<const T*>(operand->buffer); |
| } |
| |
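| // Like GetBuffer, but returns nullptr when the optional operand was omitted |
| // from the model. |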
| template <typename T> |
| inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) { |
| return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr; |
| } |
| |
| } // anonymous namespace |
| |
| LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) { |
| input_ = GetInput(operation, operands, kInputTensor); |
| |
| input_to_input_weights_ = |
| GetInput(operation, operands, kInputToInputWeightsTensor); // optional |
| input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor); |
| input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor); |
| input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor); |
| |
| recurrent_to_input_weights_ = |
| GetInput(operation, operands, kRecurrentToInputWeightsTensor); // optional |
| recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor); |
| recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor); |
| recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor); |
| |
| cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor); // optional |
| cell_to_forget_weights_ = |
| GetInput(operation, operands, kCellToForgetWeightsTensor); // optional |
| cell_to_output_weights_ = |
| GetInput(operation, operands, kCellToOutputWeightsTensor); // optional |
| |
| input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor); |
| forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor); |
| cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor); |
| output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor); |
| |
| projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor); // optional |
| projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor); // optional |
| |
| output_state_in_ = GetInput(operation, operands, kOutputStateInTensor); |
| cell_state_in_ = GetInput(operation, operands, kCellStateInTensor); |
| |
| const auto& activationOperand = *GetInput(operation, operands, kActivationParam); |
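| // ActivationFn shares its numeric values with TfLiteFusedActivation, which |
| // makes the cast here (and the one applied in LSTMStep) safe. |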
| params_.activation = static_cast<ActivationFn>(getScalarDataWithDefault<int32_t>( |
| activationOperand, TfLiteFusedActivation::kTfLiteActNone)); |
| |
| const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam); |
| const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam); |
| if (input_->type == OperandType::TENSOR_FLOAT32) { |
| params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f); |
| params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f); |
| } else { |
| params_.cell_clip = |
| static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f)); |
| params_.proj_clip = |
| static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f)); |
| } |
| |
| // We determine the LSTM version from the number of inputs to the op: |
| // LSTM from HAL version 1.0 has 23 inputs, while version 1.2 has 27. |
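| // The four extra inputs in version 1.2 are the optional layer normalization |
| // weight tensors read below. |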
| if (operation.inputs.size() == 27) { |
| input_layer_norm_weights_ = |
| GetInput(operation, operands, kInputLayerNormWeightsTensor); // optional |
| forget_layer_norm_weights_ = |
| GetInput(operation, operands, kForgetLayerNormWeightsTensor); // optional |
| cell_layer_norm_weights_ = |
| GetInput(operation, operands, kCellLayerNormWeightsTensor); // optional |
| output_layer_norm_weights_ = |
| GetInput(operation, operands, kOutputLayerNormWeightsTensor); // optional |
| } else { |
| // For LSTM from HAL v1.0, assign operands with no values. |
| static RunTimeOperandInfo no_value; |
| no_value.lifetime = Operand::LifeTime::NO_VALUE; |
| |
| input_layer_norm_weights_ = &no_value; |
| forget_layer_norm_weights_ = &no_value; |
| cell_layer_norm_weights_ = &no_value; |
| output_layer_norm_weights_ = &no_value; |
| } |
| |
| output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor); |
| cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor); |
| output_ = GetOutput(operation, operands, kOutputTensor); |
| |
| scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor); |
| } |
| |
| // static |
| bool LSTMCell::CheckInputTensorDimensions( |
| const RunTimeOperandInfo* /*input_*/, const RunTimeOperandInfo* input_to_input_weights, |
| const RunTimeOperandInfo* input_to_forget_weights, |
| const RunTimeOperandInfo* input_to_cell_weights, |
| const RunTimeOperandInfo* /*input_to_output_weights*/, |
| const RunTimeOperandInfo* recurrent_to_input_weights, |
| const RunTimeOperandInfo* recurrent_to_forget_weights, |
| const RunTimeOperandInfo* recurrent_to_cell_weights, |
| const RunTimeOperandInfo* /*recurrent_to_output_weights*/, |
| const RunTimeOperandInfo* cell_to_input_weights, |
| const RunTimeOperandInfo* cell_to_forget_weights, |
| const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias, |
| const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias, |
| const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights, |
| const RunTimeOperandInfo* projection_bias, |
| const RunTimeOperandInfo* input_layer_norm_weights, |
| const RunTimeOperandInfo* forget_layer_norm_weights, |
| const RunTimeOperandInfo* cell_layer_norm_weights, |
| const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output, |
| uint32_t n_cell, LSTMParams* params) { |
| // Making sure clipping parameters have valid values. |
| // == 0 means no clipping |
| // > 0 means clipping |
| NN_CHECK(params->cell_clip >= 0); |
| NN_CHECK(params->proj_clip >= 0); |
| |
| if (!IsNullInput(input_to_input_weights)) { |
| NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell); |
| NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input); |
| } |
| |
| NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell); |
| NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input); |
| |
| NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell); |
| NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input); |
| |
| if (!IsNullInput(recurrent_to_input_weights)) { |
| NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output); |
| } |
| |
| NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output); |
| |
| NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output); |
| |
| // We make sure the input gate's parameters are either both present (regular |
| // LSTM) or both absent (CIFG-LSTM). |
| const bool cifg_weights_all_or_none = |
| (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) || |
| (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights)); |
| NN_CHECK(cifg_weights_all_or_none); |
| |
| if (!IsNullInput(cell_to_input_weights)) { |
| NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell); |
| } |
| |
| if (!IsNullInput(cell_to_forget_weights)) { |
| NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell); |
| } |
| |
| if (!IsNullInput(cell_to_output_weights)) { |
| NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell); |
| } |
| |
| // Making sure the peephole weights are either all present or all absent. |
| params->use_cifg = IsNullInput(input_to_input_weights); |
| const bool peephole_weights_all_or_none = |
| ((!IsNullInput(cell_to_input_weights) || params->use_cifg) && |
| !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) || |
| (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) && |
| IsNullInput(cell_to_output_weights)); |
| NN_CHECK(peephole_weights_all_or_none); |
| |
| // Since we have already checked that the weights are either all present or |
| // all absent, checking the existence of just one of them is enough. |
| params->use_peephole = !IsNullInput(cell_to_output_weights); |
| // Checking the output instead of the input layer norm weights because the |
| // input ones can be omitted when CIFG LSTM is used. |
| params->use_layer_norm = !IsNullInput(output_layer_norm_weights); |
| |
| params->use_projection_weight = (projection_weights->lifetime != Operand::LifeTime::NO_VALUE); |
| params->use_projection_bias = (projection_bias->lifetime != Operand::LifeTime::NO_VALUE); |
| |
| // Make sure the input gate bias is present only when not a CIFG-LSTM. |
| if (params->use_cifg) { |
| NN_CHECK(IsNullInput(input_gate_bias)); |
| } else { |
| NN_CHECK_EQ(NumDimensions(input_gate_bias), 1u); |
| NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell); |
| } |
| |
| NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1u); |
| NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell); |
| |
| NN_CHECK_EQ(NumDimensions(cell_bias), 1u); |
| NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell); |
| |
| NN_CHECK_EQ(NumDimensions(output_gate_bias), 1u); |
| NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell); |
| |
| if (!IsNullInput(projection_weights)) { |
| NN_CHECK_EQ(NumDimensions(projection_weights), 2u); |
| NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output); |
| NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell); |
| } |
| |
| if (!IsNullInput(projection_bias)) { |
| NN_CHECK_EQ(NumDimensions(projection_bias), 1u); |
| NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output); |
| } |
| |
| // Making sure the projection tensors are consistent: |
| // 1) If projection weight is not present, then projection bias should not be |
| // present. |
| // 2) If projection weight is present, then projection bias is optional. |
| // TODO: make sure this is correct. |
| const bool projection_tensors_consistent = |
| (!IsNullInput(projection_weights) || IsNullInput(projection_bias)); |
| NN_CHECK(projection_tensors_consistent); |
| |
| if (!IsNullInput(input_layer_norm_weights)) { |
| NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell); |
| } |
| if (!IsNullInput(forget_layer_norm_weights)) { |
| NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell); |
| } |
| if (!IsNullInput(cell_layer_norm_weights)) { |
| NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell); |
| } |
| if (!IsNullInput(output_layer_norm_weights)) { |
| NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1u); |
| NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell); |
| } |
| |
| if (params->use_cifg) { |
| NN_RET_CHECK(IsNullInput(input_layer_norm_weights)) |
| << "input_layer_norm_weights are provided while CIFG is used"; |
| const bool layer_norm_weights_all_or_none_cifg = |
| (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) && |
| IsNullInput(output_layer_norm_weights)) || |
| (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) && |
| !IsNullInput(output_layer_norm_weights)); |
| NN_RET_CHECK(layer_norm_weights_all_or_none_cifg); |
| } else { |
| const bool layer_norm_weights_all_or_none = |
| (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) && |
| IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) || |
| (!IsNullInput(input_layer_norm_weights) && |
| !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) && |
| !IsNullInput(output_layer_norm_weights)); |
| NN_RET_CHECK(layer_norm_weights_all_or_none); |
| } |
| |
| return true; |
| } |
| |
| bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands, |
| Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape, |
| Shape* outputShape) { |
| // Check we have all the inputs and outputs we need. |
| NN_CHECK(NumInputsWithValues(operation, operands) >= 15 && |
| NumInputsWithValues(operation, operands) <= 27); |
| constexpr int requiredInputs[] = { |
| kInputTensor, |
| kInputToForgetWeightsTensor, |
| kInputToCellWeightsTensor, |
| kInputToOutputWeightsTensor, |
| kRecurrentToForgetWeightsTensor, |
| kRecurrentToCellWeightsTensor, |
| kRecurrentToOutputWeightsTensor, |
| kForgetGateBiasTensor, |
| kCellGateBiasTensor, |
| kOutputGateBiasTensor, |
| kOutputStateInTensor, |
| kCellStateInTensor, |
| kActivationParam, |
| kCellClipParam, |
| kProjClipParam, |
| }; |
| for (const int requiredInput : requiredInputs) { |
| NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput))) |
| << "required input " << requiredInput << " is omitted"; |
| } |
| NN_CHECK_EQ(NumOutputs(operation), 4); |
| |
| // Check that the scalar operands' buffers are large enough. |
| const auto& activationOperand = *GetInput(operation, operands, kActivationParam); |
| NN_RET_CHECK(activationOperand.length >= sizeof(int32_t)); |
| const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam); |
| const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam); |
| if (input_->type == OperandType::TENSOR_FLOAT32) { |
| NN_RET_CHECK(cellClipOperand.length >= sizeof(float)); |
| NN_RET_CHECK(projClipOperand.length >= sizeof(float)); |
| } else { |
| NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16)); |
| NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16)); |
| } |
| |
| // Inferring batch size, number of outputs and number of cells from the |
| // input tensors. |
| NN_CHECK(NumDimensions(input_) > 1); |
| const uint32_t n_batch = SizeOfDimension(input_, 0); |
| const uint32_t n_input = SizeOfDimension(input_, 1); |
| |
| const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0); |
| NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2u); |
| NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input); |
| |
| NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2u); |
| NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell); |
| const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1); |
| |
| // Check that the input tensor dimensions match each other. |
| if (!CheckInputTensorDimensions( |
| input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_, |
| input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_, |
| recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_, |
| cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_, |
| forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_, |
| projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_, |
| cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell, |
| ¶ms_)) { |
| return false; |
| } |
| |
| // Infer the shapes of the output, state, and scratch tensors. |
| const Shape& inputShape = input_->shape(); |
| |
| outputShape->type = inputShape.type; |
| outputShape->dimensions = {n_batch, n_output}; |
| outputShape->offset = inputShape.offset; |
| outputShape->scale = inputShape.scale; |
| |
| outputStateShape->type = inputShape.type; |
| outputStateShape->dimensions = {n_batch, n_output}; |
| outputStateShape->offset = inputShape.offset; |
| outputStateShape->scale = inputShape.scale; |
| |
| cellStateShape->type = inputShape.type; |
| cellStateShape->dimensions = {n_batch, n_cell}; |
| cellStateShape->offset = inputShape.offset; |
| cellStateShape->scale = inputShape.scale; |
| |
| if (params_.use_cifg) { |
| // Reserving space for Cell, Forget, Output gates |
| scratchShape->dimensions = {n_batch, n_cell * 3}; |
| } else { |
| // Reserving space for Input, Cell, Forget, Output gates |
| scratchShape->dimensions = {n_batch, n_cell * 4}; |
| } |
| scratchShape->type = inputShape.type; |
| scratchShape->offset = inputShape.offset; |
| scratchShape->scale = inputShape.scale; |
| |
| return true; |
| } |
| |
| // static |
| bool LSTMCell::LSTMEvalFloat32( |
| const LSTMParams& params, const float* input_buffer, const Shape& input_shape, |
| const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, |
| const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, |
| const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer, |
| const float* recurrent_to_forget_weights_buffer, |
| const float* recurrent_to_cell_weights_buffer, |
| const float* recurrent_to_output_weights_buffer, |
| const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer, |
| const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer, |
| const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer, |
| const float* aux_input_to_forget_weights_buffer, |
| const float* aux_input_to_cell_weights_buffer, |
| const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer, |
| const float* forget_gate_bias_buffer, const float* cell_bias_buffer, |
| const float* output_gate_bias_buffer, const float* projection_weights_buffer, |
| const float* projection_bias_buffer, const float* output_state_in_buffer, |
| const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer, |
| const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer, |
| const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, |
| float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer, |
| bool timeMajor, bool forwardSequence) { |
| NNTRACE_COMP("LSTMCell::LSTMEvalFloat32"); |
| |
| const uint32_t inputRank = getNumberOfDimensions(input_shape); |
| NN_CHECK(inputRank == 2 || inputRank == 3); |
| |
| const uint32_t maxTime = |
| (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1; |
| const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0) |
| : getSizeOfDimension(input_shape, 0); |
| const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1); |
| const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0); |
| const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1); |
| |
| Shape batchInputShape = input_shape; |
| batchInputShape.dimensions = {batchSize, inputSize}; |
| const uint32_t batchInputSize = batchSize * inputSize; |
| const uint32_t batchOutputSize = batchSize * outputSize; |
| |
| std::vector<float> transposedInput; |
| const bool hasAuxInput = (aux_input_buffer != nullptr); |
| std::vector<float> transposedAuxInput; |
| std::vector<float> transposedOutput; |
| Shape transposedInputShape; |
| Shape transposedOutputShape; |
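| // LSTMStep consumes one [batch, input] slice per time step, so batch-major |
| // input ([batch, max_time, ...]) is transposed to time-major |
| // ([max_time, batch, ...]) up front and the output is transposed back after |
| // the time loop. |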
| if (!timeMajor) { |
| transposedInput.resize(maxTime * batchInputSize); |
| transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data()); |
| if (hasAuxInput) { |
| transposedAuxInput.resize(maxTime * batchInputSize); |
| transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape, |
| transposedAuxInput.data()); |
| } |
| transposeFirstTwoDimensions(input_shape, &transposedInputShape); |
| transposedOutput.resize(maxTime * batchOutputSize); |
| transposedOutputShape = transposedInputShape; |
| transposedOutputShape.dimensions[2] = outputSize; |
| } |
| const float* inputData = timeMajor ? input_buffer : transposedInput.data(); |
| const float* auxInputData = |
| hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr; |
| float* outputData = timeMajor ? output_buffer : transposedOutput.data(); |
| |
| std::vector<float> outputStateInCurrentTimeStep( |
| output_state_in_buffer, output_state_in_buffer + batchSize * outputSize); |
| std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer, |
| cell_state_in_buffer + batchSize * numCells); |
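| // For a reversed sequence, start at the last time step and walk backwards |
| // through the input and output buffers using negative strides. |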
| const float* inputCurrentTimeStep = |
| inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)); |
| const float* auxInputCurrentTimeStep = |
| hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1))) |
| : nullptr; |
| float* outputCurrentTimeStep = |
| outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1)); |
| const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize); |
| const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize); |
| |
| for (uint32_t t = 0; t < maxTime; ++t) { |
| LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer, |
| input_to_forget_weights_buffer, input_to_cell_weights_buffer, |
| input_to_output_weights_buffer, input_to_output_weights_shape, |
| recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer, |
| recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer, |
| recurrent_to_output_weights_shape, cell_to_input_weights_buffer, |
| cell_to_forget_weights_buffer, cell_to_output_weights_buffer, |
| auxInputCurrentTimeStep, aux_input_to_input_weights_buffer, |
| aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer, |
| aux_input_to_output_weights_buffer, input_gate_bias_buffer, |
| forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer, |
| projection_weights_buffer, projection_bias_buffer, |
| outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(), |
| input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer, |
| cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer, |
| output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep, |
| scratch_buffer_buffer); |
| inputCurrentTimeStep += batchInputDelta; |
| if (hasAuxInput) { |
| auxInputCurrentTimeStep += batchInputDelta; |
| } |
| outputCurrentTimeStep += batchOutputDelta; |
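| // LSTMStep wrote the updated state into the *_out buffers; copy it back so |
| // it becomes the input state for the next time step. |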
| outputStateInCurrentTimeStep.assign(output_state_out_buffer, |
| output_state_out_buffer + batchSize * outputSize); |
| cellStateInCurrentTimeStep.assign(cell_state_out_buffer, |
| cell_state_out_buffer + batchSize * numCells); |
| } |
| |
| if (!timeMajor) { |
| transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape, |
| output_buffer); |
| } |
| |
| return true; |
| } |
| |
| // static |
| bool LSTMCell::LSTMEvalFloat16( |
| const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape, |
| const _Float16* input_to_input_weights_buffer, |
| const _Float16* input_to_forget_weights_buffer, |
| const _Float16* input_to_cell_weights_buffer, |
| const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape, |
| const _Float16* recurrent_to_input_weights_buffer, |
| const _Float16* recurrent_to_forget_weights_buffer, |
| const _Float16* recurrent_to_cell_weights_buffer, |
| const _Float16* recurrent_to_output_weights_buffer, |
| const Shape& recurrent_to_output_weights_shape, |
| const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer, |
| const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer, |
| const _Float16* aux_input_to_input_weights_buffer, |
| const _Float16* aux_input_to_forget_weights_buffer, |
| const _Float16* aux_input_to_cell_weights_buffer, |
| const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer, |
| const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer, |
| const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer, |
| const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer, |
| const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer, |
| const _Float16* forget_layer_norm_weights_buffer, |
| const _Float16* cell_layer_norm_weights_buffer, |
| const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer, |
| _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer, |
| bool timeMajor, bool forwardSequence) { |
| NNTRACE_COMP("LSTMCell::LSTMEvalFloat16"); |
| |
| const uint32_t inputRank = getNumberOfDimensions(input_shape); |
| NN_CHECK(inputRank == 2 || inputRank == 3); |
| |
| const uint32_t maxTime = |
| (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1; |
| const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0) |
| : getSizeOfDimension(input_shape, 0); |
| const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1); |
| const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0); |
| const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1); |
| |
| Shape batchInputShape = input_shape; |
| batchInputShape.dimensions = {batchSize, inputSize}; |
| const uint32_t batchInputSize = batchSize * inputSize; |
| const uint32_t batchOutputSize = batchSize * outputSize; |
| |
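| // The half-precision path converts all tensors to float32, runs the same |
| // float32 LSTMStep, and converts the results back to _Float16 at the end. |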
| std::vector<float> input_float32(maxTime * batchInputSize); |
| convertFloat16ToFloat32(input_buffer, &input_float32); |
| std::vector<float> input_to_input_weights_float32(numCells * inputSize); |
| if (input_to_input_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32); |
| } |
| std::vector<float> input_to_forget_weights_float32(numCells * inputSize); |
| convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32); |
| std::vector<float> input_to_cell_weights_float32(numCells * inputSize); |
| convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32); |
| std::vector<float> input_to_output_weights_float32(numCells * inputSize); |
| convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32); |
| |
| std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize); |
| if (recurrent_to_input_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(recurrent_to_input_weights_buffer, |
| &recurrent_to_input_weights_float32); |
| } |
| std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize); |
| convertFloat16ToFloat32(recurrent_to_forget_weights_buffer, |
| &recurrent_to_forget_weights_float32); |
| std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize); |
| convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32); |
| std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize); |
| convertFloat16ToFloat32(recurrent_to_output_weights_buffer, |
| &recurrent_to_output_weights_float32); |
| |
| std::vector<float> cell_to_input_weights_float32(numCells); |
| if (cell_to_input_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32); |
| } |
| std::vector<float> cell_to_forget_weights_float32(numCells); |
| if (cell_to_forget_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32); |
| } |
| std::vector<float> cell_to_output_weights_float32(numCells); |
| if (cell_to_output_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32); |
| } |
| |
| std::vector<float> aux_input_float32(maxTime * batchInputSize); |
| if (aux_input_buffer != nullptr) { |
| convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32); |
| } |
| std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize); |
| if (aux_input_to_input_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(aux_input_to_input_weights_buffer, |
| &aux_input_to_input_weights_float32); |
| } |
| std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize); |
| if (aux_input_to_forget_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(aux_input_to_forget_weights_buffer, |
| &aux_input_to_forget_weights_float32); |
| } |
| std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize); |
| if (aux_input_to_cell_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(aux_input_to_cell_weights_buffer, |
| &aux_input_to_cell_weights_float32); |
| } |
| std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize); |
| if (aux_input_to_output_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(aux_input_to_output_weights_buffer, |
| &aux_input_to_output_weights_float32); |
| } |
| |
| std::vector<float> input_gate_bias_float32(numCells); |
| if (input_gate_bias_buffer != nullptr) { |
| convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32); |
| } |
| std::vector<float> forget_gate_bias_float32(numCells); |
| convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32); |
| std::vector<float> cell_bias_float32(numCells); |
| convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32); |
| std::vector<float> output_gate_bias_float32(numCells); |
| convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32); |
| |
| std::vector<float> projection_weights_float32(numCells * outputSize); |
| if (projection_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32); |
| } |
| std::vector<float> projection_bias_float32(outputSize); |
| if (projection_bias_buffer != nullptr) { |
| convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32); |
| } |
| |
| std::vector<float> input_layer_norm_weights_float32(numCells); |
| if (input_layer_norm_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(input_layer_norm_weights_buffer, |
| &input_layer_norm_weights_float32); |
| } |
| std::vector<float> forget_layer_norm_weights_float32(numCells); |
| if (forget_layer_norm_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(forget_layer_norm_weights_buffer, |
| &forget_layer_norm_weights_float32); |
| } |
| std::vector<float> cell_layer_norm_weights_float32(numCells); |
| if (cell_layer_norm_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32); |
| } |
| std::vector<float> output_layer_norm_weights_float32(numCells); |
| if (output_layer_norm_weights_buffer != nullptr) { |
| convertFloat16ToFloat32(output_layer_norm_weights_buffer, |
| &output_layer_norm_weights_float32); |
| } |
| |
| std::vector<float> output_state_out_float32(batchOutputSize); |
| convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32); |
| std::vector<float> cell_state_out_float32(batchSize * numCells); |
| convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32); |
| |
| std::vector<float> output_float32(maxTime * batchOutputSize); |
| convertFloat16ToFloat32(output_buffer, &output_float32); |
| std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells |
| : 4 * batchSize * numCells); |
| convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32); |
| |
| std::vector<float> transposedInput; |
| const bool hasAuxInput = (aux_input_buffer != nullptr); |
| std::vector<float> transposedAuxInput; |
| std::vector<float> transposedOutput; |
| Shape transposedInputShape; |
| Shape transposedOutputShape; |
| if (!timeMajor) { |
| transposedInput.resize(maxTime * batchInputSize); |
| transposeFirstTwoDimensions<float>(input_float32.data(), input_shape, |
| transposedInput.data()); |
| if (hasAuxInput) { |
| transposedAuxInput.resize(maxTime * batchInputSize); |
| transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape, |
| transposedAuxInput.data()); |
| } |
| transposeFirstTwoDimensions(input_shape, &transposedInputShape); |
| transposedOutput.resize(maxTime * batchOutputSize); |
| transposedOutputShape = transposedInputShape; |
| transposedOutputShape.dimensions[2] = outputSize; |
| } |
| const float* inputData = timeMajor ? input_float32.data() : transposedInput.data(); |
| const float* auxInputData = |
| hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data()) |
| : nullptr; |
| float* outputData = timeMajor ? output_float32.data() : transposedOutput.data(); |
| |
| std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize); |
| convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep); |
| std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells); |
| convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep); |
| |
| const float* inputCurrentTimeStep = |
| inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)); |
| const float* auxInputCurrentTimeStep = |
| hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1))) |
| : nullptr; |
| float* outputCurrentTimeStep = |
| outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1)); |
| const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize); |
| const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize); |
| |
| for (uint32_t t = 0; t < maxTime; ++t) { |
| LSTMStep(params, inputCurrentTimeStep, batchInputShape, |
| input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(), |
| input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(), |
| input_to_output_weights_shape, recurrent_to_input_weights_float32.data(), |
| recurrent_to_forget_weights_float32.data(), |
| recurrent_to_cell_weights_float32.data(), |
| recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape, |
| cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(), |
| cell_to_output_weights_float32.data(), auxInputCurrentTimeStep, |
| aux_input_to_input_weights_float32.data(), |
| aux_input_to_forget_weights_float32.data(), |
| aux_input_to_cell_weights_float32.data(), |
| aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(), |
| forget_gate_bias_float32.data(), cell_bias_float32.data(), |
| output_gate_bias_float32.data(), projection_weights_float32.data(), |
| projection_bias_float32.data(), outputStateInCurrentTimeStep.data(), |
| cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(), |
| forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(), |
| output_layer_norm_weights_float32.data(), output_state_out_float32.data(), |
| cell_state_out_float32.data(), outputCurrentTimeStep, |
| scratch_buffer_float32.data()); |
| inputCurrentTimeStep += batchInputDelta; |
| if (hasAuxInput) { |
| auxInputCurrentTimeStep += batchInputDelta; |
| } |
| outputCurrentTimeStep += batchOutputDelta; |
| outputStateInCurrentTimeStep = output_state_out_float32; |
| cellStateInCurrentTimeStep = cell_state_out_float32; |
| } |
| |
| if (!timeMajor) { |
| transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape, |
| output_float32.data()); |
| } |
| |
| convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer); |
| convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer); |
| convertFloat32ToFloat16(output_float32, output_buffer); |
| convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer); |
| return true; |
| } |
| |
| // static |
| bool LSTMCell::LSTMStep( |
| const LSTMParams& params, const float* input_buffer, const Shape& input_shape, |
| const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer, |
| const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer, |
| const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer, |
| const float* recurrent_to_forget_weights_buffer, |
| const float* recurrent_to_cell_weights_buffer, |
| const float* recurrent_to_output_weights_buffer, |
| const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer, |
| const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer, |
| const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer, |
| const float* aux_input_to_forget_weights_buffer, |
| const float* aux_input_to_cell_weights_buffer, |
| const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer, |
| const float* forget_gate_bias_buffer, const float* cell_bias_buffer, |
| const float* output_gate_bias_buffer, const float* projection_weights_buffer, |
| const float* projection_bias_buffer, const float* output_state_in_buffer, |
| const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer, |
| const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer, |
| const float* output_layer_norm_weights_buffer, float* output_state_out_buffer, |
| float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) { |
| NNTRACE_COMP("LSTMCell::LSTMStep"); |
| |
| const uint32_t n_batch = input_shape.dimensions[0]; |
| const uint32_t n_input = input_shape.dimensions[1]; |
| // n_cell and n_output will be the same size when there is no projection. |
| const uint32_t n_cell = input_to_output_weights_shape.dimensions[0]; |
| const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1]; |
| const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input; |
| |
| // Point the per-gate scratch pointers into the shared scratch buffer. |
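| // Non-CIFG layout: input, cell, forget, output gate buffers, each of size |
| // n_cell * n_batch, laid out contiguously; CIFG drops the input gate buffer. |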
| float* input_gate_scratch = nullptr; |
| float* cell_scratch = nullptr; |
| float* forget_gate_scratch = nullptr; |
| float* output_gate_scratch = nullptr; |
| if (params.use_cifg) { |
| cell_scratch = scratch_buffer_buffer; |
| forget_gate_scratch = cell_scratch + n_cell * n_batch; |
| output_gate_scratch = cell_scratch + 2 * n_cell * n_batch; |
| } else { |
| input_gate_scratch = scratch_buffer_buffer; |
| cell_scratch = input_gate_scratch + n_cell * n_batch; |
| forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch; |
| output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch; |
| } |
| |
| if (!params.use_layer_norm) { |
| // Initialize scratch buffers with bias. |
| if (!params.use_cifg) { |
| tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch, |
| input_gate_scratch); |
| } |
| tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch, |
| forget_gate_scratch); |
| tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch, |
| cell_scratch); |
| tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch, |
| output_gate_scratch); |
| } else { |
| // Initialize scratch buffers with zeroes. |
| if (!params.use_cifg) { |
| std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f); |
| } |
| std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f); |
| std::fill_n(cell_scratch, n_cell * n_batch, 0.0f); |
| std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f); |
| } |
| |
| // For each batch and cell: compute input_weight * input. |
| if (!params.use_cifg) { |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_buffer, |
| n_cell, n_input, input_buffer, |
| n_batch, input_gate_scratch); |
| } |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_buffer, |
| n_cell, n_input, input_buffer, |
| n_batch, forget_gate_scratch); |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| input_to_cell_weights_buffer, n_cell, n_input, input_buffer, n_batch, cell_scratch); |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_buffer, |
| n_cell, n_input, input_buffer, |
| n_batch, output_gate_scratch); |
| |
| // If an auxiliary input is available: compute aux_input_weight * aux_input. |
| if (aux_input_buffer != nullptr) { |
| if (!params.use_cifg) { |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer, |
| n_batch, input_gate_scratch); |
| } |
| |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch, |
| forget_gate_scratch); |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch, |
| cell_scratch); |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch, |
| output_gate_scratch); |
| } |
| |
| // For each batch and cell: compute recurrent_weight * output_state. |
| if (!params.use_cifg) { |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer, |
| n_batch, input_gate_scratch); |
| } |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch, |
| forget_gate_scratch); |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch, |
| cell_scratch); |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch, |
| output_gate_scratch); |
| |
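| // The gate updates below follow the standard LSTM formulation, e.g. the |
| // forget gate is f_t = sigmoid(W_fx x_t + W_fh h_{t-1} [+ w_cf . c_{t-1}] + b_f). |
| // With layer normalization enabled, the gate pre-activation is normalized and |
| // rescaled before the bias is added (otherwise the bias was pre-loaded above). |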
| // For each batch and cell: update input gate. |
| if (!params.use_cifg) { |
| if (params.use_peephole) { |
| tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate( |
| cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch, |
| input_gate_scratch); |
| } |
| if (params.use_layer_norm) { |
| tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch, |
| n_cell, n_batch); |
| tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer, |
| n_cell, input_gate_scratch, n_batch, |
| input_gate_scratch); |
| tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch, |
| input_gate_scratch); |
| } |
| tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, |
| input_gate_scratch); |
| } |
| |
| // For each batch and cell: update forget gate. |
| if (params.use_peephole) { |
| tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer, |
| n_cell, cell_state_in_buffer, |
| n_batch, forget_gate_scratch); |
| } |
| if (params.use_layer_norm) { |
| tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch, |
| n_cell, n_batch); |
| tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer, |
| n_cell, forget_gate_scratch, n_batch, |
| forget_gate_scratch); |
| tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch, |
| forget_gate_scratch); |
| } |
| tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, |
| forget_gate_scratch); |
| |
| // For each batch and cell: update the cell. |
| if (params.use_layer_norm) { |
| tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch); |
| tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell, |
| cell_scratch, n_batch, cell_scratch); |
| tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch); |
| } |
| tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer, |
| n_batch * n_cell, cell_state_out_buffer); |
| tflite::tensor_utils::ApplyActivationToVector( |
| cell_scratch, n_batch * n_cell, static_cast<TfLiteFusedActivation>(params.activation), |
| cell_scratch); |
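| // With CIFG, the input gate is coupled to the forget gate as i_t = 1 - f_t, |
| // computed in place by Sub1Vector. |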
| if (params.use_cifg) { |
| tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, |
| forget_gate_scratch); |
| tflite::tensor_utils::VectorVectorCwiseProductAccumulate( |
| cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer); |
| } else { |
| tflite::tensor_utils::VectorVectorCwiseProductAccumulate( |
| cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer); |
| } |
| if (params.cell_clip > 0.0) { |
| tflite::tensor_utils::CwiseClipping(cell_state_out_buffer, n_batch * n_cell, |
| params.cell_clip); |
| } |
| |
| // For each batch and cell: update the output gate. |
| if (params.use_peephole) { |
| tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer, |
| n_cell, cell_state_out_buffer, |
| n_batch, output_gate_scratch); |
| } |
| if (params.use_layer_norm) { |
| tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch, |
| n_cell, n_batch); |
| tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer, |
| n_cell, output_gate_scratch, n_batch, |
| output_gate_scratch); |
| tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch, |
| output_gate_scratch); |
| } |
| tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, |
| output_gate_scratch); |
| tflite::tensor_utils::ApplyActivationToVector( |
| cell_state_out_buffer, n_batch * n_cell, |
| static_cast<TfLiteFusedActivation>(params.activation), cell_scratch); |
| tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, |
| n_batch * n_cell, output_gate_scratch); |
| |
| // For each batch: update the projection and output_state. |
| if (params.use_projection_weight) { |
| if (params.use_projection_bias) { |
| tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch, |
| output_buffer); |
| } else { |
| std::fill_n(output_buffer, n_batch * n_output, 0.0f); |
| } |
| tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate( |
| projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch, |
| output_buffer); |
| if (params.proj_clip > 0.0) { |
| tflite::tensor_utils::CwiseClipping(output_buffer, n_batch * n_output, |
| params.proj_clip); |
| } |
| } else { |
| std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer); |
| } |
| std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer); |
| return true; |
| } |
| |
| bool LSTMCell::Eval() { |
| switch (input_->type) { |
| case OperandType::TENSOR_FLOAT32: { |
| LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(), |
| GetBuffer<const float>(input_to_input_weights_), |
| GetBuffer<const float>(input_to_forget_weights_), |
| GetBuffer<const float>(input_to_cell_weights_), |
| GetBuffer<const float>(input_to_output_weights_), |
| input_to_output_weights_->shape(), |
| GetBuffer<const float>(recurrent_to_input_weights_), |
| GetBuffer<const float>(recurrent_to_forget_weights_), |
| GetBuffer<const float>(recurrent_to_cell_weights_), |
| GetBuffer<const float>(recurrent_to_output_weights_), |
| recurrent_to_output_weights_->shape(), |
| GetBuffer<const float>(cell_to_input_weights_), |
| GetBuffer<const float>(cell_to_forget_weights_), |
| GetBuffer<const float>(cell_to_output_weights_), |
| /*aux_input_buffer=*/nullptr, |
| /*aux_input_to_input_weights_buffer=*/nullptr, |
| /*aux_input_to_forget_weights_buffer=*/nullptr, |
| /*aux_input_to_cell_weights_buffer=*/nullptr, |
| /*aux_input_to_output_weights_buffer=*/nullptr, |
| GetBuffer<const float>(input_gate_bias_), |
| GetBuffer<const float>(forget_gate_bias_), |
| GetBuffer<const float>(cell_bias_), |
| GetBuffer<const float>(output_gate_bias_), |
| GetBuffer<const float>(projection_weights_), |
| GetBuffer<const float>(projection_bias_), |
| GetBuffer<const float>(output_state_in_), |
| GetBuffer<const float>(cell_state_in_), |
| GetBuffer<const float>(input_layer_norm_weights_), |
| GetBuffer<const float>(forget_layer_norm_weights_), |
| GetBuffer<const float>(cell_layer_norm_weights_), |
| GetBuffer<const float>(output_layer_norm_weights_), |
| GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_), |
| GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_)); |
| } break; |
| case OperandType::TENSOR_FLOAT16: { |
| LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(), |
| GetOptionalBuffer<const _Float16>(input_to_input_weights_), |
| GetBuffer<const _Float16>(input_to_forget_weights_), |
| GetBuffer<const _Float16>(input_to_cell_weights_), |
| GetBuffer<const _Float16>(input_to_output_weights_), |
| input_to_output_weights_->shape(), |
| GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_), |
| GetBuffer<const _Float16>(recurrent_to_forget_weights_), |
| GetBuffer<const _Float16>(recurrent_to_cell_weights_), |
| GetBuffer<const _Float16>(recurrent_to_output_weights_), |
| recurrent_to_output_weights_->shape(), |
| GetOptionalBuffer<const _Float16>(cell_to_input_weights_), |
| GetOptionalBuffer<const _Float16>(cell_to_forget_weights_), |
| GetOptionalBuffer<const _Float16>(cell_to_output_weights_), |
| /*aux_input_buffer=*/nullptr, |
| /*aux_input_to_input_weights_buffer=*/nullptr, |
| /*aux_input_to_forget_weights_buffer=*/nullptr, |
| /*aux_input_to_cell_weights_buffer=*/nullptr, |
| /*aux_input_to_output_weights_buffer=*/nullptr, |
| GetOptionalBuffer<const _Float16>(input_gate_bias_), |
| GetBuffer<const _Float16>(forget_gate_bias_), |
| GetBuffer<const _Float16>(cell_bias_), |
| GetBuffer<const _Float16>(output_gate_bias_), |
| GetOptionalBuffer<const _Float16>(projection_weights_), |
| GetOptionalBuffer<const _Float16>(projection_bias_), |
| GetBuffer<const _Float16>(output_state_in_), |
| GetBuffer<const _Float16>(cell_state_in_), |
| GetOptionalBuffer<const _Float16>(input_layer_norm_weights_), |
| GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_), |
| GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_), |
| GetOptionalBuffer<const _Float16>(output_layer_norm_weights_), |
| GetBuffer<_Float16>(output_state_out_), |
| GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_), |
| GetBuffer<_Float16>(scratch_buffer_)); |
| } break; |
| default: { |
| LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type); |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| } // namespace nn |
| } // namespace android |