blob: db7eb801009361d22513194cc1fc15501a2a0caf [file] [log] [blame]
Yang Nid0ea9fd2017-07-28 16:23:46 -07001/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
Stefano Galarraga6c3a8cc2019-07-02 16:28:53 +010017#define LOG_TAG "Operations"
18
Yang Nid0ea9fd2017-07-28 16:23:46 -070019#include "LSTM.h"
20
Lev Proleev82112fb2021-02-25 13:36:02 +000021#include <tensorflow/lite/kernels/internal/reference/portable_tensor_utils.h>
Slava Shklyaevc958cd82020-12-10 16:55:55 +000022
Lev Proleev8bd6eb72020-02-06 14:39:41 +000023#include <vector>
24
Yang Nid0ea9fd2017-07-28 16:23:46 -070025#include "CpuExecutor.h"
Lev Proleevd1c222a2018-12-28 13:24:24 +000026#include "CpuOperationUtils.h"
Slava Shklyaevc958cd82020-12-10 16:55:55 +000027#include "LegacyUtils.h"
Viet Dang1bf001b2019-01-23 00:22:19 +000028#include "OperationsUtils.h"
Mika Raento0bb84c72018-04-23 22:06:45 +010029#include "Tracing.h"
Slava Shklyaev9f29f432020-08-13 13:16:03 +010030#include "nnapi/Types.h"
Mika Raento0bb84c72018-04-23 22:06:45 +010031
Yang Nid0ea9fd2017-07-28 16:23:46 -070032namespace android {
33namespace nn {
34
Yang Nid0ea9fd2017-07-28 16:23:46 -070035namespace {
36
Yang Nid0ea9fd2017-07-28 16:23:46 -070037template <typename T>
Lev Proleevd1c222a2018-12-28 13:24:24 +000038inline T* GetBuffer(RunTimeOperandInfo* operand) {
39 return reinterpret_cast<T*>(operand->buffer);
Yang Nid0ea9fd2017-07-28 16:23:46 -070040}
41
42template <typename T>
Lev Proleevd1c222a2018-12-28 13:24:24 +000043inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
44 return reinterpret_cast<const T*>(operand->buffer);
Yang Nid0ea9fd2017-07-28 16:23:46 -070045}
46
Viet Dangc2ddad92019-01-23 00:50:00 +000047template <typename T>
48inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
49 return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
50}
51
Yang Nid0ea9fd2017-07-28 16:23:46 -070052} // anonymous namespace
53
// Caches pointers to all of the operation's input and output operands and
// reads the scalar parameters (fused activation, cell clip, projection clip)
// into params_. No shape validation happens here; that is deferred to
// Prepare(). Operands marked "optional" may be omitted by the model.
LSTMCell::LSTMCell(const Operation& operation, RunTimeOperandInfo* operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    input_to_input_weights_ =
            GetInput(operation, operands, kInputToInputWeightsTensor);  // optional
    input_to_forget_weights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    input_to_cell_weights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    input_to_output_weights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrent_to_input_weights_ =
            GetInput(operation, operands, kRecurrentToInputWeightsTensor);  // optional
    recurrent_to_forget_weights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrent_to_cell_weights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrent_to_output_weights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    // Peephole connections (all optional).
    cell_to_input_weights_ = GetInput(operation, operands, kCellToInputWeightsTensor);  // optional
    cell_to_forget_weights_ =
            GetInput(operation, operands, kCellToForgetWeightsTensor);  // optional
    cell_to_output_weights_ =
            GetInput(operation, operands, kCellToOutputWeightsTensor);  // optional

    input_gate_bias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forget_gate_bias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cell_bias_ = GetInput(operation, operands, kCellGateBiasTensor);
    output_gate_bias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    projection_weights_ = GetInput(operation, operands, kProjectionWeightsTensor);  // optional
    projection_bias_ = GetInput(operation, operands, kProjectionBiasTensor);        // optional

    output_state_in_ = GetInput(operation, operands, kOutputStateInTensor);
    cell_state_in_ = GetInput(operation, operands, kCellStateInTensor);

    // Scalar parameters; the default is used when the operand has no value
    // (see getScalarDataWithDefault).
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    params_.activation = static_cast<TfLiteFusedActivation>(getScalarDataWithDefault<int32_t>(
            activationOperand, TfLiteFusedActivation::kTfLiteActNone));

    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    // Clip thresholds are kept as float in params_; FP16 models provide them
    // as _Float16 and they are widened here.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarDataWithDefault<float>(cellClipOperand, 0.0f);
        params_.proj_clip = getScalarDataWithDefault<float>(projClipOperand, 0.0f);
    } else {
        params_.cell_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(cellClipOperand, 0.0f));
        params_.proj_clip =
                static_cast<float>(getScalarDataWithDefault<_Float16>(projClipOperand, 0.0f));
    }

    // We check the version of LSTM by checking the number of the inputs to the
    // op. For LSTM version 1.0 there were 23 inputs and for 1.2 there are 27.
    if (operation.inputs.size() == 27) {
        input_layer_norm_weights_ =
                GetInput(operation, operands, kInputLayerNormWeightsTensor);  // optional
        forget_layer_norm_weights_ =
                GetInput(operation, operands, kForgetLayerNormWeightsTensor);  // optional
        cell_layer_norm_weights_ =
                GetInput(operation, operands, kCellLayerNormWeightsTensor);  // optional
        output_layer_norm_weights_ =
                GetInput(operation, operands, kOutputLayerNormWeightsTensor);  // optional
    } else {
        // For LSTM from HAL v1.0 assign operands with no values.
        // `no_value` is a function-local static shared by all LSTMCell
        // instances; it is always (re)set to NO_VALUE and never read through
        // mutably, so the sharing is benign.
        static RunTimeOperandInfo no_value;
        no_value.lifetime = Operand::LifeTime::NO_VALUE;

        input_layer_norm_weights_ = &no_value;
        forget_layer_norm_weights_ = &no_value;
        cell_layer_norm_weights_ = &no_value;
        output_layer_norm_weights_ = &no_value;
    }

    output_state_out_ = GetOutput(operation, operands, kOutputStateOutTensor);
    cell_state_out_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);

    scratch_buffer_ = GetOutput(operation, operands, kScratchBufferTensor);
}
130
// static
// Validates the dimensions of all LSTM weight/bias/layer-norm tensors against
// the inferred sizes (n_input, n_output, n_cell), enforces the all-or-none
// pairing rules for the optional tensor groups (CIFG, peephole, projection,
// layer normalization), and fills in the derived feature flags in `params`
// (use_cifg, use_peephole, use_layer_norm, use_projection_weight/bias).
// Returns false when any check fails.
// NOTE(review): the `input_` parameter does not appear to be read in this
// function — confirm whether it is kept only for signature symmetry.
bool LSTMCell::CheckInputTensorDimensions(
        const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
        const RunTimeOperandInfo* input_to_forget_weights,
        const RunTimeOperandInfo* input_to_cell_weights,
        const RunTimeOperandInfo* input_to_output_weights,
        const RunTimeOperandInfo* recurrent_to_input_weights,
        const RunTimeOperandInfo* recurrent_to_forget_weights,
        const RunTimeOperandInfo* recurrent_to_cell_weights,
        const RunTimeOperandInfo* recurrent_to_output_weights,
        const RunTimeOperandInfo* cell_to_input_weights,
        const RunTimeOperandInfo* cell_to_forget_weights,
        const RunTimeOperandInfo* cell_to_output_weights, const RunTimeOperandInfo* input_gate_bias,
        const RunTimeOperandInfo* forget_gate_bias, const RunTimeOperandInfo* cell_bias,
        const RunTimeOperandInfo* output_gate_bias, const RunTimeOperandInfo* projection_weights,
        const RunTimeOperandInfo* projection_bias,
        const RunTimeOperandInfo* input_layer_norm_weights,
        const RunTimeOperandInfo* forget_layer_norm_weights,
        const RunTimeOperandInfo* cell_layer_norm_weights,
        const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input, uint32_t n_output,
        uint32_t n_cell, LSTMParams* params) {
    // Making sure clipping parameters have valid values.
    // == 0 means no clipping
    // > 0 means clipping
    NN_CHECK(params->cell_clip >= 0);
    NN_CHECK(params->proj_clip >= 0);

    // Input-to-gate weights are [n_cell, n_input]; the input-gate weights are
    // optional (omitted for CIFG).
    if (!IsNullInput(input_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(input_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(input_to_input_weights, 1), n_input);
    }

    NN_CHECK_EQ(NumDimensions(input_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_forget_weights, 1), n_input);

    NN_CHECK_EQ(NumDimensions(input_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(input_to_cell_weights, 1), n_input);

    // Recurrent weights are [n_cell, n_output]; input-gate ones are optional.
    if (!IsNullInput(recurrent_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(recurrent_to_input_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 0), n_cell);
        NN_CHECK_EQ(SizeOfDimension(recurrent_to_input_weights, 1), n_output);
    }

    NN_CHECK_EQ(NumDimensions(recurrent_to_forget_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_forget_weights, 1), n_output);

    NN_CHECK_EQ(NumDimensions(recurrent_to_cell_weights), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 0), n_cell);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_cell_weights, 1), n_output);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifg_weights_all_or_none =
            (!IsNullInput(input_to_input_weights) && !IsNullInput(recurrent_to_input_weights)) ||
            (IsNullInput(input_to_input_weights) && IsNullInput(recurrent_to_input_weights));
    NN_CHECK(cifg_weights_all_or_none);

    // Peephole weights are vectors of size n_cell.
    if (!IsNullInput(cell_to_input_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_input_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_input_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_forget_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_forget_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_forget_weights, 0), n_cell);
    }

    if (!IsNullInput(cell_to_output_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_to_output_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_to_output_weights, 0), n_cell);
    }

    // Making sure the peephole weights are there all or none.
    params->use_cifg = IsNullInput(input_to_input_weights);
    const bool peephole_weights_all_or_none =
            ((!IsNullInput(cell_to_input_weights) || params->use_cifg) &&
             !IsNullInput(cell_to_forget_weights) && !IsNullInput(cell_to_output_weights)) ||
            (IsNullInput(cell_to_input_weights) && IsNullInput(cell_to_forget_weights) &&
             IsNullInput(cell_to_output_weights));
    NN_CHECK(peephole_weights_all_or_none);

    // Since we have already checked that weights are all there or none, we can
    // check the existence of only one to the get the condition.
    params->use_peephole = !IsNullInput(cell_to_output_weights);
    // Checking output instead of input layer norm weights because the input
    // ones can be omitted in case CIFG LSTM is used.
    params->use_layer_norm = !IsNullInput(output_layer_norm_weights);

    params->use_projection_weight = (projection_weights->lifetime != Operand::LifeTime::NO_VALUE);
    params->use_projection_bias = (projection_bias->lifetime != Operand::LifeTime::NO_VALUE);

    // Make sure the input gate bias is present only when not a CIFG-LSTM.
    if (params->use_cifg) {
        NN_CHECK(IsNullInput(input_gate_bias));
    } else {
        NN_CHECK_EQ(NumDimensions(input_gate_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(input_gate_bias, 0), n_cell);
    }

    NN_CHECK_EQ(NumDimensions(forget_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(forget_gate_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(cell_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(cell_bias, 0), n_cell);

    NN_CHECK_EQ(NumDimensions(output_gate_bias), 1);
    NN_CHECK_EQ(SizeOfDimension(output_gate_bias, 0), n_cell);

    // Projection maps the n_cell-sized cell output down to n_output.
    if (!IsNullInput(projection_weights)) {
        NN_CHECK_EQ(NumDimensions(projection_weights), 2);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 0), n_output);
        NN_CHECK_EQ(SizeOfDimension(projection_weights, 1), n_cell);
    }

    if (!IsNullInput(projection_bias)) {
        NN_CHECK_EQ(NumDimensions(projection_bias), 1);
        NN_CHECK_EQ(SizeOfDimension(projection_bias, 0), n_output);
    }

    // Making sure the projection tensors are consistent:
    // 1) If projection weight is not present, then projection bias should not be
    // present.
    // 2) If projection weight is present, then projection bias is optional.
    // TODO: make sure this is correct.
    const bool projecton_tensors_consistent =
            (!IsNullInput(projection_weights) || IsNullInput(projection_bias));
    NN_CHECK(projecton_tensors_consistent == true);

    // Layer-norm weights, when present, are vectors of size n_cell.
    if (!IsNullInput(input_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(input_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(input_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(forget_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(forget_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(forget_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(cell_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(cell_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(cell_layer_norm_weights, 0), n_cell);
    }
    if (!IsNullInput(output_layer_norm_weights)) {
        NN_CHECK_EQ(NumDimensions(output_layer_norm_weights), 1);
        NN_CHECK_EQ(SizeOfDimension(output_layer_norm_weights, 0), n_cell);
    }

    // Layer-norm weights must be all-or-none; under CIFG the input layer-norm
    // weights must be absent (there is no input gate).
    if (params->use_cifg) {
        NN_RET_CHECK(IsNullInput(input_layer_norm_weights))
                << "input_layer_norm_weights are provided while CIFG is used";
        const bool layer_norm_weights_all_or_none_cifg =
                (IsNullInput(forget_layer_norm_weights) && IsNullInput(cell_layer_norm_weights) &&
                 IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none_cifg);
    } else {
        const bool layer_norm_weights_all_or_none =
                (IsNullInput(input_layer_norm_weights) && IsNullInput(forget_layer_norm_weights) &&
                 IsNullInput(cell_layer_norm_weights) && IsNullInput(output_layer_norm_weights)) ||
                (!IsNullInput(input_layer_norm_weights) &&
                 !IsNullInput(forget_layer_norm_weights) && !IsNullInput(cell_layer_norm_weights) &&
                 !IsNullInput(output_layer_norm_weights));
        NN_RET_CHECK(layer_norm_weights_all_or_none);
    }

    return true;
}
302
// Validates the operation's inputs, infers the batch/input/cell/output sizes
// from the input and weight tensors, and fills in the four output Shapes
// (scratch buffer, output state, cell state, output). Returns false if any
// validation fails.
bool LSTMCell::Prepare(const Operation& operation, RunTimeOperandInfo* operands,
                       Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
                       Shape* outputShape) {
    // Check we have all the inputs and outputs we need.
    NN_CHECK(NumInputsWithValues(operation, operands) >= 15 &&
             NumInputsWithValues(operation, operands) <= 27);
    // These inputs may never be omitted regardless of the LSTM variant.
    constexpr int requiredInputs[] = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!IsNullInput(GetInput(operation, operands, requiredInput)))
                << "required input " << requiredInput << " is omitted";
    }
    NN_CHECK_EQ(NumOutputs(operation), 4);

    // Check that the scalar operands' buffers are large enough.
    const auto& activationOperand = *GetInput(operation, operands, kActivationParam);
    NN_RET_CHECK(activationOperand.length >= sizeof(int32_t));
    const auto& cellClipOperand = *GetInput(operation, operands, kCellClipParam);
    const auto& projClipOperand = *GetInput(operation, operands, kProjClipParam);
    // The clip scalars' element type follows the input tensor's element type.
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(float));
        NN_RET_CHECK(projClipOperand.length >= sizeof(float));
    } else {
        NN_RET_CHECK(cellClipOperand.length >= sizeof(_Float16));
        NN_RET_CHECK(projClipOperand.length >= sizeof(_Float16));
    }

    // Inferring batch size, number of outputs and number of cells from the
    // input tensors.
    NN_CHECK(NumDimensions(input_) > 1);
    const uint32_t n_batch = SizeOfDimension(input_, 0);
    const uint32_t n_input = SizeOfDimension(input_, 1);

    // input_to_output_weights is [n_cell, n_input].
    const uint32_t n_cell = SizeOfDimension(input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(input_to_output_weights_, 1), n_input);

    // recurrent_to_output_weights is [n_cell, n_output].
    NN_CHECK_EQ(NumDimensions(recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(recurrent_to_output_weights_, 0), n_cell);
    const uint32_t n_output = SizeOfDimension(recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    if (!CheckInputTensorDimensions(
                input_, input_to_input_weights_, input_to_forget_weights_, input_to_cell_weights_,
                input_to_output_weights_, recurrent_to_input_weights_, recurrent_to_forget_weights_,
                recurrent_to_cell_weights_, recurrent_to_output_weights_, cell_to_input_weights_,
                cell_to_forget_weights_, cell_to_output_weights_, input_gate_bias_,
                forget_gate_bias_, cell_bias_, output_gate_bias_, projection_weights_,
                projection_bias_, input_layer_norm_weights_, forget_layer_norm_weights_,
                cell_layer_norm_weights_, output_layer_norm_weights_, n_input, n_output, n_cell,
                &params_)) {
        return false;
    }

    // Resize the output and output_state tensors. All outputs inherit the
    // input's element type and quantization parameters.
    const Shape& inputShape = input_->shape();

    outputShape->type = inputShape.type;
    outputShape->dimensions = {n_batch, n_output};
    outputShape->offset = inputShape.offset;
    outputShape->scale = inputShape.scale;

    outputStateShape->type = inputShape.type;
    outputStateShape->dimensions = {n_batch, n_output};
    outputStateShape->offset = inputShape.offset;
    outputStateShape->scale = inputShape.scale;

    cellStateShape->type = inputShape.type;
    cellStateShape->dimensions = {n_batch, n_cell};
    cellStateShape->offset = inputShape.offset;
    cellStateShape->scale = inputShape.scale;

    if (params_.use_cifg) {
        // Reserving space for Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 3};
    } else {
        // Reserving space for Input, Cell, Forget, Output gates
        scratchShape->dimensions = {n_batch, n_cell * 4};
    }
    scratchShape->type = inputShape.type;
    scratchShape->offset = inputShape.offset;
    scratchShape->scale = inputShape.scale;

    return true;
}
403
Michael K. Sanders70424ae2019-01-21 11:02:44 +0000404// static
Viet Dangc2ddad92019-01-23 00:50:00 +0000405bool LSTMCell::LSTMEvalFloat32(
406 const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
407 const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
408 const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
409 const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
410 const float* recurrent_to_forget_weights_buffer,
411 const float* recurrent_to_cell_weights_buffer,
412 const float* recurrent_to_output_weights_buffer,
413 const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
414 const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
Viet Dang5fe83682019-01-29 15:25:14 +0000415 const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000416 const float* aux_input_to_forget_weights_buffer,
417 const float* aux_input_to_cell_weights_buffer,
418 const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
419 const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
420 const float* output_gate_bias_buffer, const float* projection_weights_buffer,
421 const float* projection_bias_buffer, const float* output_state_in_buffer,
422 const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
423 const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
424 const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
425 float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
Viet Dang5fe83682019-01-29 15:25:14 +0000426 bool timeMajor, bool forwardSequence) {
Viet Dangc2ddad92019-01-23 00:50:00 +0000427 NNTRACE_COMP("LSTMCell::LSTMEvalFloat32");
428
429 const uint32_t inputRank = getNumberOfDimensions(input_shape);
430 NN_CHECK(inputRank == 2 || inputRank == 3);
431
432 const uint32_t maxTime =
433 (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
434 const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
435 : getSizeOfDimension(input_shape, 0);
436 const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
Viet Dang1bf001b2019-01-23 00:22:19 +0000437 const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
Viet Dangc2ddad92019-01-23 00:50:00 +0000438 const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);
439
Viet Dang75954232019-01-24 07:03:26 +0000440 Shape batchInputShape = input_shape;
441 batchInputShape.dimensions = {batchSize, inputSize};
Viet Dangc2ddad92019-01-23 00:50:00 +0000442 const uint32_t batchInputSize = batchSize * inputSize;
443 const uint32_t batchOutputSize = batchSize * outputSize;
444
Viet Dang1bf001b2019-01-23 00:22:19 +0000445 std::vector<float> transposedInput;
Viet Dangaefafaa2019-04-29 18:43:22 +0100446 const bool hasAuxInput = (aux_input_buffer != nullptr);
447 std::vector<float> transposedAuxInput;
Viet Dang1bf001b2019-01-23 00:22:19 +0000448 std::vector<float> transposedOutput;
449 Shape transposedInputShape;
450 Shape transposedOutputShape;
451 if (!timeMajor) {
452 transposedInput.resize(maxTime * batchInputSize);
Viet Dang1bf001b2019-01-23 00:22:19 +0000453 transposeFirstTwoDimensions<float>(input_buffer, input_shape, transposedInput.data());
Viet Dangaefafaa2019-04-29 18:43:22 +0100454 if (hasAuxInput) {
455 transposedAuxInput.resize(maxTime * batchInputSize);
456 transposeFirstTwoDimensions<float>(aux_input_buffer, input_shape,
457 transposedAuxInput.data());
458 }
Viet Dang1bf001b2019-01-23 00:22:19 +0000459 transposeFirstTwoDimensions(input_shape, &transposedInputShape);
Viet Dangaefafaa2019-04-29 18:43:22 +0100460 transposedOutput.resize(maxTime * batchOutputSize);
Viet Dang1bf001b2019-01-23 00:22:19 +0000461 transposedOutputShape = transposedInputShape;
462 transposedOutputShape.dimensions[2] = outputSize;
463 }
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000464 const float* inputData = timeMajor ? input_buffer : transposedInput.data();
Viet Dangaefafaa2019-04-29 18:43:22 +0100465 const float* auxInputData =
466 hasAuxInput ? (timeMajor ? aux_input_buffer : transposedAuxInput.data()) : nullptr;
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000467 float* outputData = timeMajor ? output_buffer : transposedOutput.data();
Viet Dang1bf001b2019-01-23 00:22:19 +0000468
469 std::vector<float> outputStateInCurrentTimeStep(
470 output_state_in_buffer, output_state_in_buffer + batchSize * outputSize);
471 std::vector<float> cellStateInCurrentTimeStep(cell_state_in_buffer,
472 cell_state_in_buffer + batchSize * numCells);
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000473 const float* inputCurrentTimeStep =
474 inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
Viet Dangaefafaa2019-04-29 18:43:22 +0100475 const float* auxInputCurrentTimeStep =
476 hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
477 : nullptr;
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000478 float* outputCurrentTimeStep =
479 outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
Slava Shklyaev52fc7ed2020-06-29 15:54:45 +0100480 const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
481 const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000482
Viet Dangc2ddad92019-01-23 00:50:00 +0000483 for (int t = 0; t < maxTime; ++t) {
Viet Dang75954232019-01-24 07:03:26 +0000484 LSTMStep(params, inputCurrentTimeStep, batchInputShape, input_to_input_weights_buffer,
Viet Dangc2ddad92019-01-23 00:50:00 +0000485 input_to_forget_weights_buffer, input_to_cell_weights_buffer,
486 input_to_output_weights_buffer, input_to_output_weights_shape,
487 recurrent_to_input_weights_buffer, recurrent_to_forget_weights_buffer,
488 recurrent_to_cell_weights_buffer, recurrent_to_output_weights_buffer,
489 recurrent_to_output_weights_shape, cell_to_input_weights_buffer,
Viet Dangaefafaa2019-04-29 18:43:22 +0100490 cell_to_forget_weights_buffer, cell_to_output_weights_buffer,
491 auxInputCurrentTimeStep, aux_input_to_input_weights_buffer,
492 aux_input_to_forget_weights_buffer, aux_input_to_cell_weights_buffer,
493 aux_input_to_output_weights_buffer, input_gate_bias_buffer,
494 forget_gate_bias_buffer, cell_bias_buffer, output_gate_bias_buffer,
495 projection_weights_buffer, projection_bias_buffer,
Viet Dang5fe83682019-01-29 15:25:14 +0000496 outputStateInCurrentTimeStep.data(), cellStateInCurrentTimeStep.data(),
497 input_layer_norm_weights_buffer, forget_layer_norm_weights_buffer,
498 cell_layer_norm_weights_buffer, output_layer_norm_weights_buffer,
499 output_state_out_buffer, cell_state_out_buffer, outputCurrentTimeStep,
500 scratch_buffer_buffer);
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000501 inputCurrentTimeStep += batchInputDelta;
Viet Dangaefafaa2019-04-29 18:43:22 +0100502 if (hasAuxInput) {
503 auxInputCurrentTimeStep += batchInputDelta;
504 }
Michael K. Sanders6af0ba32019-01-22 17:47:45 +0000505 outputCurrentTimeStep += batchOutputDelta;
Viet Dang1bf001b2019-01-23 00:22:19 +0000506 outputStateInCurrentTimeStep.assign(output_state_out_buffer,
507 output_state_out_buffer + batchSize * outputSize);
508 cellStateInCurrentTimeStep.assign(cell_state_out_buffer,
509 cell_state_out_buffer + batchSize * numCells);
Viet Dangc2ddad92019-01-23 00:50:00 +0000510 }
Viet Dang1bf001b2019-01-23 00:22:19 +0000511
512 if (!timeMajor) {
513 transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
514 output_buffer);
515 }
516
Viet Dangc2ddad92019-01-23 00:50:00 +0000517 return true;
518}
519
// Evaluates an LSTM over an fp16 input — either a single time step (rank-2
// input) or a whole sequence (rank-3 input, as used by the sequence LSTM ops) —
// by widening every tensor to fp32, running the per-step math through
// LSTMStep(), and narrowing the results back to fp16.
//
// Layout: for a rank-3 input, timeMajor selects [maxTime, batch, input] vs
// [batch, maxTime, input]; a rank-2 input is one step of [batch, input].
// forwardSequence selects the direction in which time steps are visited.
//
// Optional tensors (CIFG input gate, peephole, projection, layer norm, aux
// input) may be nullptr; which of them LSTMStep actually reads is decided by
// the flags in `params`, not by pointer nullness (the aux input is the one
// exception — it is forwarded as nullptr when absent).
//
// Always returns true; shape/operand errors are rejected before evaluation.
// static
bool LSTMCell::LSTMEvalFloat16(
        const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
        const _Float16* input_to_input_weights_buffer,
        const _Float16* input_to_forget_weights_buffer,
        const _Float16* input_to_cell_weights_buffer,
        const _Float16* input_to_output_weights_buffer, const Shape& input_to_output_weights_shape,
        const _Float16* recurrent_to_input_weights_buffer,
        const _Float16* recurrent_to_forget_weights_buffer,
        const _Float16* recurrent_to_cell_weights_buffer,
        const _Float16* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape,
        const _Float16* cell_to_input_weights_buffer, const _Float16* cell_to_forget_weights_buffer,
        const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
        const _Float16* aux_input_to_input_weights_buffer,
        const _Float16* aux_input_to_forget_weights_buffer,
        const _Float16* aux_input_to_cell_weights_buffer,
        const _Float16* aux_input_to_output_weights_buffer, const _Float16* input_gate_bias_buffer,
        const _Float16* forget_gate_bias_buffer, const _Float16* cell_bias_buffer,
        const _Float16* output_gate_bias_buffer, const _Float16* projection_weights_buffer,
        const _Float16* projection_bias_buffer, const _Float16* output_state_in_buffer,
        const _Float16* cell_state_in_buffer, const _Float16* input_layer_norm_weights_buffer,
        const _Float16* forget_layer_norm_weights_buffer,
        const _Float16* cell_layer_norm_weights_buffer,
        const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
        _Float16* cell_state_out_buffer, _Float16* output_buffer, _Float16* scratch_buffer_buffer,
        bool timeMajor, bool forwardSequence) {
    NNTRACE_COMP("LSTMCell::LSTMEvalFloat16");

    // Rank 2 => single time step; rank 3 => sequence of time steps.
    const uint32_t inputRank = getNumberOfDimensions(input_shape);
    NN_CHECK(inputRank == 2 || inputRank == 3);

    // maxTime and batchSize positions depend on the time-major vs batch-major
    // layout; inputSize is always the innermost dimension.
    const uint32_t maxTime =
            (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 0 : 1) : 1;
    const uint32_t batchSize = (inputRank == 3) ? getSizeOfDimension(input_shape, timeMajor ? 1 : 0)
                                                : getSizeOfDimension(input_shape, 0);
    const uint32_t inputSize = getSizeOfDimension(input_shape, inputRank - 1);
    const uint32_t numCells = getSizeOfDimension(input_to_output_weights_shape, 0);
    const uint32_t outputSize = getSizeOfDimension(recurrent_to_output_weights_shape, 1);

    // Shape of a single [batch, input] slab handed to LSTMStep each iteration.
    Shape batchInputShape = input_shape;
    batchInputShape.dimensions = {batchSize, inputSize};
    const uint32_t batchInputSize = batchSize * inputSize;
    const uint32_t batchOutputSize = batchSize * outputSize;

    // Widen every fp16 tensor to a scratch fp32 vector. Optional tensors are
    // converted only when present; when absent, the zero-initialized fp32
    // vector is still passed down by .data() and `params` flags keep LSTMStep
    // from reading it.
    std::vector<float> input_float32(maxTime * batchInputSize);
    convertFloat16ToFloat32(input_buffer, &input_float32);
    std::vector<float> input_to_input_weights_float32(numCells * inputSize);
    if (input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_to_input_weights_buffer, &input_to_input_weights_float32);
    }
    std::vector<float> input_to_forget_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_forget_weights_buffer, &input_to_forget_weights_float32);
    std::vector<float> input_to_cell_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_cell_weights_buffer, &input_to_cell_weights_float32);
    std::vector<float> input_to_output_weights_float32(numCells * inputSize);
    convertFloat16ToFloat32(input_to_output_weights_buffer, &input_to_output_weights_float32);

    std::vector<float> recurrent_to_input_weights_float32(numCells * outputSize);
    if (recurrent_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(recurrent_to_input_weights_buffer,
                                &recurrent_to_input_weights_float32);
    }
    std::vector<float> recurrent_to_forget_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_forget_weights_buffer,
                            &recurrent_to_forget_weights_float32);
    std::vector<float> recurrent_to_cell_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_cell_weights_buffer, &recurrent_to_cell_weights_float32);
    std::vector<float> recurrent_to_output_weights_float32(numCells * outputSize);
    convertFloat16ToFloat32(recurrent_to_output_weights_buffer,
                            &recurrent_to_output_weights_float32);

    // Peephole weights (optional, one value per cell).
    std::vector<float> cell_to_input_weights_float32(numCells);
    if (cell_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_input_weights_buffer, &cell_to_input_weights_float32);
    }
    std::vector<float> cell_to_forget_weights_float32(numCells);
    if (cell_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_forget_weights_buffer, &cell_to_forget_weights_float32);
    }
    std::vector<float> cell_to_output_weights_float32(numCells);
    if (cell_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_to_output_weights_buffer, &cell_to_output_weights_float32);
    }

    // Auxiliary input and its weights (optional). The aux buffer is sized like
    // the main input, i.e. the aux input is assumed to have the same per-step
    // size as the main input.
    std::vector<float> aux_input_float32(maxTime * batchInputSize);
    if (aux_input_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_buffer, &aux_input_float32);
    }
    std::vector<float> aux_input_to_input_weights_float32(numCells * inputSize);
    if (aux_input_to_input_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_input_weights_buffer,
                                &aux_input_to_input_weights_float32);
    }
    std::vector<float> aux_input_to_forget_weights_float32(numCells * inputSize);
    if (aux_input_to_forget_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_forget_weights_buffer,
                                &aux_input_to_forget_weights_float32);
    }
    std::vector<float> aux_input_to_cell_weights_float32(numCells * inputSize);
    if (aux_input_to_cell_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_cell_weights_buffer,
                                &aux_input_to_cell_weights_float32);
    }
    std::vector<float> aux_input_to_output_weights_float32(numCells * inputSize);
    if (aux_input_to_output_weights_buffer != nullptr) {
        convertFloat16ToFloat32(aux_input_to_output_weights_buffer,
                                &aux_input_to_output_weights_float32);
    }

    // Gate biases (input-gate bias is absent in the CIFG variant).
    std::vector<float> input_gate_bias_float32(numCells);
    if (input_gate_bias_buffer != nullptr) {
        convertFloat16ToFloat32(input_gate_bias_buffer, &input_gate_bias_float32);
    }
    std::vector<float> forget_gate_bias_float32(numCells);
    convertFloat16ToFloat32(forget_gate_bias_buffer, &forget_gate_bias_float32);
    std::vector<float> cell_bias_float32(numCells);
    convertFloat16ToFloat32(cell_bias_buffer, &cell_bias_float32);
    std::vector<float> output_gate_bias_float32(numCells);
    convertFloat16ToFloat32(output_gate_bias_buffer, &output_gate_bias_float32);

    // Projection (optional).
    std::vector<float> projection_weights_float32(numCells * outputSize);
    if (projection_weights_buffer != nullptr) {
        convertFloat16ToFloat32(projection_weights_buffer, &projection_weights_float32);
    }
    std::vector<float> projection_bias_float32(outputSize);
    if (projection_bias_buffer != nullptr) {
        convertFloat16ToFloat32(projection_bias_buffer, &projection_bias_float32);
    }

    // Layer-normalization weights (optional, one value per cell).
    std::vector<float> input_layer_norm_weights_float32(numCells);
    if (input_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(input_layer_norm_weights_buffer, &input_layer_norm_weights_float32);
    }
    std::vector<float> forget_layer_norm_weights_float32(numCells);
    if (forget_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(forget_layer_norm_weights_buffer,
                                &forget_layer_norm_weights_float32);
    }
    std::vector<float> cell_layer_norm_weights_float32(numCells);
    if (cell_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(cell_layer_norm_weights_buffer, &cell_layer_norm_weights_float32);
    }
    std::vector<float> output_layer_norm_weights_float32(numCells);
    if (output_layer_norm_weights_buffer != nullptr) {
        convertFloat16ToFloat32(output_layer_norm_weights_buffer,
                                &output_layer_norm_weights_float32);
    }

    // fp32 staging for the outputs; narrowed back to fp16 at the end.
    std::vector<float> output_state_out_float32(batchOutputSize);
    convertFloat16ToFloat32(output_state_out_buffer, &output_state_out_float32);
    std::vector<float> cell_state_out_float32(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_out_buffer, &cell_state_out_float32);

    std::vector<float> output_float32(maxTime * batchOutputSize);
    convertFloat16ToFloat32(output_buffer, &output_float32);
    // Scratch holds 3 gate planes with CIFG (no input gate), 4 otherwise —
    // matching the segment layout set up in LSTMStep.
    std::vector<float> scratch_buffer_float32(params.use_cifg ? 3 * batchSize * numCells
                                                              : 4 * batchSize * numCells);
    convertFloat16ToFloat32(scratch_buffer_buffer, &scratch_buffer_float32);

    // For batch-major data, transpose the leading [batch, time] dims to
    // [time, batch] so each time step is one contiguous [batch, input] slab.
    std::vector<float> transposedInput;
    const bool hasAuxInput = (aux_input_buffer != nullptr);
    std::vector<float> transposedAuxInput;
    std::vector<float> transposedOutput;
    Shape transposedInputShape;
    Shape transposedOutputShape;
    if (!timeMajor) {
        transposedInput.resize(maxTime * batchInputSize);
        transposeFirstTwoDimensions<float>(input_float32.data(), input_shape,
                                           transposedInput.data());
        if (hasAuxInput) {
            transposedAuxInput.resize(maxTime * batchInputSize);
            transposeFirstTwoDimensions<float>(aux_input_float32.data(), input_shape,
                                               transposedAuxInput.data());
        }
        transposeFirstTwoDimensions(input_shape, &transposedInputShape);
        transposedOutput.resize(maxTime * batchOutputSize);
        // Output shares the transposed layout, with the inner dim swapped for
        // the output size.
        transposedOutputShape = transposedInputShape;
        transposedOutputShape.dimensions[2] = outputSize;
    }
    const float* inputData = timeMajor ? input_float32.data() : transposedInput.data();
    const float* auxInputData =
            hasAuxInput ? (timeMajor ? aux_input_float32.data() : transposedAuxInput.data())
                        : nullptr;
    float* outputData = timeMajor ? output_float32.data() : transposedOutput.data();

    // Recurrent state carried across time steps, seeded from the *_in buffers.
    std::vector<float> outputStateInCurrentTimeStep(batchSize * outputSize);
    convertFloat16ToFloat32(output_state_in_buffer, &outputStateInCurrentTimeStep);
    std::vector<float> cellStateInCurrentTimeStep(batchSize * numCells);
    convertFloat16ToFloat32(cell_state_in_buffer, &cellStateInCurrentTimeStep);

    // Cursors start at step 0 (forward) or step maxTime-1 (backward) and move
    // by a signed per-step stride.
    const float* inputCurrentTimeStep =
            inputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1));
    const float* auxInputCurrentTimeStep =
            hasAuxInput ? (auxInputData + (forwardSequence ? 0 : batchInputSize * (maxTime - 1)))
                        : nullptr;
    float* outputCurrentTimeStep =
            outputData + (forwardSequence ? 0 : batchOutputSize * (maxTime - 1));
    const int batchInputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchInputSize);
    const int batchOutputDelta = (forwardSequence ? 1 : -1) * static_cast<int>(batchOutputSize);

    // Run the cell once per time step; the state written by step t is copied
    // back to feed step t+1.
    for (int t = 0; t < maxTime; ++t) {
        LSTMStep(params, inputCurrentTimeStep, batchInputShape,
                 input_to_input_weights_float32.data(), input_to_forget_weights_float32.data(),
                 input_to_cell_weights_float32.data(), input_to_output_weights_float32.data(),
                 input_to_output_weights_shape, recurrent_to_input_weights_float32.data(),
                 recurrent_to_forget_weights_float32.data(),
                 recurrent_to_cell_weights_float32.data(),
                 recurrent_to_output_weights_float32.data(), recurrent_to_output_weights_shape,
                 cell_to_input_weights_float32.data(), cell_to_forget_weights_float32.data(),
                 cell_to_output_weights_float32.data(), auxInputCurrentTimeStep,
                 aux_input_to_input_weights_float32.data(),
                 aux_input_to_forget_weights_float32.data(),
                 aux_input_to_cell_weights_float32.data(),
                 aux_input_to_output_weights_float32.data(), input_gate_bias_float32.data(),
                 forget_gate_bias_float32.data(), cell_bias_float32.data(),
                 output_gate_bias_float32.data(), projection_weights_float32.data(),
                 projection_bias_float32.data(), outputStateInCurrentTimeStep.data(),
                 cellStateInCurrentTimeStep.data(), input_layer_norm_weights_float32.data(),
                 forget_layer_norm_weights_float32.data(), cell_layer_norm_weights_float32.data(),
                 output_layer_norm_weights_float32.data(), output_state_out_float32.data(),
                 cell_state_out_float32.data(), outputCurrentTimeStep,
                 scratch_buffer_float32.data());
        inputCurrentTimeStep += batchInputDelta;
        if (hasAuxInput) {
            auxInputCurrentTimeStep += batchInputDelta;
        }
        outputCurrentTimeStep += batchOutputDelta;
        outputStateInCurrentTimeStep = output_state_out_float32;
        cellStateInCurrentTimeStep = cell_state_out_float32;
    }

    // Undo the time/batch transposition on the accumulated output.
    if (!timeMajor) {
        transposeFirstTwoDimensions<float>(transposedOutput.data(), transposedOutputShape,
                                           output_float32.data());
    }

    // Narrow all results back to the caller's fp16 buffers.
    convertFloat32ToFloat16(output_state_out_float32, output_state_out_buffer);
    convertFloat32ToFloat16(cell_state_out_float32, cell_state_out_buffer);
    convertFloat32ToFloat16(output_float32, output_buffer);
    convertFloat32ToFloat16(scratch_buffer_float32, scratch_buffer_buffer);
    return true;
}
763
// Computes one LSTM time step for a whole batch using the fp32 tflite
// tensor_utils kernels, writing the new cell state (cell_state_out_buffer),
// output state (output_state_out_buffer) and output (output_buffer).
//
// The per-gate pre-activations are accumulated in segments of
// scratch_buffer_buffer, then each gate is (optionally) combined with the
// peephole and layer-norm terms before its non-linearity. Supported variants,
// all selected through `params`: CIFG (use_cifg — input gate derived as
// 1 - forget gate), peephole (use_peephole), layer normalization
// (use_layer_norm), projection (use_projection_weight / use_projection_bias),
// and an optional auxiliary input (aux_input_buffer != nullptr).
//
// input_buffer is one [n_batch, n_input] slab; when aux_input_buffer is
// present it is treated as having the same width as the main input
// (n_aux_input == n_input below). Always returns true.
// static
bool LSTMCell::LSTMStep(
        const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
        const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
        const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
        const Shape& input_to_output_weights_shape, const float* recurrent_to_input_weights_buffer,
        const float* recurrent_to_forget_weights_buffer,
        const float* recurrent_to_cell_weights_buffer,
        const float* recurrent_to_output_weights_buffer,
        const Shape& recurrent_to_output_weights_shape, const float* cell_to_input_weights_buffer,
        const float* cell_to_forget_weights_buffer, const float* cell_to_output_weights_buffer,
        const float* aux_input_buffer, const float* aux_input_to_input_weights_buffer,
        const float* aux_input_to_forget_weights_buffer,
        const float* aux_input_to_cell_weights_buffer,
        const float* aux_input_to_output_weights_buffer, const float* input_gate_bias_buffer,
        const float* forget_gate_bias_buffer, const float* cell_bias_buffer,
        const float* output_gate_bias_buffer, const float* projection_weights_buffer,
        const float* projection_bias_buffer, const float* output_state_in_buffer,
        const float* cell_state_in_buffer, const float* input_layer_norm_weights_buffer,
        const float* forget_layer_norm_weights_buffer, const float* cell_layer_norm_weights_buffer,
        const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
        float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer) {
    NNTRACE_COMP("LSTMCell::LSTMStep");

    const uint32_t n_batch = input_shape.dimensions[0];
    const uint32_t n_input = input_shape.dimensions[1];
    // n_cell and n_output will be the same size when there is no projection.
    const uint32_t n_cell = input_to_output_weights_shape.dimensions[0];
    const uint32_t n_output = recurrent_to_output_weights_shape.dimensions[1];
    const uint32_t n_aux_input = aux_input_buffer == nullptr ? 0 : n_input;

    // Index the scratch buffers pointers to the global scratch buffer.
    // With CIFG there is no input gate, so the scratch buffer holds three
    // [n_batch, n_cell] planes; otherwise it holds four.
    float* input_gate_scratch = nullptr;
    float* cell_scratch = nullptr;
    float* forget_gate_scratch = nullptr;
    float* output_gate_scratch = nullptr;
    if (params.use_cifg) {
        cell_scratch = scratch_buffer_buffer;
        forget_gate_scratch = cell_scratch + n_cell * n_batch;
        output_gate_scratch = cell_scratch + 2 * n_cell * n_batch;
    } else {
        input_gate_scratch = scratch_buffer_buffer;
        cell_scratch = input_gate_scratch + n_cell * n_batch;
        forget_gate_scratch = input_gate_scratch + 2 * n_cell * n_batch;
        output_gate_scratch = input_gate_scratch + 3 * n_cell * n_batch;
    }

    // With layer normalization the bias is added after normalization (further
    // below), so the scratches must start at zero; without it they start at
    // the gate bias and the matmuls accumulate on top.
    if (!params.use_layer_norm) {
        // Initialize scratch buffers with bias.
        if (!params.use_cifg) {
            tflite::tensor_utils::VectorBatchVectorAssign(input_gate_bias_buffer, n_cell, n_batch,
                                                          input_gate_scratch);
        }
        tflite::tensor_utils::VectorBatchVectorAssign(forget_gate_bias_buffer, n_cell, n_batch,
                                                      forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(cell_bias_buffer, n_cell, n_batch,
                                                      cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAssign(output_gate_bias_buffer, n_cell, n_batch,
                                                      output_gate_scratch);
    } else {
        // Initialize scratch buffers with zeroes.
        if (!params.use_cifg) {
            std::fill_n(input_gate_scratch, n_cell * n_batch, 0.0f);
        }
        std::fill_n(forget_gate_scratch, n_cell * n_batch, 0.0f);
        std::fill_n(cell_scratch, n_cell * n_batch, 0.0f);
        std::fill_n(output_gate_scratch, n_cell * n_batch, 0.0f);
    }

    // For each batch and cell: compute input_weight * input.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_input_weights_buffer,
                                                                  n_cell, n_input, input_buffer,
                                                                  n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_forget_weights_buffer,
                                                              n_cell, n_input, input_buffer,
                                                              n_batch, forget_gate_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            input_to_cell_weights_buffer, n_cell, n_input, input_buffer, n_batch, cell_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(input_to_output_weights_buffer,
                                                              n_cell, n_input, input_buffer,
                                                              n_batch, output_gate_scratch);

    // If auxiliary input is available then compute aux_input_weight * aux_input
    if (aux_input_buffer != nullptr) {
        if (!params.use_cifg) {
            tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                    aux_input_to_input_weights_buffer, n_cell, n_aux_input, aux_input_buffer,
                    n_batch, input_gate_scratch);
        }

        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_forget_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                forget_gate_scratch);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_cell_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                cell_scratch);
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                aux_input_to_output_weights_buffer, n_cell, n_aux_input, aux_input_buffer, n_batch,
                output_gate_scratch);
    }

    // For each batch and cell: compute recurrent_weight * output_state.
    if (!params.use_cifg) {
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                recurrent_to_input_weights_buffer, n_cell, n_output, output_state_in_buffer,
                n_batch, input_gate_scratch);
    }
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_forget_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            forget_gate_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_cell_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            cell_scratch);
    tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
            recurrent_to_output_weights_buffer, n_cell, n_output, output_state_in_buffer, n_batch,
            output_gate_scratch);

    // For each batch and cell: update input gate.
    // Order per gate: (peephole term) -> (layer norm + scale + bias) -> sigmoid.
    if (!params.use_cifg) {
        if (params.use_peephole) {
            tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(
                    cell_to_input_weights_buffer, n_cell, cell_state_in_buffer, n_batch,
                    input_gate_scratch);
        }
        if (params.use_layer_norm) {
            tflite::tensor_utils::MeanStddevNormalization(input_gate_scratch, input_gate_scratch,
                                                          n_cell, n_batch);
            tflite::tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weights_buffer,
                                                                n_cell, input_gate_scratch, n_batch,
                                                                input_gate_scratch);
            tflite::tensor_utils::VectorBatchVectorAdd(input_gate_bias_buffer, n_cell, n_batch,
                                                       input_gate_scratch);
        }
        tflite::tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
                                                   input_gate_scratch);
    }

    // For each batch and cell: update forget gate.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_forget_weights_buffer,
                                                                      n_cell, cell_state_in_buffer,
                                                                      n_batch, forget_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(forget_gate_scratch, forget_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weights_buffer,
                                                            n_cell, forget_gate_scratch, n_batch,
                                                            forget_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(forget_gate_bias_buffer, n_cell, n_batch,
                                                   forget_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
                                               forget_gate_scratch);

    // For each batch and cell: update the cell.
    // cell_state_out = forget_gate .* cell_state_in
    //                  + gate .* activation(cell_candidate),
    // where gate is the input gate or, for CIFG, (1 - forget_gate).
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(cell_layer_norm_weights_buffer, n_cell,
                                                            cell_scratch, n_batch, cell_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(cell_bias_buffer, n_cell, n_batch, cell_scratch);
    }
    tflite::tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_in_buffer,
                                                   n_batch * n_cell, cell_state_out_buffer);
    tflite::tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, params.activation,
                                                  cell_scratch);
    if (params.use_cifg) {
        // CIFG couples the gates: input gate = 1 - forget gate.
        tflite::tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
                                         forget_gate_scratch);
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    } else {
        tflite::tensor_utils::VectorVectorCwiseProductAccumulate(
                cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_out_buffer);
    }
    if (params.cell_clip > 0.0) {
        tflite::tensor_utils::CwiseClipping(cell_state_out_buffer, n_batch * n_cell,
                                            params.cell_clip);
    }

    // For each batch and cell: update the output gate.
    // Note the output-gate peephole uses the *new* cell state.
    if (params.use_peephole) {
        tflite::tensor_utils::VectorBatchVectorCwiseProductAccumulate(cell_to_output_weights_buffer,
                                                                      n_cell, cell_state_out_buffer,
                                                                      n_batch, output_gate_scratch);
    }
    if (params.use_layer_norm) {
        tflite::tensor_utils::MeanStddevNormalization(output_gate_scratch, output_gate_scratch,
                                                      n_cell, n_batch);
        tflite::tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weights_buffer,
                                                            n_cell, output_gate_scratch, n_batch,
                                                            output_gate_scratch);
        tflite::tensor_utils::VectorBatchVectorAdd(output_gate_bias_buffer, n_cell, n_batch,
                                                   output_gate_scratch);
    }
    tflite::tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
                                               output_gate_scratch);
    // output_gate_scratch = output_gate .* activation(cell_state_out);
    // cell_scratch is reused to hold the activated cell state.
    tflite::tensor_utils::ApplyActivationToVector(cell_state_out_buffer, n_batch * n_cell,
                                                  params.activation, cell_scratch);
    tflite::tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
                                                   n_batch * n_cell, output_gate_scratch);

    // For each batch: update the projection and output_state.
    // With projection: output = clip(proj_bias + proj_weights * gated_output);
    // without it, the gated output is the output directly.
    if (params.use_projection_weight) {
        if (params.use_projection_bias) {
            tflite::tensor_utils::VectorBatchVectorAssign(projection_bias_buffer, n_output, n_batch,
                                                          output_buffer);
        } else {
            std::fill_n(output_buffer, n_batch * n_output, 0.0f);
        }
        tflite::tensor_utils::MatrixBatchVectorMultiplyAccumulate(
                projection_weights_buffer, n_output, n_cell, output_gate_scratch, n_batch,
                output_buffer);
        if (params.proj_clip > 0.0) {
            tflite::tensor_utils::CwiseClipping(output_buffer, n_batch * n_output,
                                                params.proj_clip);
        }
    } else {
        std::copy_n(output_gate_scratch, n_batch * n_output, output_buffer);
    }
    // The output state mirrors the output and feeds the next time step.
    std::copy_n(output_buffer, n_batch * n_output, output_state_out_buffer);
    return true;
}
989
990bool LSTMCell::Eval() {
Lev Proleevd1c222a2018-12-28 13:24:24 +0000991 switch (input_->type) {
992 case OperandType::TENSOR_FLOAT32: {
Viet Dangc2ddad92019-01-23 00:50:00 +0000993 LSTMEvalFloat32(params_, GetBuffer<const float>(input_), input_->shape(),
994 GetBuffer<const float>(input_to_input_weights_),
995 GetBuffer<const float>(input_to_forget_weights_),
996 GetBuffer<const float>(input_to_cell_weights_),
997 GetBuffer<const float>(input_to_output_weights_),
998 input_to_output_weights_->shape(),
999 GetBuffer<const float>(recurrent_to_input_weights_),
1000 GetBuffer<const float>(recurrent_to_forget_weights_),
1001 GetBuffer<const float>(recurrent_to_cell_weights_),
1002 GetBuffer<const float>(recurrent_to_output_weights_),
1003 recurrent_to_output_weights_->shape(),
1004 GetBuffer<const float>(cell_to_input_weights_),
1005 GetBuffer<const float>(cell_to_forget_weights_),
1006 GetBuffer<const float>(cell_to_output_weights_),
Viet Dang5fe83682019-01-29 15:25:14 +00001007 /*aux_input_buffer=*/nullptr,
Michael K. Sanders6af0ba32019-01-22 17:47:45 +00001008 /*aux_input_to_input_weights_buffer=*/nullptr,
1009 /*aux_input_to_forget_weights_buffer=*/nullptr,
1010 /*aux_input_to_cell_weights_buffer=*/nullptr,
1011 /*aux_input_to_output_weights_buffer=*/nullptr,
Viet Dangc2ddad92019-01-23 00:50:00 +00001012 GetBuffer<const float>(input_gate_bias_),
1013 GetBuffer<const float>(forget_gate_bias_),
1014 GetBuffer<const float>(cell_bias_),
1015 GetBuffer<const float>(output_gate_bias_),
1016 GetBuffer<const float>(projection_weights_),
1017 GetBuffer<const float>(projection_bias_),
1018 GetBuffer<const float>(output_state_in_),
1019 GetBuffer<const float>(cell_state_in_),
1020 GetBuffer<const float>(input_layer_norm_weights_),
1021 GetBuffer<const float>(forget_layer_norm_weights_),
1022 GetBuffer<const float>(cell_layer_norm_weights_),
1023 GetBuffer<const float>(output_layer_norm_weights_),
1024 GetBuffer<float>(output_state_out_), GetBuffer<float>(cell_state_out_),
1025 GetBuffer<float>(output_), GetBuffer<float>(scratch_buffer_));
Lev Proleevd1c222a2018-12-28 13:24:24 +00001026 } break;
1027 case OperandType::TENSOR_FLOAT16: {
Viet Dangc2ddad92019-01-23 00:50:00 +00001028 LSTMEvalFloat16(params_, GetBuffer<const _Float16>(input_), input_->shape(),
1029 GetOptionalBuffer<const _Float16>(input_to_input_weights_),
1030 GetBuffer<const _Float16>(input_to_forget_weights_),
1031 GetBuffer<const _Float16>(input_to_cell_weights_),
1032 GetBuffer<const _Float16>(input_to_output_weights_),
1033 input_to_output_weights_->shape(),
1034 GetOptionalBuffer<const _Float16>(recurrent_to_input_weights_),
1035 GetBuffer<const _Float16>(recurrent_to_forget_weights_),
1036 GetBuffer<const _Float16>(recurrent_to_cell_weights_),
1037 GetBuffer<const _Float16>(recurrent_to_output_weights_),
1038 recurrent_to_output_weights_->shape(),
1039 GetOptionalBuffer<const _Float16>(cell_to_input_weights_),
1040 GetOptionalBuffer<const _Float16>(cell_to_forget_weights_),
1041 GetOptionalBuffer<const _Float16>(cell_to_output_weights_),
Viet Dang5fe83682019-01-29 15:25:14 +00001042 /*aux_input_buffer=*/nullptr,
Michael K. Sanders6af0ba32019-01-22 17:47:45 +00001043 /*aux_input_to_input_weights_buffer=*/nullptr,
1044 /*aux_input_to_forget_weights_buffer=*/nullptr,
1045 /*aux_input_to_cell_weights_buffer=*/nullptr,
1046 /*aux_input_to_output_weights_buffer=*/nullptr,
Viet Dangc2ddad92019-01-23 00:50:00 +00001047 GetOptionalBuffer<const _Float16>(input_gate_bias_),
1048 GetBuffer<const _Float16>(forget_gate_bias_),
1049 GetBuffer<const _Float16>(cell_bias_),
1050 GetBuffer<const _Float16>(output_gate_bias_),
1051 GetOptionalBuffer<const _Float16>(projection_weights_),
1052 GetOptionalBuffer<const _Float16>(projection_bias_),
1053 GetBuffer<const _Float16>(output_state_in_),
1054 GetBuffer<const _Float16>(cell_state_in_),
1055 GetOptionalBuffer<const _Float16>(input_layer_norm_weights_),
1056 GetOptionalBuffer<const _Float16>(forget_layer_norm_weights_),
1057 GetOptionalBuffer<const _Float16>(cell_layer_norm_weights_),
1058 GetOptionalBuffer<const _Float16>(output_layer_norm_weights_),
1059 GetBuffer<_Float16>(output_state_out_),
1060 GetBuffer<_Float16>(cell_state_out_), GetBuffer<_Float16>(output_),
1061 GetBuffer<_Float16>(scratch_buffer_));
Lev Proleevd1c222a2018-12-28 13:24:24 +00001062 } break;
1063 default: {
1064 LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
1065 return false;
1066 }
Lev Proleev7f4d4c72018-11-08 12:06:38 +00001067 }
Lev Proleevd1c222a2018-12-28 13:24:24 +00001068 return true;
Yang Nid0ea9fd2017-07-28 16:23:46 -07001069}
1070
1071} // namespace nn
1072} // namespace android