| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/native/EmbeddingBag.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/Parallel.h> |
| #include <ATen/TensorOperators.h> |
| #include <ATen/TensorUtils.h> |
| #include <ATen/TensorSubclassLikeUtils.h> |
| |
| #include <ATen/native/CPUBlas.h> |
| #include <ATen/native/NonSymbolicBC.h> |
| |
| #include <c10/util/irange.h> |
| #include <c10/util/Half.h> |
| |
| #ifdef USE_FBGEMM |
| #include <fbgemm/Fbgemm.h> |
| #include <fbgemm/FbgemmConvert.h> |
| #else |
| #include <caffe2/perfkernels/embedding_lookup_idx.h> |
| #endif |
| |
| #include <algorithm> |
| #include <cstring> |
| #include <tuple> |
| #include <vector> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/_embedding_bag.h> |
| #include <ATen/ops/_embedding_bag_backward_native.h> |
| #include <ATen/ops/_embedding_bag_dense_backward.h> |
| #include <ATen/ops/_embedding_bag_dense_backward_native.h> |
| #include <ATen/ops/_embedding_bag_forward_only.h> |
| #include <ATen/ops/_embedding_bag_forward_only_native.h> |
| #include <ATen/ops/_embedding_bag_native.h> |
| #include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h> |
| #include <ATen/ops/_embedding_bag_sparse_backward.h> |
| #include <ATen/ops/_embedding_bag_sparse_backward_native.h> |
| #include <ATen/ops/embedding_backward_native.h> |
| #include <ATen/ops/embedding_bag_native.h> |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/max.h> |
| #include <ATen/ops/ones_like.h> |
| #include <ATen/ops/resize_native.h> |
| #include <ATen/ops/zero_native.h> |
| #include <ATen/ops/zeros.h> |
| #endif |
| |
| namespace { |
| const int MODE_SUM = 0; |
| const int MODE_MEAN = 1; |
| const int MODE_MAX = 2; |
| } // namespace |
| |
| namespace at { |
| namespace native { |
| |
| template<typename scalar_t> |
| scalar_t dot_impl(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy); |
| |
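| // Builds the offset2bag mapping: for each position in `indices`, the id of |
| // the bag it belongs to. The inline example below assumes offsets = [0, 2, 4] |
| // and an offset2bag buffer of size indices.size(0) + 1 = 5: scattering ones |
| // at the offset positions, removing the leading one, and taking a cumulative |
| // sum yields the per-index bag ids. |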
| static void make_offset2bag(const Tensor &offsets, Tensor& offset2bag) { |
| offset2bag.index_add_( |
| 0, offsets, at::ones_like(offsets, LEGACY_CONTIGUOUS_MEMORY_FORMAT)); // offset2bag = [1 0 1 0 1] |
| offset2bag[0] -= 1; // offset2bag = [0 0 1 0 1] |
| offset2bag = offset2bag.cumsum(0, offset2bag.scalar_type()); // offset2bag = [0 0 1 1 2] |
| } |
| |
| namespace { |
| |
| std::pair<Tensor, Tensor> promoteIndicesAndOffsets( |
| const Tensor& indices, |
| const Tensor& offsets) { |
| const auto commonType = |
| promoteTypes(offsets.scalar_type(), indices.scalar_type()); |
| return { |
| indices.scalar_type() == commonType ? indices |
| : indices.toType(commonType), |
| offsets.scalar_type() == commonType ? offsets |
| : offsets.toType(commonType)}; |
| } |
| |
| // Determines if we can use a fast implementation for index_select_add, which |
| // is only applicable when the weights are float or half, the rows of both |
| // src and output are contiguous (inner stride 1), and no padding_idx is in |
| // use (padding_idx < 0). |
| template<typename index_t> |
| bool is_fast_path_index_select(const Tensor& src, Tensor& output, index_t padding_idx) { |
|   return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && |
|       src.strides()[1] == 1 && output.strides()[1] == 1 && |
|       padding_idx < static_cast<index_t>(0); |
| } |
| |
| // Determines if we can use a fast implementation for index_select_scale_add, |
| // which is only applicable when the weights are float or half, the rows of |
| // src and output are contiguous (inner stride 1), per_sample_weights (scale) |
| // is contiguous, and no padding_idx is in use (padding_idx < 0). |
| template<typename index_t> |
| bool is_fast_path_index_select_scale(const Tensor& src, const Tensor& scale, Tensor& output, index_t padding_idx) { |
|   return (src.scalar_type() == kFloat || src.scalar_type() == kHalf) && |
|       src.strides()[1] == 1 && output.strides()[1] == 1 && |
|       scale.strides()[0] == 1 && padding_idx < static_cast<index_t>(0); |
| } |
| |
| template<typename index_t> |
| bool is_fast_path(const Tensor& src, const c10::optional<Tensor>& scale, Tensor& output, index_t padding_idx) { |
| return (scale.has_value() && scale.value().defined()) ? |
| is_fast_path_index_select_scale(src, scale.value(), output, padding_idx) : |
| is_fast_path_index_select(src, output, padding_idx); |
| } |
| |
| // This function combines index_select (using select_indices as the index) and |
| // index_add (using add_indices as the index), without creating an intermediary |
| // tensor to hold the selected embeddings |
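| // Effectively, for every i it performs |
| //   output[add_indices[i]] += src[select_indices[i]] |
| // (via cpublas::axpy on each row), skipping entries equal to padding_idx and |
| // decrementing bag_size for the bags those padded entries map to. |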
| template<typename data_t, typename index_t> |
| typename std::enable_if<!std::is_same<data_t, float>::value && !std::is_same<data_t, at::Half>::value, void>::type |
| index_select_add(const Tensor &select_indices, |
| const Tensor &add_indices, |
| const Tensor &src, |
| Tensor &output, |
| const Tensor& /*offsets*/, |
| bool /*include_last_offset*/, |
| Tensor &bag_size, |
| index_t padding_idx, |
| _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) { |
| TORCH_CHECK(select_indices.numel() == add_indices.numel()); |
| auto* add_indices_data = add_indices.data_ptr<index_t>(); |
| auto* select_indices_data = select_indices.data_ptr<index_t>(); |
| auto* src_data = src.data_ptr<data_t>(); |
| auto* output_data = output.data_ptr<data_t>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| index_t* bag_size_data = nullptr; |
| if (bag_size.defined()) { |
| bag_size_data = bag_size.data_ptr<index_t>(); |
| } |
| auto numel = add_indices.numel(); |
| int64_t ddim = src.size(1); |
| auto vocab_size = src.size(0); |
| auto src_stride0 = src.strides()[0]; |
| auto src_stride1 = src.strides()[1]; |
| auto output_stride0 = output.strides()[0]; |
| auto output_stride1 = output.strides()[1]; |
| |
| for (const auto i : c10::irange(numel)) { |
| // We can skip indices equal to padding_idx so they are not included in |
| // the reduction |
| auto idx = select_indices_data[i]; |
| TORCH_CHECK( |
| idx >= 0 && idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| idx); |
| if (idx != padding_idx) { |
| at::native::cpublas::axpy<data_t>(ddim, 1, |
| src_data + src_stride0 * idx, src_stride1, |
| output_data + output_stride0 * add_indices_data[i], output_stride1); |
| } else if (bag_size.defined()) { |
| // Decrement bag_size to reflect that the index is padded |
| // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) |
| bag_size_data[add_indices_data[i]]--; |
| } |
| } |
| } |
| |
| namespace { |
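| // Re-validates indices and offsets after a failed FBGEMM kernel call so that |
| // a readable error (out-of-bounds index or an inconsistent last offset) can |
| // be reported instead of silently producing wrong results. |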
| template <typename index_t> |
| void fbgemm_spmdm_report_error_( |
| int64_t output_size, |
| int index_size, |
| int64_t N, |
| const index_t* offsets, |
| const index_t* indices) { |
| for (const auto m : c10::irange(output_size)) { |
| for (index_t i = offsets[m]; i < offsets[m + 1]; ++i) { |
| TORCH_CHECK(i < index_size); |
| index_t idx = indices[i]; |
| TORCH_CHECK( |
| 0 <= idx && idx < N, |
| "Index ", |
| i, |
| " is out of bounds: ", |
| idx, |
| ", range 0 to ", |
| N); |
| } |
| } |
| TORCH_CHECK( |
| offsets[output_size] == index_size, |
| "Yout input seems to be incorrect: the last offset value should be " |
| "the size of the indices tensor, but it appears not."); |
| } |
| } // namespace |
| |
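| // at::Half specialization of index_select_add. On the fast path the whole |
| // per-bag reduction is done at once: with FBGEMM it uses a generated fp16 |
| // EmbeddingSpMDM kernel, otherwise it falls back to caffe2::EmbeddingLookupIdx |
| // accumulating into an fp32 buffer that is converted back to fp16. The slow |
| // path accumulates row by row into an fp32 buffer and converts the result |
| // back to fp16 at the end. |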
| template<typename data_t, typename index_t> |
| typename std::enable_if<std::is_same<data_t, at::Half>::value, void>::type |
| index_select_add(const Tensor &select_indices, |
| const Tensor &add_indices, |
| const Tensor &src, |
| Tensor &output, |
| const Tensor& offsets, |
| bool include_last_offset, |
| Tensor &bag_size, |
| index_t padding_idx, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache) { |
| int64_t ddim = src.size(1); |
| auto* select_indices_data = select_indices.data_ptr<index_t>(); |
| auto* output_data = output.data_ptr<at::Half>(); |
| |
| if (is_fast_path_index_select(src, output, padding_idx)) { |
| auto src_contig = src.contiguous(); |
| auto* src_data = src_contig.data_ptr<at::Half>(); |
| int64_t output_size = offsets.numel() - 1; |
| auto* offsets_data = offsets.data_ptr<index_t>(); |
| std::vector<index_t> offsets_include_last; |
| |
| if (include_last_offset) { |
| output_size = offsets.numel() - 1; |
| } else { |
| output_size = offsets.numel(); |
| offsets_include_last.resize(offsets.numel() + 1); |
| if (offsets.numel() > 0) { |
| std::memcpy( |
| offsets_include_last.data(), |
| offsets.data_ptr<index_t>(), |
| sizeof(index_t) * offsets.numel()); |
| } |
| offsets_include_last[offsets.numel()] = select_indices.numel(); |
| offsets_data = offsets_include_last.data(); |
| } |
| |
| #ifdef USE_FBGEMM |
| using float16 = uint16_t; |
| auto kernel_fp16_index_t = fbgemm_kernel_cache ? |
| fbgemm_kernel_cache->getCallback</* has_weight */ false, index_t, float16>(ddim) : |
| fbgemm::GenerateEmbeddingSpMDM<float16, index_t, index_t, float16>( |
| /* block_size */ddim, |
| /* has_weight */false, |
| /* normalize_by_lengths */false, |
| /* prefetch */16, |
| /* is_weight_positional */false, |
| /* use_offsets */true |
| ); |
| #else |
| // Initialize the intermediate output buffer to be 0. |
| Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat)); |
| auto* output_data_fp32 = output_fp32.data_ptr<float>(); |
| #endif |
| at::parallel_for( |
| 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { |
| #ifdef USE_FBGEMM |
| bool success = kernel_fp16_index_t( |
| /* output_size */end_idx - start_idx, |
| /* index_size */offsets_data[end_idx] - offsets_data[start_idx], |
| /* data_size */src.size(0), |
| /* input */reinterpret_cast<const float16*>(src_data), |
| /* indices */select_indices_data + offsets_data[start_idx], |
| /* offsets_or_lengths */offsets_data + start_idx, |
| /* weights */nullptr, |
| /* output */reinterpret_cast<float16*>(output_data + start_idx * ddim)); |
| if (!success) { |
| fbgemm_spmdm_report_error_( |
| end_idx - start_idx, |
| offsets_data[end_idx] - offsets_data[start_idx], |
| src.size(0), |
| offsets_data + start_idx, |
| select_indices_data + offsets_data[start_idx]); |
| } |
| #else |
| caffe2::EmbeddingLookupIdx( |
| /*block_size=*/ddim, |
| /*output_size=*/end_idx - start_idx, |
| /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], |
| /*data_size=*/src.size(0), |
| /*input=*/src_data, |
| /*indices=*/select_indices_data + offsets_data[start_idx], |
| /*offsets=*/offsets_data + start_idx, |
| /*weights=*/nullptr, |
| /*scale_bias=*/nullptr, |
| /*normalize_by_lengths=*/false, |
| /*out=*/output_data_fp32 + start_idx * ddim); |
|           // Convert the FP32 intermediate buffer back to FP16 for the output |
|           // dtype, only for the rows computed by this task so parallel tasks |
|           // do not touch each other's rows. |
|           for (const auto i : c10::irange(start_idx, end_idx)) { |
|             for (const auto d : c10::irange(ddim)) { |
|               (output_data + i * ddim)[d] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]); |
|             } |
|           } |
| #endif |
| }); |
| |
| } else { |
| TORCH_CHECK(select_indices.numel() == add_indices.numel()); |
| auto* src_data = src.data_ptr<at::Half>(); |
| auto* add_indices_data = add_indices.data_ptr<index_t>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| index_t* bag_size_data = nullptr; |
| if (bag_size.defined()) { |
| bag_size_data = bag_size.data_ptr<index_t>(); |
| } |
| auto vocab_size = src.size(0); |
| auto src_stride0 = src.strides()[0]; |
| auto src_stride1 = src.strides()[1]; |
| auto output_stride0 = output.strides()[0]; |
| auto output_stride1 = output.strides()[1]; |
| auto numel = add_indices.numel(); |
| |
| Tensor src_fp32 = at::empty({ddim}, src.options().dtype(at::kFloat)); |
| auto* src_data_fp32 = src_fp32.data_ptr<float>(); |
| |
| // Initialize the intermediate output buffer to be 0. |
| Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat)); |
| auto* output_data_fp32 = output_fp32.data_ptr<float>(); |
| |
| for (const auto i : c10::irange(numel)) { |
| // We can skip indices equal to padding_idx so they are not included in |
| // the reduction |
| auto idx = select_indices_data[i]; |
| TORCH_CHECK( |
| idx >= 0 && idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| idx); |
| if (idx != padding_idx) { |
| // Copy src_data + src_stride0 * idx to src_data_fp32 |
| for (const auto d : c10::irange(ddim)) { |
| src_data_fp32[d] = static_cast<float>((src_data + src_stride0 * idx)[d * src_stride1]); |
| } |
| at::native::cpublas::axpy<float>(ddim, 1, |
| src_data_fp32, 1, |
| output_data_fp32 + ddim * add_indices_data[i], 1); |
| |
| } else if (bag_size.defined()) { |
| // Decrement bag_size to reflect that the index is padded |
| // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) |
| bag_size_data[add_indices_data[i]]--; |
| } |
| } |
| for (const auto i : c10::irange(output.size(0))) { |
| // Convert FP32 intermediate buffer result back to FP16 for output dtype |
| for (const auto d : c10::irange(ddim)) { |
| (output_data + output_stride0 * i)[d * output_stride1] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]); |
| } |
| } |
| } |
| } |
| |
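| // float specialization of index_select_add. On the fast path the whole |
| // per-bag reduction is delegated to an FBGEMM EmbeddingSpMDM kernel (or to |
| // caffe2::EmbeddingLookupIdx when FBGEMM is unavailable), parallelized over |
| // bags; otherwise it falls back to per-row axpy accumulation. |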
| template<typename data_t, typename index_t> |
| typename std::enable_if<std::is_same<data_t, float>::value, void>::type |
| index_select_add(const Tensor &select_indices, |
| const Tensor &add_indices, |
| const Tensor &src, |
| Tensor &output, |
| const Tensor& offsets, |
| bool include_last_offset, |
| Tensor &bag_size, |
| index_t padding_idx, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache) { |
| int64_t ddim = src.size(1); |
| auto* select_indices_data = select_indices.data_ptr<index_t>(); |
| auto* output_data = output.data_ptr<float>(); |
| |
| if (is_fast_path_index_select(src, output, padding_idx)) { |
| auto src_contig = src.contiguous(); |
| auto* src_data = src_contig.data_ptr<float>(); |
| int64_t output_size = offsets.numel() - 1; |
| auto* offsets_data = offsets.data_ptr<index_t>(); |
| std::vector<index_t> offsets_include_last; |
| |
| if (include_last_offset) { |
| output_size = offsets.numel() - 1; |
| } else { |
| output_size = offsets.numel(); |
| offsets_include_last.resize(offsets.numel() + 1); |
| if (offsets.numel() > 0) { |
| std::memcpy( |
| offsets_include_last.data(), |
| offsets.data_ptr<index_t>(), |
| sizeof(index_t) * offsets.numel()); |
| } |
| offsets_include_last[offsets.numel()] = select_indices.numel(); |
| offsets_data = offsets_include_last.data(); |
| } |
| |
| #ifdef USE_FBGEMM |
| auto kernel_fp32_index_t = |
| fbgemm_kernel_cache ? |
| fbgemm_kernel_cache->getCallback</* has_weight */ false, index_t, float>(ddim) : |
| fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>( |
| /* block_size */ddim, |
| /* has_weight */false, |
| /* normalize_by_lengths */false, |
| /* prefetch */16, |
| /* is_weight_positional */false, |
| /* use_offsets */true |
| ); |
| #endif |
| at::parallel_for( |
| 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { |
| #ifdef USE_FBGEMM |
| bool success = kernel_fp32_index_t( |
| /* output_size */end_idx - start_idx, |
| /* index_size */offsets_data[end_idx] - offsets_data[start_idx], |
| /* data_size */src.size(0), |
| /* input */src_data, |
| /* indices */select_indices_data + offsets_data[start_idx], |
| /* offsets_or_lengths */offsets_data + start_idx, |
| /* weights */nullptr, |
| /* output */output_data + start_idx * ddim); |
| if (!success) { |
| fbgemm_spmdm_report_error_( |
| end_idx - start_idx, |
| offsets_data[end_idx] - offsets_data[start_idx], |
| src.size(0), |
| offsets_data + start_idx, |
| select_indices_data + offsets_data[start_idx]); |
| } |
| #else |
| caffe2::EmbeddingLookupIdx( |
| /*block_size=*/ddim, |
| /*output_size=*/end_idx - start_idx, |
| /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], |
| /*data_size=*/src.size(0), |
| /*input=*/src_data, |
| /*indices=*/select_indices_data + offsets_data[start_idx], |
| /*offsets=*/offsets_data + start_idx, |
| /*weights=*/nullptr, |
| /*scale_bias=*/nullptr, |
| /*normalize_by_lengths=*/false, |
| /*out=*/output_data + start_idx * ddim); |
| #endif |
| }); |
| } else { |
| AT_ASSERT(select_indices.numel() == add_indices.numel()); |
| auto* src_data = src.data_ptr<float>(); |
| auto* add_indices_data = add_indices.data_ptr<index_t>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| index_t* bag_size_data = nullptr; |
| if (bag_size.defined()) { |
| bag_size_data = bag_size.data_ptr<index_t>(); |
| } |
| auto vocab_size = src.size(0); |
| auto src_stride0 = src.strides()[0]; |
| auto src_stride1 = src.strides()[1]; |
| auto output_stride0 = output.strides()[0]; |
| auto output_stride1 = output.strides()[1]; |
| auto numel = add_indices.numel(); |
| for (const auto i : c10::irange(numel)) { |
| // We can skip indices equal to padding_idx so they are not included in |
| // the reduction |
| auto idx = select_indices_data[i]; |
| TORCH_CHECK( |
| idx >= 0 && idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| idx); |
| if (idx != padding_idx) { |
| at::native::cpublas::axpy<float>( |
| ddim, |
| 1, |
| src_data + src_stride0 * idx, |
| src_stride1, |
| output_data + output_stride0 * add_indices_data[i], |
| output_stride1); |
| } else if (bag_size.defined()) { |
| // Decrement bag_size to reflect that the index is padded |
| // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) |
| bag_size_data[add_indices_data[i]]--; |
| } |
| } |
| } |
| } |
| |
| // This function fuses the following three fns: |
| // index_select (using select_indices as the index) |
| // mul (scaling by per_sample_weights) |
| // index_add (using add_indices as the index) |
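| // Effectively, for every i it performs |
| //   output[add_indices[i]] += scale[i] * src[select_indices[i]] |
| // where scale is per_sample_weights, skipping entries equal to padding_idx |
| // and decrementing bag_size for the bags those padded entries map to. |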
| template<typename data_t, typename index_t> |
| static typename std::enable_if<!std::is_same<data_t, float>::value && !std::is_same<data_t, at::Half>::value, void>::type |
| index_select_scale_add(const Tensor &select_indices, |
| const Tensor &add_indices, |
| const Tensor &scale, |
| const Tensor &src, |
| Tensor &output, |
| const Tensor& /*offsets*/, |
| bool /*include_last_offset*/, |
| Tensor &bag_size, |
| index_t padding_idx, |
| _EmbeddingBagKernelCache* /* fbgemm_kernel_cache */) { |
| AT_ASSERT(select_indices.numel() == add_indices.numel()); |
| auto* add_indices_data = add_indices.data_ptr<index_t>(); |
| auto* select_indices_data = select_indices.data_ptr<index_t>(); |
| auto* src_data = src.data_ptr<data_t>(); |
| auto* output_data = output.data_ptr<data_t>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| index_t* bag_size_data = nullptr; |
| if (bag_size.defined()) { |
| bag_size_data = bag_size.data_ptr<index_t>(); |
| } |
| auto numel = add_indices.numel(); |
| int64_t ddim = src.size(1); |
| auto vocab_size = src.size(0); |
| auto src_stride0 = src.strides()[0]; |
| auto src_stride1 = src.strides()[1]; |
| auto output_stride0 = output.strides()[0]; |
| auto output_stride1 = output.strides()[1]; |
| |
| auto* scale_data = scale.data_ptr<data_t>(); |
| auto scale_stride = scale.strides()[0]; |
| |
| for (const auto i : c10::irange(numel)) { |
| // We can skip indices equal to padding_idx so they are not included in |
| // the reduction |
| auto idx = select_indices_data[i]; |
| TORCH_CHECK( |
| idx >= 0 && idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| idx); |
| if (idx != padding_idx) { |
| auto* src_base = src_data + src_stride0 * idx; |
| auto* output_base = output_data + output_stride0 * add_indices_data[i]; |
| auto scale = scale_data[i * scale_stride]; |
| for (const auto j : c10::irange(ddim)) { |
| output_base[j * output_stride1] += src_base[j * src_stride1] * scale; |
| } |
| } else if (bag_size.defined()) { |
| // Decrement bag_size to reflect that the index is padded |
| // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) |
| bag_size_data[add_indices_data[i]]--; |
| } |
| } |
| } |
| |
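| // at::Half specialization of the weighted reduction. The fast path converts |
| // per_sample_weights to fp32 and uses the FBGEMM fp16 EmbeddingSpMDM kernel |
| // (or caffe2::EmbeddingLookupIdx with an fp32 output buffer); the slow path |
| // accumulates scaled rows into an fp32 buffer and converts back to fp16. |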
| template<typename data_t, typename index_t> |
| typename std::enable_if<std::is_same<data_t, at::Half>::value, void>::type |
| index_select_scale_add(const Tensor &select_indices, |
| const Tensor &add_indices, |
| const Tensor &scale, |
| const Tensor &src, |
| Tensor &output, |
| const Tensor& offsets, |
| bool include_last_offset, |
| Tensor &bag_size, |
| index_t padding_idx, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache) { |
| int64_t ddim = src.size(1); |
| auto* scale_data = scale.data_ptr<at::Half>(); |
| auto* select_indices_data = select_indices.data_ptr<index_t>(); |
| auto* output_data = output.data_ptr<at::Half>(); |
| |
| if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) { |
| auto src_contig = src.contiguous(); |
| auto* src_data = src_contig.data_ptr<at::Half>(); |
| int64_t output_size = offsets.numel() - 1; |
| auto* offsets_data = offsets.data_ptr<index_t>(); |
| std::vector<index_t> offsets_include_last; |
| |
| if (include_last_offset) { |
| output_size = offsets.numel() - 1; |
| } else { |
| output_size = offsets.numel(); |
| offsets_include_last.resize(offsets.numel() + 1); |
| std::memcpy( |
| offsets_include_last.data(), |
| offsets.data_ptr<index_t>(), |
| sizeof(index_t) * offsets.numel()); |
| offsets_include_last[offsets.numel()] = select_indices.numel(); |
| offsets_data = offsets_include_last.data(); |
| } |
| |
| Tensor scale_fp32 = at::empty(scale.sizes(), scale.options().dtype(at::kFloat)); |
| auto* scale_data_fp32 = scale_fp32.data_ptr<float>(); |
| |
| #ifdef USE_FBGEMM |
| using float16 = uint16_t; |
| fbgemm::Float16ToFloat_simd(reinterpret_cast<const float16*>(scale_data), scale_data_fp32, scale_fp32.numel()); |
| auto kernel_fp16_index_t = |
| fbgemm_kernel_cache ? |
| fbgemm_kernel_cache->getCallback</* has_weight */ true, index_t, float16>(ddim) : |
| fbgemm::GenerateEmbeddingSpMDM<float16, index_t, index_t, float16>( |
| /* block_size */ddim, |
| /* has_weight */true, |
| /* normalize_by_lengths */false, |
| /* prefetch */16, |
| /* is_weight_positional */false, |
| /* use_offsets */true |
| ); |
| #else |
| // Initialize the intermediate output buffer to be 0. |
| Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat)); |
| auto* output_data_fp32 = output_fp32.data_ptr<float>(); |
| for (const auto i : c10::irange(scale.numel())) { |
| scale_data_fp32[i] = static_cast<float>(scale_data[i]); |
| } |
| #endif |
| at::parallel_for( |
| 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { |
| #ifdef USE_FBGEMM |
| bool success = kernel_fp16_index_t( |
| /* output_size */end_idx - start_idx, |
| /* index_size */offsets_data[end_idx] - offsets_data[start_idx], |
| /* data_size */src.size(0), |
| /* input */reinterpret_cast<const float16*>(src_data), |
| /* indices */select_indices_data + offsets_data[start_idx], |
| /* offsets_or_lengths */offsets_data + start_idx, |
| /* weights */scale_data_fp32 + offsets_data[start_idx], |
| /* output */reinterpret_cast<float16*>(output_data + start_idx * ddim)); |
| if (!success) { |
| fbgemm_spmdm_report_error_( |
| end_idx - start_idx, |
| offsets_data[end_idx] - offsets_data[start_idx], |
| src.size(0), |
| offsets_data + start_idx, |
| select_indices_data + offsets_data[start_idx]); |
| } |
| #else |
| caffe2::EmbeddingLookupIdx( |
| /*block_size=*/ddim, |
| /*output_size=*/end_idx - start_idx, |
| /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], |
| /*data_size=*/src.size(0), |
| /*input=*/src_data, |
| /*indices=*/select_indices_data + offsets_data[start_idx], |
| /*offsets=*/offsets_data + start_idx, |
| /*weights=*/scale_data_fp32 + offsets_data[start_idx], |
| /*scale_bias=*/nullptr, |
| /*normalize_by_lengths=*/false, |
| /*out=*/output_data_fp32 + start_idx * ddim); |
|           // Convert the FP32 intermediate buffer back to FP16 for the output |
|           // dtype, only for the rows computed by this task so parallel tasks |
|           // do not touch each other's rows. |
|           for (const auto i : c10::irange(start_idx, end_idx)) { |
|             for (const auto d : c10::irange(ddim)) { |
|               (output_data + i * ddim)[d] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]); |
|             } |
|           } |
| #endif |
| }); |
| } else { |
| AT_ASSERT(select_indices.numel() == add_indices.numel()); |
| auto* src_data = src.data_ptr<at::Half>(); |
| auto* add_indices_data = add_indices.data_ptr<index_t>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| index_t* bag_size_data = nullptr; |
| if (bag_size.defined()) { |
| bag_size_data = bag_size.data_ptr<index_t>(); |
| } |
| auto vocab_size = src.size(0); |
| auto src_stride0 = src.strides()[0]; |
| auto src_stride1 = src.strides()[1]; |
| auto output_stride0 = output.strides()[0]; |
| auto output_stride1 = output.strides()[1]; |
| auto scale_stride = scale.strides()[0]; |
| auto numel = add_indices.numel(); |
| |
| // Initialize the intermediate output buffer to be 0. |
| Tensor output_fp32 = at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat)); |
| auto* output_data_fp32 = output_fp32.data_ptr<float>(); |
| |
| for (const auto i : c10::irange(numel)) { |
| // We can skip indices equal to padding_idx so they are not included in |
| // the reduction |
| auto idx = select_indices_data[i]; |
| TORCH_CHECK( |
| idx >= 0 && idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| idx); |
| if (idx != padding_idx) { |
| |
| auto* src_base = src_data + src_stride0 * idx; |
| auto* output_base_fp32 = output_data_fp32 + ddim * add_indices_data[i]; |
| auto scale = scale_data[i * scale_stride]; |
| for (const auto j : c10::irange(ddim)) { |
| output_base_fp32[j] += static_cast<float>(src_base[j * src_stride1]) * static_cast<float>(scale); |
| } |
| } else if (bag_size.defined()) { |
| // Decrement bag_size to reflect that the index is padded |
| // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) |
| bag_size_data[add_indices_data[i]]--; |
| } |
| } |
| for (const auto i : c10::irange(output.size(0))) { |
| // Convert FP32 intermediate buffer result back to FP16 for output dtype |
| for (const auto d : c10::irange(ddim)) { |
| (output_data + output_stride0 * i)[d * output_stride1] = static_cast<at::Half>((output_data_fp32 + ddim * i)[d]); |
| } |
| } |
| } |
| } |
| |
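| // float specialization of the weighted reduction: FBGEMM / caffe2 kernels on |
| // the fast path (with per_sample_weights passed as the kernel's weights), |
| // per-row scaled accumulation otherwise. |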
| template<typename data_t, typename index_t> |
| typename std::enable_if<std::is_same<data_t, float>::value, void>::type |
| index_select_scale_add(const Tensor &select_indices, |
| const Tensor &add_indices, |
| const Tensor &scale, |
| const Tensor &src, |
| Tensor &output, |
| const Tensor& offsets, |
| bool include_last_offset, |
| Tensor &bag_size, |
| index_t padding_idx, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache) { |
| int64_t ddim = src.size(1); |
| auto* scale_data = scale.data_ptr<float>(); |
| auto* select_indices_data = select_indices.data_ptr<index_t>(); |
| auto* output_data = output.data_ptr<float>(); |
| |
| if (is_fast_path_index_select_scale(src, scale, output, padding_idx)) { |
| auto src_contig = src.contiguous(); |
| auto* src_data = src_contig.data_ptr<float>(); |
| int64_t output_size = offsets.numel() - 1; |
| auto* offsets_data = offsets.data_ptr<index_t>(); |
| std::vector<index_t> offsets_include_last; |
| |
| if (include_last_offset) { |
| output_size = offsets.numel() - 1; |
| } else { |
| output_size = offsets.numel(); |
| offsets_include_last.resize(offsets.numel() + 1); |
| std::memcpy( |
| offsets_include_last.data(), |
| offsets.data_ptr<index_t>(), |
| sizeof(index_t) * offsets.numel()); |
| offsets_include_last[offsets.numel()] = select_indices.numel(); |
| offsets_data = offsets_include_last.data(); |
| } |
| |
| #ifdef USE_FBGEMM |
| auto kernel_fp32_index_t = |
| fbgemm_kernel_cache ? |
| fbgemm_kernel_cache->getCallback</* has_weight */ true, index_t, float>(ddim) : |
| fbgemm::GenerateEmbeddingSpMDM<float, index_t, index_t>( |
| /* block_size */ddim, |
| /* has_weight */true, |
| /* normalize_by_lengths */false, |
| /* prefetch */16, |
| /* is_weight_positional */false, |
| /* use_offsets */true |
| ); |
| #endif |
| at::parallel_for( |
| 0, output_size, 1, [&](index_t start_idx, index_t end_idx) { |
| #ifdef USE_FBGEMM |
| bool success = kernel_fp32_index_t( |
| /* output_size */end_idx - start_idx, |
| /* index_size */offsets_data[end_idx] - offsets_data[start_idx], |
| /* data_size */src.size(0), |
| /* input */src_data, |
| /* indices */select_indices_data + offsets_data[start_idx], |
| /* offsets_or_lengths */offsets_data + start_idx, |
| /* weights */scale_data + offsets_data[start_idx], |
| /* output */output_data + start_idx * ddim); |
| if (!success) { |
| fbgemm_spmdm_report_error_( |
| end_idx - start_idx, |
| offsets_data[end_idx] - offsets_data[start_idx], |
| src.size(0), |
| offsets_data + start_idx, |
| select_indices_data + offsets_data[start_idx]); |
| } |
| #else |
| caffe2::EmbeddingLookupIdx( |
| /*block_size=*/ddim, |
| /*output_size=*/end_idx - start_idx, |
| /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx], |
| /*data_size=*/src.size(0), |
| /*input=*/src_data, |
| /*indices=*/select_indices_data + offsets_data[start_idx], |
| /*offsets=*/offsets_data + start_idx, |
| /*weights=*/scale_data + offsets_data[start_idx], |
| /*scale_bias=*/nullptr, |
| /*normalize_by_lengths=*/false, |
| /*out=*/output_data + start_idx * ddim); |
| #endif |
| }); |
| } else { |
| AT_ASSERT(select_indices.numel() == add_indices.numel()); |
| auto* src_data = src.data_ptr<float>(); |
| auto* add_indices_data = add_indices.data_ptr<index_t>(); |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| index_t* bag_size_data = nullptr; |
| if (bag_size.defined()) { |
| bag_size_data = bag_size.data_ptr<index_t>(); |
| } |
| auto vocab_size = src.size(0); |
| auto src_stride0 = src.strides()[0]; |
| auto src_stride1 = src.strides()[1]; |
| auto output_stride0 = output.strides()[0]; |
| auto output_stride1 = output.strides()[1]; |
| auto scale_stride = scale.strides()[0]; |
| auto numel = add_indices.numel(); |
| |
| |
| for (const auto i : c10::irange(numel)) { |
| // We can skip indices equal to padding_idx so they are not included in |
| // the reduction |
| auto idx = select_indices_data[i]; |
| TORCH_CHECK( |
| idx >= 0 && idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| idx); |
| if (idx != padding_idx) { |
| auto* src_base = src_data + src_stride0 * idx; |
| auto* output_base = output_data + output_stride0 * add_indices_data[i]; |
| auto scale = scale_data[i * scale_stride]; |
| for (const auto j : c10::irange(ddim)) { |
| output_base[j * output_stride1] += src_base[j * src_stride1] * scale; |
| } |
| } else if (bag_size.defined()) { |
| // Decrement bag_size to reflect that the index is padded |
| // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) |
| bag_size_data[add_indices_data[i]]--; |
| } |
| } |
| } |
| } |
| |
| } // namespace |
| |
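| // Validates dtypes and shapes shared by all embedding_bag entry points: |
| // integral indices/offsets of the same type, floating-point weight, |
| // offsets[0] == 0, offsets[-1] <= indices.size(0), and per_sample_weights |
| // (sum mode only) matching the weight dtype and the indices length. |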
| void check_arguments( |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const int64_t mode, |
| const c10::optional<Tensor>& per_sample_weights, |
| bool include_last_offset) { |
| auto indices_arg = TensorArg(indices, "indices", 1); |
| checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); |
| auto offsets_arg = TensorArg(offsets, "offsets", 1); |
| checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); |
| checkSameType("embedding_bag", indices_arg, offsets_arg); |
| auto weight_arg = TensorArg(weight, "weight", 1); |
| checkScalarTypes("embedding_bag", weight_arg, {kHalf, kFloat, kDouble}); |
| |
| AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "_embedding_bag_cpu_impl", [&]() { |
| if (offsets.size(0) > 0) { |
| index_t offset_0 = offsets.data_ptr<index_t>()[0]; |
| index_t offset_n = offsets.data_ptr<index_t>()[offsets.size(0)-1]; |
| TORCH_CHECK(offset_0 == 0, "offsets[0] has to be 0, i.e., the first sequence " |
| "in the mini-batch has to start from position 0. " |
| "However, got ", offsets[0]); |
|     TORCH_CHECK(offset_n <= indices.size(0), "offsets[-1] cannot " |
| "be greater than input's length ", indices.size(0), " but got offsets[-1] of ", |
| offset_n); |
| } |
| }); |
| |
| if (per_sample_weights.has_value() && per_sample_weights.value().defined()) { |
| TORCH_CHECK(mode == MODE_SUM, |
| "embedding_bag: per_sample_weights only supported with mode='sum'"); |
| auto per_input_weights_arg = TensorArg( |
| per_sample_weights.value(),"per_sample_weights", 1); |
| checkSameType("embedding_bag", weight_arg, per_input_weights_arg); |
| TORCH_CHECK(per_sample_weights.value().dim() == 1); |
| TORCH_CHECK(per_sample_weights.value().numel() == indices.numel()); |
| } |
| |
| if (include_last_offset) { |
| TORCH_CHECK( |
| offsets.size(0) >= 1, |
| "include_last_offset: number of offset should be at least 1"); |
| } |
| } |
| |
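| // Fills bag_size_out with the number of indices per bag, which is needed for |
| // the mean and max modes and for the backward pass: bag i gets |
| // offsets[i + 1] - offsets[i], and the last bag gets |
| // indices.size(0) - offsets[num_bags - 1]. For plain sum inference the |
| // tensor is only resized. |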
| void make_bag_size_out( |
| Tensor& bag_size_out, |
| const Tensor& offsets, |
| const Tensor& indices, |
| const int64_t mode, |
| const bool include_last_offset, |
| const bool requires_grad) { |
| if (requires_grad || mode == MODE_MEAN || mode == MODE_MAX) { |
| auto num_bags = offsets.size(0) - (include_last_offset ? 1 : 0); |
| at::native::resize_(bag_size_out, {num_bags}, c10::nullopt); |
| // Compute this for MODE_MEAN and MODE_MAX (latter needed for backwards) |
| if (num_bags != 1) { |
| bag_size_out.slice(0, 0, bag_size_out.size(0) - 1, 1) = |
| offsets.slice(0, 1, num_bags, 1) - |
| offsets.slice(0, 0, num_bags - 1, 1); |
| } |
| if (num_bags > 0) { |
| bag_size_out[-1] = indices.size(0) - offsets[num_bags - 1]; |
| } |
| } else { |
| at::native::resize_(bag_size_out, offsets.sizes(), c10::nullopt); |
| } |
| } |
| |
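| // For max mode, allocates a {num_bags, embedding_dim} buffer that will hold |
| // the argmax index per bag and feature; for the other modes it is only |
| // resized to match bag_size. |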
| void make_max_indices_out( |
| Tensor& max_indices_out, |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const Tensor& bag_size, |
| const int64_t mode, |
| bool include_last_offset) { |
| int64_t numBags = offsets.size(0); |
| if (mode == MODE_MAX) { |
| if (include_last_offset) { |
| TORCH_CHECK( |
| numBags >= 1, "include_last_offset: numBags should be at least 1"); |
| numBags -= 1; |
| } |
| at::native::resize_(max_indices_out, {numBags, weight.sizes()[1]}, c10::nullopt); |
| at::native::zero_(max_indices_out); |
| } else { |
| at::native::resize_(max_indices_out, bag_size.sizes(), c10::nullopt); |
| } |
| } |
| |
| void make_offset2bag_out( |
| Tensor& offset2bag, |
| Tensor& output, |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const int64_t mode, |
| const c10::optional<Tensor>& per_sample_weights, |
| const int64_t padding_idx) { |
| // To save compute, if we are going to go down the fast path case for the 'sum' |
| // mode, we skip calculating offset2bag, since it is not going to be used. |
| bool fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx); |
| |
| if (mode == MODE_MEAN || mode == MODE_MAX || !fast_path_sum) { |
| at::native::resize_(offset2bag, {indices.size(0) + 1}, c10::nullopt); |
| at::native::zero_(offset2bag); |
| |
| make_offset2bag(offsets, offset2bag); |
| at::native::resize_(offset2bag, {indices.size(0)}, c10::nullopt); |
| // only initialize output in slow path |
| at::native::zero_(output); |
| } |
| } |
| |
| static Tensor make_bag_size( |
| const Tensor& offsets, |
| const Tensor& indices, |
| const int64_t mode, |
| const bool include_last_offset, |
| const bool requires_grad) { |
| Tensor bag_size = at::empty(offsets.sizes(), offsets.options()); |
| make_bag_size_out(bag_size, offsets, indices, mode, include_last_offset, requires_grad); |
| return bag_size; |
| } |
| |
| static Tensor make_max_indices( |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const Tensor& bag_size, |
| const int64_t mode, |
| bool include_last_offset) { |
| Tensor max_indices = at::empty(bag_size.sizes(), offsets.options()); |
| make_max_indices_out(max_indices, weight, indices, offsets, bag_size, mode, include_last_offset); |
| return max_indices; |
| } |
| |
| static Tensor make_offset2bag( |
| Tensor& output, |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offsets, |
| const int64_t mode, |
| const c10::optional<Tensor>& per_sample_weights, |
| const int64_t padding_idx) { |
| Tensor offset2bag = at::empty({0}, offsets.options()); |
| make_offset2bag_out(offset2bag, output, weight, indices, offsets, mode, per_sample_weights, padding_idx); |
| return offset2bag; |
| } |
| |
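| // For mean mode, divides each output row by its bag size, clamped to at |
| // least 1 so that empty bags stay zero instead of producing NaNs. |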
| static Tensor apply_bag_size( |
| const int64_t mode, |
| Tensor &output, |
| const Tensor &bag_size) { |
| if (mode == MODE_MEAN) { |
| auto bag_size_ = at::max(bag_size, at::ones_like(bag_size, LEGACY_CONTIGUOUS_MEMORY_FORMAT)) |
| .to(output.options()) |
| .unsqueeze(1) |
| .expand_as(output); |
| output /= bag_size_; |
| } |
| return output; |
| } |
| |
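| // Backward counterpart of apply_bag_size: for mean mode, scales each gradient |
| // row by 1 / bag_size of the bag it came from (gathered via offset2bag). |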
| static Tensor apply_bag_size_backward( |
| const int64_t mode, |
| Tensor &output, |
| const Tensor &offset2bag, |
| const Tensor &bag_size) { |
| if (mode == MODE_MEAN) { |
| auto inv_bag_size_ = (1 / bag_size.to(output.options())) |
| .unsqueeze(1) |
| .index_select(0, offset2bag); |
| output *= inv_bag_size_; |
| } |
| return output; |
| } |
| |
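| // Max-mode reduction: for every index position, compares the corresponding |
| // weight row against the running per-feature maximum of its bag, recording |
| // the winning word index in max_indices. Entries equal to padding_idx are |
| // skipped and decrement the bag's size instead. |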
| template <typename scalar_t> |
| void embedding_bag_cpu_max_out( |
| Tensor* max_indices, |
| const Tensor& weight, |
| const Tensor& indices, |
| const Tensor& offset2bag, |
| const Tensor& output, |
| bool include_last_offset, |
| Tensor& bag_size, |
| int64_t padding_idx) { |
| int64_t numIndices = indices.numel(); |
| int64_t featureSize = weight.size(1); |
| int64_t vocab_size = weight.size(0); |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max_out", [&] { |
| auto* indices_data = indices.data_ptr<index_t>(); |
| auto* offset2bag_data = offset2bag.data_ptr<index_t>(); |
| |
| index_t* max_indices_data = nullptr; |
| int64_t max_indices_stride = 0; |
| if (max_indices) { |
| max_indices_data = max_indices->data_ptr<index_t>(); |
| max_indices_stride = max_indices->strides()[0]; |
| } |
| |
| auto* weight_data = weight.data_ptr<scalar_t>(); |
| auto* output_data = output.data_ptr<scalar_t>(); |
| auto* bag_size_data = bag_size.data_ptr<index_t>(); |
| auto weight_stride0 = weight.strides()[0]; |
| auto weight_stride1 = weight.strides()[1]; |
| auto output_stride = output.strides()[0]; |
| int64_t numBags = bag_size.size(0); |
| std::vector<bool> bag_empty(numBags, true); |
| |
| for (const auto i : c10::irange(numIndices)) { |
| auto bag = offset2bag_data[i]; |
| auto word_idx = indices_data[i]; |
| TORCH_CHECK( |
| word_idx >= 0 && word_idx < vocab_size, |
| "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ", |
| word_idx); |
| if (word_idx != static_cast<index_t>(padding_idx)) { |
| bool is_first_for_bag = bag_empty[bag]; |
| for (const auto dim : c10::irange(featureSize)) { |
| auto& current_item = output_data[output_stride * bag + dim]; |
| auto weight_item = |
| weight_data[weight_stride0 * word_idx + dim * weight_stride1]; |
| |
| if (is_first_for_bag || (weight_item > current_item)) { |
| current_item = weight_item; |
| if (max_indices_data) { |
| max_indices_data[max_indices_stride * bag + dim] = word_idx; |
| } |
| } |
| } |
| if (is_first_for_bag) { |
| bag_empty[bag] = false; |
| } |
| } else { |
| // Decrement bag_size to reflect that the index is padded |
| bag_size_data[bag]--; |
| } |
| } |
| }); |
| } |
| |
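| // Core CPU kernel dispatcher. Sum/mean modes go through index_select_add / |
| // index_select_scale_add followed by the mean normalization; max mode goes |
| // through embedding_bag_cpu_max_out. For sum, bag_size is zeroed afterwards |
| // so the extra outputs are deterministic, and max_indices (when requested) |
| // is just filled with a copy of bag_size. |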
| void _embedding_bag_cpu_impl_out(Tensor& output, Tensor& offset2bag, |
| Tensor& bag_size, Tensor* max_indices, |
| const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const int64_t mode, |
| const c10::optional<Tensor>& per_sample_weights, |
| bool include_last_offset, int64_t padding_idx, _EmbeddingBagKernelCache* fbgemm_kernel_cache) { |
| if (mode == MODE_MEAN || mode == MODE_SUM) { |
| AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_no_grad_cpu_out", |
| [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() { |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_no_grad_cpu_out", |
| [&indices, &offset2bag, &per_sample_weights, &weight, &output, &offsets, &include_last_offset, &mode, &bag_size, &padding_idx, &fbgemm_kernel_cache]() { |
| if (per_sample_weights.has_value() && per_sample_weights.value().defined()) { |
| TORCH_INTERNAL_ASSERT(mode == MODE_SUM); |
| index_select_scale_add<scalar_t, index_t>( |
| indices, offset2bag, per_sample_weights.value(), weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache); |
| } else { |
| index_select_add<scalar_t, index_t>(indices, offset2bag, weight, output, offsets, include_last_offset, bag_size, padding_idx, fbgemm_kernel_cache); |
| } |
| }); |
| }); |
| apply_bag_size(mode, output, bag_size); |
| if (mode == MODE_SUM) { |
| // make bag_size output deterministic |
| at::native::zero_(bag_size); |
| } |
| if (max_indices) { |
| max_indices->copy_(bag_size); |
| } |
| } else { // MODE_MAX |
| AT_DISPATCH_FLOATING_TYPES_AND_HALF( |
| weight.scalar_type(), "embedding_bag_cpu_max_out", [&]() { |
| embedding_bag_cpu_max_out<scalar_t>( |
| max_indices, weight, indices, offset2bag, output, include_last_offset, bag_size, padding_idx); |
| } |
| ); |
| } |
| } |
| |
| // Assumes all input tensors except for `weight` are contiguous. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
| std::tuple<Tensor, Tensor, Tensor, Tensor> _embedding_bag_cpu_impl( |
| const Tensor& weight, |
| const Tensor& indices_, |
| const Tensor& offsets_, |
| const int64_t mode, |
| const Tensor& per_sample_weights, |
| bool include_last_offset, |
| int64_t padding_idx, |
| bool requires_grad) { |
| Tensor indices, offsets; |
| std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); |
| check_arguments(weight, indices, offsets, mode, per_sample_weights, include_last_offset); |
| |
| Tensor output = at::empty( |
| {include_last_offset ? offsets.size(0) - 1 : offsets.size(0), |
| weight.sizes()[1]}, |
| weight.options()); |
| |
| Tensor offset2bag = make_offset2bag(output, weight, indices, offsets, mode, per_sample_weights, padding_idx); |
| |
| Tensor bag_size = make_bag_size(offsets, indices, mode, include_last_offset, requires_grad); |
| |
| Tensor max_indices = make_max_indices(weight, indices, offsets, bag_size, mode, include_last_offset); |
| |
| _embedding_bag_cpu_impl_out(output, offset2bag, |
| bag_size, &max_indices, |
| weight, indices, offsets, |
| mode, per_sample_weights, |
| include_last_offset, padding_idx); |
| |
| return std::make_tuple(std::move(output), std::move(offset2bag), std::move(bag_size), std::move(max_indices)); |
| } |
| |
| // embedding_bag wrapper to enforce contiguity in tensors other than `weight`. |
| // This is created to save extra `.contiguous()` call in backward. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
| std::tuple<Tensor, Tensor, Tensor, Tensor> |
| embedding_bag(const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const bool scale_grad_by_freq, |
| const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, |
| bool include_last_offset, c10::optional<int64_t> padding_idx_opt) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| int64_t padding_idx = -1; |
| |
| if (padding_idx_opt.has_value()) { |
| auto num_embeddings = weight.size(0); |
| padding_idx = padding_idx_opt.value(); |
| TORCH_CHECK( |
| (padding_idx >= -num_embeddings) && (padding_idx < num_embeddings), |
| "padding_idx must be within the number of embeddings, -", num_embeddings, |
| " through ", num_embeddings - 1, ", but got ", padding_idx); |
| padding_idx = maybe_wrap_dim(padding_idx, weight.size(0)); |
| } |
| std::tuple<Tensor, Tensor, Tensor, Tensor> out; |
| if (!weight.requires_grad() && !weight._fw_grad(/*level=*/0).defined()) { |
| out = at::_embedding_bag_forward_only( |
| weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq, |
| mode, sparse, per_sample_weights, include_last_offset, padding_idx); |
| } else { |
| out = at::_embedding_bag( |
| weight, indices.contiguous(), offsets.contiguous(), scale_grad_by_freq, |
| mode, sparse, per_sample_weights, include_last_offset, padding_idx); |
| } |
| return out; |
| } |
| |
| std::tuple<Tensor, Tensor, Tensor, Tensor> |
| embedding_bag(const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const bool scale_grad_by_freq, |
| const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, |
| bool include_last_offset) { |
| return at::native::embedding_bag(weight, indices, offsets, scale_grad_by_freq, |
| mode, sparse, per_sample_weights_opt, include_last_offset, c10::nullopt); |
| } |
| |
| // Assumes all input tensors except for `weight` are contiguous. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
| std::tuple<Tensor, Tensor, Tensor, Tensor> |
| _embedding_bag_forward_only_cpu(const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const bool scale_grad_by_freq, |
| const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, bool include_last_offset, |
| int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| std::ignore = scale_grad_by_freq; |
| std::ignore = sparse; |
| return _embedding_bag_cpu_impl( |
| weight, |
| indices, |
| offsets, |
| mode, |
| per_sample_weights, |
| include_last_offset, |
| padding_idx, |
| /*requires_grad=*/false); |
| } |
| |
| // Assumes all input tensors except for `weight` are contiguous. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
| std::tuple<Tensor, Tensor, Tensor, Tensor> |
| _embedding_bag_cpu(const Tensor &weight, const Tensor &indices, |
| const Tensor &offsets, const bool scale_grad_by_freq, |
| const int64_t mode, bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, bool include_last_offset, |
| int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| |
| std::ignore = scale_grad_by_freq; |
| std::ignore = sparse; |
| return _embedding_bag_cpu_impl( |
| weight, |
| indices, |
| offsets, |
| mode, |
| per_sample_weights, |
| include_last_offset, |
| padding_idx, |
| /*requires_grad=*/true); |
| } |
| |
| void _embedding_bag_cpu_out( |
| at::Tensor& output, |
| at::Tensor& offset2bag, |
| at::Tensor& bag_size, |
| at::Tensor* p_max_indices, |
| const at::Tensor& weight, |
| const at::Tensor& indices, |
| const at::Tensor& offsets, |
| const bool /* scale_grad_by_freq */, |
| const int64_t mode, |
| const bool /* sparse */, |
| const c10::optional<at::Tensor>& per_sample_weights, |
| const bool include_last_offset, |
| const c10::optional<int64_t>& padding_idx, |
| _EmbeddingBagKernelCache* fbgemm_kernel_cache) { |
| at::native::check_arguments( |
| weight, indices, offsets, mode, per_sample_weights, include_last_offset); |
| |
| at::native::make_offset2bag_out( |
| offset2bag, |
| output, |
| weight, |
| indices, |
| offsets, |
| mode, |
| per_sample_weights, |
| padding_idx.value_or(-1)); |
| |
| at::native::make_bag_size_out( |
| bag_size, offsets, indices, mode, include_last_offset, false); |
| |
| if (p_max_indices) { |
| at::native::make_max_indices_out( |
| *p_max_indices, |
| weight, |
| indices, |
| offsets, |
| bag_size, |
| mode, |
| include_last_offset); |
| } |
| |
| at::native::_embedding_bag_cpu_impl_out( |
| output, |
| offset2bag, |
| bag_size, |
| p_max_indices, |
| weight, |
| indices, |
| offsets, |
| mode, |
| per_sample_weights, |
| include_last_offset, |
| padding_idx.value_or(-1), |
| fbgemm_kernel_cache); |
| } |
| |
| Tensor _embedding_bag_backward(const Tensor &grad, const Tensor &indices_, |
| const Tensor &offsets_, |
| const Tensor &offset2bag, |
| const Tensor &bag_size_, |
| const Tensor &max_indices_, |
| int64_t num_weights, |
| bool scale_grad_by_freq, int64_t mode, |
| bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, |
| int64_t padding_idx) { |
| return at::native::_embedding_bag_backward_symint( |
| grad, indices_, offsets_, offset2bag, bag_size_, max_indices_, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_opt, padding_idx); |
| } |
| |
| // Assumes all input tensors are contiguous. |
| // See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details |
| Tensor _embedding_bag_backward_symint(const Tensor &grad, const Tensor &indices_, |
| const Tensor &offsets_, |
| const Tensor &offset2bag, |
| const Tensor &bag_size_, |
| const Tensor &max_indices_, |
| c10::SymInt num_weights, |
| bool scale_grad_by_freq, int64_t mode, |
| bool sparse, const c10::optional<Tensor>& per_sample_weights_opt, |
| int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| |
| Tensor indices, offsets; |
| std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); |
| auto indices_arg = TensorArg(indices, "indices", 1); |
| checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); |
| checkContiguous("embedding_bag", indices_arg); |
| auto offsets_arg = TensorArg(offsets, "offsets", 1); |
| checkScalarTypes("embedding_bag", offsets_arg, {kLong, kInt}); |
| checkSameType("embedding_bag", indices_arg, offsets_arg); |
| checkContiguous("embedding_bag", offsets_arg); |
| |
| Tensor offset2bag_; |
| if (indices.numel() != 0 && offset2bag.numel() == 0) { |
| offset2bag_ = offsets.new_zeros( |
| {indices.size(0) + 1}, offsets.options()); // offset2bag = [0 0 0 0 0] |
| |
| make_offset2bag(offsets, offset2bag_); |
| // For Composite Compliance, if `offset2bag_` is CCT |
| // then we can't call `resize_`. Instead we call `narrow` |
| // to slice the tensor. |
| if (isTensorSubclassLike(offset2bag_)) { |
| offset2bag_ = offset2bag_.narrow(0, 0, indices.size(0)); |
| } else { |
| offset2bag_.resize_({indices.size(0)}); |
| } |
| } else { |
| auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); |
| checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); |
| checkContiguous("embedding_bag", offset2bag_arg); |
| offset2bag_ = offset2bag; |
| } |
| |
| if (sparse) { |
| return at::_embedding_bag_sparse_backward_symint( |
| grad, indices, offsets, offset2bag_, bag_size_, num_weights, |
| scale_grad_by_freq, mode, per_sample_weights, padding_idx); |
| } else { |
| return at::_embedding_bag_dense_backward_symint( |
| grad, indices, offset2bag_, bag_size_, max_indices_, num_weights, |
| scale_grad_by_freq, mode, per_sample_weights, padding_idx); |
| } |
| } |
| |
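| // Max-mode dense backward: scatter-adds each gradient row into the weight |
| // rows recorded in max_indices (one per feature dimension), skipping bags |
| // whose bag_size is zero. |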
| static Tensor _embedding_bag_dense_backward_cpu_max( |
| const Tensor& grad, |
| const Tensor& bag_size, |
| const Tensor& max_indices, |
| int64_t num_weights) { |
| AT_ASSERT(max_indices.defined()); |
| auto index_grad_weight = |
| at::zeros({num_weights, grad.sizes()[1]}, grad.options()); |
| auto nonempty_max_indices = max_indices.index_select(0, bag_size.nonzero().view(-1)); |
| auto nonempty_grad = grad.index_select(0, bag_size.nonzero().view(-1)); |
| |
| for (const auto dim : c10::irange(grad.sizes()[1])) { |
| index_grad_weight.select(1, dim).index_add_( |
| 0, nonempty_max_indices.select(1, dim), nonempty_grad.select(1, dim)); |
| } |
| return index_grad_weight; |
| } |
| |
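| // Counts how many times each embedding index occurs; used below both to |
| // group equal indices and to implement scale_grad_by_freq. |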
| template<typename index_t> |
| static std::vector<index_t> compute_counts( |
| int64_t num_weights, |
| index_t* indices_data, |
| int64_t indices_length) { |
| std::vector<index_t> counts(num_weights, 0); |
| for (const auto i : c10::irange(indices_length)) { |
| counts[indices_data[i]]++; |
| } |
| return counts; |
| } |
| |
| // counts_uniq stores the index of the NEXT unique element |
| // of the (sorted) indices vector. |
| // |
| // For example: |
| // indices: [0, 0, 0, 1, 3, 3, 4] |
| // counts: [3, 1, 0, 2, 1, 0] |
| // counts_uniq: [3, 4, 6, 7] |
| // |
| // The unique indices can be found at index 0, 3, 4, 6. |
| template<typename index_t> |
| static std::vector<index_t> compute_counts_uniq( |
| int64_t num_weights, |
| index_t* indices_data, |
| int64_t indices_length, |
| const std::vector<index_t>& counts) { |
| std::vector<index_t> counts_uniq; |
| counts_uniq.reserve(num_weights); |
| int64_t o = 0; |
| for (int64_t i = 0; i < indices_length; i += counts[indices_data[i]]) { |
| counts_uniq.push_back(counts[indices_data[i]]); |
| if (o > 0) { |
| counts_uniq[o] += counts_uniq[o - 1]; |
| } |
| o++; |
| } |
| return counts_uniq; |
| } |
| |
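| // Sum/mean dense backward: sorts the indices, walks the groups of equal |
| // indices, and for each occurrence accumulates grad[bag] (scaled by |
| // per_sample_weights, 1/frequency and/or 1/bag_size as requested) into the |
| // gradient row of that embedding via axpy. Padded indices are skipped. |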
| template <typename scalar_t> |
| void _embedding_bag_dense_backward_cpu_sum_mean( |
| const Tensor& grad, |
| const Tensor& indices_, |
| const Tensor& offset2bag__, |
| const Tensor& bag_size_, |
| int64_t num_weights, |
| bool scale_grad_by_freq, |
| int64_t mode, |
| const Tensor& per_sample_weights_, |
| Tensor& index_grad_weight, |
| int64_t padding_idx) { |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| Tensor &offset2bag_ = const_cast<Tensor &>(offset2bag__); |
| |
| auto ind_sort_ = indices_.sort(); |
| auto indices = std::get<0>(ind_sort_); |
| auto ind_sort = std::get<1>(ind_sort_); |
| auto offset2bag = offset2bag_.index_select(0, ind_sort); |
| |
| optional<Tensor> per_sample_weights; |
| // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
| scalar_t* per_sample_weights_data; |
| optional<int64_t> per_sample_weights_stride; |
| if (per_sample_weights_.defined()) { |
| per_sample_weights = per_sample_weights_.index_select(0, ind_sort); |
| per_sample_weights_data = per_sample_weights->data_ptr<scalar_t>(); |
| per_sample_weights_stride = per_sample_weights->strides()[0]; |
| } |
| |
| int64_t numel = indices.numel(); |
| |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean", |
| [&indices, &offset2bag, &bag_size_, &num_weights, &numel, &per_sample_weights, |
| &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq, |
| &grad, &index_grad_weight, &padding_idx] { |
| auto* indices_data = indices.data_ptr<index_t>(); |
| auto* offset2bag_data = offset2bag.data_ptr<index_t>(); |
| auto* bag_size_data = bag_size_.data_ptr<index_t>(); |
| |
| auto counts = compute_counts(num_weights, indices_data, numel); |
| auto next_unique_index_idx = |
| compute_counts_uniq(num_weights, indices_data, numel, counts); |
| |
| auto loop = |
| [&next_unique_index_idx, &indices_data, &offset2bag_data, &bag_size_data, &per_sample_weights, |
| &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq, |
| &counts, &grad, &index_grad_weight, &padding_idx |
| ](index_t start, index_t end) { |
| for (index_t i = start; i < end; i++) { |
| index_t start = i == 0 ? 0 : next_unique_index_idx[i - 1]; |
| index_t index = indices_data[start]; |
| |
| if (index != static_cast<index_t>(padding_idx)) { |
| for (index_t j = start; j < next_unique_index_idx[i]; j++) { |
| index_t source = offset2bag_data[j]; |
| double scale = 1.0; |
| if (per_sample_weights) { |
| AT_ASSERT(mode == MODE_SUM); |
| scale = per_sample_weights_data[*per_sample_weights_stride * j]; |
| } |
| if (scale_grad_by_freq) { |
| scale /= counts[indices_data[i]]; |
| } |
| if (mode == MODE_MEAN) { |
| auto bag_size = bag_size_data[source]; |
| if (bag_size != 0) { |
| scale /= bag_size; |
| } |
| } |
| int64_t ddim = grad.size(1); |
| auto igwd = index_grad_weight.data_ptr<scalar_t>(); |
| auto gd = grad.data_ptr<scalar_t>(); |
| at::native::cpublas::axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1, |
| igwd + ddim * index, 1); |
| } |
| } |
| } |
| }; |
| |
| if (numel > 1000) { |
| at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop); |
| } else { |
| loop(0, (int64_t)next_unique_index_idx.size()); |
| } |
| }); |
| } |
| |
| Tensor _embedding_bag_dense_backward_cpu(const Tensor &grad_, const Tensor &indices_, |
| const Tensor &offset2bag__, |
| const Tensor &bag_size_, |
| const Tensor& max_indices_, int64_t num_weights, |
| bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights__opt, |
| int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights__maybe_owned = at::borrow_from_optional_tensor(per_sample_weights__opt); |
| const Tensor& per_sample_weights_ = *per_sample_weights__maybe_owned; |
| |
| // indices_, offsets_ and offset2bag__ are assumed having correct dtypes and |
| // contiguous here due to the checks in _embedding_bag_backward above. |
| // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml |
| // for more details. |
| auto grad = grad_.contiguous(); |
| auto grad_arg = TensorArg(grad, "grad_", 1); |
| checkScalarTypes("embedding_bag", grad_arg, {kHalf, kFloat, kDouble}); |
| |
| if (mode == MODE_MAX) { |
| return _embedding_bag_dense_backward_cpu_max( |
| grad_, bag_size_, max_indices_, num_weights); |
| } |
| AT_ASSERT(mode == MODE_MEAN || mode == MODE_SUM); |
| |
| auto index_grad_weight = |
| at::zeros({num_weights, grad.sizes()[1]}, grad.options()); |
| |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| grad.scalar_type(), |
| "embedding_bag_backward", |
| [&] { |
| _embedding_bag_dense_backward_cpu_sum_mean<scalar_t>( |
| grad, |
| indices_, |
| offset2bag__, |
| bag_size_, |
| num_weights, |
| scale_grad_by_freq, |
| mode, |
| per_sample_weights_, |
| index_grad_weight, |
| padding_idx); |
| }); |
| return index_grad_weight; |
| } |
| |
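| // Gradient with respect to per_sample_weights (sum mode only): for each |
| // sample i, output[i] = dot(grad[offset2bag[i]], weight[indices[i]]), and 0 |
| // for padded samples. |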
| template<typename scalar_t> |
| Tensor _embedding_bag_per_sample_weights_backward_cpu_template( |
| const Tensor& grad, |
| const Tensor& weight, // NB: embedding table, not per_sample_weights |
| const Tensor& indices_, |
| const Tensor& offsets_, |
| const Tensor& offset2bag, |
| int64_t mode, |
| int64_t padding_idx) { |
| TORCH_CHECK( |
| mode == MODE_SUM, |
| "embedding_bag_backward: per_sample_weights only supported for mode='sum'"); |
| |
| AT_ASSERT(grad.dim() == 2); |
| auto embedding_features = grad.sizes()[1]; |
| |
| Tensor indices, offsets; |
| std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_); |
| AT_ASSERT(indices.dim() == 1); |
| auto num_samples = indices.size(0); |
| |
| AT_ASSERT(weight.dim() == 2); |
| AT_ASSERT(weight.sizes()[1] == embedding_features); |
| |
| auto output = at::zeros({num_samples}, grad.options()); |
| |
| auto indices_arg = TensorArg(indices, "indices", 1); |
| checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt}); |
| checkContiguous("embedding_bag", indices_arg); |
| |
| Tensor offset2bag_; |
| if (indices.numel() != 0 && offset2bag.numel() == 0) { |
| offset2bag_ = at::zeros( |
| {indices.size(0) + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0] |
| |
| make_offset2bag(offsets, offset2bag_); |
| |
| at::native::resize_(offset2bag_, {indices.size(0)}, c10::nullopt); |
| } else { |
| auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1); |
| checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt}); |
| checkContiguous("embedding_bag", offset2bag_arg); |
| offset2bag_ = offset2bag; |
| } |
| |
| auto* grad_data = grad.data_ptr<scalar_t>(); |
| auto grad_stride0 = grad.strides()[0]; |
| auto grad_stride1 = grad.strides()[1]; |
| |
| auto* weight_data = weight.data_ptr<scalar_t>(); |
| auto weight_stride0 = weight.strides()[0]; |
| auto weight_stride1 = weight.strides()[1]; |
| |
| // explicitly capture all required variables to work around windows build |
| // TODO: fix this when windows can correctly capture variables in nested lambda |
| AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template", |
| [&indices, &output, &offset2bag_, &num_samples, &embedding_features, |
| &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1, |
| &padding_idx] () { |
| auto* indices_data = indices.data_ptr<index_t>(); |
| |
| // The following are contiguous |
| auto* output_data = output.data_ptr<scalar_t>(); |
| auto* offset2bag_data = offset2bag_.data_ptr<index_t>(); |
| |
| // XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number. |
| parallel_for(0, num_samples, 64, |
| [&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, |
| &weight_stride1, &offset2bag_data, &indices_data, &output_data, &padding_idx](index_t begin, index_t end) { |
| for (index_t sample_idx = begin; sample_idx < end; sample_idx++) { |
| auto bag_idx = offset2bag_data[sample_idx]; |
| auto embedding_idx = indices_data[sample_idx]; |
| |
| if (embedding_idx != static_cast<index_t>(padding_idx)) { |
| output_data[sample_idx] = dot_impl<scalar_t>( |
| embedding_features, |
| grad_data + grad_stride0 * bag_idx, grad_stride1, |
| weight_data + weight_stride0 * embedding_idx, weight_stride1); |
| } |
| } |
| }); |
| }); |
| return output; |
| } |
| |
| Tensor _embedding_bag_per_sample_weights_backward_cpu( |
| const Tensor& grad, |
| const Tensor& weight, // NB: embedding table, not per_sample_weights |
| const Tensor& indices, |
| const Tensor& offsets, |
| const Tensor& offset2bag, |
| int64_t mode, |
| int64_t padding_idx) { |
| return AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| grad.scalar_type(), |
| "_embedding_bag_per_sample_weights_backward_cpu", |
| [&]() { |
| return _embedding_bag_per_sample_weights_backward_cpu_template< |
| scalar_t>( |
| grad, weight, indices, offsets, offset2bag, mode, padding_idx); |
| }); |
| } |
| |
| Tensor _embedding_bag_sparse_backward( |
| const Tensor &grad_, const Tensor &indices, const Tensor &offsets, |
| const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights, |
| bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt, |
| int64_t padding_idx) { |
| return at::native::_embedding_bag_sparse_backward_symint(grad_, indices, offsets, offset2bag, bag_size_, num_weights, |
| scale_grad_by_freq, mode, per_sample_weights_opt, padding_idx); |
| } |
| |
| Tensor _embedding_bag_sparse_backward_symint( |
| const Tensor &grad_, const Tensor &indices, const Tensor &offsets, |
| const Tensor &offset2bag, const Tensor &bag_size_, SymInt num_weights, |
| bool scale_grad_by_freq, int64_t mode, const c10::optional<Tensor>& per_sample_weights_opt, |
| int64_t padding_idx) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt); |
| const Tensor& per_sample_weights = *per_sample_weights_maybe_owned; |
| |
| // indices, offsets and offset2bag are assumed having correct dtypes and |
| // contiguous here due to the checks in _embedding_bag_backward above. |
| // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml |
| // for more details. |
| |
| // NOLINTNEXTLINE(performance-unnecessary-copy-initialization) |
| Tensor grad = grad_; |
| Tensor index_grad = grad_.index_select(0, offset2bag); |
| |
| index_grad = apply_bag_size_backward(mode, index_grad, offset2bag, bag_size_); |
| |
| if (per_sample_weights.defined()) { |
| AT_ASSERT(mode == MODE_SUM); |
| index_grad.mul_(per_sample_weights.unsqueeze(1)); |
| } |
| return native::embedding_backward_symint(index_grad, indices, num_weights, padding_idx, |
| scale_grad_by_freq, true); |
| } |
| } // namespace native |
| } // namespace at |