Add new SPLIT op

Also add tests for it.

Bug: 113563597
Test: NeuralNetworksTest_static with new tests
Change-Id: I32cb7ccd3fae023e97d207bdfbad29c497dc4044
Merged-In: I32cb7ccd3fae023e97d207bdfbad29c497dc4044
(cherry picked from commit a85cac86b4b5d53f0173197ff0882eb70ce0b2e1)
diff --git a/common/Android.bp b/common/Android.bp
index 45351b6..654a356 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -74,6 +74,7 @@
         "operations/Reshape.cpp",
         "operations/RNN.cpp",
         "operations/SimpleMath.cpp",
+        "operations/Split.cpp",
         "operations/StridedSlice.cpp",
         "operations/SVDF.cpp",
     ],
diff --git a/common/CpuExecutor.cpp b/common/CpuExecutor.cpp
index b275297..29bb14a 100644
--- a/common/CpuExecutor.cpp
+++ b/common/CpuExecutor.cpp
@@ -1451,6 +1451,60 @@
                       setInfoAndAllocateIfNeeded(&output, outShape) &&
                       expand_dims::eval(input.buffer, input.shape(), axis, output.buffer, outShape);
         } break;
+        case OperationType::SPLIT: {
+            // Inputs: [0] tensor to split, [1] axis (INT32 scalar),
+            // [2] number of outputs (INT32 scalar).
+            if (ins.size() != 3) {
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+
+            const RunTimeOperandInfo& input = mOperands[ins[0]];
+            const int32_t axis = getScalarData<int32_t>(mOperands[ins[1]]);
+            const int32_t numOutputs = getScalarData<int32_t>(mOperands[ins[2]]);
+
+            // Reject a non-positive or mismatched split count before using it
+            // as a vector size. The explicit cast avoids the signed/unsigned
+            // comparison between int32_t and outs.size().
+            if (numOutputs <= 0 || static_cast<size_t>(numOutputs) != outs.size()) {
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+
+            std::vector<Shape> outputShapes(numOutputs);
+            for (int i = 0; i < numOutputs; ++i) {
+                outputShapes[i] = mOperands[outs[i]].shape();
+            }
+
+            // Compute each output's shape, then size/allocate the output buffers.
+            success = splitPrepare(input.shape(), axis, numOutputs, &outputShapes);
+            for (int i = 0; i < numOutputs; ++i) {
+                success = success &&
+                          setInfoAndAllocateIfNeeded(&(mOperands[outs[i]]), outputShapes[i]);
+            }
+            // Dispatch on the input element type; only these three are supported.
+            switch (input.type) {
+                case OperandType::TENSOR_FLOAT32: {
+                    std::vector<float*> outputDataPtrs(numOutputs);
+                    for (int i = 0; i < numOutputs; ++i) {
+                        outputDataPtrs[i] = reinterpret_cast<float*>(mOperands[outs[i]].buffer);
+                    }
+                    success = success &&
+                              splitFloat32(reinterpret_cast<const float*>(input.buffer),
+                                           input.shape(), axis, &outputDataPtrs, outputShapes);
+                } break;
+                case OperandType::TENSOR_INT32: {
+                    std::vector<int32_t*> outputDataPtrs(numOutputs);
+                    for (int i = 0; i < numOutputs; ++i) {
+                        outputDataPtrs[i] = reinterpret_cast<int32_t*>(mOperands[outs[i]].buffer);
+                    }
+                    success = success &&
+                              splitInt32(reinterpret_cast<const int32_t*>(input.buffer),
+                                         input.shape(), axis, &outputDataPtrs, outputShapes);
+                } break;
+                case OperandType::TENSOR_QUANT8_ASYMM: {
+                    std::vector<uint8_t*> outputDataPtrs(numOutputs);
+                    for (int i = 0; i < numOutputs; ++i) {
+                        outputDataPtrs[i] = reinterpret_cast<uint8_t*>(mOperands[outs[i]].buffer);
+                    }
+                    success = success &&
+                              splitQuant8(reinterpret_cast<const uint8_t*>(input.buffer),
+                                          input.shape(), axis, &outputDataPtrs, outputShapes);
+                } break;
+                default: { return ANEURALNETWORKS_BAD_DATA; }
+            }
+        } break;
         default:
             nnAssert(false);
             break;
diff --git a/common/OperationsUtils.cpp b/common/OperationsUtils.cpp
index 9601853..cfd1c92 100644
--- a/common/OperationsUtils.cpp
+++ b/common/OperationsUtils.cpp
@@ -921,5 +921,24 @@
 
     return true;
 }
+
+// Computes the shape of each of the numOutputs outputs of an even SPLIT of
+// `input` along `axis` (which may be negative; it is normalized first).
+// Fails if numOutputs is non-positive or does not evenly divide the axis.
+bool splitPrepare(const Shape& input, int32_t axis, int32_t numOutputs,
+                  std::vector<Shape>* output) {
+    axis = getDimensionIndex(input, axis);
+
+    // Guard the division below: a model could supply numOutputs == 0, which
+    // would otherwise divide by zero.
+    NN_OPS_CHECK(numOutputs > 0);
+    const int32_t sizeOfAxisToSplit = input.dimensions[axis];
+    NN_OPS_CHECK(sizeOfAxisToSplit % numOutputs == 0);
+    const int32_t sliceSize = sizeOfAxisToSplit / numOutputs;
+
+    // Each output inherits the input's type and quantization parameters;
+    // only the split axis shrinks to a single slice.
+    for (int i = 0; i < numOutputs; ++i) {
+        output->at(i).type = input.type;
+        output->at(i).dimensions = input.dimensions;
+        output->at(i).dimensions[axis] = sliceSize;
+        output->at(i).offset = input.offset;
+        output->at(i).scale = input.scale;
+    }
+    return true;
+}
+
 } // namespace nn
 } // namespace android
diff --git a/common/Utils.cpp b/common/Utils.cpp
index 46b811a..c195015 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -1568,6 +1568,20 @@
                                                  outputCount, outputIndexes,
                                                  outExpectedTypes);
         }
+        case ANEURALNETWORKS_SPLIT: {
+            if (inputCount != 3) {
+                // Include a separator so the message does not run into the op name.
+                LOG(ERROR) << "Invalid number of input operands (" << inputCount
+                           << ", expected 3) for operation " << kOperationNames[opType];
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+            auto inputType = operands[inputIndexes[0]].type;
+            // Reject unsupported element types here rather than at execution
+            // time; the executor only implements these three.
+            if (inputType != OperandType::TENSOR_FLOAT32 &&
+                inputType != OperandType::TENSOR_INT32 &&
+                inputType != OperandType::TENSOR_QUANT8_ASYMM) {
+                LOG(ERROR) << "Unsupported input tensor type for operation "
+                           << kOperationNames[opType];
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+            // Inputs: tensor, axis scalar, numOutputs scalar; every output
+            // must have the same element type as the input.
+            std::vector<OperandType> inExpectedTypes = {inputType, OperandType::INT32,
+                                                        OperandType::INT32};
+            std::vector<OperandType> outExpectedTypes(outputCount, inputType);
+            return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+                                                 inExpectedTypes, outputCount, outputIndexes,
+                                                 outExpectedTypes);
+        }
         default:
             return ANEURALNETWORKS_BAD_DATA;
     }
diff --git a/common/include/Operations.h b/common/include/Operations.h
index 588253f..7e2283f 100644
--- a/common/include/Operations.h
+++ b/common/include/Operations.h
@@ -238,6 +238,18 @@
 bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape,
                       int32_t axis, bool isArgMin,
                       uint8_t* outputData, const Shape& outputShape);
+
+bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis,
+                  const std::vector<float*>* outputDataPtrs,
+                  const std::vector<Shape>& outputShapes);
+
+bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis,
+                const std::vector<int32_t*>* outputDataPtrs,
+                const std::vector<Shape>& outputShapes);
+
+bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis,
+                 const std::vector<uint8_t*>* outputDataPtrs,
+                 const std::vector<Shape>& outputShapes);
 } // namespace nn
 } // namespace android
 #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H
diff --git a/common/include/OperationsUtils.h b/common/include/OperationsUtils.h
index 358be4f..84cd4d3 100644
--- a/common/include/OperationsUtils.h
+++ b/common/include/OperationsUtils.h
@@ -289,6 +289,7 @@
 
 bool argMinMaxPrepare(const Shape& input, int32_t axis, Shape* output);
 
+bool splitPrepare(const Shape& input, int32_t axis, int32_t numOutputs, std::vector<Shape>* output);
 } // namespace nn
 } // namespace android
 
diff --git a/common/operations/Split.cpp b/common/operations/Split.cpp
new file mode 100644
index 0000000..290e2c8
--- /dev/null
+++ b/common/operations/Split.cpp
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#define LOG_TAG "Operations"
+
+#include "Operations.h"
+#include "OperationsUtils.h"
+
+#include "Tracing.h"
+
+namespace android {
+namespace nn {
+
+// Evenly splits inputData along `axis` (possibly negative) into the buffers
+// in outputDataPtrs. outputShapes[i].dimensions[axis] gives slice i's size,
+// as computed by splitPrepare. Copies the input exactly once.
+template <typename Scalar>
+bool splitGeneric(const Scalar* inputData, const Shape& inputShape, int32_t axis,
+                  const std::vector<Scalar*>* outputDataPtrs,
+                  const std::vector<Shape>& outputShapes) {
+    axis = getDimensionIndex(inputShape, axis);
+    // Product of dimensions before the axis: number of outer slices to walk.
+    int outerSize = 1;
+    for (int i = 0; i < axis; ++i) {
+        outerSize *= inputShape.dimensions[i];
+    }
+    // Product of dimensions after the axis: elements per step along the axis.
+    // (Renamed from "concatDimensions" — a copy-paste leftover from concat.)
+    int baseInnerSize = 1;
+    const int dimensionCount = getNumberOfDimensions(inputShape);
+    for (int i = axis + 1; i < dimensionCount; ++i) {
+        baseInnerSize *= inputShape.dimensions[i];
+    }
+
+    // Walk the input linearly, round-robin copying each output's contiguous
+    // chunk for every outer slice.
+    const Scalar* inputPtr = inputData;
+    for (int k = 0; k < outerSize; ++k) {
+        // size_t index: avoids the signed/unsigned comparison with size().
+        for (size_t i = 0; i < outputDataPtrs->size(); ++i) {
+            const int copySize = outputShapes[i].dimensions[axis] * baseInnerSize;
+            memcpy(outputDataPtrs->at(i) + k * copySize, inputPtr, copySize * sizeof(Scalar));
+            inputPtr += copySize;
+        }
+    }
+
+    return true;
+}
+
+// Evenly splits a float32 tensor along `axis` into the given output buffers.
+bool splitFloat32(const float* inputData, const Shape& inputShape, int32_t axis,
+                  const std::vector<float*>* outputDataPtrs,
+                  const std::vector<Shape>& outputShapes) {
+    NNTRACE_COMP("splitFloat32");
+    return splitGeneric<float>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
+}
+
+// Evenly splits a quantized uint8 tensor along `axis` into the output buffers.
+bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
+                 const std::vector<uint8_t*>* outputDataPtrs,
+                 const std::vector<Shape>& outputShapes) {
+    NNTRACE_COMP("splitQuant8");
+    return splitGeneric<uint8_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
+}
+
+// Evenly splits an int32 tensor along `axis` into the given output buffers.
+bool splitInt32(const int32_t* inputData, const Shape& inputShape, int32_t axis,
+                const std::vector<int32_t*>* outputDataPtrs,
+                const std::vector<Shape>& outputShapes) {
+    NNTRACE_COMP("splitInt32");
+    return splitGeneric<int32_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
+}
+
+}  // namespace nn
+}  // namespace android