#ifndef CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_
#define CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
// Reuse helper logic from GatherOp since BatchGather is the same with axis=1.
#include "caffe2/operators/gather_op.h"

namespace caffe2 {

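// Gathers slices along axis 1 of DATA according to INDICES, independently for
// each outer (batch) slice. For example, DATA of shape (B, K, D) gathered with
// 1-D INDICES of length N yields an output of shape (B, N, D). When
// match_outer is set, INDICES shares the outer (batch) dims with DATA and
// supplies a separate row of indices per batch.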
template <class Context>
class BatchGatherOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BatchGatherOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {}

  // virtual ~BatchGatherOp() noexcept {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename TInd>
  bool DoRunWithType() {
    // BatchGather is a special case of Gather with axis = 1.
    return gather_helper::gather_impl<TInd, Context>(
        this, DATA, INDICES, 0, 1, false, match_outer_);
  }
  INPUT_TAGS(DATA, INDICES);

 protected:
  bool match_outer_;
};

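// Gradient of BatchGather (and of Gather with axis >= 1): scatters GRAD back
// into a zero-initialized tensor shaped like DATA, accumulating contributions
// wherever the same index appears more than once.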
template <class Context>
class BatchGatherGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  // The constructor accepts an "axis" argument so that this op can also serve
  // as the gradient of GatherOp; for BatchGather the default axis of 1 is used.
  template <class... Args>
  explicit BatchGatherGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 1),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {}
  virtual ~BatchGatherGradientOp() noexcept {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename TInd>
  bool DoRunWithType() {
    return DispatchHelper<
        TensorTypes2<float, GenericTensorImplementation>,
        TInd>::call(this, Input(DATA));
  }

  template <typename TInd, typename TData>
  bool DoRunWithType2() {
    auto& data = Input(DATA);
    auto& indices = Input(INDICES);
    auto& grad = Input(GRAD);

    // ONNX allows a negative axis to index from the back; the valid range is
    // [-r, r-1] where r = rank(DATA).
    int axis = axis_;
    bool match_outer = match_outer_;
    if (axis < 0) {
      axis = data.dim() + axis;
    }

    CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D");
    // Outer dimensions of input data and gradient should be the same
    // because they are preserved for gathers with axis > 0.
    for (const auto acheck : c10::irange(axis)) {
      CAFFE_ENFORCE_EQ(
          data.size(acheck),
          grad.size(acheck),
          "batch gather outer dimensions should match");
    }

    auto* output = Output(0, data.sizes(), at::dtype<TData>());
    TData* out_data = output->template mutable_data<TData>();
    if (data.numel() <= 0) {
      return true;
    }
    memset(out_data, 0, output->nbytes());

    const TData* grad_data = grad.template data<TData>();
    const TInd* idxs = indices.template data<TInd>();

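    // Flatten the dims before `axis` into a single batch dim. `batch_size` is
    // the number of DATA elements per batch and `block_size` is the number of
    // elements in one slice along `axis`.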
    auto outer_dims_product = data.size_to_dim(axis);
    auto batch_size = data.size_from_dim(axis);
    auto block_size = data.size_from_dim(axis + 1);
    auto N = indices.numel();

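    // With match_outer, INDICES shares the outer (batch) dims with DATA and
    // provides a separate row of indices per batch; N then counts only the
    // per-batch indices.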
    auto idx_inner_dims_product = indices.size_from_dim(axis);
    if (match_outer) {
      CAFFE_ENFORCE_GE(axis, 1, "Axis should be at least 1");
      for (const auto i : c10::irange(axis)) {
        CAFFE_ENFORCE_EQ(
            data.size(i),
            indices.size(i),
            "INDICES must have the same outer dims as DATA (before dim AXIS)");
      }
      N = idx_inner_dims_product;
    }

    auto gathered_grad_batch_size = N * block_size;
    // Check indexing bounds.
    auto src_indexing_axis_dim = data.size(axis);
    gather_helper::check_indexarray_range<TInd>(
        idxs, N, src_indexing_axis_dim, false);

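    // Scatter-add: route each gathered gradient block back to the slice of the
    // output it was gathered from.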
    for (const auto batch : c10::irange(outer_dims_product)) {
      auto grad_batch_base = grad_data + batch * gathered_grad_batch_size;
      auto out_batch_base = out_data + batch * batch_size;

      for (const auto i : c10::irange(N)) {
        auto idx = idxs[i];
        if (match_outer) {
          idx = idxs[batch * idx_inner_dims_product + i];
        }
        if (idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }
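        // Accumulate into the output; the scalar fast path avoids the overhead
        // of math::Add when block_size == 1.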
        if (block_size == 1) {
          out_batch_base[idx] += grad_batch_base[i];
        } else {
          math::Add(
              block_size,
              out_batch_base + idx * block_size,
              grad_batch_base + i * block_size,
              out_batch_base + idx * block_size,
              &context_);
        }
      }
    }
    return true;
  }

  template <typename TInd>
  bool DoRunWithOtherType2() {
    CAFFE_THROW(
        "BatchGatherGradient is not implemented on tensor of type ",
        Input(DATA).meta().name(),
        "; consider adding it as a type in the DispatchHelper list or "
        "implementing a generic version (which won't work for "
        "duplicated indices though)");
  }

  INPUT_TAGS(DATA, INDICES, GRAD);

 protected:
  int axis_;
  bool match_outer_;
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_BATCH_GATHER_OPS_H_