[MPS] Support for median with dim (#88807)

## Summary ⚡

**Aim**: Add support for aten::median on the MPS backend (Fixes #87220)

This is a fresh, clean version of the previous [PR](https://github.com/pytorch/pytorch/pull/88554).

- Implementing the new median function in aten/src/ATen/native/mps/operations/ReduceOps.mm
- Adding it to aten/src/ATen/native/native_functions.yaml
- Adding it to the existing test_median

### **This works as follows** 🪶
Median of the entire input tensor on MPS:
`torch.median(mps_inputTensor)`
Median along a dim:
`torch.median(mps_inputTensor, dim=[int], keepdim=[Bool])`
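
A minimal sketch of the two call forms above. The tensor values and the CPU fallback are illustrative assumptions, not part of this PR. It relies on the documented `torch.median` behavior of returning the lower of the two middle values when the element count is even, and a `(values, indices)` pair when `dim` is given.

```python
import torch

# Fall back to CPU when MPS is not available (assumption, for portability only).
device = "mps" if torch.backends.mps.is_available() else "cpu"

x = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]], device=device)

# Median over all six elements: sorted they are 1, 2, 3, 4, 5, 6,
# and torch.median returns the lower middle value.
overall = torch.median(x)

# Median along dim=1 returns both the values and their positions in each row.
values, indices = torch.median(x, dim=1, keepdim=False)

# keepdim=True keeps the reduced dimension with size 1.
kept_values, kept_indices = torch.median(x, dim=1, keepdim=True)

print(overall.item())            # 3.0
print(values.cpu().tolist())     # [3.0, 4.0]
print(indices.cpu().tolist())    # [2, 0]
print(tuple(kept_values.shape))  # (2, 1)
```

The elements in each row are distinct here, so the returned indices are unambiguous; with repeated values, which of the tied indices is returned is not guaranteed.
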
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88807
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 31e2e36..52d6695 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2624,6 +2624,47 @@
         helper(2, 8, 4, 5, torch.int32)
         # helper(2, 8, 4, 5, torch.int64)
 
+    def test_median(self):
+        def helper_dtype_int32(n1, n2, n3):
+            cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            result_cpu = torch.median(cpu_x)
+            result_mps = torch.median(mps_x)
+
+            self.assertEqual(result_cpu, result_mps)
+
+            for dim in [0, 1, 2]:
+                for keepdim in [True, False]:
+                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
+                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
+                    self.assertEqual(y, refy)
+                    self.assertEqual(idx, refidx)
+
+        def helper_dtype_float32(n1, n2, n3):
+            cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            result_cpu = torch.median(cpu_x)
+            result_mps = torch.median(mps_x)
+
+            self.assertEqual(result_cpu, result_mps)
+
+            for dim in [0, 1, 2]:
+                for keepdim in [True, False]:
+                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
+                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
+                    self.assertEqual(y, refy)
+                    self.assertEqual(idx, refidx)
+
+        helper_dtype_int32(10, 10, 10)  # even number of elements
+        helper_dtype_int32(3, 3, 3)  # odd number of elements
+        helper_dtype_int32(1, 1, 1)
+        helper_dtype_int32(1, 2, 3)
+        helper_dtype_float32(10, 10, 10)
+        helper_dtype_float32(3, 3, 3)
+        helper_dtype_float32(1, 1, 1)
+
     def test_any(self):
         def helper(shape):
             input_xs = []