[MPS] Fix unique flatten logic (#104938)
The tensor must be flattened when dim is None before checking whether the flattened length is already trivial (<= 1); previously the flatten and length computation happened after that early-exit check.
Fixes https://github.com/pytorch/pytorch/issues/104879
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104938
Approved by: https://github.com/albanD
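A minimal reproduction sketch (not part of the patch), assuming an MPS-enabled PyTorch build; it mirrors the regression test added below and the linked issue:

import torch

if torch.backends.mps.is_available():
    x = torch.arange(2, device="mps")      # tensor([0, 1], device='mps:0')
    # unique() with dim=None should flatten the (1, 1, 2) input before the
    # trivial-case early exit; prior to this fix the length check could run
    # against a size that did not account for flattening.
    print(x.reshape(1, 1, 2).unique())     # expected: tensor([0, 1], device='mps:0')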
diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm
index 23558f0..deb44ef 100644
--- a/aten/src/ATen/native/mps/operations/Unique.mm
+++ b/aten/src/ATen/native/mps/operations/Unique.mm
@@ -57,6 +57,20 @@
MPSGraphTensor* inverseIndicesTensor = nil;
MPSGraphTensor* countTensor = nil;
MPSGraphTensor* lengthTensor = nil;
+
+ const bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
+ if (needsFlatten) {
+ inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
+ length = 1;
+ for (const auto i : c10::irange([shape count])) {
+ if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
+ TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
+ }
+ }
+
+ destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
+ }
+
if (length <= 1) {
// Trivial case, only 1 element everything is unique
resultTensor = inputTensor;
@@ -76,19 +90,6 @@
inputTensor = [graph castTensor:inputTensor toType:dataType name:@"castInputTensor"];
}
- bool needsFlatten = !(dimOpt.has_value() || [shape count] == 1);
- if (needsFlatten) {
- inputTensor = [graph reshapeTensor:inputTensor withShape:@[ @-1 ] name:nil];
- length = 1;
- for (const auto i : c10::irange([shape count])) {
- if (c10::mul_overflows(length, [shape[i] unsignedIntValue], &length)) {
- TORCH_CHECK(false, "RuntimeError: Tensor size overflow");
- }
- }
-
- destShape = @[ [NSNumber numberWithUnsignedInteger:length] ];
- }
-
MPSGraphTensor* sortedInput = nil;
if (consecutive) {
sortedInput = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index a9dcbd9..59d44c5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3452,6 +3452,9 @@
helper(torch.randint(3, (10, )), True, True)
helper(torch.randint(3, (1, )), True, True)
helper(torch.randint(3, (0, )), True, True)
+ # Regression test for https://github.com/pytorch/pytorch/issues/104879
+ x = torch.arange(2, device="mps")
+ self.assertEqual(x.reshape(1, 1, 2).unique(), x)
def test_unique_consecutive(self):
def helper(x, dim, return_inverse, return_counts):