[MPS] Support for median with dim (#88807)

## Summary ⚡

**Aim**: Add support for `aten::median` on the MPS backend (Fixes #87220)

This is a fresh, clean PR superseding the previous [PR](https://github.com/pytorch/pytorch/pull/88554)

- Implements the new median functions in aten/src/ATen/native/mps/operations/ReduceOps.mm
- Registers them in aten/src/ATen/native/native_functions.yaml
- Adds test coverage via `test_median` in test/test_mps.py

### **How it works** 🪶
Median of the entire input tensor on MPS:
`torch.median(mps_inputTensor)`
Median along a given dim:
`torch.median(mps_inputTensor, dim=[int], keepdim=[Bool])`
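A minimal usage sketch of both call forms (the tensor name and shape are illustrative, not taken from the PR; requires an MPS-capable device):

```python
import torch

x = torch.randn(4, 5, device="mps")

# Median of the entire tensor -> 0-dim tensor
overall = torch.median(x)

# Median along a dim -> (values, indices); keepdim=True keeps the reduced dim as size 1
values, indices = torch.median(x, dim=1, keepdim=True)
print(overall, values.shape, indices.shape)  # values and indices have shape (4, 1)
```

As on CPU/CUDA, when the reduced slice has an even number of elements the lower of the two middle values is returned, which is what the `(n + 1) / 2 - 1` slice index in the kernel below implements.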
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88807
Approved by: https://github.com/kulinseth
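For readers skimming the diff: the dim variant sorts along the reduction dim (`sortWithTensor` / `argSortWithTensor`) and slices out the lower-middle element. A plain-PyTorch sketch of that idea, purely illustrative (the helper name `median_via_sort` is made up and not part of the PR):

```python
import torch

def median_via_sort(x: torch.Tensor, dim: int, keepdim: bool = False):
    """Sketch of the sort-and-slice approach used by the MPS kernel."""
    n = x.size(dim)
    sorted_vals, sort_idx = torch.sort(x, dim=dim)
    mid = (n + 1) // 2 - 1  # lower-middle element, matching torch.median for even n
    values = sorted_vals.narrow(dim, mid, 1)
    indices = sort_idx.narrow(dim, mid, 1)
    if not keepdim:
        values, indices = values.squeeze(dim), indices.squeeze(dim)
    return values, indices

x = torch.randn(3, 7)
ref = torch.median(x, dim=1)
vals, _ = median_via_sort(x, dim=1)
assert torch.equal(vals, ref.values)
```

The actual kernel expresses the same computation as an MPSGraph (reshape/sort/slice for the full-tensor median, sort/argSort/slice for the dim variant) and caches the compiled graph.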
diff --git a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
index 86153b5..b77db66 100644
--- a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
+++ b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
@@ -6,4 +6,12 @@
 - (MPSGraphTensor *)cumulativeSumWithTensor:(MPSGraphTensor *)tensor
                                        axis:(NSInteger)axis
                                        name:(NSString *)name;
+
+- (MPSGraphTensor *)sortWithTensor:(MPSGraphTensor *)tensor
+                                       axis:(NSInteger)axis
+                                       name:(NSString *)name;
+
+- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor
+                                       axis:(NSInteger)axis
+                                       name:(NSString *)name;
 @end
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 91aa245..c99f22d 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -9,6 +9,7 @@
 #include <ATen/native/ReduceOpsUtils.h>
 #include <ATen/native/Pool.h>
 #include <torch/library.h>
+#include <ATen/native/mps/MPSGraphVenturaOps.h>
 
 namespace at {
 namespace native {
@@ -1638,5 +1639,319 @@
     return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps");
 }
 
+// Median of entire tensor into scalar result
+Tensor median_mps(const Tensor& input_t) {
+
+  if(!is_macos_13_or_newer()){
+        TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+        return at::median(input_t.to("cpu"));
+  }
+
+    TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS");
+
+    namespace native_mps = at::native::mps;
+    using CachedGraph = native_mps::MPSUnaryCachedGraph;
+
+    native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
+
+    IntArrayRef input_shape = input_t.sizes();
+    int64_t num_input_dims = input_shape.size();
+
+    // calculate total no. of elements in the input tensor to reduce it to one dimension
+    NSMutableArray<NSNumber*> *apparent_input_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:1];
+    int64_t num_in_elements = 1;
+    for(int i = 0; i < num_input_dims; i++) {
+        num_in_elements *= input_shape[i];
+    }
+
+    apparent_input_shape[0] = [NSNumber numberWithInt:num_in_elements];
+
+    Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
+
+    if (output_t.numel() == 0 || num_in_elements == 0) {
+        return output_t;
+    }
+
+  @autoreleasepool {
+    string key = "median_mps:"+ mps::getMPSTypeString(input_t.scalar_type())  + mps::getTensorsStringKey(input_t);
+    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+    // Initialize once if configuration not found in cache
+    if(!cachedGraph) {
+      native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+
+        CachedGraph *newCachedGraph = nil;
+
+        @autoreleasepool {
+            MPSGraph* mpsGraph = native_mps::make_mps_graph();
+            newCachedGraph = new CachedGraph(mpsGraph);
+
+            MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+
+            MPSGraphTensor* outputTensor = nil;
+
+            MPSGraphTensor * reshapedTensor = [mpsGraph reshapeTensor:inputTensor
+                                                            withShape:@[@-1]
+                                                                  name:nil];
+            MPSGraphTensor * sortedTensor = [mpsGraph
+                                                  sortWithTensor:reshapedTensor
+                                                  axis:((NSUInteger) (int)0)
+                                                  name:nil];
+
+            outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                                        dimension:0
+                                                        start:((NSUInteger) (int)((num_in_elements+1)/2 ) - 1)
+                                                        length:1
+                                                        name:nil];
+
+            newCachedGraph->inputTensor_ = inputTensor;
+            newCachedGraph->outputTensor_ = outputTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+    }
+
+    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+    auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, @[@1]);
+
+    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+    };
+
+    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+    };
+
+    native_mps::runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
+  }
+
+  return output_t;
+}
+
+
+void median_out_mps
+  (const Tensor& input_t,
+  int64_t dim,
+  bool keepdim,
+  const Tensor& output_t,
+  const Tensor& indices_t,
+  const std::string& func_name) {
+
+    namespace native_mps = at::native::mps;
+
+    if (output_t.numel() == 0) {
+      return;
+    }
+    if (input_t.numel() == 1 && input_t.dim() == 0) {
+      output_t.fill_(input_t);
+      indices_t.fill_(0);
+      return;
+    }
+
+    // Derive from MPSCachedGraph
+    struct CachedGraph : public native_mps::MPSCachedGraph
+    {
+      CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+      MPSGraphTensor *inputTensor_ = nil;
+      MPSGraphTensor *outputTensor_ = nil;
+      MPSGraphTensor *indicesTensor_ = nil;
+    };
+
+    native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
+
+    int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
+
+    // Calculate the output shape according to keepdim=True
+    // If there is no dim argument, the input shape is flattened
+    IntArrayRef input_shape = input_t.sizes();
+    int64_t num_input_dims = input_shape.size();
+    NSMutableArray<NSNumber*> *apparent_out_shape = nil;
+
+    apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
+    for(int i = 0; i < num_input_dims; i++) {
+        if(dim_ == i)
+            apparent_out_shape[i] = @1;
+        else
+            apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
+    }
+    int dim_total_elements = input_shape[dim_];
+
+    auto stream = at::mps::getCurrentMPSStream();
+
+    @autoreleasepool {
+        string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type());
+        CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+
+        if(!cachedGraph) {
+          native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+
+            CachedGraph *newCachedGraph = nil;
+
+            @autoreleasepool {
+              MPSGraph* mpsGraph = native_mps::make_mps_graph();
+              newCachedGraph = new CachedGraph(mpsGraph);
+
+              MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+              MPSGraphTensor* outputTensor = nil;
+              MPSGraphTensor * sortedTensor = [mpsGraph
+                                                  sortWithTensor:inputTensor
+                                                  axis:((NSUInteger) (int)dim_)
+                                                  name:nil];
+
+              outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                                        dimension:dim_
+                                                        start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                                        length:1
+                                                        name:nil];
+              MPSGraphTensor* argreduceOutTensor = nil;
+                argreduceOutTensor = [mpsGraph argSortWithTensor:inputTensor
+                                                                        axis:(NSInteger)dim_
+                                                                        name:@"argmax_out"];
+              MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
+                                                        dimension:dim_
+                                                        start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                                        length:1
+                                                        name:nil];
+
+              newCachedGraph->inputTensor_ = inputTensor;
+              newCachedGraph->outputTensor_ = outputTensor;
+              newCachedGraph->indicesTensor_ = argOutputTensor;
+            }
+            return newCachedGraph;
+          });
+          cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+        }
+
+        auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+        auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
+        auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape);
+
+        NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+          inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+        };
+
+        NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+          outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
+          indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
+        };
+
+        native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+
+    }
+
+}
+
+// Fallback for the case where MPS sortWithTensor is not supported on this macOS version
+std::tuple<Tensor&, Tensor&> median_from_cpu(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim, Tensor & valuesI, Tensor & indicesI, IntArrayRef vec_out_shape, IntArrayRef vec_apparent_out_shape) {
+      // Tensor a = at::median(self.to("cpu"));
+      Tensor values;
+      Tensor indices;
+    if (!keepdim){
+        values = at::empty({vec_out_shape}, self.options());
+        indices = at::empty({vec_out_shape}, self.options().dtype(kLong));
+
+      }
+      else{
+          values = at::empty({vec_apparent_out_shape}, self.options());
+          indices = at::empty({vec_apparent_out_shape}, self.options().dtype(kLong));
+      }
+      at::median_out(values, indices, self, dim, keepdim);
+
+  valuesI.copy_(values);
+  indicesI.copy_(indices);
+  return std::forward_as_tuple(valuesI, indicesI);
+}
+
+TORCH_API ::std::tuple<at::Tensor &,at::Tensor &> median_out_mps
+    (const at::Tensor & input_t,
+    int64_t dim,
+    bool keepdim,
+    at::Tensor & values,
+    at::Tensor & indices){
+
+  TORCH_INTERNAL_ASSERT(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS");
+
+  namespace native_mps = at::native::mps;
+    int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
+    native::zero_numel_check_dims(input_t, dim_, "median()");
+
+    // Calculate the output shape according to keepdim=True
+    // If there is no dim argument, the input shape is flattened
+    IntArrayRef input_shape = input_t.sizes();
+    int64_t num_input_dims = input_shape.size();
+    NSMutableArray<NSNumber*> *apparent_out_shape = nil;
+    // Use this if keepdim is false
+    int64_t num_output_dims = num_input_dims - 1;
+
+    std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
+    std::vector<int64_t> vec_out_shape(num_output_dims);
+
+    apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
+    // Counter for shape when keepdim is false
+    int out_i = 0;
+    for(int i = 0; i < num_input_dims; i++) {
+        if(dim_ == i) {
+            apparent_out_shape[i] = @1;
+            vec_apparent_out_shape[i] = 1;
+        }
+        else {
+            apparent_out_shape[i] = [NSNumber numberWithInt:input_shape[i]];
+            vec_apparent_out_shape[i] = input_shape[i];
+            vec_out_shape[out_i] = input_shape[i];
+            out_i++;
+        }
+    }
+
+    if(!keepdim) {
+     values = at::native::empty_mps(
+                      IntArrayRef(vec_out_shape),
+                      input_t.scalar_type(),
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+     indices = at::native::empty_mps(
+                      IntArrayRef(vec_out_shape),
+                      ScalarType::Long,
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+    } else {
+      values = at::native::empty_mps(
+                      IntArrayRef(vec_apparent_out_shape),
+                      input_t.scalar_type(),
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+     indices = at::native::empty_mps(
+                      IntArrayRef(vec_apparent_out_shape),
+                      ScalarType::Long,
+                      c10::nullopt,
+                      kMPS,
+                      c10::nullopt,
+                      c10::nullopt);
+    }
+
+    if (values.numel() == 0 || input_t.numel() == 0) {
+        return std::tuple<Tensor&, Tensor&>{values, indices};
+    }
+
+    if(!is_macos_13_or_newer()){
+      TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+    return median_from_cpu(input_t.to("cpu"), dim, keepdim, values, indices, IntArrayRef(vec_out_shape),IntArrayRef(vec_apparent_out_shape) );
+  }
+
+    median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps");
+
+    return std::tuple<Tensor&, Tensor&>{values, indices};
+}
+
 } // native
 } // at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8046b4f..b1d1094 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3492,6 +3492,7 @@
   dispatch:
     CPU: median_cpu
     CUDA: median_cuda
+    MPS: median_mps
   autogen: median.out
 
 - func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
@@ -3503,6 +3504,7 @@
   dispatch:
     CPU: median_out_cpu
     CUDA: median_out_cuda
+    MPS: median_out_mps
 
 - func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
   variants: function, method
diff --git a/test/test_mps.py b/test/test_mps.py
index 31e2e36..52d6695 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2624,6 +2624,47 @@
         helper(2, 8, 4, 5, torch.int32)
         # helper(2, 8, 4, 5, torch.int64)
 
+    def test_median(self):
+        def helper_dtype_int32(n1, n2, n3):
+            cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=torch.int32)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            result_cpu = torch.median(cpu_x)
+            result_mps = torch.median(mps_x)
+
+            self.assertEqual(result_cpu, result_mps)
+
+            for dim in [0, 1, 2]:
+                for keepdim in [True, False]:
+                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
+                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
+                    self.assertEqual(y, refy)
+                    self.assertEqual(idx, refidx)
+
+        def helper_dtype_float32(n1, n2, n3):
+            cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=torch.float32)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            result_cpu = torch.median(cpu_x)
+            result_mps = torch.median(mps_x)
+
+            self.assertEqual(result_cpu, result_mps)
+
+            for dim in [0, 1, 2]:
+                for keepdim in [True, False]:
+                    y, idx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
+                    refy, refidx = torch.median(mps_x, dim=dim, keepdim=keepdim)
+                    self.assertEqual(y, refy)
+                    self.assertEqual(idx, refidx)
+
+        helper_dtype_int32(10, 10, 10)  # even number of elements
+        helper_dtype_int32(3, 3, 3)  # odd number of elements
+        helper_dtype_int32(1, 1, 1)
+        helper_dtype_int32(1, 2, 3)
+        helper_dtype_float32(10, 10, 10)
+        helper_dtype_float32(3, 3, 3)
+        helper_dtype_float32(1, 1, 1)
+
     def test_any(self):
         def helper(shape):
             input_xs = []