fix nll loss decomposition to properly ignore ignore_index (#95833)

Fixes #95794
This is a hotfix for decomposition only (that is currently used by inductor), reference still accesses invalid indices. Perhaps `_nll_loss_nd` and this decomp should be unified, cc @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire @lezcano

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95833
Approved by: https://github.com/lezcano, https://github.com/Chillee
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f7dfcac..4482d47 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -4174,18 +4174,38 @@
 
         self.common(fn, (torch.zeros([4, 256, 296, 304]), torch.zeros([2292, 5])))
 
-    @requires_decomp(aten.nll_loss_forward)
     def test_nll_loss_forward(self):
         def fn(a, b):
             return aten.nll_loss_forward(a, b, None, 1, -100)
 
-        self.common(
-            fn,
-            (
-                torch.randn([5, 5]),
-                torch.zeros([5], dtype=torch.int64),
-            ),
+        labels = (
+            torch.zeros([5], dtype=torch.int64),
+            torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
         )
+        inps = (torch.randn(5, 5), torch.randn(5, 5))
+        for a, b in zip(inps, labels):
+            self.common(
+                fn,
+                (a, b),
+            )
+
+    def test_nll_loss_backward(self):
+        def fn(a, b, c):
+            return aten.nll_loss_backward(
+                a, b, c, None, 1, -100, torch.tensor(1.0, device=self.device)
+            )
+
+        labels = (
+            torch.zeros([5], dtype=torch.int64),
+            torch.tensor([-100, -100, 3, -100, -100], dtype=torch.int64),
+        )
+        inps = (torch.randn(5, 5), torch.randn(5, 5))
+        grad_outs = (torch.randn(()), torch.randn(()))
+        for a, b, c in zip(grad_outs, inps, labels):
+            self.common(
+                fn,
+                (a, b, c),
+            )
 
     def test_isinf(self):
         def fn(x):
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 54266e1..f9a8d13 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -405,8 +405,9 @@
         grad_output = grad_output / total_weight
 
     target = target.unsqueeze(channel_dim)
+    safe_target = torch.where(target != ignore_index, target, 0)
     grad_input = torch.zeros_like(self)
-    grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
+    grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
 
     if grad_input.dim() > grad_output.dim() > 0:
         grad_output = grad_output.unsqueeze(channel_dim)
@@ -417,9 +418,7 @@
         weight = weight.reshape(new_shape)
         grad_output = grad_output * weight
 
-    has_ignore_index = ignore_index >= 0
-    if has_ignore_index:
-        grad_output = torch.where(target != ignore_index, grad_output, 0)
+    grad_output = torch.where(target != ignore_index, grad_output, 0)
 
     return grad_input * grad_output
 
@@ -2845,14 +2844,13 @@
     if weight is not None:
         w = weight.unsqueeze(0) if n_dims > 1 else weight
         self = self * w
-
-    target_ = target.unsqueeze(channel_dim)
+    safe_target = torch.where(target != ignore_index, target, 0)
+    safe_target_ = safe_target.unsqueeze(channel_dim)
     # target can be [N, 1] or [1]
 
-    result = -torch.gather(self, channel_dim, target_).squeeze(channel_dim)
+    result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
 
-    if ignore_index >= 0:
-        result = torch.where(target != ignore_index, result, 0)
+    result = torch.where(target != ignore_index, result, 0)
 
     if reduction == Reduction.NONE.value and n_dims > 1:
         total_weight = self.new_full((), 0.0)
@@ -2860,22 +2858,16 @@
 
     if weight is not None:
         w = weight.unsqueeze(0).expand(self.shape) if n_dims > 1 else weight
-        wsum = torch.gather(w, channel_dim, target_).squeeze(channel_dim)
-        if ignore_index >= 0:
-            wsum = torch.where(target != ignore_index, wsum, 0)
+        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
+        wsum = torch.where(target != ignore_index, wsum, 0)
         total_weight = wsum.sum()
-    elif ignore_index >= 0:
-        total_weight = (target != ignore_index).sum().to(self)
     else:
-        total_weight = self.new_full((), 1.0 * result.numel())
+        total_weight = (target != ignore_index).sum().to(self)
 
     if reduction == Reduction.SUM.value:
         result = result.sum()
     elif reduction == Reduction.MEAN.value:
-        if weight is None:
-            result = result.sum() / total_weight if ignore_index >= 0 else result.mean()
-        else:
-            result = result.sum() / total_weight
+        result = result.sum() / total_weight
 
     return result, total_weight