| /* |
| * Copyright (C) 2019 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #define LOG_TAG "Operations" |
| |
| #include "BidirectionalSequenceRNN.h" |
| |
| #include <algorithm> |
| #include <utility> |
| #include <vector> |
| |
| #include "OperationResolver.h" |
| #include "RNN.h" |
| |
| namespace android { |
| namespace nn { |
| namespace bidirectional_sequence_rnn { |
| |
| #ifdef NN_INCLUDE_CPU_IMPLEMENTATION |
| namespace { |
| |
| template <typename T> |
| void transposeFirstTwoDims(const T* input, const Shape& inputShape, T* output) { |
| const uint32_t firstDimSize = getSizeOfDimension(inputShape, 0); |
| const uint32_t secondDimSize = getSizeOfDimension(inputShape, 1); |
| const uint32_t inputSize = getSizeOfDimension(inputShape, 2); |
| for (uint32_t f = 0; f < firstDimSize; ++f) { |
| for (uint32_t s = 0; s < secondDimSize; ++s) { |
| for (uint32_t i = 0; i < inputSize; ++i) { |
| const uint32_t inputIndex = f * secondDimSize * inputSize + s * inputSize + i; |
| const uint32_t outputIndex = s * firstDimSize * inputSize + f * inputSize + i; |
| output[outputIndex] = input[inputIndex]; |
| } |
| } |
| } |
| } |
| |
| Shape removeFirstDim(const Shape& input) { |
| Shape output = input; |
| output.dimensions.resize(input.dimensions.size() - 1); |
| for (size_t i = 0; i < input.dimensions.size() - 1; ++i) { |
| output.dimensions[i] = input.dimensions[i + 1]; |
| } |
| return output; |
| } |
| |
// How the optional auxiliary input feeds into the two RNNs; derived from
// which auxiliary tensors the model provides (see getLinkingMode).
enum class LinkingMode {
    // No auxiliary input or auxiliary weights are provided.
    NO_LINKING,
    // The auxiliary input is provided and used as the regular input of the
    // backward network, so the auxiliary weights are omitted.
    PARALLEL_LINKING,
    // The auxiliary input is provided and multiplied by dedicated auxiliary
    // weights in both directions.
    CROSS_LINKING,
};
| |
| bool getLinkingMode(IOperationExecutionContext* context, LinkingMode* linkingMode) { |
| const bool hasAuxInput = !context->isOmittedInput(kAuxInputTensor); |
| const bool hasFwAuxWeights = !context->isOmittedInput(kFwAuxWeightsTensor); |
| const bool hasBwAuxWeights = !context->isOmittedInput(kBwAuxWeightsTensor); |
| |
| // Three possible configurations for three possible linking modes: |
| // 1) NO_LINKING -- no auxiliary tensors at all |
| // 2) PARALLEL_LINKING -- auxiliary input is provided and used as a regular |
| // input to the backward network, so the auxiliary weights are omitted. |
| // 3) CROSS_LINKING -- auxiliary input is provided and multiplied by |
| // auxiliary weights. |
| if (!hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) { |
| *linkingMode = LinkingMode::NO_LINKING; |
| } else if (hasAuxInput && !hasFwAuxWeights && !hasBwAuxWeights) { |
| *linkingMode = LinkingMode::PARALLEL_LINKING; |
| } else if (hasAuxInput && hasFwAuxWeights && hasBwAuxWeights) { |
| *linkingMode = LinkingMode::CROSS_LINKING; |
| } else { |
| NN_RET_CHECK_FAIL() |
| << "Unsupported auxiliary tensors configuration for BIDIRECTIONAL_SEQUENCE_RNN."; |
| } |
| |
| return true; |
| } |
| |
// CPU execution of BIDIRECTIONAL_SEQUENCE_RNN for element type T (float or
// _Float16). Runs the forward RNN over the sequence front-to-back and the
// backward RNN back-to-front, optionally merging both outputs into a single
// tensor. Returns false only if the auxiliary-tensor configuration is
// unsupported.
template <typename T>
bool executeTyped(IOperationExecutionContext* context) {
    const T* input = context->getInputBuffer<T>(kInputTensor);
    Shape inputShape = context->getInputShape(kInputTensor);

    // Forward-direction parameters.
    const T* fwWeights = context->getInputBuffer<T>(kFwWeightsTensor);
    Shape fwWeightsShape = context->getInputShape(kFwWeightsTensor);
    const T* fwRecurrentWeights = context->getInputBuffer<T>(kFwRecurrentWeightsTensor);
    Shape fwRecurrentWeightsShape = context->getInputShape(kFwRecurrentWeightsTensor);
    const T* fwBias = context->getInputBuffer<T>(kFwBiasTensor);
    const T* fwHiddenState = context->getInputBuffer<T>(kFwHiddenStateTensor);

    // Backward-direction parameters.
    const T* bwWeights = context->getInputBuffer<T>(kBwWeightsTensor);
    Shape bwWeightsShape = context->getInputShape(kBwWeightsTensor);
    const T* bwRecurrentWeights = context->getInputBuffer<T>(kBwRecurrentWeightsTensor);
    Shape bwRecurrentWeightsShape = context->getInputShape(kBwRecurrentWeightsTensor);
    const T* bwBias = context->getInputBuffer<T>(kBwBiasTensor);
    const T* bwHiddenState = context->getInputBuffer<T>(kBwHiddenStateTensor);

    // Auxiliary buffers are only fetched for the linking modes that use them;
    // fetching an omitted input would be invalid.
    const T* auxInput = nullptr;
    const T* fwAuxWeights = nullptr;
    const T* bwAuxWeights = nullptr;
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
        fwAuxWeights = context->getInputBuffer<T>(kFwAuxWeightsTensor);
        bwAuxWeights = context->getInputBuffer<T>(kBwAuxWeightsTensor);
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        auxInput = context->getInputBuffer<T>(kAuxInputTensor);
    }
    const bool hasAuxInput = (linkingMode == LinkingMode::CROSS_LINKING ||
                              linkingMode == LinkingMode::PARALLEL_LINKING);
    const bool hasAuxWeights = (linkingMode == LinkingMode::CROSS_LINKING);
    Shape auxInputShape = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeightsShape = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeightsShape = context->getInputShape(kBwAuxWeightsTensor);

    const int32_t activation = context->getInputValue<int32_t>(kActivationParam);
    const bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);

    T* fwOutput = context->getOutputBuffer<T>(kFwOutputTensor);
    Shape fwOutputShape = context->getOutputShape(kFwOutputTensor);
    T* bwOutput = nullptr;
    Shape bwOutputShape;
    if (!mergeOutputs) {
        // With merged outputs there is no separate backward output tensor; the
        // backward results are interleaved into fwOutput below.
        bwOutputShape = context->getOutputShape(kBwOutputTensor);
        bwOutput = context->getOutputBuffer<T>(kBwOutputTensor);
    }

    // If the input tensors are not in time major format, we transpose the first
    // two dimensions, and set input and output pointers to temporary vectors
    // which are transposed back after the RNN is applied.
    std::vector<T> inputTransposed;
    std::vector<T> auxInputTransposed;
    std::vector<T> fwOutputTransposed;
    std::vector<T> bwOutputTransposed;
    if (!timeMajor) {
        // First, resize temporary buffers to accommodate for transposed tensors.
        inputTransposed.resize(getNumberOfElements(inputShape));
        if (hasAuxInput) {
            auxInputTransposed.resize(getNumberOfElements(auxInputShape));
        }
        fwOutputTransposed.resize(getNumberOfElements(fwOutputShape));
        if (!mergeOutputs) {
            bwOutputTransposed.resize(getNumberOfElements(bwOutputShape));
        }

        // Transpose the input tensors.
        transposeFirstTwoDims(input, inputShape, inputTransposed.data());
        if (hasAuxInput) {
            transposeFirstTwoDims(auxInput, auxInputShape, auxInputTransposed.data());
        }

        // Change input and output pointers to the temporary buffers.
        input = inputTransposed.data();
        if (hasAuxInput) {
            auxInput = auxInputTransposed.data();
        }
        fwOutput = fwOutputTransposed.data();
        if (!mergeOutputs) {
            bwOutput = bwOutputTransposed.data();
        }

        // Swap the first two dimensions in the Shapes to reflect the
        // transposition.
        std::swap(inputShape.dimensions[0], inputShape.dimensions[1]);
        if (hasAuxInput) {
            std::swap(auxInputShape.dimensions[0], auxInputShape.dimensions[1]);
        }
        std::swap(fwOutputShape.dimensions[0], fwOutputShape.dimensions[1]);
        if (!mergeOutputs) {
            std::swap(bwOutputShape.dimensions[0], bwOutputShape.dimensions[1]);
        }
    }

    // From here on all tensors are time-major: [maxTime, batchSize, ...].
    const uint32_t maxTime = getSizeOfDimension(inputShape, 0);
    const uint32_t batchSize = getSizeOfDimension(inputShape, 1);
    const uint32_t inputSize = getSizeOfDimension(inputShape, 2);
    uint32_t auxInputSize = 0;
    if (hasAuxInput) {
        auxInputSize = getSizeOfDimension(auxInputShape, 2);
    }
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeightsShape, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeightsShape, 0);

    // RNN::RNNStep processes one time step, so it sees shapes with the time
    // dimension stripped off.
    Shape fixedTimeInputShape = removeFirstDim(inputShape);
    Shape fixedTimeAuxInputShape = auxInputShape;
    if (hasAuxInput) {
        fixedTimeAuxInputShape = removeFirstDim(auxInputShape);
    }

    // In parallel linking the auxiliary input IS the backward network's
    // regular input; clear auxInput so it is not also fed as an aux operand.
    const T* bwInput = input;
    if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        bwInput = auxInput;
        auxInput = nullptr;
    }

    const bool outputState = (context->getNumOutputs() == kNumOutputsWithState ||
                              context->getNumOutputs() == kNumOutputsMergedWithState);
    T* fwOutputHiddenState = nullptr;
    T* bwOutputHiddenState = nullptr;
    // Create an additional buffer to store a hidden state between steps.
    std::vector<T> tempHiddenState;
    if (outputState) {
        // When merged, the backward output tensor is absent, shifting the
        // hidden-state output indices down by one.
        const int delta = mergeOutputs ? 1 : 0;
        fwOutputHiddenState = context->getOutputBuffer<T>(kFwOutputHiddenStateTensor - delta);
        bwOutputHiddenState = context->getOutputBuffer<T>(kBwOutputHiddenStateTensor - delta);
    } else {
        // Scratch buffer sized for the larger of the two directions; it is
        // reused for both since the passes run sequentially.
        tempHiddenState.resize(std::max(batchSize * fwNumUnits, batchSize * bwNumUnits));
        fwOutputHiddenState = tempHiddenState.data();
        bwOutputHiddenState = tempHiddenState.data();
    }

    // Forward pass
    for (uint32_t i = 0; i < maxTime; ++i) {
        const T* inputBatchPtr = input + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            // Only CROSS_LINKING feeds the aux input through aux weights.
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        // When outputs are merged, forward results occupy the first fwNumUnits
        // slots of each (fwNumUnits + bwNumUnits)-wide row.
        const uint32_t fwOutputBatchStride = mergeOutputs ? (fwNumUnits + bwNumUnits) : fwNumUnits;
        T* fwOutputBatchPtr = fwOutput + i * batchSize * fwOutputBatchStride;

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, fwHiddenState, fwBias, fwWeights, fwWeightsShape,
                        fwAuxWeights, fwAuxWeightsShape, fwRecurrentWeights,
                        fwRecurrentWeightsShape, activation, fwOutputBatchStride,
                        /*outputBatchOffset=*/0, fwOutputBatchPtr, fwOutputHiddenState);

        // Chain the hidden state produced by this step into the next step.
        fwHiddenState = fwOutputHiddenState;
    }

    // Backward pass
    // NOTE(review): in PARALLEL_LINKING mode the backward input (the aux
    // input) is addressed with `inputSize` as the per-step stride and passed
    // with `fixedTimeInputShape`, which assumes the aux input's feature size
    // equals inputSize -- prepare() only checks it against bwWeights; TODO
    // confirm.
    for (int i = maxTime - 1; i >= 0; --i) {
        const T* inputBatchPtr = bwInput + i * batchSize * inputSize;
        const T* auxInputBatchPtr = nullptr;
        if (hasAuxWeights) {
            auxInputBatchPtr = auxInput + i * batchSize * auxInputSize;
        }
        T* bwOutputBatchPtr;
        uint32_t bwOutputBatchOffset = 0;
        uint32_t bwOutputBatchStride;
        if (mergeOutputs) {
            // Backward results are written after the forward results within
            // each merged row.
            bwOutputBatchStride = fwNumUnits + bwNumUnits;
            bwOutputBatchOffset = fwNumUnits;
            bwOutputBatchPtr = fwOutput + i * batchSize * bwOutputBatchStride;
        } else {
            bwOutputBatchStride = bwNumUnits;
            bwOutputBatchPtr = bwOutput + i * batchSize * bwOutputBatchStride;
        }

        RNN::RNNStep<T>(inputBatchPtr, fixedTimeInputShape, auxInputBatchPtr,
                        fixedTimeAuxInputShape, bwHiddenState, bwBias, bwWeights, bwWeightsShape,
                        bwAuxWeights, bwAuxWeightsShape, bwRecurrentWeights,
                        bwRecurrentWeightsShape, activation, bwOutputBatchStride,
                        bwOutputBatchOffset, bwOutputBatchPtr, bwOutputHiddenState);

        bwHiddenState = bwOutputHiddenState;
    }

    // If the inputs were in batch major format, transpose data in temporary
    // buffers and write to the output(s).
    if (!timeMajor) {
        transposeFirstTwoDims(fwOutputTransposed.data(), fwOutputShape,
                              context->getOutputBuffer<T>(kFwOutputTensor));
        if (!mergeOutputs) {
            transposeFirstTwoDims(bwOutputTransposed.data(), bwOutputShape,
                                  context->getOutputBuffer<T>(kBwOutputTensor));
        }
    }
    return true;
}
| |
| } // namespace |
| |
// Validates operand counts, ranks, and dimension consistency for
// BIDIRECTIONAL_SEQUENCE_RNN and computes the output shapes. Returns false if
// any check fails.
bool prepare(IOperationExecutionContext* context) {
    // The number of outputs depends on whether outputs are merged and whether
    // the final hidden states are also returned.
    const bool mergeOutputs = context->getInputValue<bool>(kMergeOutputsParam);
    const int32_t numOutputs = context->getNumOutputs();
    if (mergeOutputs) {
        NN_RET_CHECK(numOutputs == kNumOutputsMerged || numOutputs == kNumOutputsMergedWithState);
    } else {
        NN_RET_CHECK(numOutputs == kNumOutputs || numOutputs == kNumOutputsWithState);
    }

    // Check that none of the required inputs are omitted.
    const std::vector<int> requiredInputs = {
            kInputTensor,         kFwWeightsTensor, kFwRecurrentWeightsTensor, kFwBiasTensor,
            kFwHiddenStateTensor, kBwWeightsTensor, kBwRecurrentWeightsTensor, kBwBiasTensor,
            kBwHiddenStateTensor, kActivationParam, kTimeMajorParam,           kMergeOutputsParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    Shape input = context->getInputShape(kInputTensor);
    Shape fwWeights = context->getInputShape(kFwWeightsTensor);
    Shape fwRecurrentWeights = context->getInputShape(kFwRecurrentWeightsTensor);
    Shape fwBias = context->getInputShape(kFwBiasTensor);
    Shape fwHiddenState = context->getInputShape(kFwHiddenStateTensor);
    Shape bwWeights = context->getInputShape(kBwWeightsTensor);
    Shape bwRecurrentWeights = context->getInputShape(kBwRecurrentWeightsTensor);
    Shape bwBias = context->getInputShape(kBwBiasTensor);
    Shape bwHiddenState = context->getInputShape(kBwHiddenStateTensor);

    Shape auxInput = context->getInputShape(kAuxInputTensor);
    Shape fwAuxWeights = context->getInputShape(kFwAuxWeightsTensor);
    Shape bwAuxWeights = context->getInputShape(kBwAuxWeightsTensor);

    // Rejects invalid auxiliary-tensor combinations.
    LinkingMode linkingMode;
    NN_RET_CHECK(getLinkingMode(context, &linkingMode));

    // The layout of the first two input dimensions depends on timeMajor:
    // time-major is [maxTime, batchSize, ...], batch-major is the reverse.
    bool timeMajor = context->getInputValue<bool>(kTimeMajorParam);
    const uint32_t batchSize =
            timeMajor ? getSizeOfDimension(input, 1) : getSizeOfDimension(input, 0);
    const uint32_t maxTime =
            timeMajor ? getSizeOfDimension(input, 0) : getSizeOfDimension(input, 1);
    const uint32_t fwNumUnits = getSizeOfDimension(fwWeights, 0);
    const uint32_t bwNumUnits = getSizeOfDimension(bwWeights, 0);
    const uint32_t inputSize = getSizeOfDimension(input, 2);

    // Rank checks.
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 3u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(fwHiddenState), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwRecurrentWeights), 2u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwBias), 1u);
    NN_RET_CHECK_EQ(getNumberOfDimensions(bwHiddenState), 2u);

    // Forward-direction dimension consistency.
    NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(fwWeights, 1));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwBias, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(fwHiddenState, 0));
    NN_RET_CHECK_EQ(fwNumUnits, getSizeOfDimension(fwHiddenState, 1));

    // Backward-direction dimension consistency. In parallel linking the
    // backward network consumes the aux input, so its weights are checked
    // against the aux input's feature size below instead of inputSize.
    if (linkingMode != LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(inputSize, getSizeOfDimension(bwWeights, 1));
    }
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwBias, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwRecurrentWeights, 1));
    NN_RET_CHECK_EQ(batchSize, getSizeOfDimension(bwHiddenState, 0));
    NN_RET_CHECK_EQ(bwNumUnits, getSizeOfDimension(bwHiddenState, 1));

    // Auxiliary-tensor checks for the two linked modes.
    if (linkingMode == LinkingMode::CROSS_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(fwAuxWeights), 2u);
        NN_RET_CHECK_EQ(getNumberOfDimensions(bwAuxWeights), 2u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 0), fwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(fwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 0), bwNumUnits);
        NN_RET_CHECK_EQ(getSizeOfDimension(bwAuxWeights, 1), getSizeOfDimension(auxInput, 2));
    } else if (linkingMode == LinkingMode::PARALLEL_LINKING) {
        NN_RET_CHECK_EQ(getNumberOfDimensions(auxInput), 3u);

        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 0), getSizeOfDimension(input, 0));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 1), getSizeOfDimension(input, 1));
        NN_RET_CHECK_EQ(getSizeOfDimension(auxInput, 2), getSizeOfDimension(bwWeights, 1));
    }

    // Output shapes follow the input layout; when merged, the last dimension
    // holds both directions' units.
    Shape fwOutput = context->getOutputShape(kFwOutputTensor);
    fwOutput.dimensions.resize(3);
    fwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
    fwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
    fwOutput.dimensions[2] = mergeOutputs ? fwNumUnits + bwNumUnits : fwNumUnits;
    NN_RET_CHECK(context->setOutputShape(kFwOutputTensor, fwOutput));
    if (!mergeOutputs) {
        Shape bwOutput = context->getOutputShape(kBwOutputTensor);
        bwOutput.dimensions.resize(3);
        bwOutput.dimensions[0] = timeMajor ? maxTime : batchSize;
        bwOutput.dimensions[1] = timeMajor ? batchSize : maxTime;
        bwOutput.dimensions[2] = bwNumUnits;
        NN_RET_CHECK(context->setOutputShape(kBwOutputTensor, bwOutput));
    }

    const bool outputState =
            (numOutputs == kNumOutputsWithState || numOutputs == kNumOutputsMergedWithState);
    if (outputState) {
        // Hidden-state outputs mirror the input hidden-state shapes. When
        // merged, the missing backward output shifts the indices down by one.
        const int delta = mergeOutputs ? 1 : 0;
        NN_RET_CHECK(context->setOutputShape(kFwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kFwHiddenStateTensor)));
        NN_RET_CHECK(context->setOutputShape(kBwOutputHiddenStateTensor - delta,
                                             context->getInputShape(kBwHiddenStateTensor)));
    }

    return true;
}
| |
| bool execute(IOperationExecutionContext* context) { |
| if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) { |
| executeTyped<_Float16>(context); |
| } else { |
| executeTyped<float>(context); |
| } |
| return true; |
| } |
| #endif // NN_INCLUDE_CPU_IMPLEMENTATION |
| |
| } // namespace bidirectional_sequence_rnn |
| |
// Registers the operation with the resolver. allowOmittedOperand is set
// because the auxiliary tensors (and some outputs) may legitimately be
// omitted, as handled by getLinkingMode/prepare above.
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(BIDIRECTIONAL_SEQUENCE_RNN,
                                         bidirectional_sequence_rnn::prepare,
                                         bidirectional_sequence_rnn::execute,
                                         .allowOmittedOperand = true);
| |
| } // namespace nn |
| } // namespace android |