[MPS] Fix unique flatten logic (#104938) When `dim` is None, the tensor must be flattened before checking whether the `dim` dimension is already None. Fixes https://github.com/pytorch/pytorch/issues/104879 Pull Request resolved: https://github.com/pytorch/pytorch/pull/104938 Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py index a9dcbd9..59d44c5 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -3452,6 +3452,9 @@ helper(torch.randint(3, (10, )), True, True) helper(torch.randint(3, (1, )), True, True) helper(torch.randint(3, (0, )), True, True) + # Regression test for https://github.com/pytorch/pytorch/issues/104879 + x = torch.arange(2, device="mps") + self.assertEqual(x.reshape(1, 1, 2).unique(), x) def test_unique_consecutive(self): def helper(x, dim, return_inverse, return_counts):