[MPS] Fix and unblock TestConsistency for median (#94489)

- fix num_output_dims calculation
- fix median_out_mps key
- cast tensor sent to sortWithTensor and argSortWithTensor
- note down same issue for unique
- unblock median from blocklist
- add test_median_int16 test

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94489
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 88df3af..6f3b8d7 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -1751,11 +1751,21 @@
         @autoreleasepool {
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
-
           auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
           auto reshapedTensor = [mpsGraph reshapeTensor: inputTensor
                                               withShape: @[@-1]
                                                    name: nil];
+          MPSDataType dataType = [inputTensor dataType];
+          // #issue 104398441 sortWithTensor only supports following types, cast if necessary
+          if (dataType != MPSDataTypeInt32 &&
+              dataType != MPSDataTypeFloat32 &&
+              dataType != MPSDataTypeFloat16) {
+              dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+              reshapedTensor = [mpsGraph castTensor:reshapedTensor
+                                      toType:dataType
+                                        name:@"castReshapedTensor"];
+          }
+
           auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor
                                                   axis: ((NSUInteger) (int)0)
                                                   name: nil];
@@ -1835,7 +1845,7 @@
   auto stream = at::mps::getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t);
+    string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t);
     CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
     if (!cachedGraph) {
@@ -1847,24 +1857,39 @@
           auto mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-          auto sortedTensor = [mpsGraph sortWithTensor: inputTensor
-                                                  axis: (NSUInteger)dim_
-                                                  name: nil];
-          const NSUInteger midpoint = (dim_total_elements + 1) / 2 - 1;
-          auto outputTensor = [mpsGraph sliceTensor:sortedTensor
-                                          dimension:dim_
-                                              start:midpoint
-                                             length:1
-                                               name:nil];
-          auto argreduceOutTensor = [mpsGraph argSortWithTensor:inputTensor
-                                                           axis:(NSInteger)dim_
-                                                           name:@"argmax_out"];
-          auto argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
-                                             dimension:dim_
-                                                 start:midpoint
-                                                length:1
-                                                  name:nil];
+          MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
+          MPSGraphTensor* outputTensor = nil;
+          MPSGraphTensor* castInputTensor = inputTensor;
+          MPSDataType dataType = getMPSDataType(input_t.scalar_type());
+          // #issue 104398441 sortWithTensor only supports following types, cast if necessary
+          if (dataType != MPSDataTypeInt32 &&
+              dataType != MPSDataTypeFloat32 &&
+              dataType != MPSDataTypeFloat16) {
+              dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+              castInputTensor = [mpsGraph castTensor:inputTensor
+                                      toType:dataType
+                                        name:@"castInputTensor"];
+          }
+
+          MPSGraphTensor * sortedTensor = [mpsGraph
+                                              sortWithTensor:castInputTensor
+                                              axis:((NSUInteger) (int)dim_)
+                                              name:nil];
+
+          outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                                    dimension:dim_
+                                                    start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                                    length:1
+                                                    name:nil];
+          MPSGraphTensor* argreduceOutTensor = nil;
+            argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor
+                                                                    axis:(NSInteger)dim_
+                                                                    name:@"argmax_out"];
+          MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
+                                                    dimension:dim_
+                                                    start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                                    length:1
+                                                    name:nil];
 
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
@@ -1934,7 +1959,7 @@
   int64_t num_input_dims = input_shape.size();
   NSMutableArray<NSNumber*> *apparent_out_shape = nil;
   // Use this if keepdim is false
-  int64_t num_output_dims = num_input_dims - 1;
+  int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1;
 
   std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
   std::vector<int64_t> vec_out_shape(num_output_dims);
diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm
index 109244b..eac16a7 100644
--- a/aten/src/ATen/native/mps/operations/Unique.mm
+++ b/aten/src/ATen/native/mps/operations/Unique.mm
@@ -57,7 +57,7 @@
     return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
   }
 
-  // Sort only supports following types, cast if necessary
+  // #issue 104398441 sortWithTensor only supports following types, cast if necessary
   if (dataType != MPSDataTypeInt32 &&
       dataType != MPSDataTypeFloat32 &&
       dataType != MPSDataTypeFloat16) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 9002a0a..3cd98df 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2325,6 +2325,17 @@
                 helper(dtype, noncontiguous, dim)
 
 
+    def test_median_int16(self):
+        def helper(shape, dtype):
+            cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
+            x = cpu_x.detach().clone().to('mps')
+
+            median_result = torch.median(x)
+            median_result_cpu = torch.median(cpu_x)
+            self.assertEqual(median_result, median_result_cpu)
+
+        helper((2, 8, 4, 5), torch.int16)
+
 class TestLogical(TestCase):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)