| #include "caffe2/operators/batch_gather_ops.h" |
| |
| namespace caffe2 { |
| |
// Register the CPU implementations of the forward op and its gradient op.
// The operator classes presumably come from the included
// caffe2/operators/batch_gather_ops.h.
REGISTER_CPU_OPERATOR(BatchGather, BatchGatherOp<CPUContext>);
REGISTER_CPU_OPERATOR(BatchGatherGradient, BatchGatherGradientOp<CPUContext>);
| |
| OPERATOR_SCHEMA(BatchGather) |
| .NumInputs(2) |
| .NumOutputs(1) |
| .TensorInferenceFunction([](const OperatorDef& def, |
| const vector<TensorShape>& in) { |
| vector<TensorShape> out(1); |
| ArgumentHelper helper(def); |
| const auto& data_dims = GetDimsVector(in[0]); |
| const auto& indices_dims = GetDimsVector(in[1]); |
| |
| vector<int> output_dims = |
| caffe2::gather_helper::calc_output_shape_vector<int>( |
| data_dims, indices_dims, 1, false); |
| out[0] = CreateTensorShape(output_dims, TensorProto::FLOAT); |
| return out; |
| }) |
| .SetDoc(R"DOC( |
| Batch gather operation, first dimension in DATA is the batch size. |
| Given DATA tensor of rank r >= 2, and INDICES tensor of rank q >= 1, gather |
| entries of the second outer dimension (axis == 1) of DATA indexed by INDICES, |
| and concatenate them in an output tensor of rank q + (r - 1). |
| |
| Example: |
| DATA = [ |
| [1.0, 1.2, 2.4, 4.5], |
| [2.3, 3.4, 3.6, 2.3], |
| [4.5, 5.7, 1.2, 4.5], |
| ] |
| INDICES = [0, 2] |
| |
| OUTPUT = [ |
| [1.0, 2.4], |
| [2.3, 3.6], |
| [4.5, 1.2], |
| ] |
| )DOC") |
| .Input(0, "DATA", "Tensor of rank r >= 2.") |
| .Input(1, "INDICES", "Tensor of int32/int64 indices, of any rank q.") |
| .Output(0, "OUTPUT", "Tensor of rank q + (r - 1).") |
| .InheritOnnxSchema(); |
| |
| OPERATOR_SCHEMA(BatchGatherGradient).NumInputs(3).NumOutputs(1); |
| |
| class GetBatchGatherGradient : public GradientMakerBase { |
| using GradientMakerBase::GradientMakerBase; |
| vector<OperatorDef> GetGradientDefs() override { |
| using Op = BatchGatherOp<CPUContext>; |
| return SingleGradientDef( |
| "BatchGatherGradient", |
| "", |
| vector<string>{I(Op::DATA), I(Op::INDICES), GO(0)}, |
| vector<string>{GI(0)}); |
| } |
| }; |
| |
// Attach the gradient maker to BatchGather so that autodiff emits a
// BatchGatherGradient op for it (see GetBatchGatherGradient).
REGISTER_GRADIENT(BatchGather, GetBatchGatherGradient);
| |
| } // namespace caffe2 |