| #ifndef CAFFE2_OPERATORS_MATMUL_OP_H_ |
| #define CAFFE2_OPERATORS_MATMUL_OP_H_ |
| |
| #include "caffe2/core/context.h" |
| #include "caffe2/core/operator.h" |
| #include "caffe2/utils/math.h" |
| |
| namespace caffe2 { |
| |
| template <typename T, class Context, class Engine = DefaultEngine> |
| class MatMulOp final : public Operator<Context> { |
| public: |
| USE_OPERATOR_CONTEXT_FUNCTIONS; |
| template <class... Args> |
| explicit MatMulOp(Args&&... args) |
| : Operator<Context>(std::forward<Args>(args)...), |
| axis_a_(this->template GetSingleArgument<int>("axis_a", 1)), |
| axis_b_(this->template GetSingleArgument<int>("axis_b", 1)), |
| trans_a_(this->template GetSingleArgument<int>("trans_a", 0)), |
| trans_b_(this->template GetSingleArgument<int>("trans_b", 0)) {} |
| ~MatMulOp() {} |
| |
| bool RunOnDevice() override { |
| const auto& A = Input(0); |
| const auto& B = Input(1); |
| |
| const auto canonical_axis_a = A.canonical_axis_index(axis_a_); |
| const auto canonical_axis_b = B.canonical_axis_index(axis_b_); |
| int A_dim0 = A.size_to_dim(canonical_axis_a); |
| int A_dim1 = A.size_from_dim(canonical_axis_a); |
| int B_dim0 = B.size_to_dim(canonical_axis_b); |
| int B_dim1 = B.size_from_dim(canonical_axis_b); |
| |
| int a_dim0, a_dim1, b_dim0, b_dim1; |
| |
| if (trans_a_) { |
| a_dim0 = A_dim1; |
| a_dim1 = A_dim0; |
| } else { |
| a_dim0 = A_dim0; |
| a_dim1 = A_dim1; |
| } |
| |
| if (trans_b_) { |
| b_dim0 = B_dim1; |
| b_dim1 = B_dim0; |
| } else { |
| b_dim0 = B_dim0; |
| b_dim1 = B_dim1; |
| } |
| |
| auto dimErrorString = [&]() { |
| return c10::str( |
| "Dimension mismatch: ", |
| trans_a_ ? "trans(A): " : "A: ", |
| a_dim0, |
| " ", |
| a_dim1, |
| trans_b_ ? ", trans(B): " : ", B: ", |
| b_dim0, |
| " ", |
| b_dim1); |
| }; |
| // Error checking |
| CAFFE_ENFORCE(a_dim1 == b_dim0, dimErrorString()); |
| |
| Y_shape_cache_[0] = a_dim0; |
| Y_shape_cache_[1] = b_dim1; |
| auto* Y = Output(0, Y_shape_cache_, at::dtype<T>()); |
| CAFFE_ENFORCE(a_dim0 * b_dim1 == Y->numel(), dimErrorString()); |
| // Y = A * B |
| math::Gemm<T, Context, Engine>( |
| trans_a_ ? CblasTrans : CblasNoTrans, |
| trans_b_ ? CblasTrans : CblasNoTrans, |
| a_dim0, |
| b_dim1, |
| a_dim1, |
| 1, |
| A.template data<T>(), |
| B.template data<T>(), |
| 0, |
| Y->template mutable_data<T>(), |
| &context_); |
| |
| if (InputSize() == 3) { |
| // In gradient op, resize to input |
| Y->ResizeLike(Input(2)); |
| } |
| return true; |
| } |
| |
| protected: |
| // A local vector to cache the output shape so we don't need to recreate |
| // a vector object every time we run Run(). |
| vector<int64_t> Y_shape_cache_{0, 0}; |
| int axis_a_{1}; |
| int axis_b_{1}; |
| bool trans_a_; |
| bool trans_b_; |
| }; |
| |
| } // namespace caffe2 |
| |
| #endif // CAFFE2_OPERATORS_MATMUL_OP_H_ |