/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define LOG_TAG "Operations"

#include <algorithm>
#include <limits>
#include <vector>

#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#include <tensorflow/lite/kernels/internal/reference/reference_ops.h>
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

namespace android {
namespace nn {
namespace reduce {

constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kInputAxes = 1;
constexpr uint32_t kInputKeepDims = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

// Values from
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format#IEEE_754_half-precision_binary_floating-point_format:_binary16
constexpr _Float16 kFloat16Max = 65504;
constexpr _Float16 kFloat16Lowest = -kFloat16Max;

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
namespace {

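// Reduces the input tensor over the axes given by the kInputAxes operand by
// folding `func` over the selected elements, starting from the identity value
// `init`. The work is delegated to TFLite's reference ReduceGeneric kernel;
// tempIndex and tempAxes are scratch buffers that the kernel requires.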
template <typename T>
inline bool compute(IOperationExecutionContext* context, T init, T func(T, T)) {
    const Shape inputShape = context->getInputShape(kInputTensor);
    const Shape axesShape = context->getInputShape(kInputAxes);
    const Shape outputShape = context->getOutputShape(kOutputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    const uint32_t numAxes = getNumberOfElements(axesShape);
    std::vector<int> tempIndex(inputShape.dimensions.size());
    std::vector<int> tempAxes(numAxes);
    return tflite::reference_ops::ReduceGeneric<T>(
            context->getInputBuffer<T>(kInputTensor),
            reinterpret_cast<const int32_t*>(inputShape.dimensions.data()), inputRank,
            context->getOutputBuffer<T>(kOutputTensor),
            reinterpret_cast<const int32_t*>(outputShape.dimensions.data()),
            outputShape.dimensions.size(), context->getInputBuffer<int32_t>(kInputAxes), numAxes,
            context->getInputValue<bool8>(kInputKeepDims), tempIndex.data(), tempAxes.data(), init,
            func);
}

}  // namespace
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

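// REDUCE_PROD and REDUCE_SUM accept float16 and float32 tensors only and are
// available from Android Q onwards.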
Result<Version> validateProdSum(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32)
            << "Unsupported tensor type for REDUCE_PROD or REDUCE_SUM";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return Version::ANDROID_Q;
}

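// REDUCE_MAX and REDUCE_MIN additionally accept quantized tensors. The signed
// quantized variant (TENSOR_QUANT8_ASYMM_SIGNED) was introduced later and
// raises the minimum required version from Android Q to Android R.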
Result<Version> validateMaxMin(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_FLOAT16 ||
                 inputType == OperandType::TENSOR_FLOAT32 ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM ||
                 inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED)
            << "Unsupported tensor type for REDUCE_MAX or REDUCE_MIN";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    auto minVersion = Version::ANDROID_Q;
    if (inputType == OperandType::TENSOR_QUANT8_ASYMM_SIGNED) {
        minVersion = Version::ANDROID_R;
    }
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return minVersion;
}

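// REDUCE_ANY and REDUCE_ALL operate on 8-bit boolean tensors only.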
Result<Version> validateLogical(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    OperandType inputType = context->getInputType(kInputTensor);
    NN_RET_CHECK(inputType == OperandType::TENSOR_BOOL8)
            << "Unsupported tensor type for REDUCE_ANY or REDUCE_ALL";
    NN_RET_CHECK(
            validateInputTypes(context, {inputType, OperandType::TENSOR_INT32, OperandType::BOOL}));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    const Shape& input = context->getInputShape(kInputTensor);
    if (hasKnownRank(input)) {
        NN_RET_CHECK_LE(getNumberOfDimensions(input), 4);
    }
    return Version::ANDROID_Q;
}

#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
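// Computes the output shape: every dimension listed in the axes operand is
// either dropped or, when keepDims is true, collapsed to size 1. For example,
// reducing a [2, 3, 4] input over axis 1 yields shape [2, 4] with
// keepDims == false and [2, 1, 4] with keepDims == true.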
bool prepare(IOperationExecutionContext* context) {
    Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_LE(inputRank, 4);

    std::vector<bool> shouldReduce(inputRank);
    const int32_t* axes = context->getInputBuffer<int32_t>(kInputAxes);
    Shape axesShape = context->getInputShape(kInputAxes);
    NN_RET_CHECK_EQ(getNumberOfDimensions(axesShape), 1u);
    const uint32_t numAxes = getNumberOfElements(axesShape);
    for (uint32_t i = 0; i < numAxes; ++i) {
        int32_t axis = axes[i];
        NN_RET_CHECK(handleNegativeAxis(inputRank, &axis));
        shouldReduce[axis] = true;
    }

    // Input and output must have the same quantization parameters, etc., so
    // start from a copy of the input shape and rebuild only the dimensions.
    Shape outputShape = inputShape;
    outputShape.dimensions.clear();
    bool keepDims = context->getInputValue<bool8>(kInputKeepDims);
    for (uint32_t axis = 0; axis < inputRank; ++axis) {
        if (shouldReduce[axis]) {
            if (keepDims) {
                outputShape.dimensions.push_back(1);
            }
        } else {
            outputShape.dimensions.push_back(getSizeOfDimension(inputShape, axis));
        }
    }

    // Handle the case when all dimensions are removed.
    if (outputShape.dimensions.empty()) {
        outputShape.dimensions.push_back(1);
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}

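// Each execute* function below dispatches on the element type and folds the
// corresponding binary operation over the reduced axes via compute().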
bool executeProd(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, 1, [](_Float16 a, _Float16 b) -> _Float16 {
                // Handle the zero case because 0 * inf evaluates to nan.
                if (a == 0 || b == 0) return 0;
                return a * b;
            });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, 1, [](float a, float b) -> float {
                // Handle the zero case because 0 * inf evaluates to nan.
                if (a == 0 || b == 0) return 0;
                return a * b;
            });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_PROD";
    }
}

bool executeSum(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, 0, [](_Float16 a, _Float16 b) { return a + b; });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, 0, [](float a, float b) { return a + b; });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_SUM";
    }
}

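// The accumulator for REDUCE_MAX starts at the lowest value representable in
// the element type (and at the highest for REDUCE_MIN), so the first element
// visited always replaces the initial value.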
bool executeMax(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, kFloat16Lowest,
                                     [](_Float16 a, _Float16 b) { return std::max(a, b); });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, std::numeric_limits<float>::lowest(),
                                  [](float a, float b) { return std::max(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM:
            return compute<uint8_t>(context, std::numeric_limits<uint8_t>::lowest(),
                                    [](uint8_t a, uint8_t b) { return std::max(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return compute<int8_t>(context, std::numeric_limits<int8_t>::lowest(),
                                   [](int8_t a, int8_t b) { return std::max(a, b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MAX";
    }
}

bool executeMin(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return compute<_Float16>(context, kFloat16Max,
                                     [](_Float16 a, _Float16 b) { return std::min(a, b); });
        case OperandType::TENSOR_FLOAT32:
            return compute<float>(context, std::numeric_limits<float>::max(),
                                  [](float a, float b) { return std::min(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM:
            return compute<uint8_t>(context, std::numeric_limits<uint8_t>::max(),
                                    [](uint8_t a, uint8_t b) { return std::min(a, b); });
        case OperandType::TENSOR_QUANT8_ASYMM_SIGNED:
            return compute<int8_t>(context, std::numeric_limits<int8_t>::max(),
                                   [](int8_t a, int8_t b) { return std::min(a, b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_MIN";
    }
}

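// REDUCE_ANY is a logical OR fold starting from false; REDUCE_ALL is a
// logical AND fold starting from true.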
bool executeAny(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_BOOL8:
            return compute<bool8>(context, false,
                                  [](bool8 a, bool8 b) { return static_cast<bool8>(a || b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ANY";
    }
}

bool executeAll(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_BOOL8:
            return compute<bool8>(context, true,
                                  [](bool8 a, bool8 b) { return static_cast<bool8>(a && b); });
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation REDUCE_ALL";
    }
}
#endif  // NN_INCLUDE_CPU_IMPLEMENTATION

}  // namespace reduce

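// Register each reduction with the operation resolver, pairing the shared
// prepare() with its type-specific validation and execution functions.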
NN_REGISTER_OPERATION(REDUCE_PROD, "REDUCE_PROD", reduce::validateProdSum, reduce::prepare,
                      reduce::executeProd);
NN_REGISTER_OPERATION(REDUCE_SUM, "REDUCE_SUM", reduce::validateProdSum, reduce::prepare,
                      reduce::executeSum);
NN_REGISTER_OPERATION(REDUCE_MAX, "REDUCE_MAX", reduce::validateMaxMin, reduce::prepare,
                      reduce::executeMax);
NN_REGISTER_OPERATION(REDUCE_MIN, "REDUCE_MIN", reduce::validateMaxMin, reduce::prepare,
                      reduce::executeMin);
NN_REGISTER_OPERATION(REDUCE_ANY, "REDUCE_ANY", reduce::validateLogical, reduce::prepare,
                      reduce::executeAny);
NN_REGISTER_OPERATION(REDUCE_ALL, "REDUCE_ALL", reduce::validateLogical, reduce::prepare,
                      reduce::executeAll);

}  // namespace nn
}  // namespace android