blob: 17a75da727aecb7e181e3e593d6d392b41b2d8e9 [file] [log] [blame]
#include "caffe2/operators/matmul_op.h"
namespace caffe2 {
REGISTER_CPU_OPERATOR(MatMul, MatMulOp<float, CPUContext>);
OPERATOR_SCHEMA(MatMul)
    .NumInputs(2, 3)
    .NumOutputs(1)
    .TensorInferenceFunction([](const OperatorDef& def,
                                const vector<TensorShape>& in) {
      // Infer the output shape (M x N) of Y = A * B. Inputs of rank > 2 are
      // treated as 2D matrices by collapsing dims around the axis_a/axis_b
      // split points; trans_a/trans_b swap which side of the split
      // contributes to the output dimension.
      vector<TensorShape> out(1);
      out[0].set_data_type(in[0].data_type());
      ArgumentHelper arg_helper(def);
      int axis_a = arg_helper.GetSingleArgument<int>("axis_a", 1);
      int axis_b = arg_helper.GetSingleArgument<int>("axis_b", 1);
      bool trans_a = arg_helper.GetSingleArgument<bool>("trans_a", false);
      bool trans_b = arg_helper.GetSingleArgument<bool>("trans_b", false);
      int canonical_axis_a = canonical_axis_index_(axis_a, in[0].dims().size());
      // Fix: axis_b splits matrix B, so it must be canonicalized against
      // B's rank (in[1]), not A's. Using in[0] here produced a wrong axis
      // whenever the two inputs had different ranks (e.g. negative axis_b).
      int canonical_axis_b = canonical_axis_index_(axis_b, in[1].dims().size());
      // M = product of A's dims before the split; N = product of B's dims
      // after the split. A transpose takes the other side of the split.
      int M = size_to_dim_(canonical_axis_a, GetDimsVector(in[0]));
      int N = size_from_dim_(canonical_axis_b, GetDimsVector(in[1]));
      if (trans_a) {
        M = size_from_dim_(canonical_axis_a, GetDimsVector(in[0]));
      }
      if (trans_b) {
        N = size_to_dim_(canonical_axis_b, GetDimsVector(in[1]));
      }
      out[0].add_dims(M);
      out[0].add_dims(N);
      return out;
    })
    .SetDoc(R"DOC(
Matrix multiplication $Y = A * B$, where `A` has size (M x K), `B` has size
(K x N), and `Y` will have a size (M x N). To transpose `A` or `B` before
multiplication, pass 1 to the `trans_a` and/or `trans_b` arguments, which
separate the first and second dimensions of the respective matrices using
`axis_a` and `axis_b`.
Github Links:
- https://github.com/pytorch/pytorch/blob/main/caffe2/operators/matmul_op.cc
<details>
<summary> <b>Example</b> </summary>
**Code**
```
workspace.ResetWorkspace()
op = core.CreateOperator(
"MatMul",
["A", "B"],
["Y"],
)
workspace.FeedBlob("A", np.random.randint(10, size=(3,3)).astype(np.float32))
workspace.FeedBlob("B", np.random.randint(10, size=(3,3)).astype(np.float32))
print("A:", workspace.FetchBlob("A"))
print("B:", workspace.FetchBlob("B"))
workspace.RunOperatorOnce(op)
print("Y:", workspace.FetchBlob("Y"))
```
**Result**
```
A: [[1. 8. 3.]
[6. 4. 4.]
[5. 4. 7.]]
B: [[4. 0. 3.]
[3. 1. 1.]
[8. 5. 8.]]
Y: [[52. 23. 35.]
[68. 24. 54.]
[88. 39. 75.]]
```
</details>
)DOC")
    .Input(
        0,
        "A",
        "*(type: Tensor`<float>`)* 2D matrix of size (M x K).")
    .Input(
        1,
        "B",
        "*(type: Tensor`<float>`)* 2D matrix of size (K x N).")
    .Output(
        0,
        "Y",
        "*(type: Tensor`<float>`)* 2D matrix of size (M x N).")
    .Arg(
        "axis_a",
        "*(type: int; default: 1)* Exclusive axis that divides the first and "
        "second dimension of matrix `A`.")
    .Arg(
        "axis_b",
        "*(type: int; default: 1)* Exclusive axis that divides the first and "
        "second dimension of matrix `B`.")
    .Arg(
        "trans_a",
        "*(type: int; default: 0)* Pass 1 to transpose `A` before multiplication and "
        "after the dimension adjustment using `axis_a`.")
    .Arg(
        "trans_b",
        "*(type: int; default: 0)* Pass 1 to transpose `B` before multiplication and "
        "after the dimension adjustment using `axis_b`.");
// Gradient maker for MatMul. For Y = op_a(A) * op_b(B) (where op_* is an
// optional transpose), the gradients are themselves MatMul ops; the third
// input slot on each gradient MatMul carries the tensor whose shape the
// output gradient must match.
class GetMatMulGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    CAFFE_ENFORCE(def_.input_size() == 2 || def_.input_size() == 3);
    // Fix: axis_a/axis_b were previously declared `bool`, which silently
    // clamped any axis value > 1 down to 1 before it was forwarded to the
    // gradient operators. They are integer split points and must be ints.
    int axis_a = 1;
    int axis_b = 1;
    bool trans_a = false;
    bool trans_b = false;
    if (ArgumentHelper::HasArgument(Def(), "trans_a")) {
      trans_a = GetArgument(Def(), "trans_a").i();
    }
    if (ArgumentHelper::HasArgument(Def(), "trans_b")) {
      trans_b = GetArgument(Def(), "trans_b").i();
    }
    if (ArgumentHelper::HasArgument(Def(), "axis_a")) {
      axis_a = GetArgument(Def(), "axis_a").i();
    }
    if (ArgumentHelper::HasArgument(Def(), "axis_b")) {
      axis_b = GetArgument(Def(), "axis_b").i();
    }
    if (trans_a) {
      if (trans_b) {
        // A'B':
        // dA = B'G', dB = G'A'
        return vector<OperatorDef>{
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{I(1), GO(0), I(0)},
                vector<string>{GI(0)},
                vector<Argument>{MakeArgument<int>("trans_a", 1),
                                 MakeArgument<int>("trans_b", 1),
                                 MakeArgument<int>("axis_a", axis_b)}),
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{GO(0), I(0), I(1)},
                vector<string>{GI(1)},
                vector<Argument>{MakeArgument<int>("trans_a", 1),
                                 MakeArgument<int>("trans_b", 1),
                                 MakeArgument<int>("axis_b", axis_a)})};
      } else {
        // A'B:
        // dA = BG', dB = AG
        return vector<OperatorDef>{
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{I(1), GO(0), I(0)},
                vector<string>{GI(0)},
                vector<Argument>{MakeArgument<int>("trans_b", 1),
                                 MakeArgument<int>("axis_a", axis_b)}),
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{I(0), GO(0), I(1)},
                vector<string>{GI(1)},
                vector<Argument>{MakeArgument<int>("axis_a", axis_a)})};
      }
    } else {
      if (trans_b) {
        // AB':
        // dA = GB, dB = G'A
        return vector<OperatorDef>{
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{GO(0), I(1), I(0)},
                vector<string>{GI(0)},
                vector<Argument>{MakeArgument<int>("axis_b", axis_b)}),
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{GO(0), I(0), I(1)},
                vector<string>{GI(1)},
                vector<Argument>{MakeArgument<int>("trans_a", 1),
                                 MakeArgument<int>("axis_b", axis_a)})};
      } else {
        // AB:
        // dA = GB', dB = A'G
        return vector<OperatorDef>{
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{GO(0), I(1), I(0)},
                vector<string>{GI(0)},
                vector<Argument>{MakeArgument<int>("trans_b", 1),
                                 MakeArgument<int>("axis_b", axis_b)}),
            CreateOperatorDef(
                "MatMul",
                "",
                vector<string>{I(0), GO(0), I(1)},
                vector<string>{GI(1)},
                vector<Argument>{MakeArgument<int>("trans_a", 1),
                                 MakeArgument<int>("axis_a", axis_a)})};
      }
    }
  }
  // Arguments are forwarded explicitly above (with axes swapped/remapped per
  // case), so the default blanket copy must be disabled.
  bool CopyArguments() const override {
    return false;
  }
};
REGISTER_GRADIENT(MatMul, GetMatMulGradient);
} // namespace caffe2