/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <algorithm>
#include <utility>
#include <vector>

#include "OperationResolver.h"
#include "RNN.h"

namespace android {
namespace nn {
namespace bidirectional_sequence_rnn {

constexpr uint32_t kNumInputs = 15;
constexpr uint32_t kInputTensor = 0;
// Forward cell tensors
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;
// Backward cell tensors
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;
// Auxiliary inputs
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional
// Cell parameters
constexpr uint32_t kActivationParam = 12;
constexpr uint32_t kTimeMajorParam = 13;
constexpr uint32_t kMergeOutputsParam = 14;

constexpr uint32_t kNumOutputs = 2;
constexpr uint32_t kNumOutputsMerged = 1;
constexpr uint32_t kNumOutputsWithState = 4;
constexpr uint32_t kNumOutputsMergedWithState = 3;

constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
constexpr uint32_t kFwOutputHiddenStateTensor = 2;
constexpr uint32_t kBwOutputHiddenStateTensor = 3;
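// When mergeOutputs is true, kBwOutputTensor is absent and the hidden state
// output indices above shift down by one; prepare() and executeTyped() apply
// that shift via a delta of 1.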

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

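// Transposes the first two dimensions of a 3-D tensor: an input of shape
// [d0, d1, d2] is written to the output as [d1, d0, d2]. Used to convert
// between batch-major and time-major layouts.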
template <typename T>
void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) {
    const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0);
    const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    for (uint32_t f = 0; f < firstDimSize; ++f) {
        for (uint32_t s = 0; s < secondDimSize; ++s) {
            for (uint32_t i = 0; i < inputSize; ++i) {
                const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i;
                const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i;
                output[outputIndex] = input[inputIndex];
            }
        }
    }
}

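// Returns a copy of the given shape with the outermost dimension removed,
// e.g. [maxTime, batchSize, inputSize] -> [batchSize, inputSize].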
Shape removeFirstDim(const Shape& input) {
    Shape output = input;
    output.dimensions.resize(input.dimensions.size() - 1);
    for (size_t i = 0; i < input.dimensions.size() - 1; ++i) {
        output.dimensions[i] = input.dimensions[i + 1];
    }
    return output;
}

enum class LinkingMode {
    NO_LINKING,
    PARALLEL_LINKING,
    CROSS_LINKING,
};

bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) {
    const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor);
    const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor);
    const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor);

    // Three possible configurations for three possible linking modes:
    // 1) NO_LINKING -- no auxiliary tensors at all
    // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular
    //    input to the backward network, so the auxiliary weights are omitted.
    // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by
    //    auxiliary weights.
    if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
        *linkingMode = LinkingMode::NO_LINKING;
    } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) {
        *linkingMode = LinkingMode::PARALLEL_LINKING;
    } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) {
        *linkingMode = LinkingMode::CROSS_LINKING;
    } else {
        NN_RET_CHECK_FAIL()
                << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN.";
    }

    return true;
}

template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate the transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

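    // In parallel linking mode the auxiliary input is not multiplied by
    // auxiliary weights; instead it is fed to the backward network as its
    // regular input, so swap the pointers here.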
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
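    // When mergeOutputs is true, both directions share a single output tensor:
    // each (time, batch) step stores fwNumUnits forward values followed by
    // bwNumUnits backward values, hence the wider output stride below.
    // Conceptually, each step computes (see RNN::RNNStep in RNN.h for the
    // exact kernel): hidden' = activation(weights * input +
    // auxWeights * auxInput + recurrentWeights * hidden + bias).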
    for (uint32_t i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass
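    // Time steps are processed in reverse order; when mergeOutputs is true,
    // results are written into the shared output tensor at an offset of
    // fwNumUnits within each step.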
    for (int i = static_cast<int>(maxTime) - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

Result<Version> validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    // Exact number is dependent on the mergeOutputs parameter and checked
    // during preparation.
    const uint32_t numOutputs = context->getNumOutputs();
    NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsMerged ||
                 numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);

    OperandType inputType = context->getInputType(kInputTensor);
    if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) {
        return NN_ERROR() << "Unsupported input operand type for BIDIRECTIONAL_SEQUENCE_RNN op: "
                          << inputType;
    }
    NN_RET_CHECK(validateInputTypes(
            context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType,
                      inputType, inputType, inputType, inputType, inputType, OperandType::INT32,
                      OperandType::BOOL, OperandType::BOOL}));

    std::vector<OperandType> outExpectedTypes(numOutputs, inputType);
    NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));

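    // The hidden state outputs were introduced in Android R; requesting them
    // raises the minimum supported version accordingly.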
    Version minSupportedVersion = Version::ANDROID_Q;
    if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) {
        minSupportedVersion = Version::ANDROID_R;
    }
    return minSupportedVersion;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
bool prepare(IOperationExecutionContext* context) {
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const uint32_t numOutputs = context->getNumOutputs();
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor, kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam, kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2);

    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

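    // In parallel linking mode the backward weights are applied to the
    // auxiliary input rather than the primary input, so their inner dimension
    // is checked against auxInput further below instead.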
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}

bool execute(IOperationExecutionContext* context) {
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        return executeTyped<_Float16>(context);
    } else {
        return executeTyped<float>(context);
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace bidirectional_sequence_rnn

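// The auxiliary input and auxiliary weights are optional inputs, so the
// operation is registered with allowOmittedOperand set to true.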
NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN",
                      bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare,
                      bidirectional_sequence_rnn::execute, .allowOmittedOperand = true);

}  // namespace nn
}  // namespace android