#include "caffe2/perfkernels/embedding_lookup_idx.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/util/irange.h>
#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/perfkernels/common.h"

namespace caffe2 {

/**
 * Base implementation does runtime dispatch for each segment of reduction.
 *
 * `offsets` holds `output_size + 1` non-decreasing positions delimiting the
 * segments of `indices` (CSR style); each gathered row of `input` (of width
 * `block_size`) is accumulated into the output block of its segment, with
 * optional per-index weighting, uint8 dequantization via `scale_bias`, and
 * mean normalization.
 *
 * @return false if an index is out of bounds or the offsets are inconsistent
 * with `index_size`
 */
template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL = false>
static bool EmbeddingLookupGenericSlowIdx(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const InType* input,
const IndexType* indices,
const IndexType* offsets,
const float* weights, // optional, can be null for sum reducer
const float* scale_bias, // optional scale & bias params for uint8 input
bool normalize_by_lengths,
OutType* out) {
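// `current` is the running cursor into the flattened `indices` array; it is
// shared across all output segments.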
int64_t current = 0;
for (const auto m : c10::irange(output_size)) {
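// Zero this segment's output block before accumulating into it.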
memset(out, 0, sizeof(OutType) * block_size);
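// Consistency check: by the time segment m starts, exactly
// offsets[m] - offsets[0] indices must have been consumed.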
if (current != offsets[m] - offsets[0]) {
return false;
}
int64_t start_offset = offsets[m];
int64_t end_offset = offsets[m + 1];
int64_t length = end_offset - start_offset;
for (const auto i : c10::irange(start_offset, end_offset)) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}
#ifdef __GNUC__
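// Prefetch the next row to be gathered (read-only, low temporal locality)
// to hide memory latency on the random-access reads from `input`.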
if (current + 1 < index_size) {
__builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
}
#endif // __GNUC__
float w = 1.f, b = 0.f;
if (weights) {
w = weights[IS_WEIGHT_POSITIONAL ? i - start_offset : current];
}
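// For uint8 inputs, scale_bias stores a per-row (scale, bias) pair; folding
// it into w and b keeps the inner loop a single multiply-add per element.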
if (scale_bias) {
b = w * scale_bias[2 * indices[current] + 1];
w = w * scale_bias[2 * indices[current]];
}
for (const auto j : c10::irange(block_size)) {
out[j] += w * input[block_size * indices[current] + j] + b;
}
++current;
}
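// Optionally turn the per-segment sum into a mean.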
if (normalize_by_lengths && length) {
float scale = 1.f / length;
for (const auto j : c10::irange(block_size)) {
out[j] *= scale;
}
}
out += block_size;
}
return current == index_size;
}

// clang-format off
// Proxy back to generic implementation
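// For every (IndexType, InType, OutType, IS_WEIGHT_POSITIONAL) combination,
// the macro emits:
//   1. a `__base` kernel that forwards to EmbeddingLookupGenericSlowIdx,
//   2. a declaration of the matching `__avx2_fma` kernel (defined in a
//      separately compiled translation unit),
//   3. a dispatcher that picks between the two at runtime, and
//   4. the EmbeddingLookupIdx<> specialization, which re-walks the indices to
//      report a descriptive error whenever the chosen kernel signals failure.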
#define EMBEDDING_IDX_SPECIALIZATION( \
IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL) \
bool \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const IndexType* offsets, \
const float* weights, \
const float* scale_bias, \
bool normalize_by_lengths, \
OutType* out) { \
return EmbeddingLookupGenericSlowIdx< \
IndexType, \
InType, \
OutType, \
IS_WEIGHT_POSITIONAL>( \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
scale_bias, \
normalize_by_lengths, \
out); \
} \
decltype( \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \
bool \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const IndexType* offsets, \
const float* weights, \
const float* scale_bias, \
bool normalize_by_lengths, \
OutType* out) { \
if (std::is_same<InType, uint8_t>::value) { \
CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \
} else { \
CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \
} \
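/* Prefer the AVX2+FMA kernel when the running CPU supports it; */ \
/* otherwise fall back to the portable base implementation. */ \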
AVX2_FMA_DO( \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
scale_bias, \
normalize_by_lengths, \
out); \
BASE_DO( \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
scale_bias, \
normalize_by_lengths, \
out); \
} \
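/* Public specialization: run the dispatched kernel and, if it reports a */ \
/* bad index or a length mismatch, produce a descriptive error below. */ \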
template <> \
void EmbeddingLookupIdx<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>( \
const int64_t block_size, \
const int64_t output_size, \
const int64_t index_size, \
const int64_t data_size, \
const InType* input, \
const IndexType* indices, \
const IndexType* offsets, \
const float* weights, \
const float* scale_bias, \
bool normalize_by_lengths, \
OutType* out) { \
bool success = \
EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \
block_size, \
output_size, \
index_size, \
data_size, \
input, \
indices, \
offsets, \
weights, \
scale_bias, \
normalize_by_lengths, \
out); \
if (success) { \
return; \
} \
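/* The kernel only signals that something was wrong; re-walk the indices */ \
/* here to pinpoint and report the first offending entry. */ \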
int64_t current = 0; \
for (int64_t m = 0; m < output_size; ++m) { \
for (int64_t i = offsets[m]; i < offsets[m + 1]; ++i) { \
CAFFE_ENFORCE_LT(current, index_size); \
IndexType idx = indices[current]; \
CAFFE_ENFORCE( \
0 <= idx && idx < data_size, \
"Index ", \
current, \
" is out of bounds: ", \
idx, \
", range 0 to ", \
data_size); \
++current; \
} \
} \
CAFFE_ENFORCE_EQ( \
current, \
index_size, \
"Your input seems to be incorrect: the sum of lengths values should be " \
"the size of the indices tensor, but it appears not."); \
}
// clang-format on
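
// Illustrative usage (a sketch; `emb`, `idx`, `offs`, `out`, and `num_rows`
// are hypothetical caller-side buffers/sizes, not part of this file): with
// block_size = 4, output_size = 2, index_size = 3 and offsets = {0, 2, 3},
//   EmbeddingLookupIdx<int32_t, float, float, false>(
//       4, 2, 3, num_rows, emb, idx, offs,
//       /*weights=*/nullptr, /*scale_bias=*/nullptr,
//       /*normalize_by_lengths=*/false, out);
// sums rows idx[0] and idx[1] of `emb` into out[0..3] and copies row idx[2]
// into out[4..7].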
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, false);
EMBEDDING_IDX_SPECIALIZATION(int32_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, float, float, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, half, at::Half, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, bfloat16, at::BFloat16, float, true);
EMBEDDING_IDX_SPECIALIZATION(int32_t, uint8_t, uint8_t, float, true);
EMBEDDING_IDX_SPECIALIZATION(int64_t, uint8_t, uint8_t, float, true);
#undef EMBEDDING_IDX_SPECIALIZATION
} // namespace caffe2