[MPS] Handle 1D inputs for NLL (#81290)
* Add a test for NLL loss with 1D (unbatched) inputs
* Fix the forward NLL pass for 1D inputs
* Handle 1D inputs in the NLL backward pass
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81290
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 9d94ad1..fd735b1 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1701,6 +1701,26 @@
output_mps.sum().backward()
self.assertEqual(input.grad, input_mps.grad.to('cpu'))
+ def _nll_loss_1d_helper(self, input_size, reduction):
+
+ # CPU
+ input = torch.rand(input_size, requires_grad=True, device='cpu')
+ num_channels = input_size[0]
+ target = torch.randint(num_channels, [], device='cpu')
+
+ # MPS
+ input_mps = input.detach().clone().to('mps').requires_grad_()
+ target_mps = target.detach().clone().to('mps')
+
+ output_cpu = F.nll_loss(input, target, reduction=reduction)
+ output_mps = F.nll_loss(input_mps, target_mps, reduction=reduction)
+ # TODO(#38095): Replace assertEqualIgnoreType. See issue #38095
+ self.assertEqualIgnoreType(output_cpu, output_mps.to('cpu'))
+
+ output_cpu.sum().backward()
+ output_mps.sum().backward()
+ self.assertEqual(input.grad, input_mps.grad.to('cpu'))
+
def test_as_strided(self):
values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
values_1 = [[1.0, 1.0], [1.0, 1.0]]
@@ -1743,6 +1763,11 @@
helper(3, 3)
+ def test_nll_loss_1d(self, device='cpu'):
+ self._nll_loss_1d_helper([10], "none")
+ self._nll_loss_1d_helper([10], "mean")
+ self._nll_loss_1d_helper([10], "sum")
+
def test_nll_loss_empty_tensor_reduction_none(self, device='cpu'):
self._nll_loss_helper([1, 3], "none", torch.empty([0], device=device))
self._nll_loss_helper([3, 5, 7], "none", torch.empty([5, 7], device=device))