[MPS] Support for median with dim (#88807)
## Summary ⚡
**Aim**: Add support for aten::median for MPS backend (Fixes #87220)
This is a fresh, clean version of the previous [PR](https://github.com/pytorch/pytorch/pull/88554).
- Implementing the new median function in aten/src/ATen/native/mps/operations/ReduceOps.mm
- Adding it to aten/src/ATen/native/native_functions.yaml
- Adding a new `test_median` test in test/test_mps.py
### **Usage** 🪶
Median of the entire input tensor on MPS:
`torch.median(mps_inputTensor)`
Median along a dimension:
`torch.median(mps_inputTensor, dim=<int>, keepdim=<bool>)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88807
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 31e2e36..52d6695 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2624,6 +2624,47 @@
helper(2, 8, 4, 5, torch.int32)
# helper(2, 8, 4, 5, torch.int64)
+    def test_median(self):
+        # Compare torch.median on MPS against the CPU reference, both for the
+        # full-tensor reduction and for every (dim, keepdim) combination.
+        def helper(cpu_x):
+            """Check median/argmedian of `cpu_x` computed on MPS against CPU."""
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            # Median of the whole (flattened) tensor.
+            self.assertEqual(torch.median(cpu_x), torch.median(mps_x))
+
+            for dim in [0, 1, 2]:
+                for keepdim in [True, False]:
+                    ref_val, ref_idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
+                    val, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
+                    self.assertEqual(ref_val, val)
+                    self.assertEqual(ref_idx.shape, idx.shape)
+                    # When the median value occurs more than once, torch.median
+                    # may return any of its indices and the choice is
+                    # device-specific, so comparing indices directly is fragile.
+                    # Instead verify that the MPS indices select the median.
+                    idx_cpu = idx.cpu() if keepdim else idx.cpu().unsqueeze(dim)
+                    picked = cpu_x.gather(dim, idx_cpu)
+                    if not keepdim:
+                        picked = picked.squeeze(dim)
+                    self.assertEqual(ref_val, picked)
+
+        def helper_dtype_int32(n1, n2, n3):
+            # Values in [0, 50), so repeated values (ties) are common.
+            helper(torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32))
+
+        def helper_dtype_float32(n1, n2, n3):
+            helper(torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32))
+
+        helper_dtype_int32(10, 10, 10)  # even element count per reduction
+        helper_dtype_int32(3, 3, 3)  # odd element count per reduction
+        helper_dtype_int32(1, 1, 1)
+        helper_dtype_int32(1, 2, 3)
+        helper_dtype_float32(10, 10, 10)
+        helper_dtype_float32(3, 3, 3)
+        helper_dtype_float32(1, 1, 1)
+
def test_any(self):
def helper(shape):
input_xs = []