blob: 9719441149d38d1a412c40c53149db16b0a1dca7 [file] [log] [blame]
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cpu/utils.h>
#include <ATen/native/Resize.h>
#include <c10/util/SmallBuffer.h>
#include <ATen/TensorSubclassLikeUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/cross_entropy_loss_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/log_softmax.h>
#include <ATen/ops/nll_loss.h>
#include <ATen/ops/nll_loss2d.h>
#include <ATen/ops/nll_loss_backward_native.h>
#include <ATen/ops/nll_loss_forward.h>
#include <ATen/ops/nll_loss_forward_native.h>
#include <ATen/ops/nll_loss_native.h>
#include <ATen/ops/nll_loss_nd.h>
#include <ATen/ops/nll_loss_nd_native.h>
#endif
#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>
#include <utility>
namespace at::meta {
TORCH_META_FUNC(nll_loss_forward)
(const Tensor& self,
const Tensor& target,
const OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index) {
const Tensor& weight = weight_opt.getTensorRef();
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
TORCH_CHECK(
target.dim() <= 1,
"0D or 1D target tensor expected, multi-target not supported");
auto no_batch_dim = self.dim() == 1 && target.dim() == 0;
TORCH_CHECK(
no_batch_dim || (self.size(0) == target.size(0)),
"size mismatch (got input: ",
self.sizes(),
", target: ",
target.sizes(),
")")
const auto n_classes = self.size(-1);
TORCH_CHECK(
!weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
"weight tensor should be defined either for all ",
n_classes,
" classes or no classes"
" but got weight tensor of shape: ",
weight.sizes());
const auto n_dims = self.dim();
const auto batch_size = self.size(0);
if (reduction == Reduction::None && n_dims == 2) {
set_output_raw_strided(0, {batch_size}, {}, self.options());
} else {
// produce scalar output when reducing or input is 1d
set_output_raw_strided(0, {}, {}, self.options());
}
set_output_raw_strided(1, {}, {}, self.options());
}
TORCH_META_FUNC(nll_loss_backward)
(const Tensor& grad_output,
const Tensor& self,
const Tensor& target,
OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight) {
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "input tensor should be 1D or 2D");
TORCH_CHECK(
target.dim() <= 1,
"0D or 1D target tensor expected, multi-target not supported");
auto no_batch_dim = self.dim() == 1 && target.dim() == 0;
TORCH_CHECK(
no_batch_dim || (self.size(0) == target.size(0)),
"size mismatch (got input: ",
self.sizes(),
", target: ",
target.sizes(),
")")
TORCH_CHECK(
total_weight.numel() == 1,
"expected total_weight to be a single element tensor, got: ",
total_weight.sizes(),
" (",
total_weight.numel(),
" elements)");
const auto& weight = weight_opt.getTensorRef();
TORCH_CHECK(
!weight.defined() || weight.numel() == self.size(-1),
"weight tensor should be defined either for all or no classes");
const auto n_dims = self.dim();
if (reduction == Reduction::None && n_dims == 2) {
const auto batch_size = self.size(0);
check_dim_size(grad_output, 1, 0, batch_size);
} else {
TORCH_CHECK(
grad_output.dim() <= 1 && grad_output.numel() == 1,
"Expected a single element grad_output tensor, but got: ",
grad_output.sizes());
}
set_output_raw_strided(0, self.sizes(), {}, self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
} // namespace at::meta
namespace at::native {
namespace {
// Returns a contiguous tensor if the source tensor
// is defined. Otherwise returns the undefined
// source tensor unmodified.
inline Tensor optional_contiguous(const Tensor& source) {
return source.defined() ? source.contiguous() : source;
}
// Returns the address of the first element of a tensor
// or nullptr if the tensor is undefined.
template <typename scalar_t>
inline scalar_t* optional_data(const Tensor& source) {
if constexpr (std::is_const<scalar_t>::value) {
return source.defined() ? source.const_data_ptr<scalar_t>() : nullptr;
} else {
return source.defined() ? source.data_ptr<scalar_t>() : nullptr;
}
}
template <typename scalar_t, typename target_t>
static void nll_loss_out_frame(
const Tensor& output,
const Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index) {
const auto n_dims = input.dim();
const auto n_classes = input.size(-1);
scalar_t* total_weight_data = total_weight.data_ptr<scalar_t>();
*total_weight_data = 0;
auto weight_contiguous = optional_contiguous(weight);
const scalar_t* weight_data = optional_data<const scalar_t>(weight_contiguous);
if (reduction == Reduction::None && n_dims == 2) {
const auto batch_size = input.size(0);
at::native::resize_output(output, {batch_size});
auto input_acc = input.accessor<const scalar_t, 2>();
auto target_acc = target.accessor<const target_t, 1>();
auto output_acc = output.accessor<scalar_t, 1>();
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
const auto cur_target = target_acc[i];
if (cur_target == ignore_index) {
output_acc[i] = 0;
continue;
}
TORCH_CHECK_INDEX(
cur_target >= 0 && cur_target < n_classes,
"Target ",
cur_target,
" is out of bounds.");
scalar_t cur_weight = weight_data != nullptr ? weight_data[cur_target]
: static_cast<scalar_t>(1);
output_acc[i] = -input_acc[i][cur_target] * cur_weight;
}
});
return;
}
// produce scalar outputs for the reduction case
at::native::resize_output(output, {});
if (target.numel() == 0) {
// Here target (and input) have zero elements
// Mean reduction on empty tensors produces NaN. See the discussion in
// https://github.com/pytorch/pytorch/pull/64572#issuecomment-926504162
if (reduction == Reduction::Mean) {
output.fill_(std::numeric_limits<double>::quiet_NaN());
} else {
output.zero_();
}
total_weight.zero_();
return;
}
auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
const target_t* target_data = target_contiguous.const_data_ptr<target_t>();
const int64_t ndim = input.dim();
const int64_t batch_size = ndim == 1 ? 1 : input.size(0);
constexpr int64_t cascade_sum_num_levels = 8;
const int64_t level_power =
std::max(int64_t(4), utils::CeilLog2(batch_size) / cascade_sum_num_levels);
const int64_t level_step = (1 << level_power);
const int64_t level_mask = level_step - 1;
int64_t num_ignored = 0;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t weight_partial_sums[cascade_sum_num_levels] = {0};
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
scalar_t loss_partial_sums[cascade_sum_num_levels] = {0};
for (const auto b : c10::irange(batch_size)) {
const int64_t cur_target = target_data[b];
if (cur_target == ignore_index) {
++num_ignored;
continue;
}
TORCH_CHECK_INDEX(
cur_target >= 0 && cur_target < n_classes,
"Target ",
cur_target,
" is out of bounds.");
const auto data = input_data[b * n_classes + cur_target];
if (weight_data) {
const scalar_t weight_val = weight_data[cur_target];
loss_partial_sums[0] -= data * weight_val;
weight_partial_sums[0] += weight_val;
} else {
loss_partial_sums[0] -= data;
}
for (int64_t j = 0; j + 1 < cascade_sum_num_levels; ++j) {
const auto mask = (level_mask << (j * level_power));
if (C10_LIKELY((b & mask) != 0)) {
break;
}
weight_partial_sums[j + 1] += weight_partial_sums[j];
loss_partial_sums[j + 1] += loss_partial_sums[j];
weight_partial_sums[j] = 0;
loss_partial_sums[j] = 0;
}
}
const scalar_t total_weight_val = !weight_data ?
static_cast<scalar_t>(batch_size - num_ignored) :
std::accumulate(std::begin(weight_partial_sums),
std::end(weight_partial_sums),
scalar_t{0});
scalar_t output_val = std::accumulate(std::begin(loss_partial_sums),
std::end(loss_partial_sums),
scalar_t{0});
if (reduction == Reduction::Mean) {
output_val /= total_weight_val;
}
// write result to output tensors
*output.data_ptr<scalar_t>() = output_val;
*total_weight_data = total_weight_val;
}
void nll_loss_forward_out_cpu_template(
const Tensor& output,
const Tensor& total_weight,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index) {
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16, input.scalar_type(), "nll_loss_out_frame", [&] {
if (target.scalar_type() == kByte) {
nll_loss_out_frame<scalar_t, uint8_t>(
output,
total_weight,
input,
target,
weight,
reduction,
ignore_index);
} else {
// assumed to be int64
nll_loss_out_frame<scalar_t, int64_t>(
output,
total_weight,
input,
target,
weight,
reduction,
ignore_index);
}
});
}
template <typename scalar_t, typename target_t>
static void nll_loss_backward_out_frame(
const Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight) {
const auto n_dims = input.dim();
const auto n_classes = input.size(-1);
auto target_ = target;
if (target.dim() == 0) {
target_ = target.unsqueeze(0);
}
auto target_acc = target_.accessor<target_t, 1>();
auto weight_contiguous = optional_contiguous(weight);
const scalar_t* weight_data = optional_data<scalar_t>(weight_contiguous);
if (reduction == Reduction::None && n_dims == 2) {
const auto batch_size = input.size(0);
auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
auto grad_output_acc = grad_output.accessor<scalar_t, 1>();
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
auto cur_target = target_acc[i];
if (cur_target == ignore_index) {
continue;
}
const scalar_t w =
weight_data ? weight_data[cur_target] : static_cast<scalar_t>(1);
grad_input_acc[i][cur_target] = -w * grad_output_acc[i];
}
});
return;
}
const scalar_t total_weight_value = *total_weight.data_ptr<scalar_t>();
const scalar_t grad_output_value = *grad_output.data_ptr<scalar_t>();
if (input.dim() == 1) {
auto grad_input_acc = grad_input.accessor<scalar_t, 1>();
const auto t = target_acc[0];
if (t != ignore_index) {
TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
: grad_output_value);
grad_input_acc[t] = weight_data != nullptr ? weight_data[t] * grad
: grad;
}
} else if (input.dim() == 2) {
auto grad_input_acc = grad_input.accessor<scalar_t, 2>();
const auto grad = -(reduction == Reduction::Mean ? grad_output_value / total_weight_value
: grad_output_value);
const auto batch_size = input.size(0);
at::parallel_for(0, batch_size, 0, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
const auto t = target_acc[i];
if (t != ignore_index) {
TORCH_CHECK_INDEX(t >= 0 && t < n_classes, "Target ", t, " is out of bounds.");
grad_input_acc[i][t] = weight_data != nullptr ? weight_data[t] * grad
: grad;
}
}
});
}
}
void nll_loss_backward_out_cpu_template(
const Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight) {
grad_input.zero_();
AT_DISPATCH_FLOATING_TYPES_AND(
ScalarType::BFloat16,
input.scalar_type(),
"nll_loss_backward_out_frame",
[&] {
if (target.scalar_type() == kByte) {
nll_loss_backward_out_frame<scalar_t, uint8_t>(
grad_input,
grad_output,
input,
target,
weight,
reduction,
ignore_index,
total_weight);
} else {
// assumed to be uint64
nll_loss_backward_out_frame<scalar_t, int64_t>(
grad_input,
grad_output,
input,
target,
weight,
reduction,
ignore_index,
total_weight);
}
});
}
} // namespace
TORCH_IMPL_FUNC(nll_loss_forward_out_cpu)
(const Tensor& self,
const Tensor& target,
const OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& output,
const Tensor& total_weight) {
const Tensor& weight = weight_opt.getTensorRef();
nll_loss_forward_out_cpu_template(
output, total_weight, self, target, weight, reduction, ignore_index);
}
TORCH_IMPL_FUNC(nll_loss_backward_out_cpu)
(const Tensor& grad_output,
const Tensor& self,
const Tensor& target,
OptionalTensorRef weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight,
const Tensor& grad_input
) {
const Tensor& weight = weight_opt.getTensorRef();
nll_loss_backward_out_cpu_template(
grad_input,
grad_output,
self,
target,
weight,
reduction,
ignore_index,
total_weight);
}
static Tensor cross_entropy_loss_prob_target(
const Tensor& self,
const Tensor& target_,
const Tensor& weight,
int64_t reduction,
double label_smoothing) {
const auto class_dim = self.dim() == 1 ? 0 : 1;
const auto n_classes = self.size(class_dim);
TORCH_CHECK(
!weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
"cross_entropy: weight tensor should be defined either for all ",
n_classes,
" classes or no classes"
" but got weight tensor of shape: ",
weight.sizes());
auto input = at::log_softmax(self, class_dim, self.scalar_type());
Tensor target;
if (label_smoothing > 0.0) {
TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);
target = target_ * (1 - label_smoothing) + label_smoothing / n_classes;
} else {
target = target_;
}
if (weight.defined()) {
// Expand weight to the correct number of dims for broadcasting with input / target
Tensor weight_ = weight;
if (input.dim() > 1) {
auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
weight_broadcast_shape[1] = weight.size(0);
weight_ = weight.view(weight_broadcast_shape);
}
switch (reduction) {
case Reduction::Mean:
if (input.numel()==0){
return -(input * target * weight_).sum().fill_(std::numeric_limits<double>::quiet_NaN());
} else {
return -(input * target * weight_).sum() / (input.numel() / n_classes);
}
case Reduction::Sum:
return -(input * target * weight_).sum();
case Reduction::None:
return -(input * target * weight_).sum(class_dim);
default:
TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
}
} else {
switch (reduction) {
case Reduction::Mean:
if (input.numel()==0){
return -(input * target).sum().fill_(std::numeric_limits<double>::quiet_NaN());
} else {
return -(input * target).sum() / (input.numel() / n_classes);
}
case Reduction::Sum:
return -(input * target).sum();
case Reduction::None:
return -(input * target).sum(class_dim);
default:
TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
}
}
}
static Tensor cross_entropy_loss_label_smoothing(
const Tensor& self,
const Tensor& target,
const Tensor& weight,
int64_t reduction,
c10::SymInt ignore_index,
double label_smoothing) {
auto class_dim = self.dim() == 1 ? 0 : 1;
auto input = at::log_softmax(self, class_dim, self.scalar_type());
auto nllloss = at::nll_loss_nd_symint(input, target, weight, reduction, ignore_index);
auto n_classes = input.sym_size(class_dim);
Tensor smooth_loss;
if (weight.defined()) {
// Expand weight to the correct number of dims for broadcasting with input / target
auto weight_broadcast_shape = SmallBuffer<int64_t, 5>(input.dim());
std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
weight_broadcast_shape[class_dim] = weight.size(0);
Tensor weight_ = weight.view(weight_broadcast_shape);
smooth_loss = -(input * weight_).sum(class_dim);
} else {
smooth_loss = -input.sum(class_dim);
}
auto ignore_mask = target == std::move(ignore_index);
smooth_loss.masked_fill_(ignore_mask, 0.0);
Tensor ret;
switch (reduction) {
case Reduction::Mean:
if (weight.defined()) {
if (isTensorSubclassLike(weight)){
// we will collect weights from 0 index which is always valid
// and mask them out if they are ignored
auto filtered_target = target.masked_fill(ignore_mask, 0);
auto tgt_weights = weight.gather(0, filtered_target.flatten());
auto weight_sum =
tgt_weights.masked_fill_(ignore_mask.flatten(), 0).sum();
ret = smooth_loss.sum() / weight_sum;
} else {
// TODO: This code can path can be removed if #61309 is resolved
// loss is normalized by the weights to be consistent with
// nll_loss_nd
ret = smooth_loss.sum() /
weight.gather(0, target.masked_select(~ignore_mask).flatten())
.sum();
}
} else {
auto true_mask = ~ignore_mask;
ret = smooth_loss.sum()/ true_mask.sum();
}
break;
case Reduction::Sum:
ret = smooth_loss.sum();
break;
case Reduction::None:
ret = smooth_loss;
break;
default:
TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
}
return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes);
}
Tensor cross_entropy_loss_symint(
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight,
int64_t reduction,
c10::SymInt ignore_index,
double label_smoothing) {
Tensor ret;
if (self.sym_sizes() == target.sym_sizes()) {
// Assume soft targets when input and target shapes are the same
TORCH_CHECK(at::isFloatingType(target.scalar_type()),
"Expected floating point type for target with class probabilities, got ", target.scalar_type());
TORCH_CHECK(ignore_index < 0, "ignore_index is not supported for floating point target");
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
const Tensor& weight_ = *weight_maybe_owned;
ret = cross_entropy_loss_prob_target(self, target, weight_, reduction, label_smoothing);
} else if (label_smoothing > 0.0) {
TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
const Tensor& weight_ = *weight_maybe_owned;
ret = cross_entropy_loss_label_smoothing(self, target, weight_, reduction, std::move(ignore_index), label_smoothing);
} else {
auto class_dim = self.dim() == 1 ? 0 : 1;
ret = at::nll_loss_nd_symint(
at::log_softmax(self, class_dim, self.scalar_type()),
target,
weight,
reduction,
std::move(ignore_index));
}
return ret;
}
Tensor & nll_loss_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
Tensor total_weight = at::empty({0}, self.options());
return std::get<0>(at::nll_loss_forward_out(output, total_weight, self, target, weight, reduction, ignore_index));
}
Tensor nll_loss_symint(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, c10::SymInt ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, std::move(ignore_index)));
}
// Duplicate of above code for non-symbolic ints. Kept for BC purposes and to minimize breakages.
static Tensor nll_loss(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
return std::get<0>(at::nll_loss_forward_symint(self, target, weight, reduction, ignore_index));
}
Tensor nll_loss_nd_symint(
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight,
int64_t reduction,
c10::SymInt ignore_index) {
if (self.dim() < 1) {
TORCH_CHECK_VALUE(
false, "Expected 1 or more dimensions (got ", self.dim(), ")");
}
if (self.dim() != 1 && self.sym_sizes()[0] != target.sym_sizes()[0]) {
TORCH_CHECK_VALUE(
false,
"Expected input batch_size (",
self.sym_sizes()[0],
") to match target batch_size (",
target.sym_sizes()[0],
").");
}
Tensor ret;
Tensor input_ = self;
Tensor target_ = target;
if (input_.dim() == 1 || input_.dim() == 2) {
ret = at::nll_loss_symint(input_, target_, weight, reduction, std::move(ignore_index));
} else if (input_.dim() == 4) {
ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
} else {
// dim == 3 or dim > 4
auto n = input_.sym_sizes()[0];
auto c = input_.sym_sizes()[1];
auto out_size = input_.sym_sizes().slice(2).vec();
out_size.insert(out_size.begin(), n);
if (target_.sym_sizes().slice(1) != input_.sym_sizes().slice(2)) {
TORCH_CHECK(
false,
"Expected target size ",
SymIntArrayRef(out_size),
", got ",
target_.sym_sizes());
}
input_ = input_.contiguous();
target_ = target_.contiguous();
// support empty batches, see #15870
if (input_.sym_numel() > 0) {
input_ = input_.view_symint({n, std::move(c), 1, -1});
} else {
input_ = input_.view_symint({n, std::move(c), 0, 0});
}
if (target_.sym_numel() > 0) {
target_ = target_.view_symint({std::move(n), 1, -1});
} else {
target_ = target_.view_symint({std::move(n), 0, 0});
}
if (reduction != Reduction::None) {
ret = at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
} else {
auto out =
at::nll_loss2d_symint(input_, target_, weight, reduction, std::move(ignore_index));
ret = out.view_symint(out_size);
}
}
return ret;
}
} // namespace at::native