#ifndef GATHER_OP_H_
#define GATHER_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include <c10/util/irange.h>

namespace caffe2 {

// Index-mapping helpers shared by the Gather and BatchGather ops.
namespace gather_helper {

// The new shape is the concatenation:
// [data dims before axis] + [indices dims] + [data dims after axis]
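// For example (hypothetical shapes): data dims [2, 3, 4], indices dims [5],
// axis = 1 -> output shape [2, 5, 4]. With match_outer = true and indices
// dims [2, 5], only the indices dims from axis onward ([5]) are spliced in,
// again giving [2, 5, 4].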
template <typename IndexType, typename DataDimsVec, typename IndexDimsVec>
static vector<IndexType> calc_output_shape_vector(
    const DataDimsVec& data_dims,
    const IndexDimsVec& indices_dims,
    int axis,
    bool match_outer) {
  vector<IndexType> shape;
  // If the dimension we are indexing along is empty, just use data_dims as the
  // output shape. This replicates the behavior of
  // https://github.com/pytorch/pytorch/pull/13781, needed to allow workflows
  // with an empty batch to succeed.
  if (data_dims[axis] == 0) {
    shape.insert(shape.end(), data_dims.begin(), data_dims.end());
  } else {
    shape.insert(shape.end(), data_dims.begin(), data_dims.begin() + axis);
    if (match_outer) {
      shape.insert(
          shape.end(), indices_dims.begin() + axis, indices_dims.end());
    } else {
      shape.insert(shape.end(), indices_dims.begin(), indices_dims.end());
    }
    shape.insert(shape.end(), data_dims.begin() + axis + 1, data_dims.end());
  }
  return shape;
}

// Checks, via CAFFE_ENFORCE, that every index falls within the size of the
// indexed axis.
template <typename IndexType>
static void check_indexarray_range(
    const IndexType* indices,
    int64_t n,
    IndexType indexing_axis_dim,
    bool wrap_indices) {
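  // When wrap_indices is set, negative indices are wrapped into range before
  // the bounds check, e.g. idx = -1 with indexing_axis_dim = 4 maps to 3
  // (hypothetical values).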
  for (const auto i : c10::irange(n)) {
    auto idx = indices[i];
    if (wrap_indices && idx < 0) {
      idx = idx + indexing_axis_dim;
    }
    CAFFE_ENFORCE(
        0 <= idx && idx < indexing_axis_dim,
        "INDICES element is out of DATA bounds, id=",
        idx,
        " axis_dim=",
        indexing_axis_dim);
  }
}

// Actual gather implementation - resizes output and copies indexed data.
template <typename Index, typename Context>
static bool gather_impl(
    Operator<Context>* op,
    int dataIdx,
    int indicesIdx,
    int outputIdx,
    int axis,
    bool wrap_indices,
    bool match_outer) {
  // If we end up using this on GPU, doing O(N) memcpy calls is probably not
  // the best approach :)
  // TODO: implement prefetching if it starts mattering (TF does it)

  const Tensor& data = op->Input(dataIdx);
  const Tensor& indices = op->Input(indicesIdx);
  const TypeMeta dataType = data.dtype();
  size_t item_bytesize = dataType.itemsize();

  // ONNX allows a negative axis to index from the back; the valid range is
  // [-r, r-1] where r = rank(DATA).
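  // For example (hypothetical rank), with 3-D DATA, axis = -1 becomes
  // 3 + (-1) = 2, i.e. the last axis.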
  if (axis < 0) {
    axis = data.dim() + axis;
  }
  CAFFE_ENFORCE_GE(data.dim(), axis + 1, "DATA should be at least [axis+1]-D");
  CAFFE_ENFORCE_GE(axis, 0, "Axis should be non-negative");
  CAFFE_ENFORCE_LT(axis, data.dim(), "Axis out of range");

  // New shape:
  // [data dims before axis] + [indices dims] + [data dims after axis]
  vector<int64_t> shape = calc_output_shape_vector<int64_t>(
      data.sizes(), indices.sizes(), axis, match_outer);
  Tensor* output = op->Output(outputIdx, shape, at::dtype(dataType));
  auto out = static_cast<char*>(output->raw_mutable_data(dataType));

  // Succeed if the output is empty, which can happen for an empty batch where
  // the indexed data dimension has size 0.
  // This check *must* be done AFTER the output->raw_mutable_data() call above,
  // as that call has an important allocation side effect that we rely on.
  if (output->numel() == 0) {
    return true;
  }

  const Index* idxs = indices.template data<Index>();
  auto src_base = static_cast<const char*>(data.raw_data());

  auto outer_dims_product = data.size_to_dim(axis);
  auto block_size = data.size_from_dim(axis + 1);
  auto block_bytesize = block_size * item_bytesize;

  auto src_indexing_axis_dim = data.size(axis);
  auto src_batch_bytesize = data.size_from_dim(axis) * item_bytesize;
  // Treat indices as a single block even if they have multiple dimensions.
  // The "gathered batch" is a cumulative result combining indexed blocks.
  auto idx_inner_dims_product = indices.size_from_dim(axis);
  auto N = indices.numel();
  if (match_outer) {
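    // With match_outer, INDICES shares its outer dims (before `axis`) with
    // DATA and each outer batch consumes its own slice of indices, so only
    // the inner index dims count toward N.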
    CAFFE_ENFORCE_GE(axis, 1, "Axis should be at least 1");
    for (const auto i : c10::irange(axis)) {
      CAFFE_ENFORCE_EQ(
          data.size(i),
          indices.size(i),
          "INDICES must have the same outer dims as DATA (before dim AXIS)");
    }
    N = idx_inner_dims_product;
  }

  auto gathered_batch_bytesize = N * block_size * item_bytesize;
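  // Example (hypothetical shapes, match_outer == false): data dims [2, 3, 4],
  // indices dims [5], axis = 1 -> outer_dims_product = 2, block_size = 4,
  // src_indexing_axis_dim = 3, src_batch_bytesize = 12 * item_bytesize, N = 5,
  // and gathered_batch_bytesize = 20 * item_bytesize.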

  check_indexarray_range<Index>(idxs, N, src_indexing_axis_dim, wrap_indices);

  // Special-case single-float copy for efficiency
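  // With block_size == 1 each gathered element is a single float, so a direct
  // element-wise copy avoids per-element CopyItemsSameDevice calls.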
  if (data.template IsType<float>() && block_size == 1) {
    for (const auto batch : c10::irange(outer_dims_product)) {
      const float* src_floats = reinterpret_cast<const float*>(
          src_base + batch * src_batch_bytesize);
      float* dst_floats =
          reinterpret_cast<float*>(out + batch * gathered_batch_bytesize);

      for (const auto i : c10::irange(N)) {
        auto idx = idxs[i];
        if (match_outer) {
          idx = idxs[batch * idx_inner_dims_product + i];
        }
        if (wrap_indices && idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }
        dst_floats[i] = src_floats[idx];
      }
    }
  } else {
    // outer_dims_product specifies how many times we repeat inner dimensions,
    // so we just iterate over it to cover all outer dimensions.
    for (const auto batch : c10::irange(outer_dims_product)) {
      for (const auto i : c10::irange(N)) {
        auto idx = idxs[i];
        if (match_outer) {
          idx = idxs[batch * idx_inner_dims_product + i];
        }
        if (wrap_indices && idx < 0) {
          idx = idx + src_indexing_axis_dim;
        }

        auto src = src_base + batch * src_batch_bytesize + idx * block_bytesize;
        auto dst = out + batch * gathered_batch_bytesize + i * block_bytesize;
        op->getContext()->CopyItemsSameDevice(dataType, block_size, src, dst);
      }
    }
  }
  return true;
}

} // namespace gather_helper

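// GatherOp gathers blocks of DATA along `axis` at the positions given by
// INDICES (int32 or int64). The output shape is
// [DATA dims before axis] + [INDICES dims] + [DATA dims after axis].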
template <class Context>
class GatherOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GatherOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, 0),
        OP_SINGLE_ARG(bool, "match_outer", match_outer_, false) {
    // TBD: We may want to fix the old index-wrapping behaviour once we have
    // operator versioning, and only apply it when needed, since otherwise it
    // is likely an error.
    // Right now we apply index wrapping by default only when axis == 0,
    // since we have ONNX conversion code that relies on it. Otherwise it must
    // be requested explicitly via the "wrap_indices" argument.
    if (OperatorBase::HasArgument("wrap_indices")) {
      wrap_indices_ = Operator<Context>::template GetSingleArgument<bool>(
          "wrap_indices", false);
    } else {
      wrap_indices_ = (axis_ == 0);
    }
  }

  ~GatherOp() noexcept override {}

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
        this, this->template Input<Tensor>(INDICES, CPU));
  }

  template <typename Index>
  bool DoRunWithType() {
    return gather_helper::gather_impl<Index, Context>(
        this, DATA, INDICES, 0, axis_, wrap_indices_, match_outer_);
  }

  INPUT_TAGS(DATA, INDICES);

 protected:
  int axis_;
  bool wrap_indices_;
  bool match_outer_;
};

} // namespace caffe2
#endif // GATHER_OP_H_