[MPS] Fix unique flatten logic (#104938)
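When `dim` is None, the tensor must be flattened before checking whether the `dim` dimension already has size 1; otherwise an input such as shape (1, 1, 2) hits the size-1 shortcut and is returned unflattened.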
Fixes https://github.com/pytorch/pytorch/issues/104879
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104938
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index a9dcbd9..59d44c5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3452,6 +3452,9 @@
helper(torch.randint(3, (10, )), True, True)
helper(torch.randint(3, (1, )), True, True)
helper(torch.randint(3, (0, )), True, True)
+ # Regression test for https://github.com/pytorch/pytorch/issues/104879
+ x = torch.arange(2, device="mps")
+ self.assertEqual(x.reshape(1, 1, 2).unique(), x)
def test_unique_consecutive(self):
def helper(x, dim, return_inverse, return_counts):