#pragma once
#include <c10/util/irange.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/perfkernels/embedding_lookup.h"
#ifdef USE_FBGEMM
#include "fbgemm/Fbgemm.h"
#endif
#include <algorithm>
#include <functional>
namespace caffe2 {
// A templated class that implements SparseLengths[Sum,WeightedSum,Mean].
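// A sketch of how the template flags map to the registered operators (the
// actual aliases live in the corresponding .cc file; treat these as
// illustrative, not authoritative):
//
//   // Plain sum reduction over each bag:
//   //   CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>>
//   // Weighted sum (USE_WEIGHT = true):
//   //   CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 1>
//   // Mean (USE_MEAN = true):
//   //   CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 0, 1>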
template <
typename T, // output type
class InputTypes, // supported input types, such as TensorTypes<float>
bool USE_WEIGHT = false, // Whether it is SparseLengthsWeightedSum
bool USE_MEAN = false, // Whether this is SparseLengthsMean
bool USE_POSITIONAL_WEIGHT = false
// USE_WEIGHT = true and USE_POSITIONAL_WEIGHT = true
// -> SparseLengthsPositionalWeightedSum
>
class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
public:
USE_OPERATOR_FUNCTIONS(CPUContext);
template <class... Args>
explicit CPUSparseLengthsReductionOp(Args&&... args)
: Operator<CPUContext>(std::forward<Args>(args)...) {
static_assert(
!(USE_WEIGHT & USE_MEAN), "Cannot specify both weight and mean.");
}
~CPUSparseLengthsReductionOp() {}
// Currently, we support float and at::Half for the input data type, and
// int32_t and int64_t for the index type.
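// Dispatch happens in two stages: RunOnDevice dispatches on the data type,
// DoRunWithType then dispatches on the index type, and DoRunWithType2
// performs the actual reduction.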
bool RunOnDevice() override {
return DispatchHelper<InputTypes>::call(this, Input(DATA));
}
template <typename InputType>
bool DoRunWithType() {
return DispatchHelper<TensorTypes2<int32_t, int64_t>, InputType>::call(
this, Input(INDICES));
}
template <typename InputType, typename IndexType>
bool DoRunWithType2() {
auto& dataInput = Input(DATA);
auto& indicesInput = Input(INDICES);
auto& lengthsInput = Input(LENGTHS);
const int64_t M = lengthsInput.size(0);
const int64_t indices_size = indicesInput.numel();
auto shape = dataInput.sizes().vec();
shape[0] = M;
auto* output = Output(0, shape, at::dtype<T>());
T* out_data = output->template mutable_data<T>();
if (indices_size == 0) {
if (M > 0) {
memset(out_data, 0, output->numel() * sizeof(T));
}
return true;
}
CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector");
const int64_t N = dataInput.size(0);
const int D = dataInput.size_from_dim(1);
const InputType* in_data = dataInput.template data<InputType>();
const IndexType* indices = indicesInput.template data<IndexType>();
const int* lengths = lengthsInput.template data<int>();
const T* in_weight = nullptr;
if (USE_WEIGHT) {
// static if
auto& weightInput = Input(WEIGHT);
CAFFE_ENFORCE_EQ(1, weightInput.dim(), "WEIGHT must be a vector");
if (!USE_POSITIONAL_WEIGHT) {
CAFFE_ENFORCE_EQ(
weightInput.numel(),
indices_size,
"Weight should have the same length as indices.");
}
in_weight = weightInput.template data<T>();
}
#ifdef USE_FBGEMM
// If this is the first call or the block size has changed (which should
// never actually happen), generate a kernel.
if (D != last_block_size) {
last_block_size = D;
if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
kernel_fp32_i32_ =
fbgemm::GenerateEmbeddingSpMDM<float, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp32_i64_ =
fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
}
} else {
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
if (std::is_same<IndexType, std::int32_t>::value) {
kernel_fp16_i32_ =
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int32_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp16_i64_ =
fbgemm::GenerateEmbeddingSpMDM<fbgemm::float16, std::int64_t>(
D,
USE_WEIGHT,
USE_MEAN,
/*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
}
}
}
bool success;
if (std::is_same<InputType, float>::value) {
if (std::is_same<IndexType, std::int32_t>::value) {
success = kernel_fp32_i32_(
M,
indices_size,
N,
reinterpret_cast<const float*>(in_data),
indicesInput.template data<std::int32_t>(),
lengths,
in_weight,
out_data);
} else {
success = kernel_fp32_i64_(
M,
indices_size,
N,
reinterpret_cast<const float*>(in_data),
indicesInput.template data<std::int64_t>(),
lengths,
in_weight,
out_data);
}
} else {
if (std::is_same<IndexType, std::int32_t>::value) {
success = kernel_fp16_i32_(
M,
indices_size,
N,
reinterpret_cast<const fbgemm::float16*>(in_data),
indicesInput.template data<std::int32_t>(),
lengths,
in_weight,
out_data);
} else {
success = kernel_fp16_i64_(
M,
indices_size,
N,
reinterpret_cast<const fbgemm::float16*>(in_data),
indicesInput.template data<std::int64_t>(),
lengths,
in_weight,
out_data);
}
}
if (success) {
return true;
}
int64_t current = 0;
for (const auto m : c10::irange(M)) {
for (int i = 0; i < lengths[m]; ++i) {
CAFFE_ENFORCE_LT(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values "
"should be the size of the indices tensor, but it appears not.");
IndexType idx = indices[current];
CAFFE_ENFORCE(
0 <= idx && idx < N,
"Index ",
current,
" is out of bounds: ",
idx,
", range 0 to ",
N,
", actual batch length is ",
M);
++current;
}
}
CAFFE_ENFORCE_EQ(
current,
indices_size,
"Your input seems to be incorrect: the sum of lengths values should be "
"the size of the indices tensor, but it appears not.");
return false;
#endif
// delegate work to perfkernel that branches based on architecture
EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
D,
M,
indices_size,
N,
in_data,
indices,
lengths,
in_weight,
nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp
USE_MEAN,
out_data);
return true;
}
enum {
DATA = 0, // Data input.
WEIGHT = 1, // Weight input used in SparseLengthsWeightedSum
INDICES = 1 + USE_WEIGHT, // 1 in SparseLengths[Sum,Mean] and
// 2 in SparseLengthsWeightedSum
LENGTHS = 2 + USE_WEIGHT, // 2 in SparseLengths[Sum, Mean],
// 3 in SparseLengthsWeightedSum
};
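// For example, SparseLengthsWeightedSum takes (DATA, WEIGHT, INDICES,
// LENGTHS), while SparseLengthsSum and SparseLengthsMean take
// (DATA, INDICES, LENGTHS).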
#ifdef USE_FBGEMM
private:
std::int64_t last_block_size{-1};
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int32_t>::Type
kernel_fp32_i32_;
fbgemm::EmbeddingSpMDMKernelSignature<float, std::int64_t>::Type
kernel_fp32_i64_;
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int32_t>::Type
kernel_fp16_i32_;
fbgemm::EmbeddingSpMDMKernelSignature<fbgemm::float16, std::int64_t>::Type
kernel_fp16_i64_;
#endif
};
template <typename T, class Context, class Engine = DefaultEngine>
class TTSparseLengthsSumOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit TTSparseLengthsSumOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...),
factor_i(this->template GetRepeatedArgument<int>(
"factor_i",
vector<int>{1, 1, 1})),
factor_j(this->template GetRepeatedArgument<int>(
"factor_j",
vector<int>{1, 1, 1})),
ranks(this->template GetRepeatedArgument<int>(
"ranks",
vector<int>{1, 1, 1, 1})),
emb_size(this->template GetSingleArgument<int>("emb_size", 64)) {
// cumulative product of factor_i, used to factorize each index
l_cumprod.push_back(1);
for (const auto i : c10::irange(1, factor_i.size())) {
l_cumprod.push_back(l_cumprod[i - 1] * factor_i[i - 1]);
}
}
~TTSparseLengthsSumOp() {}
void Ind2Sub(int64_t* out_factor_index, const int64_t* indices, int len) {
// TODO: vectorization
auto N = factor_i.size();
for (const auto j : c10::irange(len)) {
auto idx = indices[j];
for (int i = N; i > 0; i--) {
out_factor_index[j * N + i - 1] = idx / l_cumprod[i - 1];
idx = idx % l_cumprod[i - 1];
}
}
}
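// Hand-worked example: with factor_i = {4, 5, 6}, l_cumprod = {1, 4, 20},
// so index 53 factorizes as 53 = 2 * 20 + 3 * 4 + 1 * 1, i.e.
// out_factor_index = {1, 3, 2}; digit k indexes into tensor core k and
// lies in [0, factor_i[k]).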
bool GetSlice(
std::vector<std::vector<T>>& tgt_slice,
const T* core,
const vector<int64_t>& ind_slice,
int bs,
int idx) {
// implements the functionality of index_select(core, 1, ind_slice)
auto num_of_elements = ranks[idx] * factor_j[idx] * ranks[idx + 1];
for (const auto i : c10::irange(bs)) {
memcpy(
tgt_slice[i].data(),
core + ind_slice[i] * num_of_elements,
num_of_elements * sizeof(T));
}
return true;
}
// ind: stores the index into each tensor core
// bs: the number of indices
// GatherAllRows computes the lengths-sum in two steps: 1) it uses the
// tensor train to compute the embedding for each index; 2) it sums the
// embeddings within each bag.
// In step 1), it batches all the indices together. Specifically, for every
// index, it uses the precomputed ind of each tensor core to extract the
// corresponding slice of the core, then runs GEMM operations sequentially
// over the slices to produce the embedding result for that index.
// In step 2), it takes the embeddings computed in step 1) and applies the
// sum reduction over each bag.
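// Shape sketch (assuming ranks.front() == ranks.back() == 1 and
// emb_size == factor_j[0] * factor_j[1] * factor_j[2]): the slice of core i
// is a (ranks[i] x factor_j[i] * ranks[i + 1]) matrix, so the running
// product Y grows from (factor_j[0] x ranks[1]) to
// (factor_j[0] * ... * factor_j[i] x ranks[i + 1]); after the last core it
// is the (emb_size x 1) embedding row for that index.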
bool GatherAllRows(
int64_t* ind,
int bs,
int x_len,
vector<const T*> cores,
int segments,
const int* lengths,
T* out_data) {
// compute the largest memory consumption of the intermediate results
// TODO: allocate dynamically per layer: cur_rows * factor_j[i] * ranks[i + 1];
// also explore contiguous memory storage for res and int_res
int max_rank = *max_element(ranks.begin(), ranks.end());
std::vector<std::vector<T>> res(bs, std::vector<T>(emb_size * max_rank, 0));
std::vector<std::vector<T>> int_res(
bs, std::vector<T>(emb_size * max_rank, 0));
// Store the matrix A
vector<T*> Y_ptr(bs);
// Store the intermediate result in each layer
vector<T*> Z_ptr(bs);
for (const auto b : c10::irange(bs)) {
Y_ptr[b] = res[b].data();
Z_ptr[b] = int_res[b].data();
}
vector<int64_t> ind_slice(bs);
int rows = 0;
for (const auto i : c10::irange(x_len)) {
// slice cur
for (const auto j : c10::irange(bs)) {
ind_slice[j] = ind[x_len * j + i];
}
if (i == 0) {
GetSlice(res, cores[i], ind_slice, bs, i);
rows = factor_j[0];
} else {
std::vector<std::vector<T>> slice(
bs, std::vector<T>(ranks[i] * factor_j[i] * ranks[i + 1], 0));
vector<const T*> X_ptr(bs);
for (const auto b : c10::irange(bs)) {
X_ptr[b] = slice[b].data();
}
GetSlice(slice, cores[i], ind_slice, bs, i);
math::GemmBatched<T, CPUContext>(
CblasNoTrans,
CblasNoTrans,
bs,
rows,
factor_j[i] * ranks[i + 1],
ranks[i],
1.0f,
const_cast<const T**>(Y_ptr.data()),
X_ptr.data(),
0.0f,
Z_ptr.data(),
&context_);
for (const auto b : c10::irange(bs)) {
std::memcpy(Y_ptr[b], Z_ptr[b], (emb_size * max_rank) * sizeof(T));
}
rows *= factor_j[i];
}
// save the intermediate output for backward path
// shape for the core
auto shape = vector<int64_t>({bs, rows, ranks[i + 1]});
if (i < 2) {
auto* core_data = Output(i + 1, shape, at::dtype<T>());
T* out_core = core_data->template mutable_data<T>();
for (const auto b : c10::irange(bs)) {
std::memcpy(
out_core + b * rows * ranks[i + 1],
Y_ptr[b],
rows * ranks[i + 1] * sizeof(T));
}
}
}
// reduction and store back to output
vector<int64_t> cum_lengths(segments);
for (const auto seg : c10::irange(segments)) {
cum_lengths[seg] =
seg == 0 ? lengths[0] : lengths[seg] + cum_lengths[seg - 1];
}
int length_idx = 0;
vector<T> tmp_sum(emb_size, 0.0f);
for (int i = 0; i <= bs; i++) {
while ((length_idx < segments) && (i == cum_lengths[length_idx])) {
// store the tmp_sum into output
memcpy(
&out_data[length_idx * emb_size],
tmp_sum.data(),
emb_size * sizeof(T));
length_idx++;
fill(tmp_sum.begin(), tmp_sum.end(), 0.0f);
}
if (i == bs) {
break;
}
transform(
res[i].begin(),
res[i].begin() + emb_size,
tmp_sum.begin(),
tmp_sum.begin(),
std::plus<T>());
}
return true;
}
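// Forward pass: inputs are the three TT cores plus INDICES and LENGTHS.
// Output(0) is the pooled embedding; Output(1) and Output(2) cache the
// intermediate core products and Output(3) the factorized indices, all of
// which are consumed by the gradient op below.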
bool RunOnDevice() override {
const auto& dataInput0 = Input(0);
const auto& dataInput1 = Input(1);
const auto& dataInput2 = Input(2);
const auto& indicesInput = Input(3);
const auto& lengthsInput = Input(4);
CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector");
int N = factor_i.size();
const int64_t M = lengthsInput.size(0);
auto shape = vector<int64_t>({M, emb_size});
auto* output = Output(0, shape, at::dtype<T>());
T* out_data = output->template mutable_data<T>();
const T* core0 = dataInput0.template data<T>();
const T* core1 = dataInput1.template data<T>();
const T* core2 = dataInput2.template data<T>();
const int* lengths = lengthsInput.template data<int>();
vector<const T*> cores = {core0, core1, core2};
const int64_t* indices = indicesInput.template data<int64_t>();
// Store the factor index for backward path
auto index_shape = vector<int64_t>({indicesInput.size(), N});
auto* index_data = Output(3, index_shape, at::dtype<int64_t>());
int64_t* out_factor_index = index_data->template mutable_data<int64_t>();
// Store the factorized index for each core
Ind2Sub(out_factor_index, indices, indicesInput.size());
return GatherAllRows(
out_factor_index, indicesInput.size(), N, cores, M, lengths, out_data);
}
protected:
vector<int> factor_i;
vector<int> factor_j;
vector<int> ranks;
vector<int> l_cumprod;
int emb_size;
};
template <typename T, class Context>
class TTSparseLengthsSumGradientOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
template <class... Args>
explicit TTSparseLengthsSumGradientOp(Args&&... args)
: Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
~TTSparseLengthsSumGradientOp() {}
};
// Implements the gradient op for TTSparseLengthsSum.
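// Backward sketch (ignoring the reshapes between GEMMs): writing the
// embedding of one index as Y = (A0 * A1) * A2, where Ai is the slice
// selected from core i, the chain rule gives
//   dA2 = (A0 * A1)^T * dY,        d(A0 * A1) = dY * A2^T,
//   dA1 = A0^T * d(A0 * A1),       dA0 = d(A0 * A1) * A1^T,
// which is the sequence of GemmBatched calls below; each slice gradient is
// scatter-added into dCore{0,1,2} at its factorized index.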
template <typename T, class Context>
bool TTSparseLengthsSumGradientOp<T, Context>::RunOnDevice() {
const auto& core0 = Input(0);
const auto& core1 = Input(1);
const auto& core2 = Input(2);
const auto& lengths = Input(3);
const auto& core0_out = Input(4);
const auto& core1_out = Input(5);
const auto& index_out = Input(6);
const auto& dY = Input(7);
const int* lengths_data = lengths.template data<int>();
const T* dY_data = dY.template data<T>();
// recover the arguments from the input shapes
const int64_t bs = index_out.size(0);
const int64_t emb_size = dY.size(1);
const int64_t num_segments = lengths.size(0);
auto core0_shape = core0.sizes().vec();
auto core1_shape = core1.sizes().vec();
auto core2_shape = core2.sizes().vec();
auto core0_out_shape = core0_out.sizes().vec();
auto core1_out_shape = core1_out.sizes().vec();
auto* dCore0 = Output(0, core0_shape, at::dtype<T>());
auto* dCore1 = Output(1, core1_shape, at::dtype<T>());
auto* dCore2 = Output(2, core2_shape, at::dtype<T>());
T* dCore0_data = dCore0->template mutable_data<T>();
T* dCore1_data = dCore1->template mutable_data<T>();
T* dCore2_data = dCore2->template mutable_data<T>();
memset(
    dCore0_data,
    0,
    sizeof(T) *
        accumulate(
            core0_shape.begin(),
            core0_shape.end(),
            int64_t{1},
            std::multiplies<int64_t>()));
memset(
    dCore1_data,
    0,
    sizeof(T) *
        accumulate(
            core1_shape.begin(),
            core1_shape.end(),
            int64_t{1},
            std::multiplies<int64_t>()));
memset(
    dCore2_data,
    0,
    sizeof(T) *
        accumulate(
            core2_shape.begin(),
            core2_shape.end(),
            int64_t{1},
            std::multiplies<int64_t>()));
const int64_t* index_out_data = index_out.template data<int64_t>();
vector<vector<int64_t>> index_slice(bs, vector<int64_t>(3, 0));
for (const auto b : c10::irange(bs)) {
memcpy(index_slice[b].data(), index_out_data + b * 3, 3 * sizeof(int64_t));
}
vector<const T*> A_ptr(bs);
vector<T*> B_ptr(bs);
vector<T*> C_ptr(bs);
// number of elements in one core slice
int64_t num_of_elements = 0;
// expand the output gradient to every index within its bag
vector<vector<T>> core2_out_grad(bs, vector<T>(emb_size, 0));
int64_t data_index = 0;
for (const auto range_index : c10::irange(num_segments)) {
for (int64_t start = data_index;
data_index < start + lengths_data[range_index];
++data_index) {
memcpy(
core2_out_grad[data_index].data(),
dY_data + range_index * emb_size,
emb_size * sizeof(T));
}
}
// =======================================================
// Calculate dCore2_data:
// 1) Transpose core1_out and multiply with core2_out_grad
// 2) Scatter-add the result into dCore2_data
vector<vector<T>> dCore2_data_slice_grad(
bs, vector<T>(core2_shape[1] * core2_shape[2] * core2_shape[3], 0));
const T* core1_out_data = core1_out.template data<T>();
for (const auto b : c10::irange(bs)) {
A_ptr[b] = core1_out_data + b * core1_out.size(1) * core1_out.size(2);
B_ptr[b] = core2_out_grad[b].data();
C_ptr[b] = dCore2_data_slice_grad[b].data();
}
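// Per index: dA2 = (A0 * A1)^T * dY, where A0 * A1 is the cached core1_out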
math::GemmBatched<T, CPUContext>(
CblasTrans,
CblasNoTrans,
bs,
core2.size(1), // M
core2.size(2) * core2.size(3), // N
core1_out.size(1), // K
1.0f,
const_cast<const T**>(A_ptr.data()),
const_cast<const T**>(B_ptr.data()),
0.0f,
C_ptr.data(),
&context_);
// update the corresponding slice
num_of_elements = core2_shape[1] * core2_shape[2] * core2_shape[3];
const T* core2_data = core2.template data<T>();
vector<vector<T>> core2_slice(
bs, vector<T>(core2_shape[1] * core2_shape[2] * core2_shape[3], 0));
for (const auto b : c10::irange(bs)) {
for (const auto i : c10::irange(num_of_elements)) {
dCore2_data[index_slice[b][2] * num_of_elements + i] += C_ptr[b][i];
}
memcpy(
core2_slice[b].data(),
core2_data + index_slice[b][2] * num_of_elements,
sizeof(T) * num_of_elements);
}
// Calculate core1_out_grad
vector<vector<T>> core1_out_grad(
bs, vector<T>(core1_out_shape[1] * core1_out_shape[2], 0));
for (const auto b : c10::irange(bs)) {
A_ptr[b] = core2_out_grad[b].data();
B_ptr[b] = core2_slice[b].data();
C_ptr[b] = core1_out_grad[b].data();
}
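// Per index: d(A0 * A1) = dY * A2^T, using the slice of core2 gathered above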
math::GemmBatched<T, CPUContext>(
CblasNoTrans,
CblasTrans,
bs,
core1_out.size(1), // M
core2_shape[1], // N
core2_shape[2] * core2_shape[3], // K
1.0f,
const_cast<const T**>(A_ptr.data()),
const_cast<const T**>(B_ptr.data()),
0.0f,
C_ptr.data(),
&context_);
// =======================================================
// Calculate dCore1_data:
// 1) Transpose core0_out and multiply with core1_out_grad
// 2) Scatter-add the result into dCore1_data
vector<vector<T>> dCore1_data_slice_grad(
bs, vector<T>(core1_shape[1] * core1_shape[2] * core1_shape[3], 0));
const T* core0_out_data = core0_out.template data<T>();
for (const auto b : c10::irange(bs)) {
A_ptr[b] = core0_out_data + b * core0_out.size(1) * core0_out.size(2);
B_ptr[b] = core1_out_grad[b].data();
C_ptr[b] = dCore1_data_slice_grad[b].data();
}
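// Per index: dA1 = A0^T * d(A0 * A1), where A0 is the cached core0_out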
math::GemmBatched<T, CPUContext>(
CblasTrans,
CblasNoTrans,
bs,
core1.size(1), // M
core1.size(2) * core1.size(3), // N
core0_out.size(1), // K
1.0f,
const_cast<const T**>(A_ptr.data()),
const_cast<const T**>(B_ptr.data()),
0.0f,
C_ptr.data(),
&context_);
// update the corresponding slice
num_of_elements = core1_shape[1] * core1_shape[2] * core1_shape[3];
const T* core1_data = core1.template data<T>();
vector<vector<T>> core1_slice(
bs, vector<T>(core1_shape[1] * core1_shape[2] * core1_shape[3], 0));
for (const auto b : c10::irange(bs)) {
for (const auto i : c10::irange(num_of_elements)) {
dCore1_data[index_slice[b][1] * num_of_elements + i] += C_ptr[b][i];
}
memcpy(
core1_slice[b].data(),
core1_data + index_slice[b][1] * num_of_elements,
sizeof(T) * num_of_elements);
}
// Calculate core0_out_grad
vector<vector<T>> core0_out_grad(
bs, vector<T>(core0_out_shape[1] * core0_out_shape[2], 0));
for (const auto b : c10::irange(bs)) {
A_ptr[b] = core1_out_grad[b].data();
B_ptr[b] = core1_slice[b].data();
C_ptr[b] = core0_out_grad[b].data();
}
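// Per index: dA0 = d(A0 * A1) * A1^T, using the slice of core1 gathered above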
math::GemmBatched<T, CPUContext>(
CblasNoTrans,
CblasTrans,
bs,
core0_out.size(1), // M
core1_shape[1], // N
core1_shape[2] * core1_shape[3], // K
1.0f,
const_cast<const T**>(A_ptr.data()),
const_cast<const T**>(B_ptr.data()),
0.0f,
C_ptr.data(),
&context_);
num_of_elements = core0_shape[1] * core0_shape[2] * core0_shape[3];
for (const auto b : c10::irange(bs)) {
for (const auto i : c10::irange(num_of_elements)) {
dCore0_data[index_slice[b][0] * num_of_elements + i] += C_ptr[b][i];
}
}
return true;
}
} // namespace caffe2