| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/core/Tensor.h> |
| #include <ATen/AccumulateType.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/Parallel.h> |
| #include <ATen/TensorMeta.h> |
| #include <ATen/TensorUtils.h> |
| #include <ATen/TensorIterator.h> |
| #include <ATen/WrapDimUtils.h> |
| #include <ATen/native/cpu/SoftmaxKernel.h> |
| #include <ATen/NamedTensorUtils.h> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/_log_softmax.h> |
| #include <ATen/ops/_log_softmax_backward_data_native.h> |
| #include <ATen/ops/_log_softmax_native.h> |
| #include <ATen/ops/_masked_softmax_backward_native.h> |
| #include <ATen/ops/_masked_softmax_native.h> |
| #include <ATen/ops/_softmax.h> |
| #include <ATen/ops/_softmax_backward_data_native.h> |
| #include <ATen/ops/_softmax_native.h> |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/empty_like.h> |
| #include <ATen/ops/log_softmax.h> |
| #include <ATen/ops/log_softmax_native.h> |
| #include <ATen/ops/softmax.h> |
| #include <ATen/ops/softmax_native.h> |
| #include <ATen/ops/special_log_softmax_native.h> |
| #include <ATen/ops/special_softmax_native.h> |
| #endif |
| |
| #include <c10/core/TensorOptions.h> |
| #include <c10/macros/Macros.h> |
| #include <c10/util/irange.h> |
| |
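| // Meta functions: validate `dim` and declare an output with the same shape |
| // as the input in (legacy) contiguous memory format. When `half_to_float` |
| // is requested the output dtype is promoted to Float; only the CUDA kernels |
| // take that path (the CPU kernels below reject it). |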
| namespace at::meta { |
| TORCH_META_FUNC(_softmax) |
| (const Tensor& input, const int64_t dim, const bool half_to_float) { |
| int64_t dim_ = maybe_wrap_dim(dim, input.dim()); |
| |
| auto output_options = |
| input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| |
| if (half_to_float) { |
| output_options = output_options.dtype(ScalarType::Float); |
| } |
| |
| int64_t input_dim = input.dim() > 0 ? input.dim() : 1; |
| TORCH_CHECK( |
| dim_ >= 0 && dim_ < input_dim, |
| "dim must be non-negative and less than input dimensions"); |
| |
| set_output_raw_strided(0, input.sizes(), {}, output_options); |
| } |
| |
| TORCH_META_FUNC(_log_softmax) ( |
| const Tensor& input, |
| const int64_t dim, |
| const bool half_to_float) { |
| int64_t dim_ = maybe_wrap_dim(dim, input.dim()); |
| |
| auto output_options = |
| input.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| |
| if (half_to_float) { |
| output_options = output_options.dtype(ScalarType::Float); |
| } |
| |
| int64_t input_dim = input.dim() > 0 ? input.dim() : 1; |
| TORCH_CHECK( |
| dim_ >= 0 && dim_ < input_dim, |
| "dim must be non-negative and less than input dimensions"); |
| |
| set_output_raw_strided(0, input.sizes(), {}, output_options); |
| } |
| |
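| // Backward meta functions: grad_input normally follows grad's dtype, but for |
| // the CUDA half_to_float path (grad is Float, input was Half) it is restored |
| // to Half, as explained in the comment below. |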
| TORCH_META_FUNC(_softmax_backward_data) |
| (const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| ScalarType input_dtype) { |
| TensorArg grad_arg{grad, "grad", 1}, output_arg{output, "output", 2}; |
| checkSameSize("softmax_backward", grad_arg, output_arg); |
| |
| int64_t dim_ = maybe_wrap_dim(dim, grad.dim()); |
| |
| auto grad_input_options = |
| grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| |
| bool half_to_float = grad.scalar_type() != input_dtype; |
| if (half_to_float) { |
| // The code below is only valid for the CUDA implementation. It's "okay" |
| // to put it here because half-to-float conversion is not supported by |
| // the CPU implementation of _softmax. There is a TORCH_CHECK in the CUDA |
| // implementation that should ideally go here as well, but there is at least |
| // one test in which the grad and input dtypes do not match for the CPU |
| // implementation of this kernel and it is not true that the grad type is |
| // float and the input dtype is half (see #63057). |
| if (grad.scalar_type() == ScalarType::Float && |
| input_dtype == ScalarType::Half) { |
| grad_input_options = grad_input_options.dtype(ScalarType::Half); |
| } |
| } |
| |
| int64_t grad_dim = grad.dim() > 0 ? grad.dim() : 1; |
| TORCH_CHECK( |
| dim_ >= 0 && dim_ < grad_dim, |
| "dim must be non-negative and less than input dimensions"); |
| |
| set_output_raw_strided(0, grad.sizes(), {}, grad_input_options); |
| } |
| |
| TORCH_META_FUNC(_log_softmax_backward_data) |
| (const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| ScalarType input_dtype) { |
| int64_t dim_ = maybe_wrap_dim(dim, grad.dim()); |
| TensorOptions grad_input_options( |
| grad.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT)); |
| |
| bool half_to_float = grad.scalar_type() != input_dtype; |
| if (half_to_float) { |
| // The code below is only valid for the CUDA implementation. It's "okay" |
| // to put it here because half-to-float conversion is not supported by |
| // the CPU implementation of _softmax. There is a TORCH_CHECK in the CUDA |
| // implementation that should ideally go here as well, but there is at least |
| // one test in which the grad and input dtypes do not match for the CPU |
| // implementation of this kernel and it is not true that the grad type is |
| // float and the input dtype is half (see #63057). |
| if (grad.scalar_type() == ScalarType::Float && |
| input_dtype == ScalarType::Half) { |
| grad_input_options = grad_input_options.dtype(ScalarType::Half); |
| } |
| } |
| |
| int64_t grad_dim = grad.dim() > 0 ? grad.dim() : 1; |
| TORCH_CHECK( |
| dim_ >= 0 && dim_ < grad_dim, |
| "dim must be non-negative and less than input dimensions"); |
| |
| set_output_raw_strided(0, grad.sizes(), {}, grad_input_options); |
| } |
| } // namespace at::meta |
| |
| namespace at::native { |
| namespace { |
| |
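| // Generic slice-by-slice CPU (log-)softmax, used by the masked softmax entry |
| // points below. `input` and `output` are assumed to be contiguous. The tensor |
| // is viewed as [outer_size, dim_size, inner_size] around `dim`, and each of |
| // the outer_size * inner_size 1-D slices of length dim_size is reduced |
| // independently inside a parallel_for. |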
| template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false> |
| void host_softmax( |
| Tensor output, |
| const Tensor& input, |
| const int64_t dim, |
| bool* mask = nullptr, |
| const c10::optional<int64_t> mask_type_ = {}) { |
| |
| if (MaskedSoftMax) { |
| TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined"); |
| int64_t mask_type = mask_type_.value(); |
| // If mask_type == 2, then mask_.sizes() must equal input_.sizes() |
| TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask) or 1 (src_key_padding_mask), or 2 (default_mask)"); |
| } |
| |
| int64_t outer_size = 1; |
| int64_t dim_size = input.size(dim); |
| int64_t inner_size = 1; |
| for (const auto i : c10::irange(dim)) { |
| outer_size *= input.size(i); |
| } |
| for (int64_t i = dim + 1; i < input.dim(); ++i) { |
| inner_size *= input.size(i); |
| } |
| int64_t dim_stride = inner_size; |
| int64_t outer_stride = dim_size * dim_stride; |
| scalar_t* input_data_base = input.data_ptr<scalar_t>(); |
| scalar_t* output_data_base = output.data_ptr<scalar_t>(); |
| bool* mask_data_base = mask; |
| int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); |
| parallel_for( |
| 0, outer_size * inner_size, grain_size, |
| [&](int64_t begin, int64_t end) __ubsan_ignore_float_divide_by_zero__ { |
| for (const auto i : c10::irange(begin, end)) { |
| int64_t outer_idx = i / inner_size; |
| int64_t inner_idx = i % inner_size; |
| scalar_t* input_data = |
| input_data_base + outer_idx * outer_stride + inner_idx; |
| scalar_t* output_data = |
| output_data_base + outer_idx * outer_stride + inner_idx; |
| bool* mask_data = nullptr; |
| if (MaskedSoftMax) { |
| // Process mask differently depending on the type: |
| // For a generic mask of mask_type == 2, mask shape is the same as the input shape, |
| // so indexing is the same. |
| auto mask_outer_idx = outer_idx; |
| if (mask_type_ == 0) { |
| // Optimized case: attention mask of shape LxL |
| // outer_idx goes over BxHxL, mask_outer_idx goes over L. |
| mask_outer_idx = outer_idx % input.size(2); |
| } else if (mask_type_ == 1) { |
| // Optimized case: padding mask of shape BxL |
| // outer_idx goes over BxHxL, mask_outer_idx goes over B. |
| mask_outer_idx = outer_idx / (input.size(1) * input.size(2)); |
| } |
| |
| mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx; |
| } |
| |
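| // Numerically stable reduction over the slice: |
| //   softmax:     y_d = exp(x_d - max) / sum_k exp(x_k - max) |
| //   log-softmax: y_d = (x_d - max) - log(sum_k exp(x_k - max)) |
| // For the masked variant, masked-out positions are skipped when computing |
| // max and sum, and their outputs become 0 (NaN if the whole slice is masked). |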
| // Calc max in softmax dim |
| bool is_meaningful_max = false; |
| scalar_t max_input = input_data[0]; |
| if (!MaskedSoftMax) { |
| for (const auto d : c10::irange(1, dim_size)) { |
| max_input = std::max(max_input, input_data[d * dim_stride]); |
| } |
| } else { |
| for (const auto d : c10::irange(0, dim_size)) { |
| if (!mask_data[d * dim_stride]) { |
| max_input = is_meaningful_max |
| ? std::max(max_input, input_data[d * dim_stride]) |
| : input_data[d * dim_stride]; |
| is_meaningful_max = true; |
| } |
| } |
| } |
| |
| // Calc sum in softmax dim |
| acc_type<scalar_t, false> tmpsum = 0; |
| for (const auto d : c10::irange(dim_size)) { |
| scalar_t z{}; |
| if (!MaskedSoftMax || !mask_data[d * dim_stride]) { |
| z = std::exp(input_data[d * dim_stride] - max_input); |
| } else { |
| z = 0; |
| } |
| if (!LogSoftMax) { |
| output_data[d * dim_stride] = z; |
| } |
| tmpsum += z; |
| } |
| |
| if (LogSoftMax) { |
| tmpsum = std::log(tmpsum); |
| } else if (tmpsum == 0) { |
| tmpsum = std::numeric_limits<scalar_t>::quiet_NaN(); |
| } else { |
| tmpsum = 1 / tmpsum; |
| } |
| |
| // update output |
| for (const auto d : c10::irange(dim_size)) { |
| // LogSoftMax and MaskedSoftMax should not both be true |
| if (LogSoftMax) { |
| output_data[d * dim_stride] = |
| input_data[d * dim_stride] - max_input - tmpsum; |
| } else { |
| output_data[d * dim_stride] *= tmpsum; |
| } |
| } |
| } |
| }); |
| } |
| |
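| // Backward counterpart of host_softmax. With y = output and dy = grad, each |
| // slice computes |
| //   softmax:     dx_d = y_d * (dy_d - sum_k dy_k * y_k) |
| //   log-softmax: dx_d = dy_d - exp(y_d) * sum_k dy_k |
| // Masked-out positions are excluded from the sum and get zero gradient. |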
| template <typename scalar_t, bool LogSoftMax, bool MaskedSoftMax = false> |
| void host_softmax_backward( |
| const Tensor& gI, |
| const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| bool* mask = nullptr) { |
| |
| int64_t outer_size = 1; |
| int64_t dim_size = grad.size(dim); |
| int64_t inner_size = 1; |
| for (const auto i : c10::irange(dim)) { |
| outer_size *= grad.size(i); |
| } |
| for (int64_t i = dim + 1; i < grad.dim(); ++i) { |
| inner_size *= grad.size(i); |
| } |
| int64_t dim_stride = inner_size; |
| int64_t outer_stride = dim_size * dim_stride; |
| scalar_t* gradInput_data_base = gI.data_ptr<scalar_t>(); |
| scalar_t* output_data_base = output.data_ptr<scalar_t>(); |
| scalar_t* gradOutput_data_base = grad.data_ptr<scalar_t>(); |
| bool* mask_data_base = mask; |
| int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, (int64_t)1); |
| parallel_for( |
| 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { |
| for (const auto i : c10::irange(begin, end)) { |
| int64_t outer_idx = i / inner_size; |
| int64_t inner_idx = i % inner_size; |
| scalar_t* gradInput_data = |
| gradInput_data_base + outer_idx * outer_stride + inner_idx; |
| scalar_t* output_data = |
| output_data_base + outer_idx * outer_stride + inner_idx; |
| const scalar_t* gradOutput_data = |
| gradOutput_data_base + outer_idx * outer_stride + inner_idx; |
| bool* mask_data = nullptr; |
| if (MaskedSoftMax) { |
| mask_data = mask_data_base + outer_idx * outer_stride + inner_idx; |
| } |
| |
| acc_type<scalar_t, false> sum = 0; |
| for (const auto d : c10::irange(dim_size)) { |
| if (!MaskedSoftMax || !mask_data[d * dim_stride]) { |
| if (LogSoftMax) { |
| sum += gradOutput_data[d * dim_stride]; |
| } else { |
| sum += |
| gradOutput_data[d * dim_stride] * output_data[d * dim_stride]; |
| } |
| } |
| } |
| |
| for (const auto d : c10::irange(dim_size)) { |
| if (MaskedSoftMax && mask_data[d * dim_stride]) { |
| gradInput_data[d * dim_stride] = 0; |
| } |
| else if (LogSoftMax) { |
| gradInput_data[d * dim_stride] = gradOutput_data[d * dim_stride] - |
| std::exp(output_data[d * dim_stride]) * sum; |
| } else { |
| gradInput_data[d * dim_stride] = output_data[d * dim_stride] * |
| (gradOutput_data[d * dim_stride] - sum); |
| } |
| } |
| } |
| }); |
| } |
| } // namespace |
| |
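| // CPU forward kernels: make the input contiguous, then dispatch either to the |
| // specialized last-dim kernel or to the generic kernel via the stubs declared |
| // in ATen/native/cpu/SoftmaxKernel.h. half_to_float is rejected on CPU. |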
| TORCH_IMPL_FUNC(softmax_cpu_out) |
| (const Tensor& input, |
| const int64_t dim, |
| const bool half_to_float, |
| const Tensor& output) { |
| TORCH_CHECK(!half_to_float, "softmax with half to float conversion is not supported on CPU"); |
| |
| if (input.numel() == 0) { |
| return; |
| } |
| |
| auto input_ = input.contiguous(); |
| int64_t dim_ = maybe_wrap_dim(dim, input_.dim()); |
| |
| if (input_.dim() == 0) { |
| input_ = input_.view(1); |
| } |
| |
| TORCH_CHECK( |
| dim_ >= 0 && dim_ < input_.dim(), |
| "dim must be non-negative and less than input dimensions"); |
| if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) { |
| softmax_lastdim_kernel(kCPU, output, input_); |
| } else { |
| softmax_kernel(kCPU, output, input_, dim_); |
| } |
| } |
| |
| TORCH_IMPL_FUNC(log_softmax_cpu_out) |
| (const Tensor& input, |
| const int64_t dim, |
| const bool half_to_float, |
| const Tensor& output) { |
| TORCH_CHECK( |
| !half_to_float, |
| "softmax with half to float conversion is not supported on CPU"); |
| |
| if (input.numel() == 0) { |
| return; |
| } |
| |
| auto input_ = input.contiguous(); |
| int64_t dim_ = maybe_wrap_dim(dim, input_.dim()); |
| |
| if (input_.dim() == 0) { |
| input_ = input_.view(1); |
| } |
| |
| if (input_.ndimension() > 0 && dim_ == input_.ndimension() - 1) { |
| log_softmax_lastdim_kernel(kCPU, output, input_); |
| } else { |
| log_softmax_kernel(kCPU, output, input_, dim_); |
| } |
| } |
| |
| TORCH_IMPL_FUNC(softmax_backward_cpu_out) |
| (const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| ScalarType input_dtype, |
| const Tensor& grad_input) { |
| int64_t dim_ = maybe_wrap_dim(dim, grad.dim()); |
| auto grad_ = grad.contiguous(); |
| auto output_ = output.contiguous(); |
| |
| if (output.numel() == 0) { |
| return; |
| } |
| |
| if (grad_.dim() == 0) { |
| grad_ = grad_.view(1); |
| } |
| |
| if (output_.dim() == 0) { |
| output_ = output_.view(1); |
| } |
| |
| if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) { |
| softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_); |
| } else { |
| softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_); |
| } |
| } |
| |
| TORCH_IMPL_FUNC(log_softmax_backward_cpu_out) ( |
| const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| ScalarType input_dtype, |
| const Tensor& grad_input) { |
| int64_t dim_ = maybe_wrap_dim(dim, grad.dim()); |
| auto grad_ = grad.contiguous(); |
| auto output_ = output.contiguous(); |
| |
| if (output.numel() != 0) { |
| if (grad_.dim() == 0) { |
| grad_ = grad_.view(1); |
| } |
| if (output_.dim() == 0) { |
| output_ = output_.view(1); |
| } |
| if (grad_.ndimension() > 0 && dim_ == grad_.ndimension() - 1) { |
| log_softmax_backward_lastdim_kernel(kCPU, grad_input, grad_, output_); |
| } else { |
| log_softmax_backward_kernel(kCPU, grad_input, grad_, output_, dim_); |
| } |
| } |
| } |
| |
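| // Composite softmax entry points. The dtype-aware overload either converts |
| // the input before calling at::_softmax, or, for Half CUDA inputs with |
| // dtype == Float, lets the kernel do the conversion itself (half_to_float). |
| // Rough usage sketch: auto p = at::softmax(t, /*dim=*/-1); // t: any float tensor |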
| static Tensor softmax(const Tensor& input_, const int64_t dim_) { |
| auto result = [&]() { |
| NoNamesGuard guard; |
| return at::_softmax(input_, dim_, false); |
| }(); |
| namedinference::propagate_names(result, input_); |
| return result; |
| } |
| |
| Tensor softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) { |
| auto result = [&]() { |
| NoNamesGuard guard; |
| if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ |
| return at::_softmax(input_, dim_, true); |
| } else { |
| Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_; |
| return at::_softmax(converted, dim_, false); |
| } |
| }(); |
| namedinference::propagate_names(result, input_); |
| return result; |
| } |
| |
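| // out= variant: the kernels write to a contiguous buffer, so when the given |
| // output is not contiguous we compute into a contiguous temporary and copy |
| // the result back into `output_` at the end. |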
| Tensor& softmax_out( |
| const Tensor& input_, |
| const int64_t dim_, |
| c10::optional<ScalarType> dtype, |
| Tensor& output_) { |
| Tensor output_temp; |
| if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && |
| dtype == ScalarType::Float) { |
| if (!output_.is_contiguous()) { |
| auto options = |
| TensorOptions().dtype(output_.dtype()).device(output_.device()); |
| output_temp = at::empty(output_.sizes(), options); |
| at::_softmax_out(output_temp, input_, dim_, true); |
| } else { |
| at::_softmax_out(output_, input_, dim_, true); |
| } |
| } else { |
| Tensor converted = |
| dtype.has_value() ? input_.toType(dtype.value()) : input_; |
| if (!output_.is_contiguous()) { |
| auto options = |
| TensorOptions().dtype(output_.dtype()).device(output_.device()); |
| output_temp = at::empty(output_.sizes(), options); |
| at::_softmax_out(output_temp, converted, dim_, false); |
| } else { |
| at::_softmax_out(output_, converted, dim_, false); |
| } |
| } |
| |
| if (!output_.is_contiguous()) { |
| output_.resize_(output_temp.sizes()); |
| output_.copy_(output_temp); |
| } |
| |
| return output_; |
| } |
| |
| // special_softmax, alias for softmax |
| Tensor special_softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) { |
| return at::softmax(input_, dim_, dtype); |
| } |
| |
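| // log_softmax wrappers mirror the softmax wrappers above, forwarding to |
| // at::_log_softmax and handling dtype conversion and the CUDA half_to_float |
| // fast path the same way. |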
| static Tensor log_softmax(const Tensor& input_, const int64_t dim_) { |
| auto result = [&]() { |
| NoNamesGuard guard; |
| return at::_log_softmax(input_, dim_, false); |
| }(); |
| namedinference::propagate_names(result, input_); |
| return result; |
| } |
| |
| Tensor log_softmax(const Tensor& input_, const int64_t dim_, c10::optional<ScalarType> dtype) { |
| auto result = [&]() { |
| NoNamesGuard guard; |
| if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float){ |
| return at::_log_softmax(input_, dim_, true); |
| } else { |
| Tensor converted = dtype.has_value()? input_.toType(dtype.value()) : input_; |
| return at::_log_softmax(converted, dim_, false); |
| } |
| }(); |
| namedinference::propagate_names(result, input_); |
| return result; |
| } |
| |
| Tensor& log_softmax_out( |
| const Tensor& input_, |
| const int64_t dim_, |
| c10::optional<ScalarType> dtype, |
| Tensor& output_) { |
| Tensor output_temp; |
| if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && |
| dtype == ScalarType::Float) { |
| if (!output_.is_contiguous()) { |
| auto options = |
| TensorOptions().dtype(output_.dtype()).device(output_.device()); |
| output_temp = at::empty(output_.sizes(), options); |
| at::_log_softmax_out(output_temp, input_, dim_, true); |
| } else { |
| at::_log_softmax_out(output_, input_, dim_, true); |
| } |
| } else { |
| Tensor converted = |
| dtype.has_value() ? input_.toType(dtype.value()) : input_; |
| if (!output_.is_contiguous()) { |
| auto options = |
| TensorOptions().dtype(output_.dtype()).device(output_.device()); |
| output_temp = at::empty(output_.sizes(), options); |
| at::_log_softmax_out(output_temp, converted, dim_, false); |
| } else { |
| at::_log_softmax_out(output_, converted, dim_, false); |
| } |
| } |
| |
| if (!output_.is_contiguous()) { |
| output_.resize_(output_temp.sizes()); |
| output_.copy_(output_temp); |
| } |
| |
| return output_; |
| } |
| |
| Tensor special_log_softmax(const Tensor& input, const int64_t dim, c10::optional<ScalarType> dtype) { |
| return at::log_softmax(input, dim, dtype); |
| } |
| |
| DEFINE_DISPATCH(softmax_lastdim_kernel); |
| DEFINE_DISPATCH(log_softmax_lastdim_kernel); |
| DEFINE_DISPATCH(softmax_backward_lastdim_kernel); |
| DEFINE_DISPATCH(log_softmax_backward_lastdim_kernel); |
| |
| DEFINE_DISPATCH(softmax_kernel); |
| DEFINE_DISPATCH(log_softmax_kernel); |
| DEFINE_DISPATCH(softmax_backward_kernel); |
| DEFINE_DISPATCH(log_softmax_backward_kernel); |
| |
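| // Named-tensor overloads: resolve the Dimname to a positional dimension and |
| // forward to the overloads above. |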
| Tensor softmax(const Tensor& self, Dimname dim, optional<ScalarType> dtype) { |
| return at::softmax(self, dimname_to_position(self, dim), dtype); |
| } |
| |
| Tensor log_softmax(const Tensor& self, Dimname dim, optional<ScalarType> dtype) { |
| return at::log_softmax(self, dimname_to_position(self, dim), dtype); |
| } |
| |
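| // Masked softmax. mask_type 0 is an (L, L) attention mask (src_mask) and |
| // mask_type 1 is a (B, L) padding mask (src_key_padding_mask); both assume a |
| // 4-D (B, H, L, L) input. mask_type 2 is a mask with the same shape as the |
| // input. True mask entries are excluded from the softmax. Cases that do not |
| // fit the optimized 2-D-mask / 4-D-input layout are expanded to a full mask. |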
| Tensor masked_softmax_cpu(const Tensor& input_, const Tensor& mask_, const c10::optional<int64_t> dim_, const c10::optional<int64_t> mask_type_) { |
| |
| auto mask = mask_.contiguous(); |
| auto mask_type = mask_type_; // Mask type might get transformed below |
| |
| TORCH_CHECK( |
| mask_.scalar_type() == ScalarType::Bool, |
| "Mask should be a boolean tensor"); |
| |
| if ((mask.dim() != 2) || (input_.dim() != 4)) { |
| // Mask types 0 and 1 are only allowed for 2D masks and 4D inputs |
| mask_type = 2; |
| } |
| |
| if (mask_type == 2) { |
| TORCH_CHECK(input_.sizes() == mask.sizes(), |
| "For mask_type == 2 mask shape should match input shape") |
| } else if (mask_type == 1) { |
| // Padding mask of shape (B, L) |
| TORCH_CHECK((input_.sizes()[0] == mask.sizes()[0]) && (input_.sizes()[2] == mask.sizes()[1]), |
| "For mask_type == 1 mask shape should be (B, L)"); |
| if (dim_ != input_.dim() - 1) { |
| // We only take the optimized padding-mask path when softmax is applied along the last dimension; |
| // otherwise we expand the mask into a generic 4D mask. |
| mask = mask_.view({input_.sizes()[0], 1, 1, input_.sizes()[2]}); |
| mask = mask.expand(input_.sizes()).contiguous(); |
| mask_type = 2; |
| } |
| } else if (mask_type == 0) { |
| // Attention mask of shape (L, L) |
| TORCH_CHECK((mask.dim() == 2) && (input_.sizes()[2] == mask.sizes()[0]) && (input_.sizes()[2] == mask.sizes()[1]), |
| "For mask_type == 0 mask shape should be (L, L)"); |
| if (dim_ != input_.dim() - 1) { |
| // We only take the optimized attention-mask path when softmax is applied along the last dimension; |
| // otherwise we expand the mask into a generic 4D mask. |
| mask = mask.view({1, 1, input_.sizes()[2], input_.sizes()[2]}); |
| mask = mask.expand(input_.sizes()).contiguous(); |
| mask_type = 2; |
| } |
| } |
| |
| Tensor output = at::empty_like(input_, input_.options()); |
| auto input = input_.contiguous(); |
| int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1; |
| dim = maybe_wrap_dim(dim, input_.dim()); |
| |
| if (input.dim() == 0) { |
| input = input.view(1); |
| } |
| |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "masked_softmax", [&] { |
| host_softmax< |
| scalar_t, |
| false /* LogSoftMax */, |
| true /* MaskedSoftMax */>( |
| output, input, dim, mask.data_ptr<bool>(), mask_type); |
| }); |
| return output; |
| } |
| |
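| // Backward of masked softmax: the mask must match grad's shape, and masked |
| // positions receive zero gradient. |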
| Tensor masked_softmax_backward_cpu( |
| const Tensor& grad_, |
| const Tensor& output_, |
| const Tensor& mask_, |
| const c10::optional<int64_t> dim_) { |
| TORCH_CHECK( |
| grad_.sizes() == mask_.sizes(), "Mask shape should match grad shape"); |
| TORCH_CHECK( |
| mask_.scalar_type() == ScalarType::Bool, |
| "Mask should be a boolean tensor"); |
| auto grad = grad_.contiguous(); |
| auto output = output_.contiguous(); |
| auto mask = mask_.contiguous(); |
| |
| int64_t dim = dim_.has_value() ? dim_.value() : output.dim() - 1; |
| dim = maybe_wrap_dim(dim, grad.dim()); |
| |
| grad = grad.dim() == 0 ? grad.view(1) : grad; |
| output = output.dim() == 0 ? output.view(1) : output; |
| mask = mask.dim() == 0 ? mask.view(1) : mask; |
| |
| Tensor grad_input = at::empty_like(grad, grad.options()); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(), "masked_softmax_backward", [&] { |
| host_softmax_backward< |
| scalar_t, |
| false /* LogSoftMax */, |
| true /* MaskedSoftmax */>(grad_input, grad, output, dim, mask.data_ptr<bool>()); |
| }); |
| return grad_input; |
| } |
| } // namespace at::native |