| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/core/Tensor.h> |
| #include <ATen/AccumulateType.h> |
| #include <ATen/ceil_div.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/cuda/Atomic.cuh> |
| #include <ATen/cuda/CUDAContext.h> |
| #include <ATen/cuda/DeviceUtils.cuh> |
| #include <ATen/TensorUtils.h> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/arange.h> |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/empty_like.h> |
| #include <ATen/ops/zeros.h> |
| #include <ATen/ops/_embedding_bag_native.h> |
| #include <ATen/ops/_embedding_bag_forward_only_native.h> |
| #include <ATen/ops/_embedding_bag_dense_backward_native.h> |
| #include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h> |
| #endif |
| |
| #include <ATen/cuda/cub.cuh> |
| #include <ATen/native/cuda/SortingCommon.cuh> |
| #include <ATen/native/cuda/EmbeddingBackwardKernel.cuh> |
| #include <ATen/native/cuda/KernelUtils.cuh> |
| #include <ATen/native/cuda/block_reduce.cuh> |
| |
| #include <c10/macros/Macros.h> |
| |
| #if CUB_SUPPORTS_SCAN_BY_KEY() |
| #include <thrust/iterator/reverse_iterator.h> |
| #endif |
| |
| namespace at::native { |
| |
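| // When the available CUB does not support scan-by-key, the per-index count used |
| // for scale_grad_by_freq falls back to embedding_dense_backward_cuda_scan, a |
| // thrust-based scan declared below and defined elsewhere in the CUDA embedding code. |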
| #if !CUB_SUPPORTS_SCAN_BY_KEY() |
| template<typename index_t> |
| void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count); |
| #endif |
| |
| namespace { |
| |
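| // Bag reduction modes; the integer values match the `mode` argument passed down |
| // by the frontend (nn.EmbeddingBag: 'sum' -> 0, 'mean' -> 1, 'max' -> 2). |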
| constexpr int MODE_SUM = 0; |
| constexpr int MODE_MEAN = 1; |
| constexpr int MODE_MAX = 2; |
| |
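| // indices and offsets may arrive as any mix of int32 and int64; promote both to |
| // their common dtype so a single AT_DISPATCH_INDEX_TYPES instantiation can read them. |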
| std::pair<Tensor, Tensor> promoteIndicesAndOffsets( |
| const Tensor& indices, |
| const Tensor& offsets) { |
| const auto commonType = |
| promoteTypes(offsets.scalar_type(), indices.scalar_type()); |
| return { |
| indices.scalar_type() == commonType ? indices |
| : indices.toType(commonType), |
| offsets.scalar_type() == commonType ? offsets |
| : offsets.toType(commonType)}; |
| } |
| |
| // This kernel assumes that all input tensors except `weight` and |
| // per_sample_weights are contiguous. |
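| // Bags are delimited by `offsets`: bag b reads input[offsets[b] .. offsets[b+1]) |
| // and the last bag runs up to numIndices. For example, indices = [1 4 5 2 3] with |
| // offsets = [0 3] gives bag 0 = rows {1, 4, 5} and bag 1 = rows {2, 3}. |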
| template <typename scalar_t, typename index_t> |
| __global__ void EmbeddingBag_updateOutputKernel_max( |
| const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output, |
| index_t *offset2bag, int64_t numIndices, int64_t numBags, |
| int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1, |
| index_t *bag_size, index_t *max_indices, |
| index_t padding_idx, int64_t numRows) { |
| |
| // the strategy here is that each bag x feature is handled by a single thread |
| |
| int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x); |
| int64_t numChunks = numBags * chunksPerBag; |
| int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; |
| int64_t chunkStride = gridDim.x * blockDim.y; |
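| // A "chunk" is a blockDim.x-wide slice of one bag's features: threadIdx.x walks |
| // the feature dimension, threadIdx.y selects the chunk row within the block, and |
| // the grid-stride loop below covers all numBags * chunksPerBag chunks. |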
| |
| for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { |
| int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; |
| if (featureDim < featureSize) { |
| int64_t bag = chunk / chunksPerBag; |
| const scalar_t *weightFeat = weight + featureDim * weight_stride1; |
| int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it |
| int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices; |
| CUDA_KERNEL_ASSERT(end >= begin); |
| scalar_t weightFeatMax = 0; |
| int64_t bag_size_ = 0; |
| int64_t maxWord = -1; |
| for (int64_t emb = begin; emb < end; emb++) { |
| bool pad = (input[emb] == padding_idx); |
| CUDA_KERNEL_ASSERT(input[emb] < numRows); |
| const int64_t weightRow = input[emb] * weight_stride0; |
| scalar_t weightValue = weightFeat[weightRow]; |
| if (bag_size_ == 0 || weightValue > weightFeatMax) { |
| weightFeatMax = pad ? weightFeatMax : weightValue; |
| maxWord = pad ? maxWord : input[emb]; |
| } |
| bag_size_ += pad ? 0 : 1; |
| |
| if (featureDim == 0) { |
| offset2bag[emb] = bag; |
| } |
| } |
| bag_size[bag] = bag_size_; |
| max_indices[bag * featureSize + featureDim] = maxWord; |
| output[bag * featureSize + featureDim] = weightFeatMax; |
| } |
| } |
| } |
| |
| // This kernel assumes that all input tensors except `weight` and |
| // per_sample_weights are contiguous. |
| template <typename scalar_t, typename index_t> |
| __global__ void EmbeddingBag_updateOutputKernel_sum_mean( |
| const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output, |
| index_t *offset2bag, int64_t numIndices, int64_t numBags, |
| int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1, |
| int mode, index_t *bag_size, |
| const scalar_t* per_sample_weights, int64_t per_sample_weights_stride, |
| index_t padding_idx, int64_t numRows) { |
| |
| // the strategy here is that each bag x feature is handled by a single thread |
| |
| using accscalar_t = acc_type<scalar_t, true>; |
| int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x); |
| int64_t numChunks = numBags * chunksPerBag; |
| int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; |
| int64_t chunkStride = gridDim.x * blockDim.y; |
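| // Same chunking scheme as EmbeddingBag_updateOutputKernel_max above: one thread |
| // per (bag, feature) element, grid-striding over chunks. |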
| |
| for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { |
| int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; |
| if (featureDim < featureSize) { |
| int64_t bag = chunk / chunksPerBag; |
| const scalar_t *weightFeat = weight + featureDim * weight_stride1; |
| int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it |
| int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices; |
| CUDA_KERNEL_ASSERT(end >= begin); |
| accscalar_t weightFeatSum = 0; |
| int64_t bag_size_ = 0; |
| for (int64_t emb = begin; emb < end; emb++) { |
| bool pad = (input[emb] == padding_idx); |
| CUDA_KERNEL_ASSERT(input[emb] < numRows); |
| const int64_t weightRow = input[emb] * weight_stride0; |
| scalar_t weightValue = weightFeat[weightRow]; |
| weightValue = pad ? static_cast<scalar_t>(0) : weightValue; |
| if (per_sample_weights) { |
| accscalar_t scaleWeightBy = static_cast<accscalar_t>( |
| per_sample_weights[emb * per_sample_weights_stride]); |
| weightFeatSum += scaleWeightBy * static_cast<accscalar_t>(weightValue); |
| } else { |
| weightFeatSum += static_cast<accscalar_t>(weightValue); |
| } |
| bag_size_ += pad ? 0 : 1; |
| |
| if (featureDim == 0) { |
| offset2bag[emb] = bag; |
| } |
| } |
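| // padding_idx entries contributed zero above and were not counted in bag_size_, |
| // so the mean below averages over non-padding elements only; an empty or |
| // all-padding bag keeps a sum of 0 and is left undivided. |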
| if (mode == MODE_MEAN) { |
| if (bag_size_ != 0) { |
| weightFeatSum = weightFeatSum / static_cast<accscalar_t>(bag_size_); |
| } |
| } |
| bag_size[bag] = bag_size_; |
| output[bag * featureSize + featureDim] = static_cast<scalar_t>(weightFeatSum); |
| } |
| } |
| } |
| |
| Tensor embedding_bag_backward_cuda_sum_avg( |
| const Tensor &grad, |
| const Tensor &indices_, |
| const Tensor &offset2bag, |
| const Tensor &bag_size, |
| int64_t num_weights, |
| bool scale_grad_by_freq, int64_t mode, |
| const Tensor& per_sample_weights, |
| int64_t padding_idx) { |
| auto indices = indices_.contiguous(); |
| |
| ptrdiff_t num_indices = indices.numel(); |
| |
| if (num_indices == 0) { |
| // all empty bags |
| return at::zeros({num_weights, grad.size(1)}, grad.options()); |
| } |
| |
| auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| Tensor count; |
| |
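| // Sort the flat indices while carrying along their original positions, so that |
| // all gradient contributions targeting the same embedding row become adjacent; |
| // embedding_backward_cuda_kernel relies on this ordering. |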
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { |
| auto range = at::arange(num_indices, indices.options()); |
| // int64_t nbits = cuda::cub::get_num_bits(num_weights); |
| cuda::cub::radix_sort_pairs( |
| indices.const_data_ptr<index_t>(), sorted_indices.mutable_data_ptr<index_t>(), |
| range.const_data_ptr<index_t>(), orig_indices.mutable_data_ptr<index_t>(), |
| num_indices, false/*, 0, nbits*/); |
| }); |
| |
| if (scale_grad_by_freq) { |
| count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| #if CUB_SUPPORTS_SCAN_BY_KEY() |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| |
| // Compute an increasing sequence per unique item in sortedIndices: |
| // sorted: 2 5 5 5 7 7 8 9 9 |
| // count: 1 1 2 3 1 2 1 1 2 |
| auto sorted_data = sorted_indices.const_data_ptr<index_t>(); |
| auto count_data = count.mutable_data_ptr<index_t>(); |
| cuda::cub::inclusive_sum_by_key( |
| sorted_data, |
| at_cuda_detail::cub::ConstantInputIterator<index_t>(1), |
| count_data, |
| num_indices |
| ); |
| |
| // Take the maximum of each count per unique key in reverse: |
| // sorted: 2 5 5 5 7 7 8 9 9 |
| // count: 1 3 3 3 2 2 1 2 2 |
| cuda::cub::inclusive_scan_by_key( |
| thrust::make_reverse_iterator(sorted_data + num_indices), |
| thrust::make_reverse_iterator(count_data + num_indices), |
| thrust::make_reverse_iterator(count_data + num_indices), |
| at_cuda_detail::cub::Max(), |
| num_indices |
| ); |
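| // After both passes count[i] holds the total number of occurrences of |
| // sorted_indices[i]; the backward kernel divides each contribution by it when |
| // scale_grad_by_freq is set. |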
| }); |
| #else |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () { |
| embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count); |
| }); |
| #endif |
| } |
| return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices, |
| count, num_weights, padding_idx, mode == MODE_MEAN, offset2bag, |
| bag_size, per_sample_weights); |
| } |
| |
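| // Backward for mode='max': the gradient of each (bag, feature) output element |
| // flows only to the embedding row recorded in max_indices during the forward |
| // pass. fastAtomicAdd is required because several bags may select the same row. |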
| template <typename scalar_t, typename index_t> |
| __global__ void EmbeddingBag_accGradParametersKernel_max( |
| const index_t *max_indices, const scalar_t *gradOutput, |
| scalar_t *gradWeight, int64_t stride, int64_t numBags, |
| index_t padding_idx, const index_t numel) { |
| |
| using accscalar_t = acc_type<scalar_t, true>; |
| |
| int64_t chunksPerBag = ceil_div(stride, (int64_t)blockDim.x); |
| int64_t numChunks = numBags * chunksPerBag; |
| int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; |
| int64_t chunkStride = gridDim.x * blockDim.y; |
| |
| for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { |
| int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; |
| if (featureDim < stride) { |
| int64_t bag = chunk / chunksPerBag; |
| |
| index_t word_idx = max_indices[bag * stride + featureDim]; |
| if (word_idx >= 0 && word_idx != padding_idx) { |
| // If the bag is empty, max_indices was set to -1 in the forward pass, so it |
| // is skipped here; rows selected at padding_idx are skipped as well. |
| fastAtomicAdd( |
| gradWeight, static_cast<index_t>(word_idx * stride + featureDim), |
| numel, gradOutput[bag * stride + featureDim], true); |
| } |
| } |
| } |
| } |
| |
| Tensor embedding_bag_backward_cuda_max(const Tensor &grad, |
| const Tensor &max_indices, |
| int64_t num_weights, |
| int64_t padding_idx) { |
| // See Note [Writing Nondeterministic Operations] |
| // Nondeterministic because of atomicAdd usage |
| globalContext().alertNotDeterministic("embedding_bag_backward_cuda_max"); |
| |
| auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options()); |
| |
| int64_t stride = grad_weight.stride(0); |
| |
| int64_t numBags = grad.size(0); |
| |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| |
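| // 256-thread blocks: blockDim.x spans the feature dimension and is sized to the |
| // warp/wavefront width (32 on CUDA, 64 on ROCm), while blockDim.y stacks chunk |
| // rows. The grid is capped at 1024 blocks; the kernel grid-strides over chunks. |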
| #if defined(USE_ROCM) |
| dim3 block = dim3(64, 4); |
| #else |
| dim3 block = dim3(32, 8); |
| #endif |
| int grid = 1024; |
| |
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
| grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] { |
| AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () { |
| EmbeddingBag_accGradParametersKernel_max< |
| scalar_t, index_t><<<grid, block, 0, stream>>>( |
| max_indices.const_data_ptr<index_t>(), grad.const_data_ptr<scalar_t>(), |
| grad_weight.mutable_data_ptr<scalar_t>(), stride, numBags, |
| padding_idx, grad_weight.numel()); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| }); |
| |
| return grad_weight; |
| } |
| } // namespace (anonymous) |
| |
| // Assumes all input tensors are contiguous. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
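| // Forward-only variant, used on paths where no backward pass will be taken; it |
| // simply forwards to _embedding_bag_cuda below. |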
| std::tuple<Tensor, Tensor, Tensor, Tensor> |
| _embedding_bag_forward_only_cuda(const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const bool scale_grad_by_freq, |
| const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, |
| bool include_last_offset, int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| |
| return _embedding_bag_cuda( |
| weight, |
| indices, |
| offsets, |
| scale_grad_by_freq, |
| mode, |
| sparse, |
| per_sample_weights, |
| include_last_offset, |
| padding_idx); |
| } |
| |
| // Assumes all input tensors are contiguous. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
| std::tuple<Tensor, Tensor, Tensor, Tensor> |
| _embedding_bag_cuda(const Tensor &weight, const Tensor &indices_, |
| const Tensor &offsets_, const bool scale_grad_by_freq, |
| const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, |
| bool include_last_offset, int64_t padding_idx) { |
| TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2, |
| "input has to be a 1D or 2D Tensor, but got Tensor of dimension ", |
| indices_.dim()); |
| if (indices_.dim() == 1) { |
| TORCH_CHECK(offsets_.dim() == 1, |
| "offsets has to be a 1D Tensor, but got Tensor of dimension ", |
| offsets_.dim()); |
| } |
| TORCH_CHECK(weight.dim() == 2, |
| "weight has to be a 2D Tensor, but got Tensor of dimension ", |
| weight.dim()); |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| |
| Tensor indices, offsets; |
| std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); |
| auto indices_arg = TensorArg(indices, "indices", 1); |
| checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt}); |
| auto offsets_arg = TensorArg(offsets, "offsets", 1); |
| checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt}); |
| checkSameType("embedding_bag_cuda", indices_arg, offsets_arg); |
| auto weight_arg = TensorArg(weight, "weight", 1); |
| checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg); |
| checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg); |
| |
| int64_t numIndices = indices.size(0); |
| int64_t numBags = offsets.size(0); |
| if (include_last_offset) { |
| // Check https://github.com/pytorch/pytorch/issues/29019 |
| // With include_last_offset, offsets carries one extra trailing element equal |
| // to the size of indices. For CUDA devices we still use the legacy |
| // implementation even when this flag is enabled. |
| TORCH_CHECK( |
| numBags >= 1, "include_last_offset: numBags should be at least 1"); |
| numBags -= 1; |
| } |
| int64_t featureSize = weight.size(1); |
| |
| auto bag_size = at::empty(offsets.sizes(), indices.options()); |
| auto offset2bag = |
| at::empty({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0] |
| |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| |
| auto output = at::empty({numBags, featureSize}, weight.options()); |
| |
| Tensor max_indices; |
| |
| if (mode == MODE_MAX) { |
| max_indices = at::empty({numBags, featureSize}, indices.options()); |
| } else { |
| // For sum/mean the backward never reads max_indices, so allocate an empty placeholder. |
| max_indices = at::empty({0}, indices.options()); |
| } |
| |
| #if defined(USE_ROCM) |
| dim3 block = dim3(64, 4); |
| #else |
| dim3 block = dim3(32, 8); |
| #endif |
| int grid = 1024; |
| AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] { |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () { |
| if (mode == MODE_MAX) { |
| EmbeddingBag_updateOutputKernel_max<scalar_t, index_t><<<grid, block, 0, stream>>>( |
| indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(), |
| weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(), |
| offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize, |
| weight.stride(0), weight.stride(1), bag_size.mutable_data_ptr<index_t>(), |
| max_indices.mutable_data_ptr<index_t>(), |
| padding_idx, weight.size(0)); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| EmbeddingBag_updateOutputKernel_sum_mean<scalar_t, index_t><<<grid, block, 0, stream>>>( |
| indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(), |
| weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(), |
| offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize, |
| weight.stride(0), weight.stride(1), mode, bag_size.mutable_data_ptr<index_t>(), |
| per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : nullptr, |
| per_sample_weights.defined() ? per_sample_weights.stride(0) : 0, |
| padding_idx, weight.size(0)); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| }); |
| }); |
| |
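| // Returned tensors: output [numBags, featureSize], offset2bag [numIndices] |
| // (bag id of every index), bag_size (same shape as offsets; count of non-padding |
| // entries per bag), and max_indices ([numBags, featureSize] argmax rows for |
| // mode='max', empty otherwise). |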
| return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices); |
| } |
| |
| Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &indices, |
| const Tensor &offset2bag, |
| const Tensor &bag_size_, |
| const Tensor &max_indices, |
| int64_t num_weights, |
| bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt, |
| int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| |
| // indices, offsets and offset2bag are assumed to have the correct dtypes and |
| // to be contiguous here, due to the checks in _embedding_bag_backward in |
| // EmbeddingBag.cpp. |
| // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml |
| // for more details. |
| |
| Tensor grad = grad_.contiguous(); |
| auto indices_arg = TensorArg(indices, "indices", 1); |
| auto grad_arg = TensorArg(grad, "grad", 1); |
| checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg); |
| |
| |
| switch (mode) { |
| case MODE_SUM: |
| case MODE_MEAN: |
| if (mode == MODE_MEAN) |
| AT_ASSERT(!per_sample_weights.defined()); |
| return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag, |
| bag_size_, num_weights, scale_grad_by_freq, mode, |
| per_sample_weights, padding_idx); |
| |
| case MODE_MAX: |
| AT_ASSERT(!per_sample_weights.defined()); |
| return embedding_bag_backward_cuda_max(grad, max_indices, num_weights, |
| padding_idx); |
| |
| default: |
| AT_ERROR( |
| "Unknown mode for embedding_bag_backward_cuda ", mode); |
| } |
| } |
| |
| template <typename scalar_t, typename index_t> |
| __global__ static void _embedding_bag_per_sample_weights_backward_kernel( |
| const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1, |
| const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1, |
| const index_t* indices, // contiguous |
| const index_t* offset2bag, // contiguous |
| int64_t num_samples, |
| int64_t embedding_features, |
| scalar_t* output, |
| index_t padding_idx) { |
| using accscalar_t = acc_type<scalar_t, true>; |
| const int idx = threadIdx.x + blockIdx.x * blockDim.x; |
| const int warp = idx / C10_WARP_SIZE; |
| const int thread_in_warp = idx % C10_WARP_SIZE; |
| const int num_warps = blockDim.x * gridDim.x / C10_WARP_SIZE; |
| |
| // Each warp is responsible for the accumulation of one sample. |
| // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx]. |
| for (int sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) { |
| accscalar_t result = 0.; |
| const int bag_idx = (int)offset2bag[sample_idx]; |
| const int embedding_idx = (int)indices[sample_idx]; |
| if (embedding_idx != padding_idx) { |
| for (int feature_idx = thread_in_warp; feature_idx < embedding_features; |
| feature_idx += C10_WARP_SIZE) { |
| result += |
| grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] * |
| weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx]; |
| } |
| } |
| result = cuda_utils::WarpReduceSum<accscalar_t>(result); |
| if (thread_in_warp == 0) { |
| output[sample_idx] = result; |
| } |
| } |
| } |
| |
| Tensor _embedding_bag_per_sample_weights_backward_cuda( |
| const Tensor& grad, |
| const Tensor& weight, // NB: embedding table, not per_sample_weights |
| const Tensor& indices_, |
| const Tensor& offsets_, |
| const Tensor& offset2bag, |
| int64_t mode, |
| int64_t padding_idx) { |
| TORCH_CHECK( |
| mode == MODE_SUM, |
| "embedding_bag_backward: per_sample_weights only supported for mode='sum'"); |
| |
| AT_ASSERT(grad.dim() == 2); |
| auto embedding_features = grad.size(1); |
| |
| Tensor indices, offsets; |
| std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); |
| AT_ASSERT(indices.dim() == 1); |
| auto num_samples = indices.size(0); |
| |
| AT_ASSERT(weight.dim() == 2); |
| AT_ASSERT(weight.size(1) == embedding_features); |
| |
| const int threads_per_block = 512; |
| const int warps_per_block = threads_per_block / at::cuda::warp_size(); |
| |
| dim3 block(threads_per_block); |
| dim3 grid((num_samples + warps_per_block - 1) / warps_per_block); |
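| // One warp per sample: each block holds warps_per_block warps, so |
| // ceil(num_samples / warps_per_block) blocks cover all samples. |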
| |
| auto output = at::empty({num_samples}, grad.options()); |
| |
| // Return early when there are no samples in the batch. This saves an unnecessary |
| // kernel launch and also prevents cudaGetLastError() from complaining about |
| // invalid launch arguments. |
| if (num_samples == 0) { |
| return output; |
| } |
| |
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
| grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() { |
| _embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t> |
| <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>( |
| grad.const_data_ptr<scalar_t>(), grad.stride(0), grad.stride(1), |
| weight.const_data_ptr<scalar_t>(), weight.stride(0), weight.stride(1), |
| indices.const_data_ptr<index_t>(), |
| offset2bag.const_data_ptr<index_t>(), |
| num_samples, |
| embedding_features, |
| output.mutable_data_ptr<scalar_t>(), |
| padding_idx); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| } |
| ); |
| return output; |
| } |
| |
| } // namespace at::native |