#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <c10/macros/Macros.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/cuda/block_reduce.cuh>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/CUDAFunctions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros_like.h>
#include <ATen/ops/sum_cuda_dispatch.h>
#include <ATen/ops/multilabel_margin_loss.h>
#endif
namespace at::native {
namespace {
constexpr int MULTILABELMARGIN_THREADS = 128;
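// Validates the input/target shapes and fills in (nframe, dim): a 0-D or 1-D
// input is treated as a single sample (nframe = 1), while a 2-D input is a
// batch of nframe samples with dim classes each.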
void multilabel_margin_loss_shape_check(
int64_t& nframe,
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
input.sizes());
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
TORCH_CHECK(
target.dim() <= 1 && target.numel() == dim,
"inconsistent target size: ", target.sizes(), " for input of size: ",
input.sizes());
} else {
nframe = input.size(0);
dim = input.size(1);
TORCH_CHECK(
target.dim() == 2 && target.size(0) == nframe &&
target.size(1) == dim,
"inconsistent target size: ", target.sizes(), " for input of size: ",
input.sizes());
}
}
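// Forward kernel: one block per sample. Each block marks the sample's target
// classes in is_target, then accumulates, over every target class y and every
// non-target class j, the hinge term max(0, 1 - input[y] + input[j]). The
// per-thread partial sums are block-reduced and the result is written scaled
// by 1/dim (and additionally by 1/nframe when size_average is set).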
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
__global__ void multilabel_margin_loss_forward_kernel(
scalar_t* output,
const scalar_t* input,
const int64_t* target,
scalar_t* is_target,
int nframe,
int dim,
bool size_average) {
// vectors:
int k = blockIdx.x;
const scalar_t* input_k = input + k * dim;
const int64_t* target_k = target + k * dim;
scalar_t* output_k = output + k;
scalar_t* is_target_k = is_target + k * dim;
// zero is_target
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
is_target_k[d] = static_cast<scalar_t>(0);
}
__syncthreads();
// mark targets in is_target (the target list is padded with -1, so stop at
// the first negative index)
if (threadIdx.x == 0) {
for (int dt = 0; dt < dim; dt++) {
int target_idx = target_k[dt];
if (target_idx < 0) {
break;
}
is_target_k[target_idx] = static_cast<scalar_t>(1);
}
}
__syncthreads();
// iterate over targets
accscalar_t sum = 0;
for (int dt = 0; dt < dim; dt++) {
// next target:
int target_idx = target_k[dt];
if (target_idx < 0) {
break;
}
// current value for target
scalar_t input_target_k = input_k[target_idx];
// compare to all inputs (multithreaded):
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
// contribute to loss only if not a target
if (!static_cast<int>(is_target_k[d])) {
scalar_t z = 1 - input_target_k + input_k[d];
if (z > 0) {
sum += z;
}
}
}
}
// Temporary sums (for mapreduce)
__shared__ accscalar_t smem[MULTILABELMARGIN_THREADS];
accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem);
if (threadIdx.x == 0) {
if (size_average) {
*output_k = static_cast<scalar_t>((total_sum / dim) / nframe);
} else {
*output_k = static_cast<scalar_t>(total_sum / dim);
}
}
}
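// Backward kernel: one block per sample. For each target class y and each
// non-target class j whose forward hinge 1 - input[y] + input[j] was positive,
// the gradient is +g at j and -g at y, where g folds in the 1/dim (and, for a
// mean-reduced loss, 1/nframe) scaling. The -g contributions are block-reduced
// before being added to grad_input[y], and all gradients are finally scaled by
// the incoming grad_output.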
template <typename scalar_t, typename accscalar_t>
C10_LAUNCH_BOUNDS_1(MULTILABELMARGIN_THREADS)
__global__ void multilabel_margin_loss_backward_kernel(
scalar_t* grad_input,
const scalar_t* grad_output,
const scalar_t* input,
const int64_t* target,
const scalar_t* is_target,
int nframe,
int dim,
bool size_average,
bool reduce) {
int k = blockIdx.x;
const scalar_t* input_k = input + k * dim;
scalar_t* grad_input_k = grad_input + k * dim;
const int64_t* target_k = target + k * dim;
const scalar_t* is_target_k = is_target + k * dim;
const scalar_t* grad_output_k = grad_output;
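// a reduced loss has a single scalar grad_output; otherwise use the
// per-sample gradient at offset k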
if (!reduce) {
grad_output_k += k;
}
// gain: 1/(nframe * dim) for a mean-reduced loss, 1/dim otherwise
scalar_t g = static_cast<scalar_t>(
size_average && reduce ? 1. / static_cast<accscalar_t>(nframe * dim)
: 1. / static_cast<accscalar_t>(dim));
// zero gradients:
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
grad_input_k[d] = static_cast<scalar_t>(0);
}
__syncthreads();
// iterate over targets
for (int dt = 0; dt < dim; dt++) {
// next target:
int target_idx = static_cast<int>(target_k[dt]);
if (target_idx < 0) {
break;
}
// current value for target
scalar_t input_target_k = input_k[target_idx];
// compare to all inputs (multithreaded):
accscalar_t sum = 0;
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
// contribute to loss only if not a target
if (!static_cast<int>(is_target_k[d])) {
scalar_t z = 1 - input_target_k + input_k[d];
if (z > 0) {
sum -= g;
grad_input_k[d] += g;
}
}
}
__syncthreads();
// Temporary sums (for mapreduce)
__shared__ accscalar_t smem[MULTILABELMARGIN_THREADS];
accscalar_t total_sum = cuda_utils::BlockReduceSum(sum, smem);
if (threadIdx.x == 0) {
grad_input_k[target_idx] += static_cast<scalar_t>(total_sum);
}
}
for (int d = threadIdx.x; d < dim; d += blockDim.x) {
grad_input_k[d] *= *grad_output_k;
}
}
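// Host-side launcher for the forward pass. A 0-D/1-D input produces a scalar
// loss directly; a 2-D input with Mean or Sum reduction writes per-sample
// losses into a temporary buffer that is then summed on the device, while
// Reduction::None keeps one loss per sample.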
void multilabel_margin_loss_forward_out_cuda_template(
const Tensor& input,
const Tensor& target,
int64_t reduction,
Tensor& output,
Tensor& is_target) {
int64_t nframe, dim;
const int64_t ndims = input.dim();
multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);
if (input.numel() == 0) {
return;
}
auto input_ = input.contiguous();
auto target_ = target.contiguous();
auto is_target_ = is_target.contiguous();
is_target_.resize_as_(target);
if (input.dim() <= 1) {
output.resize_({});
dim3 blocks(1);
dim3 threads(MULTILABELMARGIN_THREADS);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"multilabel_margin_loss_forward_kernel",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
output.mutable_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
is_target_.mutable_data_ptr<scalar_t>(),
1,
dim,
reduction == at::Reduction::Mean);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
} else if (input.dim() == 2) {
dim3 blocks(input.size(0));
dim3 threads(MULTILABELMARGIN_THREADS);
if (reduction != at::Reduction::None) {
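// per-sample losses go into a temporary buffer and are summed across the
// batch below (for Mean, the kernel has already divided by nframe)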
auto output_tmp = at::empty({input_.size(0)}, input_.options());
output.resize_({});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"multilabel_margin_loss_forward_kernel",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
output_tmp.mutable_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
is_target_.mutable_data_ptr<scalar_t>(),
nframe,
dim,
reduction == at::Reduction::Mean);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
at::cuda::sum_out(
output,
output_tmp,
at::IntArrayRef(std::vector<int64_t>{}),
false,
output.scalar_type());
} else {
output.resize_({input.size(0)});
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"multilabel_margin_loss_forward_kernel",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multilabel_margin_loss_forward_kernel<scalar_t, accscalar_t>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
output.mutable_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
is_target_.mutable_data_ptr<scalar_t>(),
nframe,
dim,
false);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
} else {
TORCH_CHECK(
false,
"Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
input.sizes());
}
}
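// Host-side launcher for the backward pass. Shapes are re-validated against
// the saved is_target buffer before dispatching one block per sample.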
void multilabel_margin_loss_backward_cuda_out_template(
const Tensor& grad_output,
const Tensor& input,
const Tensor& target,
int64_t reduction,
const Tensor& is_target,
Tensor& grad_input) {
int64_t nframe, dim;
const int64_t ndims = input.dim();
multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);
if (input.numel() == 0) {
return;
}
auto input_ = input.contiguous();
auto target_ = target.contiguous();
auto is_target_ = is_target.contiguous();
auto grad_output_ = grad_output.contiguous();
grad_input.resize_as_(input_);
if (grad_input.dim() <= 1) {
int target_size = target_.dim() == 0 ? 1 : target_.size(0);
TORCH_CHECK(
(target_.numel() != 0) && (target_.dim() <= 1) && (target_size == dim),
"inconsistent target size");
TORCH_CHECK(
target_.sizes() == is_target_.sizes(), "inconsistent is_target size");
dim3 blocks(1);
dim3 threads(MULTILABELMARGIN_THREADS);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"multilabel_margin_loss_backward_kernel",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t>
<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
grad_input.mutable_data_ptr<scalar_t>(),
grad_output_.const_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
is_target_.const_data_ptr<scalar_t>(),
1,
dim,
reduction == at::Reduction::Mean,
reduction != at::Reduction::None);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
} else if (grad_input.dim() == 2) {
TORCH_CHECK(
(input_.size(1) != 0) && (target_.dim() == 2) &&
(target_.size(0) == nframe) && (target_.size(1) == dim),
"inconsistent target size");
TORCH_CHECK(target_.sizes() == is_target_.sizes(), "inconsistent is_target size");
dim3 blocks(grad_input.size(0));
dim3 threads(MULTILABELMARGIN_THREADS);
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"multilabel_margin_loss_backward_kernel",
[&] {
using accscalar_t = at::acc_type<scalar_t, true>;
multilabel_margin_loss_backward_kernel<scalar_t, accscalar_t>
<<<blocks, threads, 0, c10::cuda::getCurrentCUDAStream()>>>(
grad_input.mutable_data_ptr<scalar_t>(),
grad_output_.const_data_ptr<scalar_t>(),
input_.const_data_ptr<scalar_t>(),
target_.const_data_ptr<int64_t>(),
is_target_.const_data_ptr<scalar_t>(),
grad_input.size(0),
grad_input.size(1),
reduction == at::Reduction::Mean,
reduction != at::Reduction::None);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
} else {
TORCH_CHECK(
false,
"Expected 2D input with optional zero batch dim, or 1D input with non-zero dims, but got sizes: ",
grad_input.sizes());
}
}
} // namespace
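// ATen entry points: thin wrappers around the launchers above; the functional
// (non-out) variants allocate their outputs here.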
std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cuda(
const Tensor& self,
const Tensor& target,
int64_t reduction,
Tensor& output,
Tensor& is_target) {
multilabel_margin_loss_forward_out_cuda_template(
self, target, reduction, output, is_target);
return std::tuple<Tensor&, Tensor&>(output, is_target);
}
std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cuda(
const Tensor& self,
const Tensor& target,
int64_t reduction) {
auto output = at::empty({0}, self.options());
auto is_target = at::empty({0}, self.options());
multilabel_margin_loss_forward_out_cuda_template(
self, target, reduction, output, is_target);
return std::make_tuple(output, is_target);
}
Tensor& multilabel_margin_loss_backward_cuda_out(
const Tensor& grad_output,
const Tensor& self,
const Tensor& target,
int64_t reduction,
const Tensor& is_target,
Tensor& grad_input) {
multilabel_margin_loss_backward_cuda_out_template(
grad_output, self, target, reduction, is_target, grad_input);
return grad_input;
}
Tensor multilabel_margin_loss_backward_cuda(
const Tensor& grad_output,
const Tensor& self,
const Tensor& target,
int64_t reduction,
const Tensor& is_target) {
auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
multilabel_margin_loss_backward_cuda_out_template(
grad_output, self, target, reduction, is_target, grad_input);
return grad_input;
}
} // namespace at::native