Support zero-sized batches in broadcast ops (ADD, MUL, SUB, DIV).

Also switch these ops to the OperationResolver framework: the per-op
cases in CpuExecutor.cpp and Utils.cpp are deleted, and the kernels
move from SimpleMath.cpp into the new operations/Broadcast.cpp.

Bug: 126737477
Test: NeuralNetworksTest_static
Change-Id: Ia2aaf7db4539ce5ffb97eb2341b0b5a56b2b8483
Merged-In: Ia2aaf7db4539ce5ffb97eb2341b0b5a56b2b8483
(cherry picked from commit 041d28acbe75b80b5d55db5daea7303751f1c4fa)
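
Background: a "zero batch" is an operand whose shape contains a
dimension of size 0 (for example {0, 2, 2, 1}), e.g. when an upstream
detection op selects no regions. Three pieces make this work for ADD,
MUL, SUB, and DIV:
  1. calculateBroadcastedShape() now propagates a zero-sized dimension
     instead of taking std::max() with the broadcast partner.
  2. The new execute functions return early when the output has zero
     elements, skipping the TFLite kernels.
  3. The ops are registered with .allowZeroSizedInput = true so the
     runtime admits zero-sized operands for them.
Illustrative reviewer sketches (standalone C++, not part of the patch)
are interleaved between the file diffs below.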
diff --git a/common/Android.bp b/common/Android.bp
index 21d2e3e..34608ee 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -25,6 +25,7 @@
     srcs: [
         "OperationResolver.cpp",
         "operations/BidirectionalSequenceRNN.cpp",
+        "operations/Broadcast.cpp",
         "operations/ChannelShuffle.cpp",
         "operations/Comparisons.cpp",
         "operations/Conv2D.cpp",
diff --git a/common/CpuExecutor.cpp b/common/CpuExecutor.cpp
index fa4e00a..0114389 100644
--- a/common/CpuExecutor.cpp
+++ b/common/CpuExecutor.cpp
@@ -688,64 +688,6 @@
             LOG(ERROR) << "OEM operation not supported for CPU execution";
             success = false;
         } break;
-        case OperationType::ADD: {
-            if (!allParametersPresent(3, 1)) {
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            const RunTimeOperandInfo& in1 = mOperands[ins[0]];
-            const RunTimeOperandInfo& in2 = mOperands[ins[1]];
-            int32_t activation = getScalarData<int32_t>(mOperands[ins[2]]);
-
-            RunTimeOperandInfo& out = mOperands[outs[0]];
-            Shape outShape = out.shape();
-
-            if (!addMulPrepare(in1.shape(), in2.shape(), &outShape) ||
-                !setInfoAndAllocateIfNeeded(&out, outShape, &result)) {
-                break;
-            }
-            if (in1.type == OperandType::TENSOR_FLOAT32) {
-                success = addFloat32(reinterpret_cast<const float*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const float*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<float*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_FLOAT16) {
-                success = addFloat16(reinterpret_cast<const _Float16*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const _Float16*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<_Float16*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_QUANT8_ASYMM) {
-                success = addQuant8(reinterpret_cast<const uint8_t*>(in1.buffer), in1.shape(),
-                                    reinterpret_cast<const uint8_t*>(in2.buffer), in2.shape(),
-                                    activation, reinterpret_cast<uint8_t*>(out.buffer), outShape);
-            }
-        } break;
-        case OperationType::MUL: {
-            if (!allParametersPresent(3, 1)) {
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            const RunTimeOperandInfo& in1 = mOperands[ins[0]];
-            const RunTimeOperandInfo& in2 = mOperands[ins[1]];
-            int32_t activation = getScalarData<int32_t>(mOperands[ins[2]]);
-
-            RunTimeOperandInfo& out = mOperands[outs[0]];
-            Shape outShape = out.shape();
-
-            if (!addMulPrepare(in1.shape(), in2.shape(), &outShape) ||
-                !setInfoAndAllocateIfNeeded(&out, outShape, &result)) {
-                break;
-            }
-            if (in1.type == OperandType::TENSOR_FLOAT32) {
-                success = mulFloat32(reinterpret_cast<const float*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const float*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<float*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_FLOAT16) {
-                success = mulFloat16(reinterpret_cast<const _Float16*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const _Float16*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<_Float16*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_QUANT8_ASYMM) {
-                success = mulQuant8(reinterpret_cast<const uint8_t*>(in1.buffer), in1.shape(),
-                                    reinterpret_cast<const uint8_t*>(in2.buffer), in2.shape(),
-                                    activation, reinterpret_cast<uint8_t*>(out.buffer), outShape);
-            }
-        } break;
         case OperationType::FLOOR: {
             if (!allParametersPresent(1, 1)) {
                 return ANEURALNETWORKS_BAD_DATA;
@@ -1716,60 +1658,6 @@
                                         reinterpret_cast<const int32_t*>(strides.buffer), beginMask,
                                         endMask, shrinkAxisMask, output.buffer, outShape);
         } break;
-        case OperationType::DIV: {
-            if (!allParametersPresent(3, 1)) {
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            const RunTimeOperandInfo& in1 = mOperands[ins[0]];
-            const RunTimeOperandInfo& in2 = mOperands[ins[1]];
-            int32_t activation = getScalarData<int32_t>(mOperands[ins[2]]);
-
-            RunTimeOperandInfo& out = mOperands[outs[0]];
-            Shape outShape = out.shape();
-
-            if (!addMulPrepare(in1.shape(), in2.shape(), &outShape) ||
-                !setInfoAndAllocateIfNeeded(&out, outShape, &result)) {
-                break;
-            }
-            if (in1.type == OperandType::TENSOR_FLOAT32) {
-                success = divFloat32(reinterpret_cast<const float*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const float*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<float*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_FLOAT16) {
-                success = divFloat16(reinterpret_cast<const _Float16*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const _Float16*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<_Float16*>(out.buffer), outShape);
-            }
-        } break;
-        case OperationType::SUB: {
-            if (!allParametersPresent(3, 1)) {
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            const RunTimeOperandInfo& in1 = mOperands[ins[0]];
-            const RunTimeOperandInfo& in2 = mOperands[ins[1]];
-            int32_t activation = getScalarData<int32_t>(mOperands[ins[2]]);
-
-            RunTimeOperandInfo& out = mOperands[outs[0]];
-            Shape outShape = out.shape();
-
-            if (!addMulPrepare(in1.shape(), in2.shape(), &outShape) ||
-                !setInfoAndAllocateIfNeeded(&out, outShape, &result)) {
-                break;
-            }
-            if (in1.type == OperandType::TENSOR_FLOAT16) {
-                success = subFloat16(reinterpret_cast<const _Float16*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const _Float16*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<_Float16*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_FLOAT32) {
-                success = subFloat32(reinterpret_cast<const float*>(in1.buffer), in1.shape(),
-                                     reinterpret_cast<const float*>(in2.buffer), in2.shape(),
-                                     activation, reinterpret_cast<float*>(out.buffer), outShape);
-            } else if (in1.type == OperandType::TENSOR_QUANT8_ASYMM) {
-                success = subQuant8(reinterpret_cast<const uint8_t*>(in1.buffer), in1.shape(),
-                                    reinterpret_cast<const uint8_t*>(in2.buffer), in2.shape(),
-                                    activation, reinterpret_cast<uint8_t*>(out.buffer), outShape);
-            }
-        } break;
         case OperationType::MEAN: {
             if (!allParametersPresent(3, 1)) {
                 return ANEURALNETWORKS_BAD_DATA;
diff --git a/common/OperationResolver.cpp b/common/OperationResolver.cpp
index 8f1edfb..e77cf40 100644
--- a/common/OperationResolver.cpp
+++ b/common/OperationResolver.cpp
@@ -25,6 +25,7 @@
 
 // TODO(b/119608412): Find a way to not reference every operation here.
 const OperationRegistration* register_ABS();
+const OperationRegistration* register_ADD();
 const OperationRegistration* register_AVERAGE_POOL_2D();
 const OperationRegistration* register_AXIS_ALIGNED_BBOX_TRANSFORM();
 const OperationRegistration* register_BIDIRECTIONAL_SEQUENCE_RNN();
@@ -33,6 +34,7 @@
 const OperationRegistration* register_CONV_2D();
 const OperationRegistration* register_DEQUANTIZE();
 const OperationRegistration* register_DETECTION_POSTPROCESSING();
+const OperationRegistration* register_DIV();
 const OperationRegistration* register_EQUAL();
 const OperationRegistration* register_EXP();
 const OperationRegistration* register_FULLY_CONNECTED();
@@ -51,6 +53,7 @@
 const OperationRegistration* register_LOGICAL_OR();
 const OperationRegistration* register_LOG_SOFTMAX();
 const OperationRegistration* register_MAX_POOL_2D();
+const OperationRegistration* register_MUL();
 const OperationRegistration* register_NEG();
 const OperationRegistration* register_NOT_EQUAL();
 const OperationRegistration* register_PRELU();
@@ -68,11 +71,13 @@
 const OperationRegistration* register_SELECT();
 const OperationRegistration* register_SIN();
 const OperationRegistration* register_SQRT();
+const OperationRegistration* register_SUB();
 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_LSTM();
 const OperationRegistration* register_UNIDIRECTIONAL_SEQUENCE_RNN();
 
 BuiltinOperationResolver::BuiltinOperationResolver() {
     registerOperation(register_ABS());
+    registerOperation(register_ADD());
     registerOperation(register_AVERAGE_POOL_2D());
     registerOperation(register_AXIS_ALIGNED_BBOX_TRANSFORM());
     registerOperation(register_BIDIRECTIONAL_SEQUENCE_RNN());
@@ -81,6 +86,7 @@
     registerOperation(register_CONV_2D());
     registerOperation(register_DEQUANTIZE());
     registerOperation(register_DETECTION_POSTPROCESSING());
+    registerOperation(register_DIV());
     registerOperation(register_EQUAL());
     registerOperation(register_EXP());
     registerOperation(register_FULLY_CONNECTED());
@@ -99,6 +105,7 @@
     registerOperation(register_LOGICAL_OR());
     registerOperation(register_LOG_SOFTMAX());
     registerOperation(register_MAX_POOL_2D());
+    registerOperation(register_MUL());
     registerOperation(register_NEG());
     registerOperation(register_NOT_EQUAL());
     registerOperation(register_PRELU());
@@ -116,6 +123,7 @@
     registerOperation(register_SELECT());
     registerOperation(register_SIN());
     registerOperation(register_SQRT());
+    registerOperation(register_SUB());
     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_LSTM());
     registerOperation(register_UNIDIRECTIONAL_SEQUENCE_RNN());
 }
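
Reviewer sketch: the register_* symbols declared above are generated by
NN_REGISTER_OPERATION in the operation files (see Broadcast.cpp below);
each returns a statically allocated registration record, and the
resolver indexes the records by operation type. That lookup is what
allows the per-op switch cases in CpuExecutor.cpp and Utils.cpp to be
deleted. A simplified, self-contained analogue (illustrative names and
signatures, not the real NNAPI API):

    #include <iostream>
    #include <map>

    enum class OpType { ADD, MUL, SUB, DIV };

    struct Registration {
        OpType type;
        const char* name;
        bool (*execute)();  // the real record also carries validate/prepare and flags
    };

    bool executeAdd() {
        std::cout << "running ADD kernel\n";
        return true;
    }

    // Analogue of a macro-generated register_ADD(): a factory returning a
    // statically allocated record.
    const Registration* register_ADD() {
        static Registration r{OpType::ADD, "ADD", &executeAdd};
        return &r;
    }

    class Resolver {
       public:
        Resolver() { registerOperation(register_ADD()); }
        const Registration* findOperation(OpType t) const {
            auto it = mTable.find(t);
            return it == mTable.end() ? nullptr : it->second;
        }

       private:
        void registerOperation(const Registration* r) { mTable.emplace(r->type, r); }
        std::map<OpType, const Registration*> mTable;
    };

    int main() {
        Resolver resolver;
        if (const Registration* r = resolver.findOperation(OpType::ADD)) {
            r->execute();
        }
        return 0;
    }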
diff --git a/common/OperationsUtils.cpp b/common/OperationsUtils.cpp
index af4e2e7..25a5258 100644
--- a/common/OperationsUtils.cpp
+++ b/common/OperationsUtils.cpp
@@ -289,6 +289,7 @@
 }
 
 bool calculateBroadcastedShape(const Shape& in1, const Shape& in2, Shape* out) {
+    NN_RET_CHECK(in1.type == in2.type);
     uint32_t numberOfDims1 = getNumberOfDimensions(in1);
     uint32_t numberOfDims2 = getNumberOfDimensions(in2);
     uint32_t maxDims = std::max(numberOfDims1, numberOfDims2);
@@ -308,7 +309,7 @@
                        << "\nSecond tensor: dimension " << numberOfDims2 - i << "of size " << dim2;
             return false;
         }
-        out->dimensions[maxDims - i] = std::max(dim1, dim2);
+        out->dimensions[maxDims - i] = (dim1 == 1) ? dim2 : dim1;
     }
     return true;
 }
@@ -318,15 +319,6 @@
     return static_cast<uint8_t>(doubleValue / newShape.scale + newShape.offset);
 }
 
-bool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out) {
-    NN_OPS_CHECK(getNumberOfDimensions(in1) <= 4 && getNumberOfDimensions(in2) <= 4);
-    NN_OPS_CHECK(in1.type == in2.type);
-    if (SameShape(in1, in2)) {
-        return SetShape(in1, out);
-    }
-    return calculateBroadcastedShape(in1, in2, out);
-}
-
 bool floorPrepare(const Shape& input, Shape* output) {
     return SetShape(input, output);
 }
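
Reviewer sketch for the calculateBroadcastedShape() hunk above. The
compatibility check already guarantees each dimension pair is equal or
has one side == 1, so replacing std::max(dim1, dim2) with
(dim1 == 1) ? dim2 : dim1 only changes behavior for zero-sized
dimensions: previously std::max(0, 1) yielded 1 and invented elements.
A minimal standalone illustration (hypothetical broadcastDim helper,
not NNAPI code):

    #include <cstdint>
    #include <iostream>

    // Per-dimension broadcast rule after this patch. Callers have already
    // rejected pairs that are unequal with neither side equal to 1.
    uint32_t broadcastDim(uint32_t dim1, uint32_t dim2) {
        return (dim1 == 1) ? dim2 : dim1;
    }

    int main() {
        std::cout << broadcastDim(0, 1) << "\n";  // 0: zero-sized dim propagates
        std::cout << broadcastDim(1, 5) << "\n";  // 5: ordinary broadcast
        std::cout << broadcastDim(3, 3) << "\n";  // 3: equal dims unchanged
        return 0;
    }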
diff --git a/common/Utils.cpp b/common/Utils.cpp
index 20429fd..ccb4f62 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -571,80 +571,6 @@
         case ANEURALNETWORKS_OEM_OPERATION: {
             return ANEURALNETWORKS_NO_ERROR;
         }
-        case ANEURALNETWORKS_ADD: {
-            if (inputCount != 3 || outputCount != 1) {
-                logInvalidInOutNumber(3, 1);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            auto inputType = operands[inputIndexes[0]].type;
-            std::vector<OperandType> inExpectedTypes;
-            std::vector<OperandType> outExpectedTypes;
-            if (inputType == OperandType::TENSOR_FLOAT32 ||
-                inputType == OperandType::TENSOR_QUANT8_ASYMM) {
-                inExpectedTypes = {
-                        inputType,
-                        inputType,
-                        OperandType::INT32,
-                };
-                outExpectedTypes = {inputType};
-                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
-            } else if (inputType == OperandType::TENSOR_FLOAT16) {
-                inExpectedTypes = {
-                        OperandType::TENSOR_FLOAT16,
-                        OperandType::TENSOR_FLOAT16,
-                        OperandType::INT32,
-                };
-                outExpectedTypes = {OperandType::TENSOR_FLOAT16};
-                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
-            } else {
-                LOG(ERROR) << "Unsupported input tensor type for operation "
-                           << getOperationName(opType);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            return validateOperationOperandTypes(operands,
-                                                 inputCount, inputIndexes,
-                                                 inExpectedTypes,
-                                                 outputCount, outputIndexes,
-                                                 outExpectedTypes);
-        }
-        case ANEURALNETWORKS_MUL: {
-            if (inputCount != 3 || outputCount != 1) {
-                logInvalidInOutNumber(3, 1);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            auto inputType = operands[inputIndexes[0]].type;
-            std::vector<OperandType> inExpectedTypes;
-            std::vector<OperandType> outExpectedTypes;
-            if (inputType == OperandType::TENSOR_FLOAT32) {
-                inExpectedTypes = {OperandType::TENSOR_FLOAT32,
-                                   OperandType::TENSOR_FLOAT32,
-                                   OperandType::INT32};
-                outExpectedTypes = {OperandType::TENSOR_FLOAT32};
-            } else if (inputType == OperandType::TENSOR_FLOAT16) {
-                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
-                inExpectedTypes = {
-                        OperandType::TENSOR_FLOAT16,
-                        OperandType::TENSOR_FLOAT16,
-                        OperandType::INT32,
-                };
-                outExpectedTypes = {OperandType::TENSOR_FLOAT16};
-            } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
-                inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
-                                   OperandType::TENSOR_QUANT8_ASYMM,
-                                   OperandType::INT32};
-                outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM};
-            } else {
-                LOG(ERROR) << "Unsupported input tensor type for operation "
-                           << getOperationName(opType);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
-            return validateOperationOperandTypes(operands,
-                                                 inputCount, inputIndexes,
-                                                 inExpectedTypes,
-                                                 outputCount, outputIndexes,
-                                                 outExpectedTypes);
-        }
         case ANEURALNETWORKS_FLOOR: {
             if (inputCount != 1 || outputCount != 1) {
                 logInvalidInOutNumber(1, 1);
@@ -1749,68 +1675,6 @@
                                                  inExpectedTypes, outputCount, outputIndexes,
                                                  outExpectedTypes);
         }
-        case ANEURALNETWORKS_DIV: {
-            if (inputCount != 3 || outputCount != 1) {
-                logInvalidInOutNumber(3, 1);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            auto inputType = operands[inputIndexes[0]].type;
-            std::vector<OperandType> inExpectedTypes;
-            std::vector<OperandType> outExpectedTypes;
-            if (inputType == OperandType::TENSOR_FLOAT32) {
-                inExpectedTypes = {OperandType::TENSOR_FLOAT32,
-                                   OperandType::TENSOR_FLOAT32,
-                                   OperandType::INT32};
-                outExpectedTypes = {OperandType::TENSOR_FLOAT32};
-            } else if (inputType == OperandType::TENSOR_FLOAT16) {
-                inExpectedTypes = {
-                        OperandType::TENSOR_FLOAT16,
-                        OperandType::TENSOR_FLOAT16,
-                        OperandType::INT32,
-                };
-                outExpectedTypes = {OperandType::TENSOR_FLOAT16};
-                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
-            } else {
-                LOG(ERROR) << "Unsupported input tensor type for operation "
-                           << getOperationName(opType);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
-            return validateOperationOperandTypes(operands,
-                                                 inputCount, inputIndexes,
-                                                 inExpectedTypes,
-                                                 outputCount, outputIndexes,
-                                                 outExpectedTypes);
-        }
-        case ANEURALNETWORKS_SUB: {
-            if (inputCount != 3 || outputCount != 1) {
-                logInvalidInOutNumber(3, 1);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            auto inputType = operands[inputIndexes[0]].type;
-            std::vector<OperandType> inExpectedTypes;
-            std::vector<OperandType> outExpectedTypes;
-            if (inputType == OperandType::TENSOR_FLOAT32) {
-                inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
-                                   OperandType::INT32};
-                outExpectedTypes = {OperandType::TENSOR_FLOAT32};
-                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_1));
-            } else if (inputType == OperandType::TENSOR_FLOAT16 ||
-                       inputType == OperandType::TENSOR_QUANT8_ASYMM) {
-                inExpectedTypes = {inputType, inputType, OperandType::INT32};
-                outExpectedTypes = {inputType};
-                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
-            } else {
-                LOG(ERROR) << "Unsupported input tensor type for operation "
-                           << getOperationName(opType);
-                return ANEURALNETWORKS_BAD_DATA;
-            }
-            return validateOperationOperandTypes(operands,
-                                                 inputCount, inputIndexes,
-                                                 inExpectedTypes,
-                                                 outputCount, outputIndexes,
-                                                 outExpectedTypes);
-        }
         case ANEURALNETWORKS_MEAN: {
             if (inputCount != 3 || outputCount != 1) {
                 logInvalidInOutNumber(3, 1);
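
The validation blocks removed above are folded into broadcast::validate
in Broadcast.cpp (below), including their HAL-version gating. A hedged
recap of the minimum HAL versions those removed checks enforced, per op
and input type (my summary as standalone C++, not code from the patch):

    #include <iostream>
    #include <string>

    // Minimum HAL version per (op, input type); "" means unsupported.
    std::string minHalVersion(const std::string& op, const std::string& type) {
        if (type == "TENSOR_FLOAT16") return "V1_2";
        if (type == "TENSOR_FLOAT32")
            return (op == "SUB" || op == "DIV") ? "V1_1" : "V1_0";
        if (type == "TENSOR_QUANT8_ASYMM") {
            if (op == "DIV") return "";  // quantized DIV is rejected
            return (op == "SUB") ? "V1_2" : "V1_0";
        }
        return "";
    }

    int main() {
        for (const char* op : {"ADD", "MUL", "SUB", "DIV"}) {
            for (const char* type :
                 {"TENSOR_FLOAT16", "TENSOR_FLOAT32", "TENSOR_QUANT8_ASYMM"}) {
                const std::string v = minHalVersion(op, type);
                std::cout << op << " / " << type << " -> "
                          << (v.empty() ? "unsupported" : v) << "\n";
            }
        }
        return 0;
    }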
diff --git a/common/include/Operations.h b/common/include/Operations.h
index 3eafd7e..d816761 100644
--- a/common/include/Operations.h
+++ b/common/include/Operations.h
@@ -44,20 +44,6 @@
 
 struct Shape;
 
-bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut);
-bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut);
-bool addQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
-               int32_t activation, uint8_t* out, const Shape& shapeOut);
-
-bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut);
-bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut);
-bool mulQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
-               int32_t activation, uint8_t* out, const Shape& shapeOut);
-
 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape);
 bool floorFloat32(const float* inputData, float* outputData, const Shape& shape);
 
@@ -172,20 +158,6 @@
                          const int32_t* padding, const Shape& paddingShape, T* outputData,
                          const Shape& outputShape);
 
-bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut);
-
-bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut);
-
-bool subQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
-               int32_t activation, uint8_t* out, const Shape& shapeOut);
-
-bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut);
-bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut);
-
 template <typename T>
 bool transposeGeneric(const T* inputData, const Shape& inputShape, const int32_t* perm,
                       const Shape& permShape, T* outputData, const Shape& outputShape);
diff --git a/common/include/OperationsUtils.h b/common/include/OperationsUtils.h
index 0f6666e..604c355 100644
--- a/common/include/OperationsUtils.h
+++ b/common/include/OperationsUtils.h
@@ -293,8 +293,6 @@
 uint8_t requantize(uint8_t value, const Shape& oldShape, const Shape& newShape);
 
 // Preparation functions for the corresponding ops
-bool addMulPrepare(const Shape& in1, const Shape& in2, Shape* out1);
-
 bool floorPrepare(const Shape& input, Shape* output);
 
 bool quantizePrepare(const Shape& input, Shape* output);
diff --git a/common/operations/Broadcast.cpp b/common/operations/Broadcast.cpp
new file mode 100644
index 0000000..76b1c44
--- /dev/null
+++ b/common/operations/Broadcast.cpp
@@ -0,0 +1,527 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Contains the implementations of the broadcast operations: ADD, MUL, SUB, DIV.
+
+#define LOG_TAG "Operations"
+
+#include "CpuOperationUtils.h"
+#include "OperationResolver.h"
+
+#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
+#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
+
+#include "Tracing.h"
+
+namespace android {
+namespace nn {
+namespace broadcast {
+
+constexpr uint32_t kNumInputs = 3;
+constexpr uint32_t kInputTensor1 = 0;
+constexpr uint32_t kInputTensor2 = 1;
+constexpr uint32_t kActivationScalar = 2;
+
+constexpr uint32_t kNumOutputs = 1;
+constexpr uint32_t kOutputTensor = 0;
+
+namespace {
+
+#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
+    switch (activation) {                                               \
+        case (int32_t)FusedActivationFunc::NONE:                        \
+            macro(kNone);                                               \
+            break;                                                      \
+        case (int32_t)FusedActivationFunc::RELU:                        \
+            macro(kRelu);                                               \
+            break;                                                      \
+        case (int32_t)FusedActivationFunc::RELU1:                       \
+            macro(kRelu1);                                              \
+            break;                                                      \
+        case (int32_t)FusedActivationFunc::RELU6:                       \
+            macro(kRelu6);                                              \
+            break;                                                      \
+        default:                                                        \
+            LOG(ERROR) << "Unsupported fused activation function type"; \
+            return false;                                               \
+    }
+
+using binaryFunctionFloat32 = std::function<bool(
+        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
+        int32_t activation, float* out, const Shape& shapeOut)>;
+
+bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
+                            const Shape& shape2, int32_t activation, _Float16* out,
+                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
+    std::vector<float> in1_float32(getNumberOfElements(shape1));
+    convertFloat16ToFloat32(in1, &in1_float32);
+    std::vector<float> in2_float32(getNumberOfElements(shape2));
+    convertFloat16ToFloat32(in2, &in2_float32);
+    std::vector<float> out_float32(getNumberOfElements(shapeOut));
+
+    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
+                     out_float32.data(), shapeOut);
+    convertFloat32ToFloat16(out_float32, out);
+
+    return true;
+}
+
+bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
+                int32_t activation, float* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("addFloat32");
+    bool needBroadcast = !SameShape(shape1, shape2);
+    if (needBroadcast) {
+        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
+#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
+    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
+            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
+            convertShapeToDims(shapeOut))
+
+        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
+#undef ANDROID_NN_BROADCAST_ADD
+    } else {
+        NNTRACE_COMP_SWITCH("optimized_ops::Add");
+#define ANDROID_NN_ADD(activation)                                                 \
+    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
+            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
+            convertShapeToDims(shapeOut))
+
+        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
+#undef ANDROID_NN_ADD
+    }
+
+    return true;
+}
+
+bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
+                int32_t activation, _Float16* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("addFloat16");
+    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
+}
+
+bool addQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
+               int32_t activation, uint8_t* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("addQuant8");
+    bool needBroadcast = !SameShape(shape1, shape2);
+
+    const int32_t input1_offset = -shape1.offset;
+    const int32_t input2_offset = -shape2.offset;
+    const int32_t output_offset = shapeOut.offset;
+    const int left_shift = 20;
+    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
+    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
+    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
+    const double real_output_multiplier =
+            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);
+
+    int32_t input1_multiplier;
+    int32_t input1_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
+                                          &input1_shift)) {
+        return false;
+    }
+    int32_t input2_multiplier;
+    int32_t input2_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
+                                          &input2_shift)) {
+        return false;
+    }
+    int32_t output_multiplier;
+    int32_t output_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
+                                          &output_shift)) {
+        return false;
+    }
+    int32_t output_activation_min;
+    int32_t output_activation_max;
+    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
+                                  &output_activation_max);
+
+    if (needBroadcast) {
+        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
+#define ANDROID_NN_BROADCAST_ADD(activation)                                                     \
+    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(        \
+            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
+            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
+            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
+            output_activation_max, out, convertShapeToDims(shapeOut))
+
+        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
+#undef ANDROID_NN_BROADCAST_ADD
+    } else {
+        NNTRACE_COMP_SWITCH("optimized_ops::Add");
+#define ANDROID_NN_NORMAL_ADD(activation)                                                        \
+    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(                 \
+            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
+            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
+            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
+            output_activation_max, out, convertShapeToDims(shapeOut))
+
+        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
+#undef ANDROID_NN_NORMAL_ADD
+    }
+
+    return true;
+}
+
+bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
+                int32_t activation, float* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("mulFloat32");
+    bool needBroadcast = !SameShape(shape1, shape2);
+
+    if (needBroadcast) {
+        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
+#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
+    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
+            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
+            convertShapeToDims(shapeOut))
+
+        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
+#undef ANDROID_NN_BROADCAST_MUL
+    } else {
+        float output_activation_min, output_activation_max;
+        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
+
+        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
+        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
+                                   output_activation_min, output_activation_max, out,
+                                   convertShapeToDims(shapeOut));
+    }
+
+    return true;
+}
+
+bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
+                int32_t activation, _Float16* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("mulFloat16");
+    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
+}
+
+bool mulQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
+               int32_t activation, uint8_t* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("mulQuant8");
+    const int32_t input1_offset = -shape1.offset;
+    const int32_t input2_offset = -shape2.offset;
+    const int32_t output_offset = shapeOut.offset;
+    const double input_product_scale = shape1.scale * shape2.scale;
+    const double real_multiplier = input_product_scale / shapeOut.scale;
+    int32_t output_multiplier;
+    int32_t output_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift)) {
+        return false;
+    }
+    int32_t output_activation_min;
+    int32_t output_activation_max;
+    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
+                                  &output_activation_max);
+
+    // Use the BROADCAST version unconditionally; it also handles the same-shape case.
+    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
+    tflite::optimized_ops::BroadcastMul(in1, convertShapeToDims(shape1), input1_offset, in2,
+                                        convertShapeToDims(shape2), input2_offset, output_offset,
+                                        output_multiplier, output_shift, output_activation_min,
+                                        output_activation_max, out, convertShapeToDims(shapeOut));
+
+    return true;
+}
+
+bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
+                int32_t activation, float* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("subFloat32");
+    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
+    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
+                               out, convertShapeToDims(shapeOut));
+
+    // TFLite does not apply activation to broadcast sub.
+    float output_activation_min, output_activation_max;
+    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
+    uint32_t numOutputElements = getNumberOfElements(shapeOut);
+    for (uint32_t i = 0; i < numOutputElements; i++) {
+        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
+    }
+    return true;
+}
+
+bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
+                int32_t activation, _Float16* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("subFloat16");
+    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
+}
+
+bool subQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
+               int32_t activation, uint8_t* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("subQuant8");
+
+    const int32_t input1_offset = -shape1.offset;
+    const int32_t input2_offset = -shape2.offset;
+    const int32_t output_offset = shapeOut.offset;
+    const int left_shift = 20;
+    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
+    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
+    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
+    const double real_output_multiplier =
+            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);
+
+    int32_t input1_multiplier;
+    int32_t input1_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
+                                          &input1_shift)) {
+        return false;
+    }
+    int32_t input2_multiplier;
+    int32_t input2_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
+                                          &input2_shift)) {
+        return false;
+    }
+    input2_multiplier *= -1;
+    int32_t output_multiplier;
+    int32_t output_shift;
+    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
+                                          &output_shift)) {
+        return false;
+    }
+    int32_t output_activation_min;
+    int32_t output_activation_max;
+    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
+                                  &output_activation_max);
+
+    // We are using tflite::optimized_ops::BroadcastAdd unconditionally here
+    // because tflite::optimized_ops::Add fails to pass some of the
+    // sub_quantized_different_scales tests.
+    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
+#define ANDROID_NN_BROADCAST_ADD(activation)                                                     \
+    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(        \
+            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
+            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
+            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
+            output_activation_max, out, convertShapeToDims(shapeOut))
+
+    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
+#undef ANDROID_NN_BROADCAST_ADD
+
+    return true;
+}
+
+bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
+                int32_t activation, float* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("divFloat32");
+    float output_activation_min, output_activation_max;
+    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
+
+    bool needBroadcast = !SameShape(shape1, shape2);
+    if (needBroadcast) {
+        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
+        tflite::optimized_ops::BroadcastDiv(
+                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
+                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
+    } else {
+        NNTRACE_COMP_SWITCH("optimized_ops::Div");
+        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
+                                   output_activation_min, output_activation_max, out,
+                                   convertShapeToDims(shapeOut));
+    }
+    return true;
+}
+
+bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
+                int32_t activation, _Float16* out, const Shape& shapeOut) {
+    NNTRACE_TRANS("divFloat16");
+    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
+}
+
+}  // namespace
+
+bool validate(OperationType opType, const IOperationValidationContext* context) {
+    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
+    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
+    auto inputType = context->getInputType(kInputTensor1);
+    if (inputType == OperandType::TENSOR_FLOAT32) {
+        // SUB and DIV require HAL 1.1+, matching the checks removed from Utils.cpp.
+        if (opType == OperationType::SUB || opType == OperationType::DIV) {
+            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_1));
+        } else {
+            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
+        }
+    } else if (inputType == OperandType::TENSOR_FLOAT16) {
+        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
+    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
+        if (opType == OperationType::SUB) {
+            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
+        } else if (opType == OperationType::DIV) {
+            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
+        } else {
+            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
+        }
+    } else {
+        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
+    }
+    return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
+           validateOutputTypes(context, {inputType});
+}
+
+bool prepare(IOperationExecutionContext* context) {
+    Shape input1 = context->getInputShape(kInputTensor1);
+    Shape input2 = context->getInputShape(kInputTensor2);
+    Shape output = context->getOutputShape(kOutputTensor);
+    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
+    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
+    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
+    return context->setOutputShape(kOutputTensor, output);
+}
+
+bool executeAdd(IOperationExecutionContext* context) {
+    // Bypass execution in the case of zero-sized input.
+    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
+    switch (context->getInputType(kInputTensor1)) {
+        case OperandType::TENSOR_FLOAT16:
+            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<_Float16>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<_Float16>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_FLOAT32:
+            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<float>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<float>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_QUANT8_ASYMM:
+            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
+                             context->getInputShape(kInputTensor1),
+                             context->getInputBuffer<uint8_t>(kInputTensor2),
+                             context->getInputShape(kInputTensor2),
+                             context->getInputValue<int32_t>(kActivationScalar),
+                             context->getOutputBuffer<uint8_t>(kOutputTensor),
+                             context->getOutputShape(kOutputTensor));
+        default:
+            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
+    }
+}
+
+bool executeMul(IOperationExecutionContext* context) {
+    // Bypass execution in the case of zero-sized input.
+    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
+    switch (context->getInputType(kInputTensor1)) {
+        case OperandType::TENSOR_FLOAT16:
+            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<_Float16>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<_Float16>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_FLOAT32:
+            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<float>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<float>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_QUANT8_ASYMM:
+            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
+                             context->getInputShape(kInputTensor1),
+                             context->getInputBuffer<uint8_t>(kInputTensor2),
+                             context->getInputShape(kInputTensor2),
+                             context->getInputValue<int32_t>(kActivationScalar),
+                             context->getOutputBuffer<uint8_t>(kOutputTensor),
+                             context->getOutputShape(kOutputTensor));
+        default:
+            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
+    }
+}
+
+bool executeSub(IOperationExecutionContext* context) {
+    // Bypass execution in the case of zero-sized input.
+    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
+    switch (context->getInputType(kInputTensor1)) {
+        case OperandType::TENSOR_FLOAT16:
+            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<_Float16>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<_Float16>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_FLOAT32:
+            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<float>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<float>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_QUANT8_ASYMM:
+            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
+                             context->getInputShape(kInputTensor1),
+                             context->getInputBuffer<uint8_t>(kInputTensor2),
+                             context->getInputShape(kInputTensor2),
+                             context->getInputValue<int32_t>(kActivationScalar),
+                             context->getOutputBuffer<uint8_t>(kOutputTensor),
+                             context->getOutputShape(kOutputTensor));
+        default:
+            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
+    }
+}
+
+bool executeDiv(IOperationExecutionContext* context) {
+    // Bypass execution in the case of zero-sized input.
+    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
+    switch (context->getInputType(kInputTensor1)) {
+        case OperandType::TENSOR_FLOAT16:
+            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<_Float16>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<_Float16>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        case OperandType::TENSOR_FLOAT32:
+            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
+                              context->getInputShape(kInputTensor1),
+                              context->getInputBuffer<float>(kInputTensor2),
+                              context->getInputShape(kInputTensor2),
+                              context->getInputValue<int32_t>(kActivationScalar),
+                              context->getOutputBuffer<float>(kOutputTensor),
+                              context->getOutputShape(kOutputTensor));
+        default:
+            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
+    }
+}
+
+}  // namespace broadcast
+
+using std::placeholders::_1;
+NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
+                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
+NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
+                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
+NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
+                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
+NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
+                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);
+
+}  // namespace nn
+}  // namespace android
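
Reviewer sketch of the zero-size bypass shared by the four execute
functions above: the element count is the product of the dimensions, so
any zero-sized dimension empties the output and the TFLite kernel is
skipped. Standalone illustration (hypothetical helper names, not NNAPI
code):

    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    uint32_t numberOfElements(const std::vector<uint32_t>& dims) {
        return std::accumulate(dims.begin(), dims.end(), 1u,
                               std::multiplies<uint32_t>());
    }

    bool executeBinaryOp(const std::vector<uint32_t>& outputDims) {
        // Mirrors the first statement of executeAdd/Mul/Sub/Div: an empty
        // output is a successful no-op, so zero-sized batches flow through.
        if (numberOfElements(outputDims) == 0) return true;
        std::cout << "kernel over " << numberOfElements(outputDims)
                  << " elements\n";
        return true;
    }

    int main() {
        executeBinaryOp({0, 2, 2, 1});  // zero-sized batch: bypassed
        executeBinaryOp({1, 2, 2, 1});  // 4 elements: kernel runs
        return 0;
    }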
diff --git a/common/operations/SimpleMath.cpp b/common/operations/SimpleMath.cpp
index 50c0b5e..ad9711f 100644
--- a/common/operations/SimpleMath.cpp
+++ b/common/operations/SimpleMath.cpp
@@ -29,204 +29,6 @@
 namespace android {
 namespace nn {
 
-#define ANDROID_NN_MACRO_DISPATCH(macro)                                \
-    switch (activation) {                                               \
-        case (int32_t)FusedActivationFunc::NONE:                        \
-            macro(kNone);                                               \
-            break;                                                      \
-        case (int32_t)FusedActivationFunc::RELU:                        \
-            macro(kRelu);                                               \
-            break;                                                      \
-        case (int32_t)FusedActivationFunc::RELU1:                       \
-            macro(kRelu1);                                              \
-            break;                                                      \
-        case (int32_t)FusedActivationFunc::RELU6:                       \
-            macro(kRelu6);                                              \
-            break;                                                      \
-        default:                                                        \
-            LOG(ERROR) << "Unsupported fused activation function type"; \
-            return false;                                               \
-    }
-
-using binaryFunctionFloat32 = std::function<bool(
-        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-        int32_t activation, float* out, const Shape& shapeOut)>;
-
-bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
-                            const Shape& shape2, int32_t activation, _Float16* out,
-                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
-    std::vector<float> in1_float32(getNumberOfElements(shape1));
-    convertFloat16ToFloat32(in1, &in1_float32);
-    std::vector<float> in2_float32(getNumberOfElements(shape2));
-    convertFloat16ToFloat32(in2, &in2_float32);
-    std::vector<float> out_float32(getNumberOfElements(shapeOut));
-
-    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
-                     out_float32.data(), shapeOut);
-    convertFloat32ToFloat16(out_float32, out);
-
-    return true;
-}
-
-bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("addFloat16");
-    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
-}
-
-bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("addFloat32");
-    bool needBroadcast = !SameShape(shape1, shape2);
-    if (needBroadcast) {
-        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
-#define ANDROID_NN_BROADCAST_ADD(activation)                                              \
-    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
-            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
-            convertShapeToDims(shapeOut))
-
-        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
-#undef ANDROID_NN_BROADCAST_ADD
-    } else {
-        NNTRACE_COMP_SWITCH("optimized_ops::Add");
-#define ANDROID_NN_ADD(activation)                                                 \
-    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
-            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
-            convertShapeToDims(shapeOut))
-
-        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
-#undef ANDROID_NN_ADD
-    }
-
-    return true;
-}
-
-bool addQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
-               int32_t activation, uint8_t* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("addQuant8");
-    bool needBroadcast = !SameShape(shape1, shape2);
-
-    const int32_t input1_offset = -shape1.offset;
-    const int32_t input2_offset = -shape2.offset;
-    const int32_t output_offset = shapeOut.offset;
-    const int left_shift = 20;
-    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
-    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
-    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
-    const double real_output_multiplier =
-            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);
-
-    int32_t input1_multiplier;
-    int32_t input1_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
-                                          &input1_shift)) {
-        return false;
-    }
-    int32_t input2_multiplier;
-    int32_t input2_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
-                                          &input2_shift)) {
-        return false;
-    }
-    int32_t output_multiplier;
-    int32_t output_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
-                                          &output_shift)) {
-        return false;
-    }
-    int32_t output_activation_min;
-    int32_t output_activation_max;
-    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
-                                  &output_activation_max);
-
-    if (needBroadcast) {
-        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
-#define ANDROID_NN_BROADCAST_ADD(activation)                                                     \
-    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(        \
-            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
-            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
-            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
-            output_activation_max, out, convertShapeToDims(shapeOut))
-
-        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
-#undef ANDROID_NN_BROADCAST_ADD
-    } else {
-        NNTRACE_COMP_SWITCH("optimized_ops::Add");
-#define ANDROID_NN_NORMAL_ADD(activation)                                                        \
-    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(                 \
-            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
-            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
-            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
-            output_activation_max, out, convertShapeToDims(shapeOut))
-
-        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
-#undef ANDROID_NN_NORMAL_ADD
-    }
-
-    return true;
-}
-
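The multiplier setup above implements the standard uint8 requantized add: both inputs are brought onto a common scale (twice the larger input scale), summed with left_shift = 20 bits of headroom, and the sum is rescaled to the output scale; the 1 << left_shift in real_output_multiplier is what cancels the headroom again. A standalone sketch of that arithmetic, with hypothetical scales chosen only to keep the numbers readable:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const double s1 = 0.5, s2 = 1.0, sOut = 2.0;  // hypothetical scales
        const int left_shift = 20;  // same headroom constant as the kernel
        const double twice_max = 2 * std::max(s1, s2);               // 2.0
        const double m1 = s1 / twice_max;                            // 0.25
        const double m2 = s2 / twice_max;                            // 0.5
        const double mOut = twice_max / ((1 << left_shift) * sOut);  // 2^-20
        // QuantizeMultiplierSmallerThanOne encodes each m as a Q31 fixed-point
        // multiplier plus a right shift, m ~= q * 2^-(31 + shift), so the
        // kernel can perform the whole rescale in integer arithmetic.
        std::printf("m1=%g m2=%g mOut=%g\n", m1, m2, mOut);
        return 0;
    }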
-bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("mulFloat16");
-    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
-}
-
-bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("mulFloat32");
-    bool needBroadcast = !SameShape(shape1, shape2);
-
-    if (needBroadcast) {
-        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
-#define ANDROID_NN_BROADCAST_MUL(activation)                                              \
-    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
-            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
-            convertShapeToDims(shapeOut))
-
-        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
-#undef ANDROID_NN_BROADCAST_MUL
-    } else {
-        float output_activation_min, output_activation_max;
-        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
-
-        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
-        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
-                                   output_activation_min, output_activation_max, out,
-                                   convertShapeToDims(shapeOut));
-    }
-
-    return true;
-}
-
-bool mulQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
-               int32_t activation, uint8_t* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("mulQuant8");
-    const int32_t input1_offset = -shape1.offset;
-    const int32_t input2_offset = -shape2.offset;
-    const int32_t output_offset = shapeOut.offset;
-    const double input_product_scale = shape1.scale * shape2.scale;
-    const double real_multiplier = input_product_scale / shapeOut.scale;
-    int32_t output_multiplier;
-    int32_t output_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift)) {
-        return false;
-    }
-    int32_t output_activation_min;
-    int32_t output_activation_max;
-    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
-                                  &output_activation_max);
-
-    // The BROADCAST kernel also handles the same-shape case, so it is used unconditionally.
-    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
-    tflite::optimized_ops::BroadcastMul(in1, convertShapeToDims(shape1), input1_offset, in2,
-                                        convertShapeToDims(shape2), input2_offset, output_offset,
-                                        output_multiplier, output_shift, output_activation_min,
-                                        output_activation_max, out, convertShapeToDims(shapeOut));
-
-    return true;
-}
-
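The quantized multiply needs only one multiplier: the raw integer product of two quantized values carries the scale s1 * s2, so real_multiplier = s1 * s2 / sOut requantizes it in a single step. For example, with hypothetical scales s1 = s2 = 0.5 and sOut = 1.0, real_multiplier = 0.25, which QuantizeMultiplierSmallerThanOne encodes as the Q31 value 2^30 with a right shift of 1 (2^30 * 2^-31 * 2^-1 = 0.25).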
 bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape) {
     NNTRACE_TRANS("floorFloat16");
     std::vector<float> inputDataFloat32(getNumberOfElements(shape));
@@ -270,111 +72,6 @@
     return true;
 }
 
-bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("subFloat16");
-    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
-}
-
-bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("subFloat32");
-    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
-    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
-                               out, convertShapeToDims(shapeOut));
-
-    // TFLite does not apply activation to broadcast sub.
-    float output_activation_min, output_activation_max;
-    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
-    uint32_t numOutputElements = getNumberOfElements(shapeOut);
-    for (uint32_t i = 0; i < numOutputElements; i++) {
-        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
-    }
-    return true;
-}
-
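Because the plain Sub kernel ignores the fused activation, the clamp is applied in a second pass over the output. CalculateActivationRangeFloat maps the activation code to the usual NNAPI intervals (NONE: unbounded, RELU: [0, inf), RELU1: [-1, 1], RELU6: [0, 6]); with RELU6, for instance, an element of -2.5 clamps to 0 and 7.25 clamps to 6.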
-bool subQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
-               int32_t activation, uint8_t* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("subQuant8");
-
-    const int32_t input1_offset = -shape1.offset;
-    const int32_t input2_offset = -shape2.offset;
-    const int32_t output_offset = shapeOut.offset;
-    const int left_shift = 20;
-    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
-    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
-    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
-    const double real_output_multiplier =
-            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);
-
-    int32_t input1_multiplier;
-    int32_t input1_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
-                                          &input1_shift)) {
-        return false;
-    }
-    int32_t input2_multiplier;
-    int32_t input2_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
-                                          &input2_shift)) {
-        return false;
-    }
-    input2_multiplier *= -1;
-    int32_t output_multiplier;
-    int32_t output_shift;
-    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
-                                          &output_shift)) {
-        return false;
-    }
-    int32_t output_activation_min;
-    int32_t output_activation_max;
-    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
-                                  &output_activation_max);
-
-    // We are using tflite::optimized_ops::BroadcastAdd unconditionally here
-    // because tflite::optimized_ops::Add fails to pass some of the
-    // sub_quantized_different_scales tests.
-    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
-#define ANDROID_NN_BROADCAST_ADD(activation)                                                     \
-    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(        \
-            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,       \
-            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier,     \
-            input2_shift, output_offset, output_multiplier, output_shift, output_activation_min, \
-            output_activation_max, out, convertShapeToDims(shapeOut))
-
-    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
-#undef ANDROID_NN_BROADCAST_ADD
-
-    return true;
-}
-
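Two details above deserve a note. First, negating input2_multiplier turns the shared add kernel into a subtraction: after zero-point correction the kernel accumulates roughly (q1 + input1_offset) * m1 + (q2 + input2_offset) * m2, so flipping the sign of m2 yields the rescaled difference. Second, as the comment says, the broadcast kernel is used even for same-shape inputs because the non-broadcast Add failed some of the sub_quantized_different_scales tests.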
-bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
-                int32_t activation, _Float16* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("divFloat16");
-    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
-}
-
-bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
-                int32_t activation, float* out, const Shape& shapeOut) {
-    NNTRACE_TRANS("divFloat32");
-    float output_activation_min, output_activation_max;
-    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
-
-    bool needBroadcast = !SameShape(shape1, shape2);
-    if (needBroadcast) {
-        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
-        tflite::optimized_ops::BroadcastDiv(
-                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
-                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
-    } else {
-        NNTRACE_COMP_SWITCH("optimized_ops::Div");
-        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
-                                   output_activation_min, output_activation_max, out,
-                                   convertShapeToDims(shapeOut));
-    }
-    return true;
-}
-
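Each of these float kernels follows the same dispatch: SameShape routes to the fast element-wise kernel, everything else to the Broadcast* variant. The new shared Broadcast.cpp takes over that dispatch, and per the commit subject it must also tolerate zero-sized batches: broadcasting a hypothetical {0, 4} input against a {1, 4} one yields a {0, 4} output whose getNumberOfElements is 0, so a correct kernel writes nothing and still reports success instead of indexing into empty buffers.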
 bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
                  const Shape& axisShape, bool keepDims, _Float16* outputData,
                  const Shape& outputShape) {