[MPS] Fix unique flatten logic (#104938)

Tensor must be flatted if dim is none before checking whether or not dim dimension is already None

Fixes https://github.com/pytorch/pytorch/issues/104879

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104938
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm
index 23558f0..deb44ef 100644
--- a/aten/src/ATen/native/mps/operations/Unique.mm
+++ b/aten/src/ATen/native/mps/operations/Unique.mm
@@ -57,6 +57,20 @@
   MPSGraphTensor* inverseIndicesTensor = nil;
   MPSGraphTensor* countTensor = nil;
   MPSGraphTensor* lengthTensor = nil;
+
+  const bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
+  if (needsFlatten) {
+    inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
+    length = 1;
+    for (const auto i : c10::irange([shape count])) {
+      if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
+        TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
+      }
+    }
+
+    destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
+  }
+
   if (length <= 1) {
     // Trivial case, only 1 element everything is unique
     resultTensor = inputTensor;
@@ -76,19 +90,6 @@
     inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
   }
 
-  bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
-  if (needsFlatten) {
-    inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
-    length = 1;
-    for (const auto i : c10::irange([shape count])) {
-      if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
-        TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
-      }
-    }
-
-    destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
-  }
-
   MPSGraphTensor* sortedInput = nil;
   if (consecutive) {
     sortedInput = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index a9dcbd9..59d44c5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3452,6 +3452,9 @@
         helper(torch.randint(3, (10, )), True, True)
         helper(torch.randint(3, (1, )), True, True)
         helper(torch.randint(3, (0, )), True, True)
+        # Regression test for https://github.com/pytorch/pytorch/issues/104879
+        x = torch.arange(2, device="mps")
+        self.assertEqual(x.reshape(1, 1, 2).unique(), x)
 
     def test_unique_consecutive(self):
         def helper(x, dim, return_inverse, return_counts):