[MPS] Fix and unblock TestConsistency for median (#94489)
- fix the num_output_dims calculation
- fix the median_out_mps key
- cast the tensor passed to sortWithTensor and argSortWithTensor
- add a note that unique has the same issue
- remove median from the TestConsistency blocklist
- add a test_median_int16 test
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94489
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 9002a0a..3cd98df 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2325,6 +2325,17 @@
helper(dtype, noncontiguous, dim)
+ def test_median_int16(self):  # torch.median on an int16 MPS tensor must match the CPU result
+ def helper(shape, dtype):
+ cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)  # bounds fit in int16; high is exclusive
+ x = cpu_x.detach().clone().to('mps')  # same values, copied onto the MPS device
+
+ median_result = torch.median(x)  # no dim argument: median over all elements
+ median_result_cpu = torch.median(cpu_x)  # CPU reference value
+ self.assertEqual(median_result, median_result_cpu)
+
+ helper((2, 8, 4, 5), torch.int16)  # 320 elements (even count); NOTE(review): an odd-sized shape is not covered
+
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)