| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/core/Tensor.h> |
| #include <ATen/AccumulateType.h> |
| #include <ATen/Dispatch.h> |
| #include <c10/macros/Macros.h> |
| #include <ATen/cuda/CUDAContext.h> |
| #include <ATen/native/cuda/block_reduce.cuh> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/CUDAFunctions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
#include <ATen/ops/empty.h>
#include <ATen/ops/multilabel_margin_loss_backward_native.h>
#include <ATen/ops/multilabel_margin_loss_forward_native.h>
#include <ATen/ops/sum_cuda_dispatch.h>
#include <ATen/ops/zeros_like.h>
| #endif |
| |
| namespace at::native { |
| |
| namespace { |
constexpr int MULTILABELMARGIN_THREADS = 128;
| |
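// Shared shape contract: input is a 1-D vector of `dim` scores (optionally
// 0-dim), or a 2-D (nframe, dim) batch; target has the same shape and holds
// class indices in [0, dim), with -1 acting as a terminator (entries after
// the first negative index in a row are ignored by the kernels below).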
| void multilabel_margin_loss_shape_check( |
| int64_t& nframe, |
| int64_t& dim, |
    int64_t ndims,
| const Tensor& input, |
| const Tensor& target) { |
| TORCH_CHECK( |
| (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0, |
| "Expected non-empty vector or matrix with optional 0-dim batch size, but got: ", |
| input.sizes()); |
| |
| if (ndims <= 1) { |
| nframe = 1; |
| dim = ndims == 0 ? 1 : input.size(0); |
| TORCH_CHECK( |
| target.dim() <= 1 && target.numel() == dim, |
| "inconsistent target size: ", target.sizes(), " for input of size: ", |
| input.sizes()); |
| } else { |
| nframe = input.size(0); |
| dim = input.size(1); |
| TORCH_CHECK( |
| target.dim() == 2 && target.size(0) == nframe && |
| target.size(1) == dim, |
| "inconsistent target size: ", target.sizes(), " for input of size: ", |
| input.sizes()); |
| } |
| } |
| |
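// For each sample k (one CUDA block per sample), this kernel computes
//   loss_k = (1/dim) * sum_{j : y_k[j] >= 0} sum_{d : is_target_k[d] == 0}
//            max(0, 1 - x_k[y_k[j]] + x_k[d])
// and additionally divides by nframe when size_average is true. is_target is
// written as a 0/1 mask of target positions so the backward pass can reuse it.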
| template <typename scalar_t, typename accscalar_t> |
| C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS) |
| __global__ void multilabel_margin_loss_forward_kernel( |
| scalar_t* output, |
| const scalar_t* input, |
| const int64_t* target, |
| scalar_t* is_target, |
| int nframe, |
| int dim, |
| bool size_average) { |
| |
  // each block processes one sample (row) of the batch:
| int k = blockIdx.x; |
| const scalar_t* input_k = input + k * dim; |
| const int64_t* target_k = target + k * dim; |
| scalar_t* output_k = output + k; |
| scalar_t* is_target_k = is_target + k * dim; |
| |
| // zero is_target |
| for (int d = threadIdx.x; d < dim; d += blockDim.x) { |
| is_target_k[d] = static_cast<scalar_t>(0); |
| } |
| __syncthreads(); |
| |
| // mark targets in is_target |
| if (threadIdx.x == 0) { |
| for (int dt = 0; dt < dim; dt++) { |
      int target_idx = static_cast<int>(target_k[dt]);
| if (target_idx < 0) { |
| break; |
| } |
| is_target_k[target_idx] = static_cast<scalar_t>(1); |
| } |
| } |
| __syncthreads(); |
| |
| // iterate over targets |
| accscalar_t sum = 0; |
| for (int dt = 0; dt < dim; dt++) { |
| // next target: |
    int target_idx = static_cast<int>(target_k[dt]);
| if (target_idx < 0) { |
| break; |
| } |
| |
| // current value for target |
| scalar_t input_target_k = input_k[target_idx]; |
| |
| // compare to all inputs (multithreaded): |
| for (int d = threadIdx.x; d < dim; d += blockDim.x) { |
| // contribute to loss only if not a target |
| if (!static_cast<int>(is_target_k[d])) { |
| scalar_t z = 1 - input_target_k + input_k[d]; |
| if (z > 0) { |
| sum += z; |
| } |
| } |
| } |
| } |
| |
  // block-wide reduction of the per-thread partial hinge sums
| __shared__ accscalar_t smem[MULTILABELMARGIN_THREADS]; |
| accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem); |
| if (threadIdx.x == 0) { |
| if (size_average) { |
| *output_k = static_cast<scalar_t>((total_sum / dim) / nframe); |
| } else { |
| *output_k = static_cast<scalar_t>(total_sum / dim); |
| } |
| } |
| } |
| |
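// Backward of the hinge terms above: each pair (target y_k[j], non-target d)
// with positive margin z = 1 - x_k[y_k[j]] + x_k[d] contributes +g to
// grad_input[d] and -g to grad_input[y_k[j]], where g = 1/(nframe*dim) when
// the forward pass size-averaged over the batch and 1/dim otherwise. The -g
// contributions are block-reduced per target, and the whole gradient is
// finally scaled by grad_output (per-sample when reduction is None).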
| template <typename scalar_t, typename accscalar_t> |
| C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS) |
| __global__ void multilabel_margin_loss_backward_kernel( |
| scalar_t* grad_input, |
| const scalar_t* grad_output, |
| const scalar_t* input, |
| const int64_t* target, |
| const scalar_t* is_target, |
| int nframe, |
| int dim, |
| bool size_average, |
| bool reduce) { |
| |
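  // each block processes one sample (row) of the batch: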
| int k = blockIdx.x; |
| const scalar_t* input_k = input + k * dim; |
| scalar_t* grad_input_k = grad_input + k * dim; |
| const int64_t* target_k = target + k * dim; |
| const scalar_t* is_target_k = is_target + k * dim; |
| |
| const scalar_t* grad_output_k = grad_output; |
| if (!reduce) { |
| grad_output_k += k; |
| } |
| |
  // gradient scale ("gain"): 1/(nframe*dim) if averaging over the batch, else 1/dim:
| scalar_t g = static_cast<scalar_t>( |
| size_average && reduce ? 1. / static_cast<accscalar_t>(nframe * dim) |
| : 1. / static_cast<accscalar_t>(dim)); |
| |
| // zero gradients: |
| for (int d = threadIdx.x; d < dim; d += blockDim.x) { |
| grad_input_k[d] = static_cast<scalar_t>(0); |
| } |
| __syncthreads(); |
| |
| // iterate over targets |
| for (int dt = 0; dt < dim; dt++) { |
| // next target: |
| int target_idx = static_cast<int>(target_k[dt]); |
| if (target_idx < 0) { |
| break; |
| } |
| |
| // current value for target |
| scalar_t input_target_k = input_k[target_idx]; |
| |
| // compare to all inputs (multithreaded): |
| accscalar_t sum = 0; |
| for (int d = threadIdx.x; d < dim; d += blockDim.x) { |
| // contribute to loss only if not a target |
| if (!static_cast<int>(is_target_k[d])) { |
| scalar_t z = 1 - input_target_k + input_k[d]; |
| if (z > 0) { |
| sum -= g; |
| grad_input_k[d] += g; |
| } |
| } |
| } |
| __syncthreads(); |
| |
    // block-wide reduction of the per-thread -g contributions for this target
| __shared__ accscalar_t smem[MULTILABELMARGIN_THREADS]; |
| accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem); |
| if (threadIdx.x == 0) { |
| grad_input_k[target_idx] += static_cast<scalar_t>(total_sum); |
| } |
| } |
| |
| for (int d = threadIdx.x; d < dim; d += blockDim.x) { |
| grad_input_k[d] *= *grad_output_k; |
| } |
| } |
| |
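// Host-side wrappers: shape-check, make the operands contiguous, and launch
// one block per sample (nframe blocks x MULTILABELMARGIN_THREADS threads),
// with threads striding over the dim classes.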
| void multilabel_margin_loss_forward_out_cuda_template( |
| const Tensor& input, |
| const Tensor& target, |
| int64_t reduction, |
| Tensor& output, |
| Tensor& is_target) { |
| int64_t nframe, dim; |
| const int64_t ndims = input.dim(); |
| multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target); |
| |
| if (input.numel() == 0) { |
| return; |
| } |
| |
| auto input_ = input.contiguous(); |
| auto target_ = target.contiguous(); |
| auto is_target_ = is_target.contiguous(); |
| is_target_.resize_as_(target); |
| |
| if (input.dim() <= 1) { |
| output.resize_({}); |
| |
| dim3 blocks(1); |
| dim3 threads(MULTILABELMARGIN_THREADS); |
| |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input.scalar_type(), |
| "multilabel_margin_loss_forward_kernel", |
| [&] { |
| using accscalar_t = at::acc_type<scalar_t, true>; |
| multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t> |
| <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( |
| output.mutable_data_ptr<scalar_t>(), |
| input_.const_data_ptr<scalar_t>(), |
| target_.const_data_ptr<int64_t>(), |
| is_target_.mutable_data_ptr<scalar_t>(), |
| 1, |
| dim, |
| reduction == at::Reduction::Mean); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| } else if (input.dim() == 2) { |
| dim3 blocks(input.size(0)); |
| dim3 threads(MULTILABELMARGIN_THREADS); |
| |
| if (reduction != at::Reduction::None) { |
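      // Each block writes its sample's loss (already divided by nframe for
      // Mean reduction) into a per-sample buffer, which is then summed into
      // the 0-dim output.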
| auto output_tmp = at::empty({input_.size(0)}, input_.options()); |
| output.resize_({}); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input.scalar_type(), |
| "multilabel_margin_loss_forward_kernel", |
| [&] { |
| using accscalar_t = at::acc_type<scalar_t, true>; |
| multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t> |
| <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( |
| output_tmp.mutable_data_ptr<scalar_t>(), |
| input_.const_data_ptr<scalar_t>(), |
| target_.const_data_ptr<int64_t>(), |
| is_target_.mutable_data_ptr<scalar_t>(), |
| nframe, |
| dim, |
| reduction == at::Reduction::Mean); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| at::cuda::sum_out( |
| output, |
| output_tmp, |
| at::IntArrayRef(std::vector<int64_t>{}), |
| false, |
| output.scalar_type()); |
| } else { |
| output.resize_({input.size(0)}); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input.scalar_type(), |
| "multilabel_margin_loss_forward_kernel", |
| [&] { |
| using accscalar_t = at::acc_type<scalar_t, true>; |
| multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t> |
| <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( |
| output.mutable_data_ptr<scalar_t>(), |
| input_.const_data_ptr<scalar_t>(), |
| target_.const_data_ptr<int64_t>(), |
| is_target_.mutable_data_ptr<scalar_t>(), |
| nframe, |
| dim, |
| false); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| } |
| } else { |
| TORCH_CHECK( |
| false, |
| "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", |
| input.sizes()); |
| } |
| } |
| |
| void multilabel_margin_loss_backward_cuda_out_template( |
| const Tensor& grad_output, |
| const Tensor& input, |
| const Tensor& target, |
| int64_t reduction, |
| const Tensor& is_target, |
| Tensor& grad_input) { |
| int64_t nframe, dim; |
| const int64_t ndims = input.dim(); |
| multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target); |
| |
| if (input.numel() == 0) { |
| return; |
| } |
| |
| auto input_ = input.contiguous(); |
| auto target_ = target.contiguous(); |
| auto is_target_ = is_target.contiguous(); |
| auto grad_output_ = grad_output.contiguous(); |
| grad_input.resize_as_(input_); |
| |
| if (grad_input.dim() <= 1) { |
| int target_size = target_.dim() == 0 ? 1 : target_.size(0); |
| TORCH_CHECK( |
| (target_.numel() != 0) && (target_.dim() <= 1) && (target_size == dim), |
| "inconsistent target size"); |
| TORCH_CHECK( |
| target_.sizes() == is_target_.sizes(), "inconsistent is_target size"); |
| dim3 blocks(1); |
| dim3 threads(MULTILABELMARGIN_THREADS); |
| |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input.scalar_type(), |
| "multilabel_margin_loss_backward_kernel", |
| [&] { |
| using accscalar_t = at::acc_type<scalar_t, true>; |
| multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t> |
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
| grad_input.mutable_data_ptr<scalar_t>(), |
| grad_output_.const_data_ptr<scalar_t>(), |
| input_.const_data_ptr<scalar_t>(), |
| target_.const_data_ptr<int64_t>(), |
| is_target_.const_data_ptr<scalar_t>(), |
| 1, |
| dim, |
| reduction == at::Reduction::Mean, |
| reduction != at::Reduction::None); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| } else if (grad_input.dim() == 2) { |
| TORCH_CHECK( |
| (input_.size(1) != 0) && (target_.dim() == 2) && |
| (target_.size(0) == nframe) && (target_.size(1) == dim), |
| "inconsistent target size"); |
| TORCH_CHECK(target_.sizes() == is_target_.sizes(), "inconsistent is_target size"); |
| dim3 blocks(grad_input.size(0)); |
| dim3 threads(MULTILABELMARGIN_THREADS); |
| |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input.scalar_type(), |
| "multilabel_margin_loss_backward_kernel", |
| [&] { |
| using accscalar_t = at::acc_type<scalar_t, true>; |
| multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t> |
              <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
| grad_input.mutable_data_ptr<scalar_t>(), |
| grad_output_.const_data_ptr<scalar_t>(), |
| input_.const_data_ptr<scalar_t>(), |
| target_.const_data_ptr<int64_t>(), |
| is_target_.const_data_ptr<scalar_t>(), |
| grad_input.size(0), |
| grad_input.size(1), |
| reduction == at::Reduction::Mean, |
| reduction != at::Reduction::None); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| }); |
| } else { |
| TORCH_CHECK( |
| false, |
| "Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ", |
| grad_input.sizes()); |
| } |
| } |
| |
| } // namespace |
| |
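// CUDA implementations of the multilabel_margin_loss_forward/_backward native
// functions (reached from Python via torch.nn.functional.multilabel_margin_loss).
// A minimal host-side usage sketch, assuming a CUDA-enabled libtorch build
// (shapes and the chosen target index are illustrative):
//
//   at::Tensor x = at::randn({4, 10}, at::kCUDA);
//   at::Tensor y = at::full({4, 10}, -1, x.options().dtype(at::kLong));
//   y.select(1, 0).fill_(3);  // one target class (index 3) per sample
//   at::Tensor loss = at::multilabel_margin_loss(x, y, at::Reduction::Mean);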
| std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cuda( |
| const Tensor& self, |
| const Tensor& target, |
| int64_t reduction, |
| Tensor& output, |
| Tensor& is_target) { |
| multilabel_margin_loss_forward_out_cuda_template( |
| self, target, reduction, output, is_target); |
| return std::tuple<Tensor&, Tensor&>(output, is_target); |
| } |
| |
| std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cuda( |
| const Tensor& self, |
| const Tensor& target, |
| int64_t reduction) { |
| auto output = at::empty({0}, self.options()); |
| auto is_target = at::empty({0}, self.options()); |
| multilabel_margin_loss_forward_out_cuda_template( |
| self, target, reduction, output, is_target); |
| return std::make_tuple(output, is_target); |
| } |
| |
| Tensor& multilabel_margin_loss_backward_cuda_out( |
| const Tensor& grad_output, |
| const Tensor& self, |
| const Tensor& target, |
| int64_t reduction, |
| const Tensor& is_target, |
| Tensor& grad_input) { |
| multilabel_margin_loss_backward_cuda_out_template( |
| grad_output, self, target, reduction, is_target, grad_input); |
| return grad_input; |
| } |
| |
| Tensor multilabel_margin_loss_backward_cuda( |
| const Tensor& grad_output, |
| const Tensor& self, |
| const Tensor& target, |
| int64_t reduction, |
| const Tensor& is_target) { |
| auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| multilabel_margin_loss_backward_cuda_out_template( |
| grad_output, self, target, reduction, is_target, grad_input); |
| return grad_input; |
| } |
| |
| } // namespace at::native |