[MPS] Handle 1D inputs for NLL (#81290)

* Add test for NLL with 1D input
* Fix NLL forward for the 1D case
* Handle NLL backward for the 1D case

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81290
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 9d94ad1..fd735b1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1701,6 +1701,26 @@
         output_mps.sum().backward()
         self.assertEqual(input.grad, input_mps.grad.to('cpu'))
 
+    def _nll_loss_1d_helper(self, input_size, reduction):
+        # Check F.nll_loss forward/backward parity between CPU and MPS; input_size is expected to be [C] (unbatched 1D).
+        # CPU
+        input = torch.rand(input_size, requires_grad=True, device='cpu')
+        num_channels = input_size[0]  # for a 1D input the only dimension is the class dimension
+        target = torch.randint(num_channels, [], device='cpu')  # 0-d target: a single class index in [0, num_channels)
+
+        # MPS
+        input_mps = input.detach().clone().to('mps').requires_grad_()  # separate leaf tensor so MPS grads accumulate independently of the CPU input
+        target_mps = target.detach().clone().to('mps')
+
+        output_cpu = F.nll_loss(input, target, reduction=reduction)
+        output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
+        # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+        self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))
+
+        output_cpu.sum().backward()  # .sum() keeps the pattern used above; for a 1D input the loss is presumably already 0-d -- confirm
+        output_mps.sum().backward()
+        self.assertEqual(input.grad, input_mps.grad.to('cpu'))  # backward parity: input gradients must match across devices
+
     def test_as_strided(self):
         values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
         values_1 = [[1.0, 1.0], [1.0, 1.0]]
@@ -1743,6 +1763,11 @@
 
         helper(3, 3)
 
+    def test_nll_loss_1d(self, device='cpu'):  # `device` is unused here; signature mirrors the neighbouring NLL tests
+        self._nll_loss_1d_helper([10], "none")  # exercise every reduction mode on a 10-class 1D input
+        self._nll_loss_1d_helper([10], "mean")
+        self._nll_loss_1d_helper([10], "sum")
+
     def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
         self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
         self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))