| #include "caffe2/operators/matmul_op.h" |
| |
| namespace caffe2 { |
| |
| REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>); |
| |
| OPERATOR_SCHEMA(MatMul) |
| .NumInputs(2, 3) |
| .NumOutputs(1) |
| .TensorInferenceFunction([](const OperatorDef& def, |
| const vector<TensorShape>& in) { |
| vector<TensorShape> out(1); |
| out[0].set_data_type(in[0].data_type()); |
| ArgumentHelper arg_helper(def); |
| int axis_a = arg_helper.GetSingleArgument<int>("axis_a", 1); |
| int axis_b = arg_helper.GetSingleArgument<int>("axis_b", 1); |
| int trans_a = arg_helper.GetSingleArgument<bool>("trans_a", false); |
| int trans_b = arg_helper.GetSingleArgument<bool>("trans_b", false); |
| int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size()); |
| int canonical_axis_b = canonical_axis_index_(axis_b, in[0].dims().size()); |
| |
| int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0])); |
| int N = size_from_dim_(canonical_axis_b, GetDimsVector(in[1])); |
| if (trans_a) { |
| M = size_from_dim_(canonical_axis_a, GetDimsVector(in[0])); |
| } |
| if (trans_b) { |
| N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1])); |
| } |
| |
| out[0].add_dims(M); |
| out[0].add_dims(N); |
| |
| return out; |
| }) |
| .SetDoc(R"DOC( |
| Matrix multiplication $Y = A * B$, where `A` has size (M x K), `B` has size |
| (K x N), and `Y` will have a size (M x N). To transpose `A` or `B` before |
| multiplication, pass 1 to the `trans_a` and/or `trans_b` arguments, which |
| separate the first and second dimensions of the respective matrices using |
| `axis_a` and `axis_b`. |
| |
| Github Links: |
| |
| - https://github.com/pytorch/pytorch/blob/main/caffe2/operators/matmul_op.cc |
| |
| <details> |
| |
| <summary> <b>Example</b> </summary> |
| |
| **Code** |
| |
| ``` |
| workspace.ResetWorkspace() |
| |
| op = core.CreateOperator( |
| "MatMul", |
| ["A", "B"], |
| ["Y"], |
| ) |
| |
| workspace.FeedBlob("A", np.random.randint(10, size=(3,3)).astype(np.float32)) |
| workspace.FeedBlob("B", np.random.randint(10, size=(3,3)).astype(np.float32)) |
| print("A:", workspace.FetchBlob("A")) |
| print("B:", workspace.FetchBlob("B")) |
| workspace.RunOperatorOnce(op) |
| print("Y:", workspace.FetchBlob("Y")) |
| ``` |
| |
| **Result** |
| |
| ``` |
| A: [[1. 8. 3.] |
| [6. 4. 4.] |
| [5. 4. 7.]] |
| B: [[4. 0. 3.] |
| [3. 1. 1.] |
| [8. 5. 8.]] |
| Y: [[52. 23. 35.] |
| [68. 24. 54.] |
| [88. 39. 75.]] |
| ``` |
| |
| </details> |
| |
| )DOC") |
| .Input( |
| 0, |
| "A", |
| "*(type: Tensor`<float>`)* 2D matrix of size (M x K).") |
| .Input( |
| 1, |
| "B", |
| "*(type: Tensor`<float>`)* 2D matrix of size (K x N).") |
| .Output( |
| 0, |
| "Y", |
| "*(type: Tensor`<float>`)* 2D matrix of size (M x N).") |
| .Arg( |
| "axis_a", |
| "*(type: int; default: 1)* Exclusive axis that divides the first and " |
| "second dimension of matrix `A`.") |
| .Arg( |
| "axis_b", |
| "*(type: int; default: 1)* Exclusive axis that divides the first and " |
| "second dimension of matrix `B`.") |
| .Arg( |
| "trans_a", |
| "*(type: int; default: 0)* Pass 1 to transpose `A` before multiplication and " |
| "after the dimension adjustment using `axis_a`.") |
| .Arg( |
| "trans_b", |
| "*(type: int; default: 0)* Pass 1 to transpose `B` before multiplication and " |
| "after the dimension adjustment using `axis_b`."); |
| |
| class GetMatMulGradient : public GradientMakerBase { |
| using GradientMakerBase::GradientMakerBase; |
| vector<OperatorDef> GetGradientDefs() override { |
| CAFFE_ENFORCE(def_.input_size() == 2 || def_.input_size() == 3); |
| |
| // NOLINTNEXTLINE(modernize-use-bool-literals) |
| bool axis_a = 1; |
| // NOLINTNEXTLINE(modernize-use-bool-literals) |
| bool axis_b = 1; |
| // NOLINTNEXTLINE(modernize-use-bool-literals) |
| bool trans_a = 0; |
| // NOLINTNEXTLINE(modernize-use-bool-literals) |
| bool trans_b = 0; |
| |
| if (ArgumentHelper::HasArgument(Def(), "trans_a")) { |
| trans_a = GetArgument(Def(), "trans_a").i(); |
| } |
| if (ArgumentHelper::HasArgument(Def(), "trans_b")) { |
| trans_b = GetArgument(Def(), "trans_b").i(); |
| } |
| if (ArgumentHelper::HasArgument(Def(), "axis_a")) { |
| axis_a = GetArgument(Def(), "axis_a").i(); |
| } |
| if (ArgumentHelper::HasArgument(Def(), "axis_b")) { |
| axis_b = GetArgument(Def(), "axis_b").i(); |
| } |
| |
| if (trans_a) { |
| if (trans_b) { |
| // A'B': |
| // dA = B'G', dB = G'A' |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{I(1), GO(0), I(0)}, |
| vector<string>{GI(0)}, |
| vector<Argument>{MakeArgument<int>("trans_a", 1), |
| MakeArgument<int>("trans_b", 1), |
| MakeArgument<int>("axis_a", axis_b)}), |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{GO(0), I(0), I(1)}, |
| vector<string>{GI(1)}, |
| vector<Argument>{MakeArgument<int>("trans_a", 1), |
| MakeArgument<int>("trans_b", 1), |
| MakeArgument<int>("axis_b", axis_a)})}; |
| } else { |
| // A'B: |
| // dA = BG', dB = AG |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{I(1), GO(0), I(0)}, |
| vector<string>{GI(0)}, |
| vector<Argument>{MakeArgument<int>("trans_b", 1), |
| MakeArgument<int>("axis_a", axis_b)}), |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{I(0), GO(0), I(1)}, |
| vector<string>{GI(1)}, |
| vector<Argument>{MakeArgument<int>("axis_a", axis_a)})}; |
| } |
| } else { |
| if (trans_b) { |
| // AB': |
| // dA = GB, dB = G'A |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{GO(0), I(1), I(0)}, |
| vector<string>{GI(0)}, |
| vector<Argument>{MakeArgument<int>("axis_b", axis_b)}), |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{GO(0), I(0), I(1)}, |
| vector<string>{GI(1)}, |
| vector<Argument>{MakeArgument<int>("trans_a", 1), |
| MakeArgument<int>("axis_b", axis_a)})}; |
| } else { |
| // AB: |
| // dA = GB', dB = A'G |
| return vector<OperatorDef>{ |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{GO(0), I(1), I(0)}, |
| vector<string>{GI(0)}, |
| vector<Argument>{MakeArgument<int>("trans_b", 1), |
| MakeArgument<int>("axis_b", axis_b)}), |
| CreateOperatorDef( |
| "MatMul", |
| "", |
| vector<string>{I(0), GO(0), I(1)}, |
| vector<string>{GI(1)}, |
| vector<Argument>{MakeArgument<int>("trans_a", 1), |
| MakeArgument<int>("axis_a", axis_a)})}; |
| } |
| } |
| } |
| |
| bool CopyArguments() const override { |
| return false; |
| } |
| }; |
| |
| REGISTER_GRADIENT(MatMul, GetMatMulGradient); |
| |
| } // namespace caffe2 |