| #include "caffe2/operators/batch_matmul_op.h" |
| |
| #include "caffe2/core/operator_schema.h" |
| #include "caffe2/core/types.h" |
| |
| namespace caffe2 { |
| |
| REGISTER_CPU_OPERATOR(BatchMatMul, BatchMatMulOp<CPUContext>); |
| |
| vector<TensorShape> TensorInferenceForBatchMatMul( |
| const OperatorDef& def, |
| const vector<TensorShape>& in) { |
| ArgumentHelper helper(def); |
| bool broadcast = helper.GetSingleArgument<int>("broadcast", 0); |
| if (!broadcast) { |
| const auto ndim = in[0].dims_size(); |
| CAFFE_ENFORCE_GE(ndim, 2); |
| CAFFE_ENFORCE_GE(in[1].dims_size(), 2); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int a_dim0; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| int b_dim1; |
| if (helper.GetSingleArgument<int>("trans_a", 0)) { |
| a_dim0 = in[0].dims(ndim - 1); |
| } else { |
| a_dim0 = in[0].dims(ndim - 2); |
| } |
| |
| if (helper.GetSingleArgument<int>("trans_b", 0)) { |
| b_dim1 = in[1].dims(ndim - 2); |
| } else { |
| b_dim1 = in[1].dims(ndim - 1); |
| } |
| |
| auto output_dims = |
| vector<int64_t>{in[0].dims().begin(), in[0].dims().end()}; |
| output_dims[ndim - 2] = a_dim0; |
| output_dims[ndim - 1] = b_dim1; |
| |
| return vector<TensorShape>{ |
| CreateTensorShape(vector<int64_t>{output_dims}, in[0].data_type())}; |
| } else { |
| auto ndims_A = in[0].dims_size(); |
| auto ndims_B = in[1].dims_size(); |
| std::vector<int64_t> dims_A(ndims_A), dims_B(ndims_B); |
| for (int i = 0; i < ndims_A; ++i) { |
| dims_A[i] = in[0].dims(i); |
| } |
| for (int i = 0; i < ndims_B; ++i) { |
| dims_B[i] = in[1].dims(i); |
| } |
| bool A_broadcasted = false, B_broadcasted = false; |
| if (ndims_A == 1) { |
| dims_A.insert(dims_A.begin(), 1); |
| ndims_A = 2; |
| A_broadcasted = true; |
| } |
| if (ndims_B == 1) { |
| dims_B.push_back(1); |
| ndims_B = 2; |
| B_broadcasted = true; |
| } |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| size_t M, N; |
| if (helper.GetSingleArgument<int>("trans_a", 0)) { |
| M = dims_A[ndims_A - 1]; |
| } else { |
| M = dims_A[ndims_A - 2]; |
| } |
| if (helper.GetSingleArgument<int>("trans_b", 0)) { |
| N = dims_B[ndims_B - 2]; |
| } else { |
| N = dims_B[ndims_B - 1]; |
| } |
| |
| const int ndims = std::max(ndims_A, ndims_B); |
| std::vector<int64_t> new_dims(ndims - 2); |
| std::vector<int64_t> dims_A_broadcast(ndims - 2, 1); |
| std::vector<int64_t> dims_B_broadcast(ndims - 2, 1); |
| |
| std::copy_n(dims_A.begin(), ndims_A - 2, dims_A_broadcast.begin() + ndims - ndims_A); |
| std::copy_n(dims_B.begin(), ndims_B - 2, dims_B_broadcast.begin() + ndims - ndims_B); |
| for (int i = 0; i < ndims - 2; ++i) { |
| if (!dims_A_broadcast[i] || !dims_B_broadcast[i]) { |
| new_dims[i] = 0; |
| } else { |
| new_dims[i] = std::max(dims_A_broadcast[i], dims_B_broadcast[i]); |
| } |
| } |
| if (!A_broadcasted) { |
| new_dims.push_back(M); |
| } |
| if (!B_broadcasted) { |
| new_dims.push_back(N); |
| } |
| if (A_broadcasted && B_broadcasted) { |
| new_dims.push_back(1); |
| } |
| return vector<TensorShape>{ |
| CreateTensorShape(vector<int64_t>{new_dims}, in[0].data_type())}; |
| } |
| } |
| |
| OpSchema::Cost CostInferenceForBatchMatMul( |
| const OperatorDef& def, |
| const vector<TensorShape>& in) { |
| CAFFE_ENFORCE_EQ(in.size(), 2U, "BatchMatMul requires two inputs"); |
| |
| ArgumentHelper helper(def); |
| struct OpSchema::Cost c; |
| const auto& A = in[0]; |
| const auto& B = in[1]; |
| const TensorShape Y = TensorInferenceForBatchMatMul(def, in)[0]; |
| |
| uint64_t nElemA = nElemFromDim(A); |
| uint64_t nElemB = nElemFromDim(B); |
| uint64_t nElemY = nElemFromDim(Y); |
| |
| auto ndims_A = A.dims_size(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| size_t K; |
| if (helper.GetSingleArgument<int>("trans_a", 0)) { |
| K = in[0].dims(ndims_A - 2); |
| } else { |
| K = in[0].dims(ndims_A - 1); |
| } |
| |
| auto const& A_element_size_byte = |
| DataTypeToTypeMeta(A.data_type()).itemsize(); |
| auto const& Y_element_size_byte = |
| DataTypeToTypeMeta(Y.data_type()).itemsize(); |
| c.flops = 2 * nElemY * K; |
| c.bytes_read = (nElemA + nElemB) * A_element_size_byte; |
| c.bytes_written = nElemY * Y_element_size_byte; |
| c.params_bytes = 0; |
| return c; |
| } |
| |
| OPERATOR_SCHEMA(BatchMatMul) |
| .NumInputs(2) |
| .NumOutputs(1) |
| .SetDoc(R"DOC( |
| Batch Matrix multiplication Yi = Ai * Bi, where A has shape (dim0, dim1, ... M, K), |
| B has shape (dim0, dim1, ... K, N), Y has shape (dim0, dim1, ... M, N) and i ranges |
| from 0 to (dim0 * dim1 ...) - 1. rank(A) == rank(B) >= 2. In case of A and B being |
| two dimensional, it behaves like normal matrix multiplication. |
| )DOC") |
| .Input(0, "A", "tensor of shape (dim0, dim1 ... M, K)") |
| .Input(1, "B", "tensor of shape (dim0, dim1 ... K, N)") |
| .Output(0, "Y", "tensor of shape (dim0, dim1 ... M, N)") |
| .Arg( |
| "trans_a", |
| "Pass 1 to transpose the last two dimensions of A before " |
| "doing multiplication") |
| .Arg( |
| "trans_b", |
| "Pass 1 to transpose the last two dimensions of B before " |
| "doing multiplication") |
| .Arg( |
| "broadcast", |
| "Pass 1 to allow broadcasting of dimensions. Behavior is the same as numpy.matmul. Gradient is currently not supported when running in broadcast mode.") |
| .TensorInferenceFunction(TensorInferenceForBatchMatMul) |
| .CostInferenceFunction( |
| OpSchema::CostInferenceFunctionType(CostInferenceForBatchMatMul)) |
| .InheritOnnxSchema(); |
| |
| class GetBatchMatMulGradient : public GradientMakerBase { |
| using GradientMakerBase::GradientMakerBase; |
| vector<OperatorDef> GetGradientDefs() override { |
| CAFFE_ENFORCE_EQ(def_.input_size(), 2); |
| |
| bool broadcast = false; |
| if (ArgumentHelper::HasArgument(Def(), "broadcast")) { |
| broadcast = GetArgument(Def(), "broadcast").i(); |
| } |
| CAFFE_ENFORCE( |
| !broadcast, |
| "Gradient is currently not supported with " |
| "broadcast=1 for BatchMatMul."); |
| |
| // NOLINTNEXTLINE(modernize-use-bool-literals) |
| bool trans_a = 0; |
| // NOLINTNEXTLINE(modernize-use-bool-literals) |
| bool trans_b = 0; |
| |
| if (ArgumentHelper::HasArgument(Def(), "trans_a")) { |
| trans_a = GetArgument(Def(), "trans_a").i(); |
| } |
| if (ArgumentHelper::HasArgument(Def(), "trans_b")) { |
| trans_b = GetArgument(Def(), "trans_b").i(); |
| } |
| |
| auto no_trans_arg = vector<Argument>(); |
| auto trans_a_arg = vector<Argument>{MakeArgument<int>("trans_a", 1)}; |
| auto trans_b_arg = vector<Argument>{MakeArgument<int>("trans_b", 1)}; |
| auto trans_both_arg = vector<Argument>{ |
| MakeArgument<int>("trans_a", 1), MakeArgument<int>("trans_b", 1)}; |
| |
| if (trans_a) { |
| if (trans_b) { |
| // A'B': |
| // dA = B'G', dB = G'A' |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{I(1), GO(0)}, |
| vector<string>{GI(0)}, |
| trans_both_arg), |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{GO(0), I(0)}, |
| vector<string>{GI(1)}, |
| trans_both_arg)}; |
| } else { |
| // A'B: |
| // dA = BG', dB = AG |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{I(1), GO(0)}, |
| vector<string>{GI(0)}, |
| trans_b_arg), |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{I(0), GO(0)}, |
| vector<string>{GI(1)}, |
| no_trans_arg)}; |
| } |
| } else { |
| if (trans_b) { |
| // AB': |
| // dA = GB, dB = G'A |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{GO(0), I(1)}, |
| vector<string>{GI(0)}, |
| no_trans_arg), |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{GO(0), I(0)}, |
| vector<string>{GI(1)}, |
| trans_a_arg)}; |
| } else { |
| // AB: |
| // dA = GB', dB = A'G |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{GO(0), I(1)}, |
| vector<string>{GI(0)}, |
| trans_b_arg), |
| CreateOperatorDef( |
| "BatchMatMul", |
| "", |
| vector<string>{I(0), GO(0)}, |
| vector<string>{GI(1)}, |
| trans_a_arg)}; |
| } |
| } |
| } |
| |
| bool CopyArguments() const override { |
| return false; |
| } |
| }; |
| |
| REGISTER_GRADIENT(BatchMatMul, GetBatchMatMulGradient); |
| |
| } // namespace caffe2 |