fix nll loss decomposition to properly ignore ignore_index (#95833)
Fixes #95794
This is a hotfix for the decomposition only (which is what inductor currently uses); the reference implementation still accesses invalid indices. Perhaps `_nll_loss_nd` and this decomp should be unified, cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @lezcano
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95833
Approved by: https://github.com/lezcano, https://github.com/Chillee
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f7dfcac..4482d47 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -4174,18 +4174,38 @@
self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
- @requires_decomp(aten.nll_loss_forward)
def test_nll_loss_forward(self):
def fn(a, b):
return aten.nll_loss_forward(a, b, None, 1, -100)
- self.common(
- fn,
- (
- torch.randn([5, 5]),
- torch.zeros([5], dtype=torch.int64),
- ),
+ labels = (
+ torch.zeros([5], dtype=torch.int64),
+ torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
)
+ inps = (torch.randn(5, 5), torch.randn(5, 5))
+ for a, b in zip(inps, labels):
+ self.common(
+ fn,
+ (a, b),
+ )
+
+ def test_nll_loss_backward(self):
+ def fn(a, b, c):
+ return aten.nll_loss_backward(
+ a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device)
+ )
+
+ labels = (
+ torch.zeros([5], dtype=torch.int64),
+ torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
+ )
+ inps = (torch.randn(5, 5), torch.randn(5, 5))
+ grad_outs = (torch.randn(()), torch.randn(()))
+ for a, b, c in zip(grad_outs, inps, labels):
+ self.common(
+ fn,
+ (a, b, c),
+ )
def test_isinf(self):
def fn(x):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 54266e1..f9a8d13 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -405,8 +405,9 @@
grad_output = grad_output / total_weight
target = target.unsqueeze(channel_dim)
+ safe_target = torch.where(target != ignore_index, target, 0)
grad_input = torch.zeros_like(self)
- grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
+ grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
if grad_input.dim() > grad_output.dim() > 0:
grad_output = grad_output.unsqueeze(channel_dim)
@@ -417,9 +418,7 @@
weight = weight.reshape(new_shape)
grad_output = grad_output * weight
- has_ignore_index = ignore_index >= 0
- if has_ignore_index:
- grad_output = torch.where(target != ignore_index, grad_output, 0)
+ grad_output = torch.where(target != ignore_index, grad_output, 0)
return grad_input * grad_output
@@ -2845,14 +2844,13 @@
if weight is not None:
w = weight.unsqueeze(0) if n_dims > 1 else weight
self = self * w
-
- target_ = target.unsqueeze(channel_dim)
+ safe_target = torch.where(target != ignore_index, target, 0)
+ safe_target_ = safe_target.unsqueeze(channel_dim)
# target can be [N, 1] or [1]
- result = -torch.gather(self, channel_dim, target_).squeeze(channel_dim)
+ result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
- if ignore_index >= 0:
- result = torch.where(target != ignore_index, result, 0)
+ result = torch.where(target != ignore_index, result, 0)
if reduction == Reduction.NONE.value and n_dims > 1:
total_weight = self.new_full((), 0.0)
@@ -2860,22 +2858,16 @@
if weight is not None:
w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
- wsum = torch.gather(w, channel_dim, target_).squeeze(channel_dim)
- if ignore_index >= 0:
- wsum = torch.where(target != ignore_index, wsum, 0)
+ wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
+ wsum = torch.where(target != ignore_index, wsum, 0)
total_weight = wsum.sum()
- elif ignore_index >= 0:
- total_weight = (target != ignore_index).sum().to(self)
else:
- total_weight = self.new_full((), 1.0 * result.numel())
+ total_weight = (target != ignore_index).sum().to(self)
if reduction == Reduction.SUM.value:
result = result.sum()
elif reduction == Reduction.MEAN.value:
- if weight is None:
- result = result.sum() / total_weight if ignore_index >= 0 else result.mean()
- else:
- result = result.sum() / total_weight
+ result = result.sum() / total_weight
return result, total_weight