Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (C) 2019 The Android Open Source Project |
| 3 | * |
| 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | * you may not use this file except in compliance with the License. |
| 6 | * You may obtain a copy of the License at |
| 7 | * |
| 8 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | * |
| 10 | * Unless required by applicable law or agreed to in writing, software |
| 11 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | * See the License for the specific language governing permissions and |
| 14 | * limitations under the License. |
| 15 | */ |
| 16 | |
| 17 | #define LOG_TAG "Operations" |
| 18 | |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 19 | #include <algorithm> |
| 20 | #include <utility> |
| 21 | #include <vector> |
| 22 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 23 | #include "OperationResolver.h" |
| 24 | #include "RNN.h" |
| 25 | |
| 26 | namespace android { |
| 27 | namespace nn { |
| 28 | namespace bidirectional_sequence_rnn { |
| 29 | |
// ---------------------------------------------------------------------------
// Operand indices for BIDIRECTIONAL_SEQUENCE_RNN.
// ---------------------------------------------------------------------------

// Total number of input operands (tensors and scalar parameters).
constexpr uint32_t kNumInputs = 15;

// Primary input sequence.
constexpr uint32_t kInputTensor = 0;

// Tensors of the forward cell.
constexpr uint32_t kFwWeightsTensor = 1;
constexpr uint32_t kFwRecurrentWeightsTensor = 2;
constexpr uint32_t kFwBiasTensor = 3;
constexpr uint32_t kFwHiddenStateTensor = 4;

// Tensors of the backward cell.
constexpr uint32_t kBwWeightsTensor = 5;
constexpr uint32_t kBwRecurrentWeightsTensor = 6;
constexpr uint32_t kBwBiasTensor = 7;
constexpr uint32_t kBwHiddenStateTensor = 8;

// Auxiliary inputs; each of them may be omitted.
constexpr uint32_t kAuxInputTensor = 9;       // optional
constexpr uint32_t kFwAuxWeightsTensor = 10;  // optional
constexpr uint32_t kBwAuxWeightsTensor = 11;  // optional

// Scalar parameters of the cells.
constexpr uint32_t kActivationParam = 12;
constexpr uint32_t kTimeMajorParam = 13;
constexpr uint32_t kMergeOutputsParam = 14;

// Accepted output counts: separate or merged sequence outputs, each with or
// without the two trailing hidden-state outputs.
constexpr uint32_t kNumOutputs = 2;
constexpr uint32_t kNumOutputsMerged = 1;
constexpr uint32_t kNumOutputsWithState = 4;
constexpr uint32_t kNumOutputsMergedWithState = 3;

// Output operand indices in the non-merged layout.
constexpr uint32_t kFwOutputTensor = 0;
constexpr uint32_t kBwOutputTensor = 1;  // Only if mergeOutputs parameter is false
constexpr uint32_t kFwOutputHiddenStateTensor = 2;
constexpr uint32_t kBwOutputHiddenStateTensor = 3;
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 60 | |
Slava Shklyaev | 84829a9 | 2021-02-26 12:05:38 +0000 | [diff] [blame] | 61 | #ifdef NN_INCLUDE_CPU_IMPLEMENTATION |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 62 | namespace { |
| 63 | |
| 64 | template <typename T> |
| 65 | void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) { |
| 66 | const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0); |
| 67 | const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1); |
| 68 | const uint32_t inputSize = getSizeOfDimension(inputShape, 2); |
| 69 | for (int f = 0; f < firstDimSize; ++f) { |
| 70 | for (int s = 0; s < secondDimSize; ++s) { |
| 71 | for (int i = 0; i < inputSize; ++i) { |
| 72 | const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i; |
| 73 | const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i; |
| 74 | output[outputIndex] = input[inputIndex]; |
| 75 | } |
| 76 | } |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | Shape removeFirstDim(const Shape& input) { |
| 81 | Shape output = input; |
| 82 | output.dimensions.resize(input.dimensions.size() - 1); |
| 83 | for (int i = 0; i < input.dimensions.size() - 1; ++i) { |
| 84 | output.dimensions[i] = input.dimensions[i + 1]; |
| 85 | } |
| 86 | return output; |
| 87 | } |
| 88 | |
// Describes how the optional auxiliary tensors are wired into the two cells.
enum class LinkingMode {
    // No auxiliary tensors are supplied.
    NO_LINKING,
    // Only the auxiliary input is supplied; it serves as the regular input of
    // the backward cell.
    PARALLEL_LINKING,
    // The auxiliary input and both auxiliary weight tensors are supplied.
    CROSS_LINKING,
};
| 94 | |
| 95 | bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) { |
| 96 | const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor); |
| 97 | const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor); |
| 98 | const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor); |
| 99 | |
| 100 | // Three possible configurations for three possible linking modes: |
| 101 | // 1) NO_LINKING -- no auxiliary tensors at all |
| 102 | // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular |
| 103 | // input to the backward network, so the auxiliary weights are omitted. |
| 104 | // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by |
| 105 | // auxiliary weights. |
| 106 | if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) { |
| 107 | *linkingMode = LinkingMode::NO_LINKING; |
| 108 | } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) { |
| 109 | *linkingMode = LinkingMode::PARALLEL_LINKING; |
| 110 | } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) { |
| 111 | *linkingMode = LinkingMode::CROSS_LINKING; |
| 112 | } else { |
| 113 | NN_RET_CHECK_FAIL() |
| 114 | << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN."; |
| 115 | } |
| 116 | |
| 117 | return true; |
| 118 | } |
| 119 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 120 | template <typename T> |
| 121 | bool executeTyped(IOperationExecutionContext* context) { |
| 122 | const T* input = context->getInputBuffer<T>(kInputTensor); |
| 123 | Shape inputShape = context->getInputShape(kInputTensor); |
| 124 | |
| 125 | const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor); |
| 126 | Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor); |
| 127 | const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor); |
| 128 | Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor); |
| 129 | const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor); |
| 130 | const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor); |
| 131 | |
| 132 | const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor); |
| 133 | Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor); |
| 134 | const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor); |
| 135 | Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor); |
| 136 | const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor); |
| 137 | const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor); |
| 138 | |
| 139 | const T* auxInput = nullptr; |
| 140 | const T* fwAuxWeights = nullptr; |
| 141 | const T* bwAuxWeights = nullptr; |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 142 | LinkingMode linkingMode; |
| 143 | NN_RET_CHECK(getLinkingMode(context, &linkingMode)); |
| 144 | if (linkingMode == LinkingMode::CROSS_LINKING) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 145 | auxInput = context->getInputBuffer<T>(kAuxInputTensor); |
| 146 | fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor); |
| 147 | bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor); |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 148 | } else if (linkingMode == LinkingMode::PARALLEL_LINKING) { |
| 149 | auxInput = context->getInputBuffer<T>(kAuxInputTensor); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 150 | } |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 151 | const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING || |
| 152 | linkingMode == LinkingMode::PARALLEL_LINKING); |
| 153 | const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 154 | Shape auxInputShape = context->getInputShape(kAuxInputTensor); |
| 155 | Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor); |
| 156 | Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor); |
| 157 | |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 158 | const int32_t activation = context->getInputValue<int32_t>(kActivationParam); |
| 159 | const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam); |
| 160 | const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 161 | |
| 162 | T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor); |
| 163 | Shape fwOutputShape = context->getOutputShape(kFwOutputTensor); |
| 164 | T* bwOutput = nullptr; |
| 165 | Shape bwOutputShape; |
| 166 | if (!mergeOutputs) { |
| 167 | bwOutputShape = context->getOutputShape(kBwOutputTensor); |
| 168 | bwOutput = context->getOutputBuffer<T>(kBwOutputTensor); |
| 169 | } |
| 170 | |
| 171 | // If the input tensors are not in time major format, we transpose the first |
| 172 | // two dimensions, and set input and output pointers to temporary vectors |
| 173 | // which are transposed back after the RNN is applied. |
| 174 | std::vector<T> inputTransposed; |
| 175 | std::vector<T> auxInputTransposed; |
| 176 | std::vector<T> fwOutputTransposed; |
| 177 | std::vector<T> bwOutputTransposed; |
| 178 | if (!timeMajor) { |
| 179 | // First, resize temporary buffers to accommodate for transposed tensors. |
| 180 | inputTransposed.resize(getNumberOfElements(inputShape)); |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 181 | if (hasAuxInput) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 182 | auxInputTransposed.resize(getNumberOfElements(auxInputShape)); |
| 183 | } |
| 184 | fwOutputTransposed.resize(getNumberOfElements(fwOutputShape)); |
| 185 | if (!mergeOutputs) { |
| 186 | bwOutputTransposed.resize(getNumberOfElements(bwOutputShape)); |
| 187 | } |
| 188 | |
| 189 | // Transpose the input tensors. |
| 190 | transposeFirstTwoDims(input, inputShape, inputTransposed.data()); |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 191 | if (hasAuxInput) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 192 | transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data()); |
| 193 | } |
| 194 | |
| 195 | // Change input and output pointers to the temporary buffers. |
| 196 | input = inputTransposed.data(); |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 197 | if (hasAuxInput) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 198 | auxInput = auxInputTransposed.data(); |
| 199 | } |
| 200 | fwOutput = fwOutputTransposed.data(); |
| 201 | if (!mergeOutputs) { |
| 202 | bwOutput = bwOutputTransposed.data(); |
| 203 | } |
| 204 | |
| 205 | // Swap the first two dimensions in the Shapes to reflect the |
| 206 | // transposition. |
| 207 | std::swap(inputShape.dimensions[0], inputShape.dimensions[1]); |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 208 | if (hasAuxInput) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 209 | std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]); |
| 210 | } |
| 211 | std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]); |
| 212 | if (!mergeOutputs) { |
| 213 | std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]); |
| 214 | } |
| 215 | } |
| 216 | |
| 217 | const uint32_t maxTime = getSizeOfDimension(inputShape, 0); |
| 218 | const uint32_t batchSize = getSizeOfDimension(inputShape, 1); |
| 219 | const uint32_t inputSize = getSizeOfDimension(inputShape, 2); |
Lev Proleev | 884ee2b | 2019-06-14 14:15:45 +0100 | [diff] [blame] | 220 | uint32_t auxInputSize = 0; |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 221 | if (hasAuxInput) { |
Lev Proleev | 884ee2b | 2019-06-14 14:15:45 +0100 | [diff] [blame] | 222 | auxInputSize = getSizeOfDimension(auxInputShape, 2); |
| 223 | } |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 224 | const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0); |
| 225 | const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0); |
| 226 | |
| 227 | Shape fixedTimeInputShape = removeFirstDim(inputShape); |
Lev Proleev | 884ee2b | 2019-06-14 14:15:45 +0100 | [diff] [blame] | 228 | Shape fixedTimeAuxInputShape = auxInputShape; |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 229 | if (hasAuxInput) { |
Lev Proleev | 884ee2b | 2019-06-14 14:15:45 +0100 | [diff] [blame] | 230 | fixedTimeAuxInputShape = removeFirstDim(auxInputShape); |
| 231 | } |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 232 | |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 233 | const T* bwInput = input; |
| 234 | if (linkingMode == LinkingMode::PARALLEL_LINKING) { |
| 235 | bwInput = auxInput; |
| 236 | auxInput = nullptr; |
| 237 | } |
| 238 | |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 239 | const bool outputState = (context->getNumOutputs() == kNumOutputsWithState || |
| 240 | context->getNumOutputs() == kNumOutputsMergedWithState); |
| 241 | T* fwOutputHiddenState = nullptr; |
| 242 | T* bwOutputHiddenState = nullptr; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 243 | // Create an additional buffer to store a hidden state between steps. |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 244 | std::vector<T> tempHiddenState; |
| 245 | if (outputState) { |
| 246 | const int delta = mergeOutputs ? 1 : 0; |
| 247 | fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta); |
| 248 | bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta); |
| 249 | } else { |
| 250 | tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits)); |
| 251 | fwOutputHiddenState = tempHiddenState.data(); |
| 252 | bwOutputHiddenState = tempHiddenState.data(); |
| 253 | } |
| 254 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 255 | // Forward pass |
| 256 | for (int i = 0; i < maxTime; ++i) { |
| 257 | const T* inputBatchPtr = input + i * batchSize * inputSize; |
| 258 | const T* auxInputBatchPtr = nullptr; |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 259 | if (hasAuxWeights) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 260 | auxInputBatchPtr = auxInput + i * batchSize * auxInputSize; |
| 261 | } |
| 262 | const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits; |
| 263 | T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride; |
| 264 | |
| 265 | RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr, |
| 266 | fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape, |
| 267 | fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights, |
| 268 | fwRecurrentWeightsShape, activation, fwOutputBatchStride, |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 269 | /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 270 | |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 271 | fwHiddenState = fwOutputHiddenState; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 272 | } |
| 273 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 274 | // Backward pass |
| 275 | for (int i = maxTime - 1; i >= 0; --i) { |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 276 | const T* inputBatchPtr = bwInput + i * batchSize * inputSize; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 277 | const T* auxInputBatchPtr = nullptr; |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 278 | if (hasAuxWeights) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 279 | auxInputBatchPtr = auxInput + i * batchSize * auxInputSize; |
| 280 | } |
| 281 | T* bwOutputBatchPtr; |
| 282 | uint32_t bwOutputBatchOffset = 0; |
| 283 | uint32_t bwOutputBatchStride; |
| 284 | if (mergeOutputs) { |
| 285 | bwOutputBatchStride = fwNumUnits + bwNumUnits; |
| 286 | bwOutputBatchOffset = fwNumUnits; |
| 287 | bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride; |
| 288 | } else { |
| 289 | bwOutputBatchStride = bwNumUnits; |
| 290 | bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride; |
| 291 | } |
| 292 | |
| 293 | RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr, |
| 294 | fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape, |
| 295 | bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights, |
| 296 | bwRecurrentWeightsShape, activation, bwOutputBatchStride, |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 297 | bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 298 | |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 299 | bwHiddenState = bwOutputHiddenState; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 300 | } |
| 301 | |
| 302 | // If the inputs were in batch major format, transpose data in temporary |
| 303 | // buffers and write to the output(s). |
| 304 | if (!timeMajor) { |
| 305 | transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape, |
| 306 | context->getOutputBuffer<T>(kFwOutputTensor)); |
| 307 | if (!mergeOutputs) { |
| 308 | transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape, |
| 309 | context->getOutputBuffer<T>(kBwOutputTensor)); |
| 310 | } |
| 311 | } |
| 312 | return true; |
| 313 | } |
| 314 | |
| 315 | } // namespace |
Slava Shklyaev | 84829a9 | 2021-02-26 12:05:38 +0000 | [diff] [blame] | 316 | #endif // NN_INCLUDE_CPU_IMPLEMENTATION |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 317 | |
Michael Butler | 274ff7b | 2020-11-02 23:17:11 -0800 | [diff] [blame] | 318 | Result<Version> validate(const IOperationValidationContext* context) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 319 | NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs); |
| 320 | // Exact number is dependent on the mergeOutputs parameter and checked |
| 321 | // during preparation. |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 322 | const uint32_t numOutputs = context->getNumOutputs(); |
| 323 | NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsMerged || |
| 324 | numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState); |
| 325 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 326 | OperandType inputType = context->getInputType(kInputTensor); |
| 327 | if (inputType != OperandType::TENSOR_FLOAT16 && inputType != OperandType::TENSOR_FLOAT32) { |
Michael Butler | 274ff7b | 2020-11-02 23:17:11 -0800 | [diff] [blame] | 328 | return NN_ERROR() << "Unsupported input operand type for UNIDIRECTIONAL_SEQUENCE_RNN op: " |
| 329 | << inputType; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 330 | } |
| 331 | NN_RET_CHECK(validateInputTypes( |
| 332 | context, {inputType, inputType, inputType, inputType, inputType, inputType, inputType, |
| 333 | inputType, inputType, inputType, inputType, inputType, OperandType::INT32, |
| 334 | OperandType::BOOL, OperandType::BOOL})); |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 335 | |
| 336 | std::vector<OperandType> outExpectedTypes(numOutputs, inputType); |
| 337 | NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes)); |
| 338 | |
Michael Butler | 25d5073 | 2020-11-01 23:47:40 -0800 | [diff] [blame] | 339 | Version minSupportedVersion = Version::ANDROID_Q; |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 340 | if (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState) { |
Michael Butler | 25d5073 | 2020-11-01 23:47:40 -0800 | [diff] [blame] | 341 | minSupportedVersion = Version::ANDROID_R; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 342 | } |
Michael Butler | 274ff7b | 2020-11-02 23:17:11 -0800 | [diff] [blame] | 343 | return minSupportedVersion; |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 344 | } |
| 345 | |
Slava Shklyaev | 84829a9 | 2021-02-26 12:05:38 +0000 | [diff] [blame] | 346 | #ifdef NN_INCLUDE_CPU_IMPLEMENTATION |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 347 | bool prepare(IOperationExecutionContext* context) { |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 348 | const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam); |
| 349 | const int32_t numOutputs = context->getNumOutputs(); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 350 | if (mergeOutputs) { |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 351 | NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 352 | } else { |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 353 | NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 354 | } |
| 355 | |
Lev Proleev | 1c6a4d5 | 2019-04-24 15:30:04 +0100 | [diff] [blame] | 356 | // Check that none of the required inputs are omitted. |
| 357 | const std::vector<int> requiredInputs = { |
| 358 | kInputTensor, kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor, |
| 359 | kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor, |
| 360 | kBwHiddenStateTensor, kActivationParam, kTimeMajorParam, kMergeOutputsParam, |
| 361 | }; |
| 362 | for (const int requiredInput : requiredInputs) { |
| 363 | NN_RET_CHECK(!context->isOmittedInput(requiredInput)) |
| 364 | << "required input " << requiredInput << " is omitted"; |
| 365 | } |
| 366 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 367 | Shape input = context->getInputShape(kInputTensor); |
| 368 | Shape fwWeights = context->getInputShape(kFwWeightsTensor); |
| 369 | Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor); |
| 370 | Shape fwBias = context->getInputShape(kFwBiasTensor); |
| 371 | Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor); |
| 372 | Shape bwWeights = context->getInputShape(kBwWeightsTensor); |
| 373 | Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor); |
| 374 | Shape bwBias = context->getInputShape(kBwBiasTensor); |
| 375 | Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor); |
| 376 | |
| 377 | Shape auxInput = context->getInputShape(kAuxInputTensor); |
| 378 | Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor); |
| 379 | Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor); |
| 380 | |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 381 | LinkingMode linkingMode; |
| 382 | NN_RET_CHECK(getLinkingMode(context, &linkingMode)); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 383 | |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 384 | bool timeMajor = context->getInputValue<bool>(kTimeMajorParam); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 385 | const uint32_t batchSize = |
| 386 | timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0); |
| 387 | const uint32_t maxTime = |
| 388 | timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1); |
| 389 | const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0); |
| 390 | const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0); |
| 391 | const uint32_t inputSize = getSizeOfDimension(input, 2); |
| 392 | |
| 393 | NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3); |
| 394 | NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2); |
| 395 | NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2); |
| 396 | NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1); |
| 397 | NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2); |
| 398 | NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2); |
| 399 | NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2); |
| 400 | NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1); |
| 401 | NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2); |
| 402 | |
| 403 | NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1)); |
| 404 | NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0)); |
| 405 | NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0)); |
| 406 | NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1)); |
| 407 | NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0)); |
| 408 | NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1)); |
| 409 | |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 410 | if (linkingMode != LinkingMode::PARALLEL_LINKING) { |
| 411 | NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1)); |
| 412 | } |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 413 | NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0)); |
| 414 | NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0)); |
| 415 | NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1)); |
| 416 | NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0)); |
| 417 | NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1)); |
| 418 | |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 419 | if (linkingMode == LinkingMode::CROSS_LINKING) { |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 420 | NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3); |
| 421 | NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2); |
| 422 | NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2); |
| 423 | |
| 424 | NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0)); |
| 425 | NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1)); |
| 426 | NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits); |
| 427 | NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2)); |
| 428 | NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits); |
| 429 | NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2)); |
Lev Proleev | 721ee16 | 2020-01-27 15:46:59 +0000 | [diff] [blame] | 430 | } else if (linkingMode == LinkingMode::PARALLEL_LINKING) { |
| 431 | NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3); |
| 432 | |
| 433 | NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0)); |
| 434 | NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1)); |
| 435 | NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1)); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 436 | } |
| 437 | |
| 438 | Shape fwOutput = context->getOutputShape(kFwOutputTensor); |
| 439 | fwOutput.dimensions.resize(3); |
| 440 | fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize; |
| 441 | fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime; |
| 442 | fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits; |
| 443 | NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput)); |
| 444 | if (!mergeOutputs) { |
| 445 | Shape bwOutput = context->getOutputShape(kBwOutputTensor); |
| 446 | bwOutput.dimensions.resize(3); |
| 447 | bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize; |
| 448 | bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime; |
| 449 | bwOutput.dimensions[2] = bwNumUnits; |
| 450 | NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput)); |
| 451 | } |
| 452 | |
Lev Proleev | fbc2a3d | 2020-01-14 17:35:36 +0000 | [diff] [blame] | 453 | const bool outputState = |
| 454 | (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState); |
| 455 | if (outputState) { |
| 456 | const int delta = mergeOutputs ? 1 : 0; |
| 457 | NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta, |
| 458 | context->getInputShape(kFwHiddenStateTensor))); |
| 459 | NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta, |
| 460 | context->getInputShape(kBwHiddenStateTensor))); |
| 461 | } |
| 462 | |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 463 | return true; |
| 464 | } |
| 465 | |
| 466 | bool execute(IOperationExecutionContext* context) { |
| 467 | if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) { |
| 468 | executeTyped<_Float16>(context); |
| 469 | } else { |
| 470 | executeTyped<float>(context); |
| 471 | } |
| 472 | return true; |
| 473 | } |
Slava Shklyaev | 84829a9 | 2021-02-26 12:05:38 +0000 | [diff] [blame] | 474 | #endif // NN_INCLUDE_CPU_IMPLEMENTATION |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 475 | |
| 476 | } // namespace bidirectional_sequence_rnn |
| 477 | |
| 478 | NN_REGISTER_OPERATION(BIDIRECTIONAL_SEQUENCE_RNN, "BIDIRECTIONAL_SEQUENCE_RNN", |
| 479 | bidirectional_sequence_rnn::validate, bidirectional_sequence_rnn::prepare, |
Xusong Wang | ef83690 | 2019-03-06 13:10:14 -0800 | [diff] [blame] | 480 | bidirectional_sequence_rnn::execute, .allowOmittedOperand = true); |
Lev Proleev | 471aa21 | 2019-01-22 21:00:53 +0000 | [diff] [blame] | 481 | |
| 482 | } // namespace nn |
| 483 | } // namespace android |