| #include "caffe2/operators/segment_reduction_op.h" |
| |
| namespace caffe2 { |
| |
| OpSchema::Cost CostInferenceForSparseLengths( |
| const OperatorDef& def, |
| const vector<TensorShape>& inputs, |
| bool use_weight) { |
| int min_num_of_inputs = 3 + use_weight; |
| CAFFE_ENFORCE_GE( |
| inputs.size(), |
| min_num_of_inputs, |
| def.type() + " requires at least " + c10::to_string(min_num_of_inputs)); |
| |
| const TensorShape data = inputs[0]; |
| const TensorShape indices = inputs[1 + use_weight]; |
| const TensorShape lengths = inputs[2 + use_weight]; |
| |
| OpSchema::Cost c; |
| CAFFE_ENFORCE_GT(data.dims_size(), 0, "data requires at least 1 dimension"); |
| uint64_t N = data.dims(0); |
| if (N == 0) { |
| return c; |
| } |
| uint64_t D = nElemFromDim(data, 1); |
| CAFFE_ENFORCE_GT( |
| lengths.dims_size(), 0, "lengths requires at least 1 dimension"); |
| uint64_t M = lengths.dims(0); |
| uint64_t indices_size = nElemFromDim(indices); |
| |
| c.flops = indices_size * D; |
| c.bytes_read = indices_size * |
| (D * sizeof(data.data_type()) + sizeof(indices.data_type())) + |
| M * sizeof(lengths.data_type()); |
| c.params_bytes = N * D * sizeof(data.data_type()); |
| if (use_weight) { |
| const TensorShape weights = inputs[1]; |
| c.flops += indices_size * D; |
| c.bytes_read += indices_size * sizeof(weights.data_type()); |
| } |
| |
| return c; |
| } |
| |
// Registering the 5-input / 2-output gradient which additionally consumes the
// forward op's main input (see AbstractLengthsWithMainInputGradientOp for the
// exact input/output roles).
// gradient of SparseLengthsWeightedSum
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
    .NumInputs(5)
    .NumOutputs(2);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient,
    AbstractLengthsWithMainInputGradientOp<
        float,
        float,
        int,
        CPUContext,
        WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*SparseFused*/,
        true /*GradientNeedIndices*/>);

// Registering the 4-input version (does not need the forward main input).
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumGradient)
    .NumInputs(4)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientWeightedSumGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
| |
// Registering the 3-input version.
// gradient of SparseLengthsSum
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientSumGradient)
    .NumInputs(3)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientSumGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        SumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
// gradient of LengthsSum (same backward op type as the sparse variant above)
OPERATOR_SCHEMA(LengthsIndicesInGradientSumGradient).NumInputs(3).NumOutputs(1);
REGISTER_CPU_OPERATOR(
    LengthsIndicesInGradientSumGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        SumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
| |
// Registering the 3-input version.
// gradient of SparseLengthsMean
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientMeanGradient)
    .NumInputs(3)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientMeanGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        MeanReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
// gradient of LengthsMean (same backward op type as the sparse variant above)
OPERATOR_SCHEMA(LengthsIndicesInGradientMeanGradient)
    .NumInputs(3)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    LengthsIndicesInGradientMeanGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        MeanReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
| |
| namespace { |
| |
// Documentation snippets spliced into the generated operator docs via the
// "{extra}" placeholder (see FormatDoc() below). These are runtime string
// literals consumed by the schema machinery — keep the example output in sync
// with actual op behavior if edited.

// Worked example for LengthsMax.
static const char* kLengthsMaxExtra = R"DOC(
The *LengthsMax* op takes two inputs *DATA* and *LENGTHS*, and produces a single output *OUTPUT*. The op finds the maximum value in each of the segments of *DATA*, where segments are defined by their lengths.
For example, if $DATA = [2,4,3,1,2,10]$ and $LENGTHS = [2,3,1]$ then $OUTPUT = [max([2,4]), max([3,1,2]), max([10])] = [4,3,10]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsMax",
    ["DATA", "LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
[ 2. 4. 3. 1. 2. 10.]
LENGTHS:
[2 3 1]
OUTPUT:
[ 4. 3. 10.]

```

</details>

)DOC";

// Worked example for LengthsMean.
static const char* kLengthsMeanExtra = R"DOC(
The *LengthsMean* op takes two inputs *DATA* and *LENGTHS*, and produces a single output *OUTPUT*. The op finds the mean value in each of the segments of *DATA*, where segments are defined by their lengths.
For example, if $DATA = [2,4,3,1,2,10]$ and $LENGTHS = [2,3,1]$ then $OUTPUT = [mean([2,4]), mean([3,1,2]), mean([10])] = [3,2,10]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsMean",
    ["DATA", "LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
[ 2. 4. 3. 1. 2. 10.]
LENGTHS:
[2 3 1]
OUTPUT:
[ 3. 2. 10.]

```

</details>

)DOC";

// Worked example for LengthsSum.
static const char* kLengthsSumExtra = R"DOC(
The *LengthsSum* op takes two inputs *DATA* and *LENGTHS*, and produces a single output *OUTPUT*. The op finds the sum in each of the segments of *DATA*, where segments are defined by their lengths.
For example, if $DATA = [2,4,3,1,2,10]$ and $LENGTHS = [2,3,1]$ then $OUTPUT = [sum([2,4]), sum([3,1,2]), sum([10])] = [6,6,10]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsSum",
    ["DATA", "LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
[ 2. 4. 3. 1. 2. 10.]
LENGTHS:
[2 3 1]
OUTPUT:
[ 6. 6. 10.]

```

</details>

)DOC";

// Worked example for LengthsWeightedSum.
static const char* kLengthsWeightedSumExtra = R"DOC(
The *LengthsWeightedSum* op takes three inputs *DATA*, *LENGTHS*, and *SCALARS*, and produces a single output *OUTPUT*. The op finds the weighted sum in each of the segments of *DATA*, where segments are defined by their lengths. Before calculating the sums, the input *DATA* is weighted by the contents of *SCALARS*.
For example, if $DATA = [2,4,3,1,2,10]$, $SCALARS = [8, 2, 1, 4, 1, 0.6]$, and $LENGTHS = [2,3,1]$, then $OUTPUT = [sum([8*2,2*4]), sum([1*3,4*1,1*2]), sum([0.6*10])] = [24,9,6]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsWeightedSum",
    ["DATA", "SCALARS","LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("SCALARS", np.array([8, 2, 1, 4, 1, 0.6]).astype(np.float32))
print("SCALARS:\n", workspace.FetchBlob("SCALARS"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
[ 2. 4. 3. 1. 2. 10.]
SCALARS:
[8. 2. 1. 4. 1. 0.6]
LENGTHS:
[2 3 1]
OUTPUT:
[24. 9. 6.]

```

</details>

)DOC";
| |
| template <typename Def> |
| string FormatDoc() { |
| string doc = Def::doc; |
| c10::ReplaceAll(doc, "{op}", Def::OpDef::name); |
| c10::ReplaceAll(doc, "{op_doc}", Def::OpDef::doc); |
| if (strcmp(Def::OpDef::name, "Max") == 0) { |
| c10::ReplaceAll(doc, "{extra}", kLengthsMaxExtra); |
| } else if (strcmp(Def::OpDef::name, "Mean") == 0) { |
| c10::ReplaceAll(doc, "{extra}", kLengthsMeanExtra); |
| } else if (strcmp(Def::OpDef::name, "Sum") == 0) { |
| c10::ReplaceAll(doc, "{extra}", kLengthsSumExtra); |
| } else if (strcmp(Def::OpDef::name, "WeightedSum") == 0) { |
| c10::ReplaceAll(doc, "{extra}", kLengthsWeightedSumExtra); |
| } else { |
| c10::ReplaceAll(doc, "{extra}", " "); |
| } |
| return doc; |
| } |
| |
// Helper function to enforce naming conventions at compile time.
// Returns true iff `lhs` is exactly the concatenation rhs1 + rhs2 + rhs3.
constexpr bool equal(
    char const* lhs,
    char const* rhs1,
    char const* rhs2,
    char const* rhs3 = "") {
  // Match each rhs piece against the front of lhs in turn.
  for (; *rhs1 != 0; ++lhs, ++rhs1) {
    if (*lhs != *rhs1) {
      return false;
    }
  }
  for (; *rhs2 != 0; ++lhs, ++rhs2) {
    if (*lhs != *rhs2) {
      return false;
    }
  }
  for (; *rhs3 != 0; ++lhs, ++rhs3) {
    if (*lhs != *rhs3) {
      return false;
    }
  }
  // lhs must be fully consumed too — no trailing characters allowed.
  return *lhs == 0;
}
| |
// Helper macro when the main op is defined elsewhere, and we only need to
// define the schema, and the gradient op. The static_asserts enforce the
// naming convention: segment_name must equal basename + OpDef::name, and
// gradient_name must equal segment_name + "Gradient".
// TODO: enable input fillers
#define REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY( \
    segment_name, gradient_name, ...) \
  static_assert( \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name); \
  static_assert( \
      equal( \
          #gradient_name, \
          __VA_ARGS__::basename, \
          __VA_ARGS__::OpDef::name, \
          "Gradient"), \
      #gradient_name); \
  OPERATOR_SCHEMA(segment_name) \
      .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs) \
      .NumOutputs(1) \
      .DisallowInputFillers() \
      .SetDoc(FormatDoc<__VA_ARGS__>()) \
      .Output(0, "OUTPUT", "Aggregated tensor") \
      .FillUsing(__VA_ARGS__::PopulateSchema); \
  REGISTER_CPU_OPERATOR_STR(string(#gradient_name), __VA_ARGS__::BackwardOp); \
  OPERATOR_SCHEMA(gradient_name) \
      .NumInputs(__VA_ARGS__::BackwardOp::kNumInputs) \
      .NumOutputs(1) \
      .DisallowInputFillers(); \
  REGISTER_GRADIENT_STR(string(#segment_name), __VA_ARGS__::GetGradient)

// Same as the macro above, but also registers the forward CPU operator
// (__VA_ARGS__::ForwardOp) under segment_name.
#define REGISTER_SEGMENT_DEF(segment_name, gradient_name, ...) \
  static_assert( \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name); \
  REGISTER_CPU_OPERATOR_STR(string(#segment_name), __VA_ARGS__::ForwardOp); \
  REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY( \
      segment_name, gradient_name, __VA_ARGS__)
| |
// ---- SortedSegmentRange* ops (range-reducer variants) ----
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeSum,
    SortedSegmentRangeSumGradient,
    AbstractSortedSegmentRangeDef<float, int, CPUContext, SumRangeReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeLogSumExp,
    SortedSegmentRangeLogSumExpGradient,
    AbstractSortedSegmentRangeDef<
        float,
        int,
        CPUContext,
        LogSumExpRangeReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeLogMeanExp,
    SortedSegmentRangeLogMeanExpGradient,
    AbstractSortedSegmentRangeDef<
        float,
        int,
        CPUContext,
        LogMeanExpRangeReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeMean,
    SortedSegmentRangeMeanGradient,
    AbstractSortedSegmentRangeDef<float, int, CPUContext, MeanRangeReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeMax,
    SortedSegmentRangeMaxGradient,
    AbstractSortedSegmentRangeDef<float, int, CPUContext, MaxRangeReducerDef>);

// ---- Sum reducer: sorted / sparse-sorted / unsorted / sparse-unsorted ----
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentSum,
    SortedSegmentSumGradient,
    AbstractSortedSegmentDef<float, int, CPUContext, SumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SparseSortedSegmentSum,
    SparseSortedSegmentSumGradient,
    AbstractSparseSortedSegmentDef<float, int, CPUContext, SumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    UnsortedSegmentSum,
    UnsortedSegmentSumGradient,
    AbstractUnsortedSegmentDef<float, int, CPUContext, SumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SparseUnsortedSegmentSum,
    SparseUnsortedSegmentSumGradient,
    AbstractSparseUnsortedSegmentDef<float, int, CPUContext, SumReducerDef>);

// Lengths-based Sum (GradientNeedIndices = true).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    LengthsSum,
    LengthsSumGradient,
    AbstractLengthsDef<float, int, CPUContext, SumReducerDef, true>);

// ---- Mean reducer: sorted / sparse-sorted / unsorted / sparse-unsorted ----
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentMean,
    SortedSegmentMeanGradient,
    AbstractSortedSegmentDef<float, int, CPUContext, MeanReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SparseSortedSegmentMean,
    SparseSortedSegmentMeanGradient,
    AbstractSparseSortedSegmentDef<float, int, CPUContext, MeanReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    UnsortedSegmentMean,
    UnsortedSegmentMeanGradient,
    AbstractUnsortedSegmentDef<float, int, CPUContext, MeanReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SparseUnsortedSegmentMean,
    SparseUnsortedSegmentMeanGradient,
    AbstractSparseUnsortedSegmentDef<float, int, CPUContext, MeanReducerDef>);

// Lengths-based Mean (GradientNeedIndices = true).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    LengthsMean,
    LengthsMeanGradient,
    AbstractLengthsDef<float, int, CPUContext, MeanReducerDef, true>);

// ---- WeightedSum reducer variants ----
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    ReduceFrontWeightedSum,
    ReduceFrontWeightedSumGradient,
    AbstractReduceFrontDef<float, CPUContext, WeightedSumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SortedSegmentWeightedSum,
    SortedSegmentWeightedSumGradient,
    AbstractSortedSegmentDef<float, int, CPUContext, WeightedSumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SparseSortedSegmentWeightedSum,
    SparseSortedSegmentWeightedSumGradient,
    AbstractSparseSortedSegmentDef<
        float,
        int,
        CPUContext,
        WeightedSumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    UnsortedSegmentWeightedSum,
    UnsortedSegmentWeightedSumGradient,
    AbstractUnsortedSegmentDef<float, int, CPUContext, WeightedSumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    SparseUnsortedSegmentWeightedSum,
    SparseUnsortedSegmentWeightedSumGradient,
    AbstractSparseUnsortedSegmentDef<
        float,
        int,
        CPUContext,
        WeightedSumReducerDef>);
// Lengths-based WeightedSum (GradientNeedIndices = false here).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_SEGMENT_DEF(
    LengthsWeightedSum,
    LengthsWeightedSumGradient,
    AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef, false>);
| |
// Auxiliary output gradients are currently implemented only for Lengths version
// Registers a gradient op (WithMainInputBackwardOp) that also consumes the
// forward op's main input. The static_assert enforces the
// "<basename><OpName>WithMainInputGradient" naming convention.
#define REGISTER_GRADIENT_WITH_MAIN_INPUT(gradient_name, ...) \
  static_assert( \
      equal( \
          #gradient_name, \
          __VA_ARGS__::basename, \
          __VA_ARGS__::OpDef::name, \
          "WithMainInputGradient"), \
      #gradient_name); \
  REGISTER_CPU_OPERATOR_STR( \
      string(#gradient_name), __VA_ARGS__::WithMainInputBackwardOp); \
  OPERATOR_SCHEMA(gradient_name) \
      .NumInputs(__VA_ARGS__::WithMainInputBackwardOp::kNumInputs) \
      .NumOutputs(1, INT_MAX)

// Dense and sparse WeightedSum variants of the main-input gradient.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_GRADIENT_WITH_MAIN_INPUT(
    LengthsWeightedSumWithMainInputGradient,
    AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,clang-diagnostic-unused-function)
REGISTER_GRADIENT_WITH_MAIN_INPUT(
    SparseLengthsWeightedSumWithMainInputGradient,
    AbstractSparseLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
| } // namespace |
| |
// Registers a gradient op that consumes both the forward op's main input and
// its forward output (WithMainInputAndForwardOutputBackwardOp). The
// static_assert enforces the
// "<basename><OpName>WithMainInputAndForwardOutputGradient" naming convention.
#define REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT( \
    gradient_name, ...) \
  static_assert( \
      equal( \
          #gradient_name, \
          __VA_ARGS__::basename, \
          __VA_ARGS__::OpDef::name, \
          "WithMainInputAndForwardOutputGradient"), \
      #gradient_name); \
  REGISTER_CPU_OPERATOR_STR( \
      string(#gradient_name), \
      __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp); \
  OPERATOR_SCHEMA(gradient_name) \
      .NumInputs( \
          __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp::kNumInputs) \
      .NumOutputs(1, INT_MAX)

// Defines the forward schema and hooks up the main-input+forward-output
// gradient; the forward CPU operator itself must be registered separately.
#define REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT( \
    segment_name, gradient_name, ...) \
  static_assert( \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name); \
  OPERATOR_SCHEMA(segment_name) \
      .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs) \
      .NumOutputs(1) \
      .SetDoc(FormatDoc<__VA_ARGS__>()) \
      .Output(0, "OUTPUT", "Aggregated tensor") \
      .FillUsing(__VA_ARGS__::PopulateSchema); \
  REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT( \
      gradient_name, __VA_ARGS__); \
  REGISTER_GRADIENT_STR(string(#segment_name), __VA_ARGS__::GetGradient)

// This implements and registers a length op with a gradient which requires
// the main input as well as the output of the forward output.
#define REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT( \
    segment_name, gradient_name, ...) \
  static_assert( \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name); \
  REGISTER_CPU_OPERATOR_STR(string(#segment_name), __VA_ARGS__::ForwardOp); \
  REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT( \
      segment_name, gradient_name, __VA_ARGS__)
| |
// LengthsMax is registered via the main-input+forward-output macro: its
// backward op takes both the original DATA and the forward result
// (see MaxReducerDef / WithMainInputAndForwardOutputBackwardOp).
REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(
    LengthsMax,
    LengthsMaxWithMainInputAndForwardOutputGradient,
    AbstractLengthsDef<float, int, CPUContext, MaxReducerDef>);
| } // namespace caffe2 |
| |
// Aliases for the concrete CPU forward ops: the C10 export macro below cannot
// take a template-id argument containing commas, so name each instantiation
// first.
using LengthsSumCPUOp = caffe2::AbstractLengthsDef<
    float,
    int,
    caffe2::CPUContext,
    caffe2::SumReducerDef,
    true>::ForwardOp;
using LengthsMeanCPUOp = caffe2::AbstractLengthsDef<
    float,
    int,
    caffe2::CPUContext,
    caffe2::MeanReducerDef,
    true>::ForwardOp;
using LengthsMaxCPUOp = caffe2::AbstractLengthsDef<
    float,
    int,
    caffe2::CPUContext,
    caffe2::MaxReducerDef,
    true>::ForwardOp;
| |
// Expose the Lengths{Sum,Mean,Max} forward ops to the c10 operator registry
// under the "_caffe2::" namespace with explicit (data, lengths) schemas.
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
    LengthsSum,
    "_caffe2::LengthsSum(Tensor data, Tensor lengths) -> Tensor",
    LengthsSumCPUOp);
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
    LengthsMean,
    "_caffe2::LengthsMean(Tensor data, Tensor lengths) -> Tensor",
    LengthsMeanCPUOp);
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
    LengthsMax,
    "_caffe2::LengthsMax(Tensor data, Tensor lengths) -> Tensor",
    LengthsMaxCPUOp);