Add new op CHANNEL_SHUFFLE.
Add a reference CPU implementation of CHANNEL_SHUFFLE for NHWC data
layout and FP32/Quant8 input data types.
Create tests for the new op. The test cases were verified by running the
same op in the Caffe2 library. Generate CTS/VTS tests.
Bug: 113562591
Test: NeuralNetworksTest_static
Change-Id: I4792bab2c3336125e15bd7dae7cf6de022dbacd6
Merged-In: I4792bab2c3336125e15bd7dae7cf6de022dbacd6
(cherry picked from commit 451bee05b2fdcbc3ebc944ae7099f1a36ffa5834)
diff --git a/common/Android.bp b/common/Android.bp
index a217db3..6d0436a 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -60,6 +60,7 @@
"ValidateHal.cpp",
"operations/Activation.cpp",
"operations/ArgMinMax.cpp",
+ "operations/ChannelShuffle.cpp",
"operations/Concatenation.cpp",
"operations/Conv2D.cpp",
"operations/DepthwiseConv2D.cpp",
diff --git a/common/CpuExecutor.cpp b/common/CpuExecutor.cpp
index b62e347..e746cc0 100644
--- a/common/CpuExecutor.cpp
+++ b/common/CpuExecutor.cpp
@@ -1625,6 +1625,21 @@
reinterpret_cast<uint8_t*>(output.buffer), outShape);
}
} break;
+        case OperationType::CHANNEL_SHUFFLE: {
+            // ins[0]: tensor to shuffle, ins[1]: INT32 group count; outs[0]: shuffled tensor.
+            if (!allParametersPresent(2, 1)) return ANEURALNETWORKS_BAD_DATA;
+            const RunTimeOperandInfo& inputOperand = mOperands[ins[0]];
+            const int32_t groupCount = getScalarData<int32_t>(mOperands[ins[1]]);
+            RunTimeOperandInfo& outputOperand = mOperands[outs[0]];
+
+            // Derive/validate the output shape, set it (allocating if needed), then run the kernel.
+            Shape outputShape = outputOperand.shape();
+            success = channelShufflePrepare(inputOperand.shape(), groupCount, &outputShape);
+            success = success && setInfoAndAllocateIfNeeded(&outputOperand, outputShape);
+            success = success && channelShuffleGeneric(inputOperand.buffer, inputOperand.shape(),
+                                                       groupCount, outputOperand.buffer,
+                                                       outputShape);
+        } break;
default:
nnAssert(false);
break;
diff --git a/common/OperationsUtils.cpp b/common/OperationsUtils.cpp
index c6d4665..75bd97d 100644
--- a/common/OperationsUtils.cpp
+++ b/common/OperationsUtils.cpp
@@ -1056,5 +1056,23 @@
output->dimensions = {batches, outHeight, outWidth, channels_out};
return true;
}
+
+// Computes the output shape of CHANNEL_SHUFFLE. The output has exactly the input's type,
+// dimensions and quantization parameters (scale/offset); only the channel order changes.
+// Fails (via NN_OPS_CHECK) if the input has rank 0, if numGroups is not positive, or if the
+// last (channel) dimension is not a multiple of numGroups.
+bool channelShufflePrepare(const Shape& input, int32_t numGroups, Shape* output) {
+    const uint32_t numDimensions = getNumberOfDimensions(input);
+    // Must come first: with rank 0, numDimensions - 1 below would wrap around to UINT32_MAX.
+    NN_OPS_CHECK(numDimensions > 0);
+    NN_OPS_CHECK(numGroups > 0);
+    NN_OPS_CHECK(getSizeOfDimension(input, numDimensions - 1) % numGroups == 0);
+
+    output->type = input.type;
+    output->dimensions = input.dimensions;
+    output->offset = input.offset;
+    output->scale = input.scale;
+    return true;
+}
} // namespace nn
} // namespace android
diff --git a/common/Utils.cpp b/common/Utils.cpp
index c67ce56..ed22186 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -1650,6 +1650,29 @@
inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
+        case ANEURALNETWORKS_CHANNEL_SHUFFLE: {
+            // Signature: (tensor, INT32 numGroups) -> one tensor of the same type as the input.
+            if (inputCount != 2 || outputCount != 1) {
+                logInvalidInOutNumber(2, 1);
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+
+            const auto tensorType = operands[inputIndexes[0]].type;
+            const bool isSupportedType = tensorType == OperandType::TENSOR_FLOAT32 ||
+                                         tensorType == OperandType::TENSOR_QUANT8_ASYMM;
+            if (!isSupportedType) {
+                LOG(ERROR) << "Unsupported input tensor type for operation "
+                           << kOperationNames[opType];
+                return ANEURALNETWORKS_BAD_DATA;
+            }
+            // Every supported type follows the same pattern: the tensor input and the output
+            // share the input's tensor type, and the group count is a scalar INT32.
+            std::vector<OperandType> inExpectedTypes = {tensorType, OperandType::INT32};
+            std::vector<OperandType> outExpectedTypes = {tensorType};
+            return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+                                                 inExpectedTypes, outputCount, outputIndexes,
+                                                 outExpectedTypes);
+        }
default:
return ANEURALNETWORKS_BAD_DATA;
}
diff --git a/common/include/Operations.h b/common/include/Operations.h
index 20a589d..10dc484 100644
--- a/common/include/Operations.h
+++ b/common/include/Operations.h
@@ -271,6 +271,9 @@
int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
int32_t stride_height, int32_t activation, uint8_t* outputData,
const Shape& outputShape);
+
+bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
+ uint8_t* outputData, const Shape& outputShape);
} // namespace nn
} // namespace android
#endif // ANDROID_ML_NN_COMMON_OPERATIONS_H
diff --git a/common/include/OperationsUtils.h b/common/include/OperationsUtils.h
index 7263eeb..dcefbc9 100644
--- a/common/include/OperationsUtils.h
+++ b/common/include/OperationsUtils.h
@@ -302,6 +302,8 @@
int32_t padding_left, int32_t padding_right, int32_t padding_top,
int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
int32_t numGroups, Shape* output);
+
+bool channelShufflePrepare(const Shape& input, int32_t numGroups, Shape* output);
} // namespace nn
} // namespace android
diff --git a/common/operations/ChannelShuffle.cpp b/common/operations/ChannelShuffle.cpp
new file mode 100644
index 0000000..fbc0485
--- /dev/null
+++ b/common/operations/ChannelShuffle.cpp
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "CpuOperationUtils.h"
+#include "Operations.h"
+
+#include "Tracing.h"
+
+namespace android {
+namespace nn {
+
+// Reference CHANNEL_SHUFFLE kernel over the innermost dimension: each run of inDepth values,
+// viewed as a [numGroups][groupSize] matrix, is written out transposed ([groupSize][numGroups]).
+template <typename T>
+inline bool channelShuffleGeneric(const T* inputData, const Shape& inputShape, int32_t numGroups,
+                                  T* outputData, const Shape& outputShape) {
+    const uint32_t numDimensions = getNumberOfDimensions(inputShape);
+    // Reject rank 0 and non-positive group counts so the index and division below are valid.
+    if (numDimensions == 0 || numGroups <= 0) return false;
+    const uint32_t inDepth = getSizeOfDimension(inputShape, numDimensions - 1);
+    const uint32_t groups = static_cast<uint32_t>(numGroups);  // unsigned, like the loop indices
+    if (inDepth % groups != 0) return false;  // channelShufflePrepare rejects this too
+    const uint32_t groupSize = inDepth / groups;
+    const T* inputDataEnd = inputData + getNumberOfElements(inputShape);
+    T* outPtr = outputData;
+    for (const T* inputBase = inputData; inputBase < inputDataEnd; inputBase += inDepth) {
+        for (uint32_t j = 0; j < groupSize; ++j) {
+            for (uint32_t k = 0; k < groups; ++k) *outPtr++ = inputBase[j + k * groupSize];
+        }
+    }
+    return true;
+}
+
+bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
+                           uint8_t* outputData, const Shape& outputShape) {
+    NNTRACE_TRANS("channelShuffleGeneric");
+    // Untyped entry point: reinterpret the raw byte buffers according to inputShape.type.
+    if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
+        return channelShuffleGeneric<uint8_t>(inputData, inputShape, numGroups, outputData,
+                                              outputShape);
+    }
+    if (inputShape.type == OperandType::TENSOR_FLOAT32) {
+        return channelShuffleGeneric<float>(reinterpret_cast<const float*>(inputData), inputShape,
+                                            numGroups, reinterpret_cast<float*>(outputData),
+                                            outputShape);
+    }
+    LOG(ERROR) << "Unsupported data type";
+    return false;
+}
+} // namespace nn
+} // namespace android