Adds float16 support for LSH_PROJECTION.
Bug: 118607785
Test: NeuralNetworksTest_static
Change-Id: Ibfb752efae48cc63a3a3b11e8ef664e2b4dcd988
Merged-In: Ibfb752efae48cc63a3a3b11e8ef664e2b4dcd988
(cherry picked from commit 62fc7896f1ed790260bef1849cd950482fb8c315)
diff --git a/common/Android.bp b/common/Android.bp
index f812beb..ef06e44 100644
--- a/common/Android.bp
+++ b/common/Android.bp
@@ -85,8 +85,9 @@
vendor_available: true,
// b/109953668, disable OpenMP
// openmp: true,
- export_include_dirs: ["include"],
-
+ export_include_dirs: [
+ "include",
+ ],
srcs: [
"CpuExecutor.cpp",
"GraphDump.cpp",
diff --git a/common/CpuExecutor.cpp b/common/CpuExecutor.cpp
index 814402e..6aa6638 100644
--- a/common/CpuExecutor.cpp
+++ b/common/CpuExecutor.cpp
@@ -1681,16 +1681,29 @@
lookup.Eval();
} break;
case OperationType::LSH_PROJECTION: {
- RunTimeOperandInfo &output =
- mOperands[outs[LSHProjection::kOutputTensor]];
-
+ RunTimeOperandInfo& output = mOperands[outs[LSHProjection::kOutputTensor]];
Shape outputShape;
- LSHProjection lsh(operation, mOperands);
+ if (!LSHProjection::Prepare(operation, mOperands, &outputShape) ||
+ !setInfoAndAllocateIfNeeded(&output, outputShape)) {
+ break;
+ }
- success = LSHProjection::Prepare(operation, mOperands,
- &outputShape) &&
- setInfoAndAllocateIfNeeded(&output, outputShape) &&
- lsh.Eval();
+ LSHProjection lsh(operation, mOperands);
+ const RunTimeOperandInfo& hash = mOperands[ins[LSHProjection::kHashTensor]];
+ switch (hash.type) {
+ case OperandType::TENSOR_FLOAT32: {
+ success = lsh.Eval<float>();
+ break;
+ }
+ case OperandType::TENSOR_FLOAT16: {
+ success = lsh.Eval<_Float16>();
+ break;
+ }
+ default: {
+ success = false;
+ LOG(ERROR) << "Unsupported data type";
+ }
+ }
} break;
case OperationType::LSTM: {
RunTimeOperandInfo& scratch = mOperands[outs[LSTMCell::kScratchBufferTensor]];
diff --git a/common/Utils.cpp b/common/Utils.cpp
index 9946180..c3e4fc1 100644
--- a/common/Utils.cpp
+++ b/common/Utils.cpp
@@ -1531,24 +1531,40 @@
return ANEURALNETWORKS_BAD_DATA;
}
auto inputType = operands[inputIndexes[1]].type;
- if (inputType != OperandType::TENSOR_FLOAT32 &&
+ if (inputType != OperandType::TENSOR_FLOAT16 &&
+ inputType != OperandType::TENSOR_FLOAT32 &&
inputType != OperandType::TENSOR_INT32 &&
inputType != OperandType::TENSOR_QUANT8_ASYMM) {
LOG(ERROR) << "Unsupported input tensor type for operation "
<< getOperationName(opType);
return ANEURALNETWORKS_BAD_DATA;
}
- std::vector<OperandType> inExpectedTypes = {OperandType::TENSOR_FLOAT32,
- inputType,
- OperandType::TENSOR_FLOAT32,
- OperandType::INT32};
+ auto hashType = operands[inputIndexes[0]].type;
+ std::vector<OperandType> inExpectedTypes;
+ if (hashType == OperandType::TENSOR_FLOAT16) {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_2));
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT16,
+ inputType,
+ OperandType::TENSOR_FLOAT16,
+ OperandType::INT32,
+ };
+ } else if (hashType == OperandType::TENSOR_FLOAT32) {
+ NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
+ inExpectedTypes = {
+ OperandType::TENSOR_FLOAT32,
+ inputType,
+ OperandType::TENSOR_FLOAT32,
+ OperandType::INT32,
+ };
+ } else {
+ LOG(ERROR) << "Unsupported hash tensor type for operation "
+ << getOperationName(opType);
+ return ANEURALNETWORKS_BAD_DATA;
+ }
std::vector<OperandType> outExpectedTypes = {OperandType::TENSOR_INT32};
- // TODO(mks): Return V1_2 if inputType is sparse.
- NN_RETURN_IF_ERROR(validateHalVersion(opType, halVersion, HalVersion::V1_0));
- return validateOperationOperandTypes(operands,
- inputCount, inputIndexes,
- inExpectedTypes,
- outputCount, outputIndexes,
+ return validateOperationOperandTypes(operands, inputCount, inputIndexes,
+ inExpectedTypes, outputCount, outputIndexes,
outExpectedTypes);
}
case ANEURALNETWORKS_LSTM: {
diff --git a/common/operations/LSHProjection.cpp b/common/operations/LSHProjection.cpp
index fc5bc26..e66aade 100644
--- a/common/operations/LSHProjection.cpp
+++ b/common/operations/LSHProjection.cpp
@@ -82,14 +82,14 @@
// to match the trained model. This is going to be changed once the new
// model is trained in an optimized method.
//
-int running_sign_bit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight,
- float seed) {
+template <typename T>
+int runningSignBit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight, float seed) {
double score = 0.0;
int input_item_bytes = sizeOfData(input->type, input->dimensions) / SizeOfDimension(input, 0);
char* input_ptr = (char*)(input->buffer);
- const size_t seed_size = sizeof(float);
- const size_t key_bytes = sizeof(float) + input_item_bytes;
+ const size_t seed_size = sizeof(seed);
+ const size_t key_bytes = seed_size + input_item_bytes;
std::unique_ptr<char[]> key(new char[key_bytes]);
for (uint32_t i = 0; i < SizeOfDimension(input, 0); ++i) {
@@ -103,13 +103,14 @@
if (weight->lifetime == OperandLifeTime::NO_VALUE) {
score += running_value;
} else {
- score += reinterpret_cast<float*>(weight->buffer)[i] * running_value;
+ score += static_cast<double>(reinterpret_cast<T*>(weight->buffer)[i]) * running_value;
}
}
return (score > 0) ? 1 : 0;
}
+template <typename T>
void SparseLshProjection(LSHProjectionType type, const RunTimeOperandInfo* hash,
const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight,
int32_t* out_buf) {
@@ -118,8 +119,8 @@
for (int i = 0; i < num_hash; i++) {
int32_t hash_signature = 0;
for (int j = 0; j < num_bits; j++) {
- float seed = reinterpret_cast<float*>(hash->buffer)[i * num_bits + j];
- int bit = running_sign_bit(input, weight, seed);
+ T seed = reinterpret_cast<T*>(hash->buffer)[i * num_bits + j];
+ int bit = runningSignBit<T>(input, weight, static_cast<float>(seed));
hash_signature = (hash_signature << 1) | bit;
}
if (type == LSHProjectionType_SPARSE_DEPRECATED) {
@@ -130,19 +131,21 @@
}
}
+template <typename T>
void DenseLshProjection(const RunTimeOperandInfo* hash, const RunTimeOperandInfo* input,
const RunTimeOperandInfo* weight, int32_t* out_buf) {
int num_hash = SizeOfDimension(hash, 0);
int num_bits = SizeOfDimension(hash, 1);
for (int i = 0; i < num_hash; i++) {
for (int j = 0; j < num_bits; j++) {
- float seed = reinterpret_cast<float*>(hash->buffer)[i * num_bits + j];
- int bit = running_sign_bit(input, weight, seed);
+ T seed = reinterpret_cast<T*>(hash->buffer)[i * num_bits + j];
+ int bit = runningSignBit<T>(input, weight, static_cast<float>(seed));
*out_buf++ = bit;
}
}
}
+template <typename T>
bool LSHProjection::Eval() {
NNTRACE_COMP("LSHProjection::Eval");
@@ -150,11 +153,11 @@
switch (type_) {
case LSHProjectionType_DENSE:
- DenseLshProjection(hash_, input_, weight_, out_buf);
+ DenseLshProjection<T>(hash_, input_, weight_, out_buf);
break;
case LSHProjectionType_SPARSE:
case LSHProjectionType_SPARSE_DEPRECATED:
- SparseLshProjection(type_, hash_, input_, weight_, out_buf);
+ SparseLshProjection<T>(type_, hash_, input_, weight_, out_buf);
break;
default:
return false;
@@ -162,5 +165,27 @@
return true;
}
+template bool LSHProjection::Eval<float>();
+template bool LSHProjection::Eval<_Float16>();
+
+template int runningSignBit<float>(const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, float seed);
+template int runningSignBit<_Float16>(const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, float seed);
+
+template void SparseLshProjection<float>(LSHProjectionType type, const RunTimeOperandInfo* hash,
+ const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, int32_t* outBuffer);
+template void SparseLshProjection<_Float16>(LSHProjectionType type, const RunTimeOperandInfo* hash,
+ const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, int32_t* outBuffer);
+
+template void DenseLshProjection<float>(const RunTimeOperandInfo* hash,
+ const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, int32_t* outBuffer);
+template void DenseLshProjection<_Float16>(const RunTimeOperandInfo* hash,
+ const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, int32_t* outBuffer);
+
} // namespace nn
} // namespace android
diff --git a/common/operations/LSHProjection.h b/common/operations/LSHProjection.h
index 8cf2fdc..7a25518 100644
--- a/common/operations/LSHProjection.h
+++ b/common/operations/LSHProjection.h
@@ -40,6 +40,7 @@
static bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
Shape* outputShape);
+ template <typename T>
bool Eval();
static constexpr int kHashTensor = 0;
@@ -60,6 +61,18 @@
RunTimeOperandInfo* output_;
};
+template <typename T>
+int runningSignBit(const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight, float seed);
+
+template <typename T>
+void SparseLshProjection(LSHProjectionType type, const RunTimeOperandInfo* hash,
+ const RunTimeOperandInfo* input, const RunTimeOperandInfo* weight,
+ int32_t* outBuffer);
+
+template <typename T>
+void DenseLshProjection(const RunTimeOperandInfo* hash, const RunTimeOperandInfo* input,
+ const RunTimeOperandInfo* weight, int32_t* outBuffer);
+
} // namespace nn
} // namespace android