Add new SPLIT op

Also add tests for it.

Bug: 113563597
Test: NeuralNetworksTest_static with new tests
Change-Id: I32cb7ccd3fae023e97d207bdfbad29c497dc4044
Merged-In: I32cb7ccd3fae023e97d207bdfbad29c497dc4044
(cherry picked from commit a85cac86b4b5d53f0173197ff0882eb70ce0b2e1)
diff --git a/common/Android.bp b/common/Android.bp index 45351b6..654a356 100644 --- a/common/Android.bp +++ b/common/Android.bp
@@ -74,6 +74,7 @@ "operations/Reshape.cpp", "operations/RNN.cpp", "operations/SimpleMath.cpp", + "operations/Split.cpp", "operations/StridedSlice.cpp", "operations/SVDF.cpp", ],
diff --git a/common/CpuExecutor.cpp b/common/CpuExecutor.cpp index b275297..29bb14a 100644 --- a/common/CpuExecutor.cpp +++ b/common/CpuExecutor.cpp
@@ -1451,6 +1451,68 @@
                       setInfoAndAllocateIfNeeded(&output, outShape) &&
                       expand_dims::eval(input.buffer, input.shape(), axis, output.buffer, outShape);
         } break;
+        case OperationType::SPLIT: {
+            // Inputs: [0] tensor to split, [1] axis (int32 scalar), [2] number of
+            // outputs (int32 scalar).
+            if (ins.size() != 3) {
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+
+            const RunTimeOperandInfo& input = mOperands[ins[0]];
+            const int32_t axis = getScalarData<int32_t>(mOperands[ins[1]]);
+            const int32_t numOutputs = getScalarData<int32_t>(mOperands[ins[2]]);
+
+            // Reject non-positive counts explicitly: zero would make splitPrepare take a
+            // remainder by zero, and a negative value must not be compared against the
+            // unsigned outs.size() through an implicit conversion.
+            if (numOutputs <= 0 || static_cast<size_t>(numOutputs) != outs.size()) {
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+
+            // Seed each output shape from its operand; splitPrepare then overwrites the
+            // fields it derives from the input shape.
+            std::vector<Shape> outputShapes(numOutputs);
+            for (int i = 0; i < numOutputs; ++i) {
+                outputShapes[i] = mOperands[outs[i]].shape();
+            }
+
+            success = splitPrepare(input.shape(), axis, numOutputs, &outputShapes);
+            for (int i = 0; i < numOutputs; ++i) {
+                success = success &&
+                          setInfoAndAllocateIfNeeded(&(mOperands[outs[i]]), outputShapes[i]);
+            }
+            // Dispatch on the element type; each branch gathers typed output pointers.
+            switch (input.type) {
+                case OperandType::TENSOR_FLOAT32: {
+                    std::vector<float*> outputDataPtrs(numOutputs);
+                    for (int i = 0; i < numOutputs; ++i) {
+                        outputDataPtrs[i] = reinterpret_cast<float*>(mOperands[outs[i]].buffer);
+                    }
+                    success = success &&
+                              splitFloat32(reinterpret_cast<const float*>(input.buffer),
+                                           input.shape(), axis, &outputDataPtrs, outputShapes);
+                } break;
+                case OperandType::TENSOR_INT32: {
+                    std::vector<int32_t*> outputDataPtrs(numOutputs);
+                    for (int i = 0; i < numOutputs; ++i) {
+                        outputDataPtrs[i] = reinterpret_cast<int32_t*>(mOperands[outs[i]].buffer);
+                    }
+                    success = success &&
+                              splitInt32(reinterpret_cast<const int32_t*>(input.buffer),
+                                         input.shape(), axis, &outputDataPtrs, outputShapes);
+                } break;
+                case OperandType::TENSOR_QUANT8_ASYMM: {
+                    std::vector<uint8_t*> outputDataPtrs(numOutputs);
+                    for (int i = 0; i < numOutputs; ++i) {
+                        outputDataPtrs[i] = reinterpret_cast<uint8_t*>(mOperands[outs[i]].buffer);
+                    }
+                    success = success &&
+                              splitQuant8(reinterpret_cast<const uint8_t*>(input.buffer),
+                                          input.shape(), axis, &outputDataPtrs, outputShapes);
+                } break;
+                default: { return ANEURALNETWORKS_BAD_DATA; }
+            }
+        } break;
         default:
             nnAssert(false);
             break;
diff --git a/common/OperationsUtils.cpp b/common/OperationsUtils.cpp index 9601853..cfd1c92 100644 --- a/common/OperationsUtils.cpp +++ b/common/OperationsUtils.cpp
@@ -921,5 +921,38 @@

     return true;
 }
+
+// Fills the first numOutputs entries of *output with the shapes produced by splitting
+// input into numOutputs equal slices along axis. Each entry copies the input's type,
+// dimensions, offset and scale, with the split dimension replaced by the slice size.
+// The check macro is expected to make the function return false when the count is not
+// positive, *output has fewer than numOutputs entries, or the axis size is not evenly
+// divisible by numOutputs.
+bool splitPrepare(const Shape& input, int32_t axis, int32_t numOutputs,
+                  std::vector<Shape>* output) {
+    // numOutputs is used below as a divisor and as the number of entries to write.
+    NN_OPS_CHECK(numOutputs > 0);
+    NN_OPS_CHECK(output != nullptr);
+    NN_OPS_CHECK(output->size() >= static_cast<size_t>(numOutputs));
+
+    // NOTE(review): presumably maps a negative axis to its non-negative index -- confirm
+    // against getDimensionIndex.
+    axis = getDimensionIndex(input, axis);
+
+    const int32_t sizeOfAxisToSplit = input.dimensions[axis];
+    NN_OPS_CHECK(sizeOfAxisToSplit % numOutputs == 0);
+    const int32_t sliceSize = sizeOfAxisToSplit / numOutputs;
+
+    for (int32_t i = 0; i < numOutputs; ++i) {
+        Shape& slice = (*output)[i];
+        slice.type = input.type;
+        slice.dimensions = input.dimensions;
+        slice.dimensions[axis] = sliceSize;
+        slice.offset = input.offset;
+        slice.scale = input.scale;
+    }
+    return true;
+}
+
 } // namespace nn
 } // namespace android
diff --git a/common/Utils.cpp b/common/Utils.cpp index 46b811a..c195015 100644 --- a/common/Utils.cpp +++ b/common/Utils.cpp
@@ -1568,6 +1568,23 @@
                                                  outputCount, outputIndexes,
                                                  outExpectedTypes);
         }
+        case ANEURALNETWORKS_SPLIT: {
+            // SPLIT takes three inputs: the tensor to split, the axis and the number of
+            // outputs (the latter two are checked as INT32 operands below).
+            if (inputCount != 3) {
+                LOG(ERROR) << "Invalid number of input operands (" << inputCount
+                           << ", expected 3) for operation " << kOperationNames[opType];
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+            // Every output must have the same operand type as the input tensor.
+            auto inputType = operands[inputIndexes[0]].type;
+            std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
+                                                        OperandType::INT32};
+            std::vector<OperandType> outExpectedTypes(outputCount, inputType);
+            return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+                                                 inExpectedTypes, outputCount, outputIndexes,
+                                                 outExpectedTypes);
+        }
         default:
             return ANEURALNETWORKS_BAD_DATA;
     }
diff --git a/common/include/Operations.h b/common/include/Operations.h index 588253f..7e2283f 100644 --- a/common/include/Operations.h +++ b/common/include/Operations.h
@@ -238,6 +238,21 @@
 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape,
                       int32_t axis, bool isArgMin,
                       uint8_t* outputData, const Shape& outputShape);
+
+// SPLIT: copies consecutive slices of the input along axis into the buffers listed in
+// *outputDataPtrs, using outputShapes[i] for the extent of output i along that axis.
+// There is one entry point per supported element type.
+bool splitFloat32(const float* inputData, const Shape& inputShape, int32_t axis,
+                  const std::vector<float*>* outputDataPtrs,
+                  const std::vector<Shape>& outputShapes);
+
+bool splitInt32(const int32_t* inputData, const Shape& inputShape, int32_t axis,
+                const std::vector<int32_t*>* outputDataPtrs,
+                const std::vector<Shape>& outputShapes);
+
+bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
+                 const std::vector<uint8_t*>* outputDataPtrs,
+                 const std::vector<Shape>& outputShapes);
 } // namespace nn
 } // namespace android
 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H
diff --git a/common/include/OperationsUtils.h b/common/include/OperationsUtils.h index 358be4f..84cd4d3 100644 --- a/common/include/OperationsUtils.h +++ b/common/include/OperationsUtils.h
@@ -289,6 +289,10 @@

 bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output);

+// Computes the shapes of the numOutputs outputs obtained by splitting input along axis
+// and stores them in *output.
+bool splitPrepare(const Shape& input, int32_t axis, int32_t numOutputs,
+                  std::vector<Shape>* output);
 } // namespace nn
 } // namespace android

diff --git a/common/operations/Split.cpp b/common/operations/Split.cpp new file mode 100644 index 0000000..290e2c8 --- /dev/null +++ b/common/operations/Split.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "Operations"
+
+#include "Operations.h"
+#include "OperationsUtils.h"
+
+#include "Tracing.h"
+
+#include <cstring>
+#include <vector>
+
+namespace android {
+namespace nn {
+
+// Copies consecutive slices of inputData along axis into the buffers in *outputDataPtrs;
+// output i receives outputShapes[i].dimensions[axis] entries along that axis.
+// Returns false if the output lists are missing or inconsistent, true otherwise.
+template <typename Scalar>
+bool splitGeneric(const Scalar* inputData, const Shape& inputShape, int32_t axis,
+                  const std::vector<Scalar*>* outputDataPtrs,
+                  const std::vector<Shape>& outputShapes) {
+    // outputShapes is indexed in lock-step with *outputDataPtrs below.
+    if (outputDataPtrs == nullptr || outputShapes.size() < outputDataPtrs->size()) {
+        return false;
+    }
+
+    // NOTE(review): presumably maps a negative axis to its non-negative index -- confirm
+    // against getDimensionIndex.
+    axis = getDimensionIndex(inputShape, axis);
+    // Product of the dimensions that precede the split axis.
+    int outerSize = 1;
+    for (int i = 0; i < axis; ++i) {
+        outerSize *= inputShape.dimensions[i];
+    }
+    // Product of the dimensions that follow the split axis.
+    int baseInnerSize = 1;
+    const int numDimensions = getNumberOfDimensions(inputShape);
+    for (int i = axis + 1; i < numDimensions; ++i) {
+        baseInnerSize *= inputShape.dimensions[i];
+    }
+
+    const Scalar* inputPtr = inputData;
+    for (int k = 0; k < outerSize; ++k) {
+        for (size_t i = 0; i < outputDataPtrs->size(); ++i) {
+            const int copySize = outputShapes[i].dimensions[axis] * baseInnerSize;
+            // Output i receives one contiguous run of copySize elements per outer step k.
+            std::memcpy((*outputDataPtrs)[i] + k * copySize, inputPtr, copySize * sizeof(Scalar));
+            inputPtr += copySize;
+        }
+    }
+
+    return true;
+}
+
+// float entry point: tags the call with NNTRACE_COMP and forwards to splitGeneric.
+bool splitFloat32(const float* inputData, const Shape& inputShape, int32_t axis,
+                  const std::vector<float*>* outputDataPtrs,
+                  const std::vector<Shape>& outputShapes) {
+    NNTRACE_COMP("splitFloat32");
+    return splitGeneric<float>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
+}
+
+// uint8_t entry point: tags the call with NNTRACE_COMP and forwards to splitGeneric.
+bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
+                 const std::vector<uint8_t*>* outputDataPtrs,
+                 const std::vector<Shape>& outputShapes) {
+    NNTRACE_COMP("splitQuant8");
+    return splitGeneric<uint8_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
+}
+
+// int32_t entry point: tags the call with NNTRACE_COMP and forwards to splitGeneric.
+bool splitInt32(const int32_t* inputData, const Shape& inputShape, int32_t axis,
+                const std::vector<int32_t*>* outputDataPtrs,
+                const std::vector<Shape>& outputShapes) {
+    NNTRACE_COMP("splitInt32");
+    return splitGeneric<int32_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
+}
+
+} // namespace nn
+} // namespace android