#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/cuda/cub_definitions.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty_like.h>
#endif

#include <ATen/cuda/ThrustAllocator.h>
#include <thrust/copy.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/scan.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/reverse_iterator.h>

namespace at::native {

void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices) {
  sorted_indices.copy_(linearIndex);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  using device_ptr = thrust::device_ptr<int64_t>;

  // Fill orig_indices with the sequential indices 0..num_indices-1
  const auto count_iter = thrust::counting_iterator<int64_t>(0);
  auto orig_data = device_ptr(orig_indices.mutable_data_ptr<int64_t>());
  thrust::copy(policy, count_iter, count_iter + num_indices, orig_data);

  // Sort the flattened indices, carrying the original positions along. We
  // don't need a stable or multidimensional sort, so just use Thrust directly.
  // NB: the explicit comparator matters; without it Thrust picks radix sort,
  // which hurts performance a lot, at least for medium-sized (a few thousand
  // element) index sets.
  auto sorted_data = device_ptr(sorted_indices.mutable_data_ptr<int64_t>());
  thrust::sort_by_key(policy, sorted_data, sorted_data + num_indices, orig_data, LTOp<int64_t>());
}
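
// A minimal, self-contained sketch of the sort-with-positions pattern above
// (hypothetical demo, not part of the ATen build; LTOp is replaced by
// thrust::less here). Enable the guard and compile with nvcc to try it.
#if 0
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/functional.h>
#include <cstdio>

static void sort_with_positions_demo() {
  const int64_t h_keys[] = {3, 1, 3, 0, 1};
  thrust::device_vector<int64_t> keys(h_keys, h_keys + 5);
  thrust::device_vector<int64_t> positions(5);

  // 0 1 2 3 4, as the thrust::copy of a counting_iterator does above.
  thrust::sequence(positions.begin(), positions.end());

  // Sort the keys, carrying original positions along; the explicit comparator
  // selects merge sort rather than radix sort.
  thrust::sort_by_key(keys.begin(), keys.end(), positions.begin(),
                      thrust::less<int64_t>());

  // keys:      0 1 1 3 3
  // positions: 3 1 4 0 2   (one possible result; sort_by_key makes no
  //                         stability guarantee for equal keys)
  for (size_t i = 0; i < keys.size(); ++i) {
    printf("%lld -> from %lld\n", (long long)keys[i], (long long)positions[i]);
  }
}
#endif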

#if !CUB_SUPPORTS_SCAN_BY_KEY()

template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);

  auto num_indices = count.numel();

  // Compute an increasing sequence per unique item in sorted_indices:
  //   sorted: 2 5 5 5 7 7 8 9 9
  //   count:  1 1 2 3 1 2 1 1 2
  auto sorted_data = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto count_data = thrust::device_ptr<index_t>(count.mutable_data_ptr<index_t>());
  thrust::inclusive_scan_by_key(
    policy,
    sorted_data,
    sorted_data + num_indices,
    thrust::make_constant_iterator(1),
    count_data
  );

  // Take the maximum of each count per unique key in reverse, which
  // broadcasts each key's total occurrence count to all of its positions:
  //   sorted: 2 5 5 5 7 7 8 9 9
  //   count:  1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
    policy,
    thrust::make_reverse_iterator(sorted_data + num_indices),
    thrust::make_reverse_iterator(sorted_data),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::make_reverse_iterator(count_data + num_indices),
    thrust::equal_to<index_t>(),
    thrust::maximum<index_t>()
  );
}
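
// A minimal, self-contained sketch of the same two-pass scan-by-key pattern,
// using the worked example from the comments above (hypothetical demo, not
// part of the ATen build; enable the guard and compile with nvcc to try it).
#if 0
#include <thrust/device_vector.h>
#include <thrust/scan.h>
#include <thrust/functional.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/reverse_iterator.h>
#include <cstdio>

static void scan_by_key_demo() {
  // Keys must already be sorted, as in the embedding backward path above.
  const int h_keys[] = {2, 5, 5, 5, 7, 7, 8, 9, 9};
  thrust::device_vector<int> keys(h_keys, h_keys + 9);
  thrust::device_vector<int> count(keys.size());

  // Pass 1: running occurrence number per key -> 1 1 2 3 1 2 1 1 2
  thrust::inclusive_scan_by_key(
      keys.begin(), keys.end(),
      thrust::make_constant_iterator(1),
      count.begin());

  // Pass 2: a max-scan over the reversed range broadcasts each key's total
  // count to every one of its positions, in place -> 1 3 3 3 2 2 1 2 2
  thrust::inclusive_scan_by_key(
      thrust::make_reverse_iterator(keys.end()),
      thrust::make_reverse_iterator(keys.begin()),
      thrust::make_reverse_iterator(count.end()),
      thrust::make_reverse_iterator(count.end()),
      thrust::equal_to<int>(),
      thrust::maximum<int>());

  for (size_t i = 0; i < count.size(); ++i) {
    printf("%d ", (int)count[i]);  // 1 3 3 3 2 2 1 2 2
  }
  printf("\n");
}
#endif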

template
void embedding_dense_backward_cuda_scan<int>(Tensor &sorted_indices, Tensor &count);
template
void embedding_dense_backward_cuda_scan<int64_t>(Tensor &sorted_indices, Tensor &count);

#endif

template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets) {
  auto stream = at::cuda::getCurrentCUDAStream();
  at::cuda::ThrustAllocator allocator;
  auto policy = thrust::cuda::par(allocator).on(stream);
  const ptrdiff_t numel = sorted_indices.numel();
  auto sorted_indices_dev = thrust::device_ptr<const index_t>(sorted_indices.const_data_ptr<index_t>());
  auto dummy = at::empty_like(sorted_indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto dummy_dev = thrust::device_ptr<index_t>(dummy.mutable_data_ptr<index_t>());
  // For each run of equal indices, keep the first index (into dummy) and the
  // position at which the run starts (into segment_offsets). Count with
  // index_t rather than int so the offsets cannot overflow for large inputs.
  auto ends = thrust::unique_by_key_copy(
    policy,
    sorted_indices_dev,
    sorted_indices_dev + numel,
    thrust::make_counting_iterator<index_t>(0),
    dummy_dev,
    thrust::device_ptr<index_t>(segment_offsets.mutable_data_ptr<index_t>()));
  return thrust::get<0>(ends) - dummy_dev;
}
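
// A minimal, self-contained sketch of the segment-offset computation above
// (hypothetical demo, not part of the ATen build; enable the guard and
// compile with nvcc to try it).
#if 0
#include <thrust/device_vector.h>
#include <thrust/unique.h>
#include <thrust/iterator/counting_iterator.h>
#include <cstdio>

static void segment_offsets_demo() {
  const int h_keys[] = {1, 1, 3, 3, 3, 7};
  thrust::device_vector<int> keys(h_keys, h_keys + 6);
  thrust::device_vector<int> unique_keys(6);
  thrust::device_vector<int> offsets(6);

  // For each run of equal keys, keep the first key and the position at which
  // the run starts (the value drawn from the counting iterator).
  auto ends = thrust::unique_by_key_copy(
      keys.begin(), keys.end(),
      thrust::make_counting_iterator(0),
      unique_keys.begin(),
      offsets.begin());
  int num_unique = ends.first - unique_keys.begin();

  // unique_keys: 1 3 7
  // offsets:     0 2 5
  printf("%d unique keys\n", num_unique);  // 3
}
#endif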

template
int64_t embedding_backward_cuda_kernel_unique_by_key<int>(const Tensor &sorted_indices, Tensor &segment_offsets);
template
int64_t embedding_backward_cuda_kernel_unique_by_key<int64_t>(const Tensor &sorted_indices, Tensor &segment_offsets);

} // namespace at::native