Adds float16 support for LSH_PROJECTION.

Bug: 118607785
Test: NeuralNetworksTest_static
Change-Id: Ibfb752efae48cc63a3a3b11e8ef664e2b4dcd988
Merged-In: Ibfb752efae48cc63a3a3b11e8ef664e2b4dcd988
(cherry picked from commit 62fc7896f1ed790260bef1849cd950482fb8c315)
diff --git a/common/Android.bp b/common/Android.bp
index f812beb..ef06e44 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -85,8 +85,9 @@
     vendor_available: true,
     // b/109953668, disable OpenMP
     // openmp: true,
-    export_include_dirs: ["include"],
-
+    export_include_dirs: [
+        "include",
+    ],
     srcs: [
         "CpuExecutor.cpp",
         "GraphDump.cpp",
diff --git a/common/CpuExecutor.cpp b/common/CpuExecutor.cpp
index 814402e..6aa6638 100644
--- a/common/CpuExecutor.cpp
+++ b/common/CpuExecutor.cpp
@@ -1681,16 +1681,29 @@
                 lookup.Eval();
         } break;
         case OperationType::LSH_PROJECTION: {
-            RunTimeOperandInfo &output =
-                mOperands[outs[LSHProjection::kOutputTensor]];
-
+            RunTimeOperandInfo& output = mOperands[outs[LSHProjection::kOutputTensor]];
             Shape outputShape;
-            LSHProjection lsh(operation, mOperands);
+            if (!LSHProjection::Prepare(operation, mOperands, &outputShape) ||
+                !setInfoAndAllocateIfNeeded(&output, outputShape)) {
+                break;
+            }
 
-            success = LSHProjection::Prepare(operation, mOperands,
-                                             &outputShape) &&
-                setInfoAndAllocateIfNeeded(&output, outputShape) &&
-                lsh.Eval();
+            LSHProjection lsh(operation, mOperands);
+            const RunTimeOperandInfo& hash = mOperands[ins[LSHProjection::kHashTensor]];
+            switch (hash.type) {  // hash tensor type selects the Eval<T> instantiation
+                case OperandType::TENSOR_FLOAT32: {
+                    success = lsh.Eval<float>();
+                    break;
+                }
+                case OperandType::TENSOR_FLOAT16: {
+                    success = lsh.Eval<_Float16>();
+                    break;
+                }
+                default: {
+                    success = false;
+                    LOG(ERROR) << "Unsupported data type";
+                }
+            }
         } break;
         case OperationType::LSTM: {
             RunTimeOperandInfo& scratch = mOperands[outs[LSTMCell::kScratchBufferTensor]];
diff --git a/common/Utils.cpp b/common/Utils.cpp
index 9946180..c3e4fc1 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -1531,24 +1531,40 @@
                 return ANEURALNETWORKS_BAD_DATA;
             }
             auto inputType = operands[inputIndexes[1]].type;
-            if (inputType != OperandType::TENSOR_FLOAT32 &&
+            if (inputType != OperandType::TENSOR_FLOAT16 &&
+                inputType != OperandType::TENSOR_FLOAT32 &&
                 inputType != OperandType::TENSOR_INT32 &&
                 inputType != OperandType::TENSOR_QUANT8_ASYMM) {
                 LOG(ERROR) << "Unsupported input tensor type for operation "
                            << getOperationName(opType);
                 return ANEURALNETWORKS_BAD_DATA;
             }
-            std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_FLOAT32,
-                                                        inputType,
-                                                        OperandType::TENSOR_FLOAT32,
-                                                        OperandType::INT32};
+            auto hashType = operands[inputIndexes[0]].type;  // hash and weight share this type
+            std::vector<OperandType> inExpectedTypes;
+            if (hashType == OperandType::TENSOR_FLOAT16) {  // fp16 requires HAL 1.2
+                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
+                inExpectedTypes = {
+                        OperandType::TENSOR_FLOAT16,
+                        inputType,
+                        OperandType::TENSOR_FLOAT16,
+                        OperandType::INT32,
+                };
+            } else if (hashType == OperandType::TENSOR_FLOAT32) {  // fp32 since HAL 1.0
+                NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
+                inExpectedTypes = {
+                        OperandType::TENSOR_FLOAT32,
+                        inputType,
+                        OperandType::TENSOR_FLOAT32,
+                        OperandType::INT32,
+                };
+            } else {
+                LOG(ERROR) << "Unsupported hash tensor type for operation "
+                           << getOperationName(opType);
+                return ANEURALNETWORKS_BAD_DATA;
+            }
             std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
-            // TODO(mks): Return V1_2 if inputType is sparse.
-            NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
-            return validateOperationOperandTypes(operands,
-                                                 inputCount, inputIndexes,
-                                                 inExpectedTypes,
-                                                 outputCount, outputIndexes,
+            return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+                                                 inExpectedTypes, outputCount, outputIndexes,
                                                  outExpectedTypes);
         }
         case ANEURALNETWORKS_LSTM: {
diff --git a/common/operations/LSHProjection.cpp b/common/operations/LSHProjection.cpp
index fc5bc26..e66aade 100644
--- a/common/operations/LSHProjection.cpp
+++ b/common/operations/LSHProjection.cpp
@@ -82,14 +82,14 @@
 //       to match the trained model. This is going to be changed once the new
 //       model is trained in an optimized method.
 //
-int running_sign_bit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight,
-                     float seed) {
+template <typename T>  // T: element type of the weight tensor
+int runningSignBit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight, float seed) {
     double score = 0.0;
     int input_item_bytes = sizeOfData(input->type, input->dimensions) / SizeOfDimension(input, 0);
     char* input_ptr = (char*)(input->buffer);
 
-    const size_t seed_size = sizeof(float);
-    const size_t key_bytes = sizeof(float) + input_item_bytes;
+    const size_t seed_size = sizeof(seed);  // seed is float for every T
+    const size_t key_bytes = seed_size + input_item_bytes;
     std::unique_ptr<char[]> key(new char[key_bytes]);
 
     for (uint32_t i = 0; i < SizeOfDimension(input, 0); ++i) {
@@ -103,13 +103,14 @@
         if (weight->lifetime == OperandLifeTime::NO_VALUE) {
             score += running_value;
         } else {
-            score += reinterpret_cast<float*>(weight->buffer)[i] * running_value;
+            score += static_cast<double>(reinterpret_cast<T*>(weight->buffer)[i]) * running_value;
         }
     }
 
     return (score > 0) ? 1 : 0;
 }
 
+template <typename T>  // T: element type of the hash and weight tensors
 void SparseLshProjection(LSHProjectionType type, const RunTimeOperandInfo* hash,
                          const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight,
                          int32_t* out_buf) {
@@ -118,8 +119,8 @@
     for (int i = 0; i < num_hash; i++) {
         int32_t hash_signature = 0;
         for (int j = 0; j < num_bits; j++) {
-            float seed = reinterpret_cast<float*>(hash->buffer)[i * num_bits + j];
-            int bit = running_sign_bit(input, weight, seed);
+            T seed = reinterpret_cast<T*>(hash->buffer)[i * num_bits + j];
+            int bit = runningSignBit<T>(input, weight, static_cast<float>(seed));
             hash_signature = (hash_signature << 1) | bit;
         }
         if (type == LSHProjectionType_SPARSE_DEPRECATED) {
@@ -130,19 +131,21 @@
     }
 }
 
+template <typename T>  // T: element type of the hash and weight tensors
 void DenseLshProjection(const RunTimeOperandInfo* hash, const RunTimeOperandInfo* input,
                         const RunTimeOperandInfo* weight, int32_t* out_buf) {
     int num_hash = SizeOfDimension(hash, 0);
     int num_bits = SizeOfDimension(hash, 1);
     for (int i = 0; i < num_hash; i++) {
         for (int j = 0; j < num_bits; j++) {
-            float seed = reinterpret_cast<float*>(hash->buffer)[i * num_bits + j];
-            int bit = running_sign_bit(input, weight, seed);
+            T seed = reinterpret_cast<T*>(hash->buffer)[i * num_bits + j];
+            int bit = runningSignBit<T>(input, weight, static_cast<float>(seed));
             *out_buf++ = bit;
         }
     }
 }
 
+template <typename T>  // T: element type of the hash and weight tensors
 bool LSHProjection::Eval() {
     NNTRACE_COMP("LSHProjection::Eval");
 
@@ -150,11 +153,11 @@
 
     switch (type_) {
         case LSHProjectionType_DENSE:
-            DenseLshProjection(hash_, input_, weight_, out_buf);
+            DenseLshProjection<T>(hash_, input_, weight_, out_buf);
             break;
         case LSHProjectionType_SPARSE:
         case LSHProjectionType_SPARSE_DEPRECATED:
-            SparseLshProjection(type_, hash_, input_, weight_, out_buf);
+            SparseLshProjection<T>(type_, hash_, input_, weight_, out_buf);
             break;
         default:
             return false;
@@ -162,5 +165,27 @@
     return true;
 }
 
+template bool LSHProjection::Eval<float>();
+template bool LSHProjection::Eval<_Float16>();
+
+template int runningSignBit<float>(const RunTimeOperandInfo* input,
+                                   const RunTimeOperandInfo* weight, float seed);
+template int runningSignBit<_Float16>(const RunTimeOperandInfo* input,
+                                      const RunTimeOperandInfo* weight, float seed);
+
+template void SparseLshProjection<float>(LSHProjectionType type, const RunTimeOperandInfo* hash,
+                                         const RunTimeOperandInfo* input,
+                                         const RunTimeOperandInfo* weight, int32_t* outBuffer);
+template void SparseLshProjection<_Float16>(LSHProjectionType type, const RunTimeOperandInfo* hash,
+                                            const RunTimeOperandInfo* input,
+                                            const RunTimeOperandInfo* weight, int32_t* outBuffer);
+
+template void DenseLshProjection<float>(const RunTimeOperandInfo* hash,
+                                        const RunTimeOperandInfo* input,
+                                        const RunTimeOperandInfo* weight, int32_t* outBuffer);
+template void DenseLshProjection<_Float16>(const RunTimeOperandInfo* hash,
+                                           const RunTimeOperandInfo* input,
+                                           const RunTimeOperandInfo* weight, int32_t* outBuffer);
+
 }  // namespace nn
 }  // namespace android
diff --git a/common/operations/LSHProjection.h b/common/operations/LSHProjection.h
index 8cf2fdc..7a25518 100644
--- a/common/operations/LSHProjection.h
+++ b/common/operations/LSHProjection.h
@@ -40,6 +40,7 @@
 
     static bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
                         Shape* outputShape);
+    template <typename T>
     bool Eval();
 
     static constexpr int kHashTensor = 0;
@@ -60,6 +61,26 @@
     RunTimeOperandInfo* output_;
 };
 
+// Returns 1 if the score accumulated over the rows of `input` (each term scaled
+// by the matching `weight` element when a weight tensor is present) is positive.
+// The seed is always passed as float, even for _Float16 hash tensors; T is the
+// element type used to read `weight`. Must match the definition in the .cpp.
+template <typename T>
+int runningSignBit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight, float seed);
+
+// Computes one packed bit signature per row of `hash`; what is stored in
+// outBuffer depends on `type`. T is the element type of `hash` and `weight`.
+template <typename T>
+void SparseLshProjection(LSHProjectionType type, const RunTimeOperandInfo* hash,
+                         const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight,
+                         int32_t* outBuffer);
+
+// Writes one 0/1 value per (row, column) of `hash` into outBuffer.
+// T is the element type of `hash` and `weight`.
+template <typename T>
+void DenseLshProjection(const RunTimeOperandInfo* hash, const RunTimeOperandInfo* input,
+                        const RunTimeOperandInfo* weight, int32_t* outBuffer);
+
 }  // namespace nn
 }  // namespace android