blob: 8f377df4d19b42296314814014747afadac89a3c [file] [log] [blame]
#ifndef CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
#define CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
#include <algorithm>
#include <functional>
#include <numeric>
#include <string>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
/// Batched matrix multiplication of Input(0) (A) and Input(1) (B) into
/// Output(0) (Y). Only float inputs are dispatched (see RunOnDevice).
///
/// Arguments:
///   trans_a   - if true, op(A) swaps the last two dims of A.
///   trans_b   - if true, op(B) swaps the last two dims of B.
///   broadcast - must be true whenever the batch (leading) dims of A and B,
///               as aligned by math::utils::ComputeBroadcastBinaryOpDims, are
///               not identical; otherwise the op fails an enforce.
///
/// Shape rules implemented in DoRunWithType:
///   * 1-D . 1-D : dot product; Y has shape {1}.
///   * 1-D x N-D : A is contracted with the row dim of op(B); that dim is
///                 dropped from the output shape.
///   * N-D x 1-D : B is contracted with the column dim of op(A); that dim is
///                 dropped from the output shape.
///   * N-D x N-D : Y has shape [batch dims..., M, N], where M x K is op(A)'s
///                 matrix shape and K x N is op(B)'s.
/// 0-d inputs are rejected.
template <class Context, class Engine = DefaultEngine>
class BatchMatMulOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BatchMatMulOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "trans_a", trans_a_, false),
        OP_SINGLE_ARG(bool, "trans_b", trans_b_, false),
        OP_SINGLE_ARG(bool, "broadcast", broadcast_, false) {}

  bool RunOnDevice() override {
    // Only float is supported; other dtypes fail inside DispatchHelper.
    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);
    const int A_ndim = A.dim();
    const int B_ndim = B.dim();
    // Every branch below indexes dims[ndim - 1] and/or dims[ndim - 2]; a 0-d
    // input would read out of bounds.
    CAFFE_ENFORCE_GE(A_ndim, 1, "BatchMatMul does not support 0-d inputs");
    CAFFE_ENFORCE_GE(B_ndim, 1, "BatchMatMul does not support 0-d inputs");
    const std::vector<std::int64_t> A_dims = A.sizes().vec();
    const std::vector<std::int64_t> B_dims = B.sizes().vec();
    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();

    // Vector . vector: a single dot product.
    if (A_ndim == 1 && B_ndim == 1) {
      CAFFE_ENFORCE_EQ(A.numel(), B.numel());
      auto* Y = Output(0, {1}, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      math::Dot<T, Context>(A.numel(), A_data, B_data, Y_data, &context_);
      return true;
    }

    // Vector x (batched) matrix. Here B_ndim >= 2.
    if (A_ndim == 1) {
      const int N = A.numel();
      if (trans_b_) {
        CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], N);
      } else {
        CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], N);
      }
      // Output shape is B's shape with the contracted dim removed.
      std::vector<std::int64_t> Y_dims(B_ndim - 1);
      if (trans_b_) {
        std::copy_n(B_dims.cbegin(), B_ndim - 1, Y_dims.begin());
      } else {
        std::copy_n(B_dims.cbegin(), B_ndim - 2, Y_dims.begin());
        Y_dims.back() = B_dims.back();
      }
      auto* Y = Output(0, Y_dims, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      if (Y->numel() == 0) {
        // Nothing to compute; also guarantees M > 0 below.
        return true;
      }
      if (N == 0) {
        // Empty contraction: every output element is an empty sum. Handled
        // explicitly because the divisions below would divide by zero and
        // BLAS gemv may return without writing Y when N == 0.
        math::Set<T, Context>(Y->numel(), T(0), Y_data, &context_);
        return true;
      }
      if (trans_b_) {
        // B is contiguous, so it can be viewed as one (numel / N) x N matrix.
        const int M = B.numel() / N;
        math::Gemv<T, Context, Engine>(
            CblasNoTrans, M, N, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
      } else {
        const int M = B_dims[B_ndim - 1];
        const int batch_size = B.numel() / (M * N);
        if (batch_size == 1) {
          math::Gemv<T, Context, Engine>(
              CblasTrans, N, M, 1.0f, B_data, A_data, 0.0f, Y_data, &context_);
        } else {
          // Each batch computes (N x M)^T * (N x 1); A is shared (stride 0).
          math::GemmStridedBatched<T, Context, Engine>(
              CblasTrans,
              CblasNoTrans,
              batch_size,
              M,
              1,
              N,
              1.0f,
              B_data,
              M * N,
              A_data,
              0,
              0.0f,
              Y_data,
              M,
              &context_);
        }
      }
      return true;
    }

    // (Batched) matrix x vector. Here A_ndim >= 2.
    if (B_ndim == 1) {
      const int N = B.numel();
      if (trans_a_) {
        CAFFE_ENFORCE_EQ(A_dims[A_ndim - 2], N);
      } else {
        CAFFE_ENFORCE_EQ(A_dims[A_ndim - 1], N);
      }
      // Output shape is A's shape with the contracted dim removed. With
      // trans_a the contracted dim is the second-to-last one, so the last
      // dim (M) is the one that survives.
      std::vector<std::int64_t> Y_dims(
          A_dims.cbegin(), A_dims.cbegin() + A_ndim - 2);
      Y_dims.push_back(trans_a_ ? A_dims[A_ndim - 1] : A_dims[A_ndim - 2]);
      auto* Y = Output(0, Y_dims, at::dtype<T>());
      T* Y_data = Y->template mutable_data<T>();
      if (Y->numel() == 0) {
        // Nothing to compute; also guarantees M > 0 below.
        return true;
      }
      if (N == 0) {
        // Empty contraction: see the matching comment in the 1-D x N-D branch.
        math::Set<T, Context>(Y->numel(), T(0), Y_data, &context_);
        return true;
      }
      if (trans_a_) {
        const int M = A_dims[A_ndim - 1];
        const int batch_size = A.numel() / (M * N);
        if (batch_size == 1) {
          math::Gemv<T, Context, Engine>(
              CblasTrans, N, M, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
        } else {
          // Each batch computes (N x M)^T * (N x 1); B is shared (stride 0).
          math::GemmStridedBatched<T, Context, Engine>(
              CblasTrans,
              CblasNoTrans,
              batch_size,
              M,
              1,
              N,
              1.0f,
              A_data,
              M * N,
              B_data,
              0,
              0.0f,
              Y_data,
              M,
              &context_);
        }
      } else {
        // A is contiguous, so it can be viewed as one (numel / N) x N matrix.
        const int M = A.numel() / N;
        math::Gemv<T, Context, Engine>(
            CblasNoTrans, M, N, 1.0f, A_data, B_data, 0.0f, Y_data, &context_);
      }
      return true;
    }

    // (Batched) matrix x (batched) matrix. Here A_ndim >= 2 and B_ndim >= 2.
    const int M = trans_a_ ? A_dims[A_ndim - 1] : A_dims[A_ndim - 2];
    const int K = trans_a_ ? A_dims[A_ndim - 2] : A_dims[A_ndim - 1];
    if (trans_b_) {
      CAFFE_ENFORCE_EQ(B_dims[B_ndim - 1], K);
    } else {
      CAFFE_ENFORCE_EQ(B_dims[B_ndim - 2], K);
    }
    const int N = trans_b_ ? B_dims[B_ndim - 2] : B_dims[B_ndim - 1];
    const int ndim = std::max(A_ndim, B_ndim);
    std::vector<std::int64_t> A_broadcast_dims(ndim);
    std::vector<std::int64_t> B_broadcast_dims(ndim);
    std::vector<std::int64_t> Y_broadcast_dims(ndim);
    // Aligns the batch dims (all but the trailing matrix dims) of A and B and
    // computes the output batch dims; fills the first ndim - 2 entries.
    math::utils::ComputeBroadcastBinaryOpDims(
        A_ndim - 2,
        A_dims.data(),
        B_ndim - 2,
        B_dims.data(),
        A_broadcast_dims.data(),
        B_broadcast_dims.data(),
        Y_broadcast_dims.data());
    Y_broadcast_dims[ndim - 2] = M;
    Y_broadcast_dims[ndim - 1] = N;
    auto* Y = Output(0, Y_broadcast_dims, at::dtype<T>());
    T* Y_data = Y->template mutable_data<T>();
    const int batch_dim = ndim - 2;
    const bool is_broadcast_dims = !std::equal(
        A_broadcast_dims.cbegin(),
        A_broadcast_dims.cbegin() + batch_dim,
        B_broadcast_dims.cbegin());
    if (is_broadcast_dims) {
      // Mismatched batch dims are only accepted when explicitly requested.
      CAFFE_ENFORCE(broadcast_);
    }
    const std::int64_t A_batch_size = std::accumulate(
        A_broadcast_dims.cbegin(),
        A_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    const std::int64_t B_batch_size = std::accumulate(
        B_broadcast_dims.cbegin(),
        B_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    const std::int64_t Y_batch_size = std::accumulate(
        Y_broadcast_dims.cbegin(),
        Y_broadcast_dims.cbegin() + batch_dim,
        1LL,
        std::multiplies<std::int64_t>());
    if (Y->numel() == 0) {
      // Covers an empty batch as well as M == 0 or N == 0.
      return true;
    }
    if (K == 0) {
      // Empty contraction: Y is all zeros. Done explicitly so the result
      // does not depend on how each GEMM backend treats K == 0.
      math::Set<T, Context>(Y->numel(), T(0), Y_data, &context_);
      return true;
    }
    if (A_batch_size == 1 && B_batch_size == 1) {
      // Plain single matrix product.
      math::Gemm<T, Context, Engine>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          M,
          N,
          K,
          1.0f,
          A_data,
          B_data,
          0.0f,
          Y_data,
          &context_);
    } else if (A_batch_size == 1) {
      // A is shared by every batch of B.
      if (M == 1 && trans_b_) {
        // A is a single 1 x K row and every B batch is N x K, so the whole
        // product collapses to one gemv over B viewed as (B_batch_size*N) x K.
        math::Gemv<T, Context, Engine>(
            CblasNoTrans,
            B_batch_size * N,
            K,
            1.0f,
            B_data,
            A_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        math::GemmStridedBatched<T, Context, Engine>(
            trans_a_ ? CblasTrans : CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            Y_batch_size,
            M,
            N,
            K,
            1.0f,
            A_data,
            0,
            B_data,
            K * N,
            0.0f,
            Y_data,
            M * N,
            &context_);
      }
    } else if (B_batch_size == 1) {
      // B is shared by every batch of A.
      if (!trans_a_) {
        // Untransposed A batches are stacked rows, so one GEMM over
        // (A_batch_size * M) x K handles all of them.
        math::Gemm<T, Context, Engine>(
            CblasNoTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            A_batch_size * M,
            N,
            K,
            1.0f,
            A_data,
            B_data,
            0.0f,
            Y_data,
            &context_);
      } else {
        math::GemmStridedBatched<T, Context, Engine>(
            CblasTrans,
            trans_b_ ? CblasTrans : CblasNoTrans,
            Y_batch_size,
            M,
            N,
            K,
            1.0f,
            A_data,
            M * K,
            B_data,
            0,
            0.0f,
            Y_data,
            M * N,
            &context_);
      }
    } else if (!is_broadcast_dims) {
      // Identical batch dims: fixed strides through all three tensors.
      math::GemmStridedBatched<T, Context, Engine>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          Y_batch_size,
          M,
          N,
          K,
          1.0f,
          A_data,
          M * K,
          B_data,
          K * N,
          0.0f,
          Y_data,
          M * N,
          &context_);
    } else {
      // General broadcast: walk Y's batch index and map it to the matching
      // A/B batch offsets (GetIndexFromDims presumably treats size-1 dims as
      // broadcast), then run one pointer-array batched GEMM.
      std::vector<const T*> A_ptr(Y_batch_size);
      std::vector<const T*> B_ptr(Y_batch_size);
      std::vector<T*> Y_ptr(Y_batch_size);
      std::vector<std::int64_t> index(batch_dim);
      for (std::int64_t i = 0; i < Y_batch_size; ++i) {
        const std::int64_t A_index = math::utils::GetIndexFromDims(
            batch_dim, A_broadcast_dims.data(), index.data());
        const std::int64_t B_index = math::utils::GetIndexFromDims(
            batch_dim, B_broadcast_dims.data(), index.data());
        A_ptr[i] = A_data + A_index * M * K;
        B_ptr[i] = B_data + B_index * K * N;
        Y_ptr[i] = Y_data + i * M * N;
        math::utils::IncreaseIndexInDims(
            batch_dim, Y_broadcast_dims.data(), index.data());
      }
      math::GemmBatched<T, Context, Engine>(
          trans_a_ ? CblasTrans : CblasNoTrans,
          trans_b_ ? CblasTrans : CblasNoTrans,
          Y_batch_size,
          M,
          N,
          K,
          1.0f,
          A_ptr.data(),
          B_ptr.data(),
          0.0f,
          Y_ptr.data(),
          &context_);
    }
    return true;
  }

 private:
  const bool trans_a_;
  const bool trans_b_;
  const bool broadcast_;
};
} // namespace caffe2
#endif // CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_