[MPS] Remove casts from reduction/cumsum/sort ops starting with macOS 13.3 (#95817)

MPS in macOS 13.3 added support for int64 in reduction ops / cumsum / sort / argsort. This change removes the hard-coded casts and the error messages that were emitted prior to macOS 13.3, allowing these ops to run natively with int64.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95817
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm
index 0576f9b..792e7a8 100644
--- a/aten/src/ATen/mps/MPSDevice.mm
+++ b/aten/src/ATen/mps/MPSDevice.mm
@@ -99,8 +99,9 @@
     sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
   static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
   static bool _macos_13_3_plus = NO;
-  if (@available(macOS 13.3, *))
+  if (@available(macOS 13.3, *)) {
     _macos_13_3_plus = YES;
+  }
 
   switch (version) {
     case MacOSVersion::MACOS_VER_13_0_PLUS:  return _macos_13_0_plus;
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 689d58f..622efc9 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -56,8 +56,8 @@
 Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
 bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape);
 MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
-MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input);
-MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input);
+MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
+MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
 
 // The MPSShape could vary based on memory format
 MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
@@ -93,6 +93,7 @@
 MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
 MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
 MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
+MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
 MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
 MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
 
@@ -244,6 +245,10 @@
 // Common math operations
 MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
 
+#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                                   \
+  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                               \
+     TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, ", casting to int32. Support has been added in macOS 13.3");    \
+  }
 
 } // namespace mps
 } // namespace native
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index c5e8b5d1f..0ecdfd0 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -38,15 +38,17 @@
 // #issue 104398441 sortWithTensor and argsortWithTensor has support of
 // Int32, Half and Float32 types. These utilities are to help cast to these
 // types.
-MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input) {
+MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
   MPSDataType dataType = getMPSDataType(input.scalar_type());
-  if (dataType != MPSDataTypeInt32 &&
-      dataType != MPSDataTypeFloat32 &&
-      dataType != MPSDataTypeFloat16) {
-      dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
-      return [mpsGraph castTensor:inputTensor
-                          toType:dataType
-                          name:@"castInputTensor"];
+  bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
+  if (includesInt64) {
+    condition = condition && (dataType != MPSDataTypeInt64);
+  }
+  if (condition) {
+    dataType = ((dataType & MPSDataTypeFloatBit) || (dataType == MPSDataTypeInt64)) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+    return [mpsGraph castTensor:inputTensor
+                         toType:dataType
+                           name:@"castInputTensor"];
   }
   return inputTensor;
 }
@@ -54,14 +56,16 @@
 // #issue 104398441 sortWithTensor and argsortWithTensor has support of
 // Int32, Half and Float32 types. These utilities are to help cast from these
 // types.
-MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input) {
+MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
   MPSDataType dataType = getMPSDataType(input.scalar_type());
-  if (dataType != MPSDataTypeInt32 &&
-      dataType != MPSDataTypeFloat32 &&
-      dataType != MPSDataTypeFloat16) {
-      inputTensor = [mpsGraph castTensor:inputTensor
-                              toType:dataType
-                                name:@"castInputTensor"];
+  bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
+  if (includesInt64) {
+    condition = condition && (dataType != MPSDataTypeInt64);
+  }
+  if (condition) {
+    inputTensor = [mpsGraph castTensor:inputTensor
+                                toType:dataType
+                                  name:@"castInputTensor"];
   }
   return inputTensor;
 }
@@ -399,6 +403,10 @@
 
 // this is meant to suppress the availability warning on castTensor
 // we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
+MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
+  return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
+}
+
 MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
   return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"];
 }
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 583d12d..ceacc61 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -138,13 +138,10 @@
   const Tensor& output_t,
   MPSReductionType reduction_type,
   const std::string& func_name) {
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);
 
-  // issue 103641234, reduction ops does not have int64 support
-  if (input_t.scalar_type() == ScalarType::Long) {
-    TORCH_WARN_ONCE("MPS: no support for int64 reduction ops, casting it to int32");
-  }
-  IntArrayRef input_shape = input_t.sizes();
-
+  auto input_shape = input_t.sizes();
   if (opt_dim.has_value()) {
     IntArrayRef dim = opt_dim.value();
     for (const auto dim_val : dim) {
@@ -172,7 +169,6 @@
     }
     return;
   }
-
   auto stream = at::mps::getCurrentMPSStream();
   @autoreleasepool {
     std::string dtype_str = dtype.has_value() ? mps::getMPSTypeString(dtype.value()) : "";
@@ -199,22 +195,19 @@
 
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
           MPSGraphTensor* castInputTensor = inputTensor;
-          MPSDataType inputCastDtype = MPSDataTypeInvalid;
+          MPSDataType inputCastType = MPSDataTypeInvalid;
           if (dtype.has_value() &&
-             (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt)) {
-            inputCastDtype = getMPSDataType(dtype.value());
-          } else if (input_type != MPSDataTypeInt32   &&
-                     input_type != MPSDataTypeFloat32 &&
-                     input_type != MPSDataTypeFloat16) {
-            inputCastDtype = MPSDataTypeFloat32;
+             (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
+             (dtype.value() == kLong && macOS13_3_plus))) {
+            inputCastType = getMPSDataType(dtype.value());
           } else if (!is_macos_13_or_newer() && input_type == MPSDataTypeFloat16) {
-            inputCastDtype = MPSDataTypeFloat32;
+            inputCastType = MPSDataTypeFloat32;
           }
 
-          if (inputCastDtype != MPSDataTypeInvalid) {
-            castInputTensor = [mpsGraph castTensor:inputTensor
-                                            toType:inputCastDtype
-                                              name:@"castInputTensor"];
+          if (inputCastType != MPSDataTypeInvalid) {
+            castInputTensor = castMPSTensor(mpsGraph, inputTensor, inputCastType);
+          } else {
+            castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
           }
 
           MPSGraphTensor* castOutputTensor = nil;
@@ -276,14 +269,9 @@
                                                            name:nil];
           }
 
-          MPSGraphTensor* outputTensor = nil;
-
-          if (output_t.scalar_type() != ScalarType::Float) {
-            outputTensor = [mpsGraph castTensor:castOutputTensor
-                                         toType:getMPSDataType(output_t.scalar_type())
-                                           name:@"outputTensor"];
-          } else {
-            outputTensor = castOutputTensor;
+          MPSGraphTensor* outputTensor = castOutputTensor;
+          if (getMPSDataType(output_t.scalar_type()) != [castOutputTensor dataType]) {
+            outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
           }
 
           newCachedGraph->inputTensor_ = inputTensor;
@@ -959,6 +947,9 @@
     return;
   }
 
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_out");
+
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
   native::zero_numel_check_dims(input_t, dim_, "any()");
@@ -987,29 +978,18 @@
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* outputTensor;
           MPSDataType input_type = getMPSDataType(input_t.scalar_type());
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
 
-          if (input_type != MPSDataTypeInt32 &&
-              input_type != MPSDataTypeFloat32 &&
-              input_type != MPSDataTypeFloat16) {
-            MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
-                                                              toType:MPSDataTypeInt32
-                                                                name:@"any_all"];
-            MPSGraphTensor* outputCastedTensor = [mpsGraph reductionOrWithTensor:inputCastedTensor
-                                                                              axis:dim_
-                                                                              name:nil];
-            outputTensor = [mpsGraph castTensor:outputCastedTensor
-                                          toType:MPSDataTypeBool
-                                            name:@"any"];
-          } else {
-            MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionOrWithTensor:inputTensor
-                                                                                axis:dim_
-                                                                                name:nil];
-            outputTensor = [mpsGraph castTensor:outputUncastedTensor
-                                          toType:MPSDataTypeBool
-                                            name:@"any"];
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+          MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor
+                                                                        axis:dim_
+                                                                        name:nil];
+          MPSGraphTensor* outputTensor = castOutputTensor;
+          if (MPSDataTypeBool != [castOutputTensor dataType]) {
+            outputTensor = [mpsGraph castTensor:castOutputTensor
+                                         toType:MPSDataTypeBool
+                                           name:@"outputTensor"];
           }
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
@@ -1044,6 +1024,9 @@
     return;
   }
 
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");
+
   auto cache_ = MPSGraphCache::getInstance();
   auto stream = at::mps::getCurrentMPSStream();
 
@@ -1061,29 +1044,16 @@
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* outputTensor;
           MPSDataType input_type = getMPSDataType(input_t.scalar_type());
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+          MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor
+                                                                            axes:nil
+                                                                            name:nil];
 
-          if (input_type != MPSDataTypeInt32 &&
-              input_type != MPSDataTypeFloat32 &&
-              input_type != MPSDataTypeFloat16) {
-              MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
-                                                                toType:MPSDataTypeInt32
-                                                                  name:@"any_all"];
-              MPSGraphTensor* outputCastedTensor = [mpsGraph reductionOrWithTensor:inputCastedTensor
-                                                                                axes:nil
-                                                                                name:nil];
-              outputTensor = [mpsGraph castTensor:outputCastedTensor
-                                            toType:MPSDataTypeBool
-                                              name:@"any_all"];
-          } else {
-              MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionOrWithTensor:inputTensor
-                                                                                  axes:nil
-                                                                                  name:nil];
-              outputTensor = [mpsGraph castTensor:outputUncastedTensor
-                                            toType:MPSDataTypeBool
-                                              name:@"any_all"];
+          MPSGraphTensor* outputTensor = castOutputTensor;
+          if (getMPSDataType(output_t.scalar_type()) != [castOutputTensor dataType]) {
+            outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
           }
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
@@ -1119,6 +1089,9 @@
     return;
   }
 
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_out");
+
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
   native::zero_numel_check_dims(input_t, dim_, "all()");
@@ -1147,30 +1120,15 @@
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* outputTensor;
           MPSDataType input_type = getMPSDataType(input_t.scalar_type());
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
-
-          if (input_type != MPSDataTypeInt32 &&
-              input_type != MPSDataTypeFloat32 &&
-              input_type != MPSDataTypeFloat16 )
-          {
-              MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
-                                                                toType:MPSDataTypeInt32
-                                                                  name:@"all_all"];
-              MPSGraphTensor* outputCastedTensor = [mpsGraph reductionAndWithTensor:inputCastedTensor
-                                                                               axis:dim_
-                                                                               name:nil];
-              outputTensor = [mpsGraph castTensor:outputCastedTensor
-                                           toType:MPSDataTypeBool
-                                             name:@"all"];
-          } else {
-              MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionAndWithTensor:inputTensor
-                                                                                 axis:dim_
-                                                                                 name:nil];
-              outputTensor = [mpsGraph castTensor:outputUncastedTensor
-                                           toType:MPSDataTypeBool
-                                             name:@"all"];
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+          MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor
+                                                                         axis:dim_
+                                                                         name:nil];
+          MPSGraphTensor* outputTensor = castOutputTensor;
+          if (MPSDataTypeBool != [castOutputTensor dataType]) {
+            outputTensor = castMPSTensor(mpsGraph, castOutputTensor, MPSDataTypeBool);
           }
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
@@ -1199,6 +1157,9 @@
     return;
   }
 
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");
+
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   auto stream = at::mps::getCurrentMPSStream();
@@ -1215,30 +1176,17 @@
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* outputTensor;
           MPSDataType input_type = getMPSDataType(input_t.scalar_type());
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
-
-          if (input_type != MPSDataTypeInt32 &&
-              input_type != MPSDataTypeFloat32 &&
-              input_type != MPSDataTypeFloat16) {
-              MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
-                                                                toType:MPSDataTypeInt32
-                                                                  name:@"all_all"];
-              MPSGraphTensor* outputCastedTensor = [mpsGraph reductionAndWithTensor:inputCastedTensor
-                                                                               axes:nil
-                                                                               name:nil];
-              outputTensor = [mpsGraph castTensor:outputCastedTensor
-                                           toType:MPSDataTypeBool
-                                             name:@"all_all"];
-          } else {
-              MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionAndWithTensor:inputTensor
-                                                                                 axes:nil
-                                                                                 name:nil];
-              outputTensor = [mpsGraph castTensor:outputUncastedTensor
-                                           toType:MPSDataTypeBool
-                                             name:@"all_all"];
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+          MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor
+                                                                            axes:nil
+                                                                            name:nil];
+          MPSGraphTensor* outputTensor = castOutputTensor;
+          if (MPSDataTypeBool != [castOutputTensor dataType]) {
+            outputTensor = castMPSTensor(mpsGraph, castOutputTensor, MPSDataTypeBool);
           }
+
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
 
@@ -1268,9 +1216,8 @@
   (const Tensor& input_t,
    MPSReductionType reduction_type,
    const std::string& func_name) {
-  if (input_t.scalar_type() == ScalarType::Long) {
-    TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
-  }
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max");
 
   using CachedGraph = MPSUnaryCachedGraph;
 
@@ -1297,40 +1244,27 @@
 
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
 
-          MPSGraphTensor* outputTensor = nil;
-          MPSGraphTensor* castInputTensor = nil;
           MPSGraphTensor* castOutputTensor = nil;
-
-          if (input_t.scalar_type() != ScalarType::Float &&
-              input_t.scalar_type() != ScalarType::Int   &&
-              input_t.scalar_type() != ScalarType::Half) {
-            castInputTensor =  [mpsGraph castTensor:inputTensor
-                                             toType:MPSDataTypeInt32
-                                               name:@"castInputTensor"];
-          } else {
-            castInputTensor = inputTensor;
-          }
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
 
           NSArray<NSNumber*>* axes = getTensorAxes(input_t);
           if (reduction_type == MPSReductionType::MAX) {
-            outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
+            castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
                                                            axes:axes
                                                            name:nil];
           } else if(reduction_type == MPSReductionType::MIN) {
-            outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
+            castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
                                                            axes:axes
                                                            name:nil];
           }
 
-          if(input_t.scalar_type() == ScalarType::Long) {
-            castOutputTensor =  [mpsGraph castTensor:outputTensor
-                                             toType:MPSDataTypeInt64
-                                               name:@"castInputTensor"];
-          } else {
-            castOutputTensor = outputTensor;
+          MPSGraphTensor* outputTensor = castOutputTensor;
+          if (getMPSDataType(output_t.scalar_type()) != [castOutputTensor dataType]) {
+            outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
           }
+
           newCachedGraph->inputTensor_ = inputTensor;
-          newCachedGraph->outputTensor_ = castOutputTensor;
+          newCachedGraph->outputTensor_ = outputTensor;
         }
         return newCachedGraph;
       });
@@ -1373,7 +1307,8 @@
   const Tensor& indices_t,
   MPSReductionType reduction_type,
   const std::string& func_name) {
-  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support min/max ops with int64 input");
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out");
 
   if (output_t.numel() == 0) {
     return;
@@ -1423,26 +1358,17 @@
 
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
           MPSGraphTensor* outputTensor = nil;
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
 
-          MPSGraphTensor* castInputTensor = inputTensor;
-          bool castOutput = false;
-          if(input_t.scalar_type() != ScalarType::Float &&
-             input_t.scalar_type() != ScalarType::Int   &&
-             input_t.scalar_type() != ScalarType::Half) {
-            castInputTensor =  [mpsGraph castTensor:inputTensor
-                                             toType:MPSDataTypeInt32
-                                               name:@"castInputTensor"];
-            castOutput = true;
-          }
-
-          if(reduction_type == MPSReductionType::MAX)
+          if(reduction_type == MPSReductionType::MAX) {
             outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
                                                            axis:(NSInteger)dim_
                                                            name:nil];
-          else if(reduction_type == MPSReductionType::MIN)
+          } else if(reduction_type == MPSReductionType::MIN) {
             outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
                                                            axis:(NSInteger)dim_
                                                            name:nil];
+          }
 
           MPSGraphTensor* argreduceOutTensor = nil;
           if(reduction_type == MPSReductionType::MAX)
@@ -1454,14 +1380,15 @@
                                                                     axis:(NSInteger)dim_
                                                                     name:@"argmax_out"];
 
-          MPSGraphTensor *indicesTensor = [mpsGraph castTensor:argreduceOutTensor
-                                                        toType:MPSDataTypeInt64
-                                                          name:@"cast_out"];
+          MPSGraphTensor *indicesTensor = nil;
+          if ([argreduceOutTensor dataType] != MPSDataTypeInt64) {
+            indicesTensor = [mpsGraph castTensor:argreduceOutTensor
+                                          toType:MPSDataTypeInt64
+                                            name:@"cast_out"];
+          }
 
-          if (castOutput) {
-            outputTensor = [mpsGraph castTensor:outputTensor
-                                         toType:getMPSDataType(output_t.scalar_type())
-                                           name:@"cast_out"];
+          if ([outputTensor dataType] != getMPSDataType(output_t.scalar_type())) {
+            outputTensor = castMPSTensor(mpsGraph, outputTensor, output_t.scalar_type());
           }
           newCachedGraph->inputTensor_ = inputTensor;
           newCachedGraph->outputTensor_ = outputTensor;
@@ -1526,6 +1453,9 @@
   using CachedGraph = MPSUnaryCachedGraph;
   auto cache_ = MPSGraphCache::getInstance();
 
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out");
+
   int64_t dim_ = -1;
 
   if (dim.has_value()) {
@@ -1585,18 +1515,9 @@
           newCachedGraph = new CachedGraph(mpsGraph);
 
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), apparent_in_shape);
-
-          MPSGraphTensor* castInputTensor = inputTensor;
           MPSGraphTensor* argreduceOutTensor = nil;
 
-          if (input_t.scalar_type() != ScalarType::Float &&
-              input_t.scalar_type() != ScalarType::Int   &&
-              input_t.scalar_type() != ScalarType::Half) {
-            castInputTensor =  [mpsGraph castTensor: inputTensor
-                                             toType: MPSDataTypeFloat32
-                                               name: @"castInputTensor"];
-          }
-
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
           if (reduction_type == MPSReductionType::MAX) {
             argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor: castInputTensor
                                                                     axis: (NSInteger)dim_
@@ -1606,9 +1527,11 @@
                                                                     axis: (NSInteger)dim_
                                                                     name: nil];
           }
-          MPSGraphTensor* outputTensor = [mpsGraph castTensor: argreduceOutTensor
-                                                       toType: MPSDataTypeInt64
-                                                         name: @"castOutputTensor"];
+
+          MPSGraphTensor* outputTensor = argreduceOutTensor;
+          if (getMPSDataType(output_t.scalar_type()) != [argreduceOutTensor dataType]) {
+            outputTensor = castMPSTensor(mpsGraph, argreduceOutTensor, output_t.scalar_type());
+          }
 
           MPSGraphTensor* outputClampedTensor = [mpsGraph clampWithTensor: outputTensor
                                                            minValueTensor: [mpsGraph constantWithScalar:0 dataType:MPSDataTypeInt64]
@@ -1752,7 +1675,8 @@
     return at::median(input_t.to("cpu"));
   }
 
-  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support median op with int64 input");
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median");
 
   using CachedGraph = MPSUnaryCachedGraph;
 
@@ -1783,20 +1707,11 @@
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
           auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
-          auto reshapedTensor = [mpsGraph reshapeTensor: inputTensor
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+
+          auto reshapedTensor = [mpsGraph reshapeTensor: castInputTensor
                                               withShape: @[@-1]
                                                    name: nil];
-          MPSDataType dataType = [inputTensor dataType];
-          // #issue 104398441 sortWithTensor only supports following types, cast if necessary
-          if (dataType != MPSDataTypeInt32 &&
-              dataType != MPSDataTypeFloat32 &&
-              dataType != MPSDataTypeFloat16) {
-              dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
-              reshapedTensor = [mpsGraph castTensor:reshapedTensor
-                                      toType:dataType
-                                        name:@"castReshapedTensor"];
-          }
-
           auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor
                                                   axis: ((NSUInteger) (int)0)
                                                   name: nil];
@@ -1858,6 +1773,8 @@
   };
 
   auto cache_ = MPSGraphCache::getInstance();
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
 
   int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
 
@@ -1888,34 +1805,22 @@
           auto mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
-          MPSGraphTensor* outputTensor = nil;
-          MPSGraphTensor* castInputTensor = inputTensor;
-          MPSDataType dataType = getMPSDataType(input_t.scalar_type());
-          // #issue 104398441 sortWithTensor only supports following types, cast if necessary
-          if (dataType != MPSDataTypeInt32 &&
-              dataType != MPSDataTypeFloat32 &&
-              dataType != MPSDataTypeFloat16) {
-              dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
-              castInputTensor = [mpsGraph castTensor:inputTensor
-                                      toType:dataType
-                                        name:@"castInputTensor"];
-          }
+          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
 
-          MPSGraphTensor * sortedTensor = [mpsGraph
-                                              sortWithTensor:castInputTensor
-                                              axis:((NSUInteger) (int)dim_)
-                                              name:nil];
+          MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
+                                                              axis:((NSUInteger) (int)dim_)
+                                                              name:nil];
 
-          outputTensor = [mpsGraph sliceTensor:sortedTensor
-                                                    dimension:dim_
-                                                    start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
-                                                    length:1
-                                                    name:nil];
+          MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                                     dimension:dim_
+                                                         start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+                                                        length:1
+                                                          name:nil];
           MPSGraphTensor* argreduceOutTensor = nil;
             argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor
-                                                                    axis:(NSInteger)dim_
-                                                                    name:@"argmax_out"];
+                                                        axis:(NSInteger)dim_
+                                                        name:@"argmax_out"];
           MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
                                                     dimension:dim_
                                                     start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
@@ -1978,8 +1883,8 @@
   bool keepdim,
   at::Tensor & values,
   at::Tensor & indices){
-
-  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support median ops with int64 input");
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
 
   int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
   native::zero_numel_check_dims(input_t, dim_, "max()");
diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm
index d2155d2..2052bc8 100644
--- a/aten/src/ATen/native/mps/operations/Repeat.mm
+++ b/aten/src/ATen/native/mps/operations/Repeat.mm
@@ -230,10 +230,10 @@
 Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> output_size) {
   Tensor output;
   Tensor repeat = repeat_;
-  if (repeat.scalar_type() == kLong) {
+  if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
     // #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
     // which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
-    TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32");
+    TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
     repeat = repeat.to(kInt);
   }
   AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm
index 4b3bb69..7b9706a 100644
--- a/aten/src/ATen/native/mps/operations/Sort.mm
+++ b/aten/src/ATen/native/mps/operations/Sort.mm
@@ -18,6 +18,10 @@
  const Tensor& values,
  const Tensor& indices) {
   using namespace mps;
+
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+  MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");
+
   values.copy_(self);
   // check if self is scalar
   dim = maybe_wrap_dim(dim, self.dim(), true);
@@ -35,9 +39,6 @@
     indices.copy_(cpu_indices);
     return;
   }
-  if (self.scalar_type() == ScalarType::Long) {
-    TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
-  }
 
   MPSStream* stream = getCurrentMPSStream();
   struct CachedGraph : public MPSCachedGraph {
@@ -60,17 +61,17 @@
             newCachedGraph = new CachedGraph(mpsGraph);
             newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), input_shape);
 
-            MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self);
+            MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
             MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
                                                                 axis:(NSInteger)dim
                                                                 descending:(BOOL)descending
                                                                 name:@"sort_out"];
-            sortedTensor = castFromIHFTypes(mpsGraph, sortedTensor, values);
+            sortedTensor = castFromIHFTypes(mpsGraph, sortedTensor, values, /*includesInt64=*/macOS13_3_plus);
             MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
                                                                      axis:(NSInteger)dim
                                                                      descending:(BOOL)descending
                                                                      name:@"argsort_out"];
-            argSortedTensor = castFromIHFTypes(mpsGraph, argSortedTensor, indices);
+            argSortedTensor = castFromIHFTypes(mpsGraph, argSortedTensor, indices, /*includesInt64=*/macOS13_3_plus);
             newCachedGraph->valuesTensor = sortedTensor;
             newCachedGraph->indicesTensor = argSortedTensor;
         }
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 76b7e25..6396ff3 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -418,6 +418,7 @@
  c10::optional<ScalarType> dtype,
  const Tensor& result) {
 
+  bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
   auto nDims = self.dim();
   auto wrapped_dim = maybe_wrap_dim(dim, nDims);
   TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")");
@@ -428,18 +429,26 @@
     return;
   }
   auto input = dtype.has_value() ? self.to(dtype.value()) : self;
-  TORCH_CHECK(input.scalar_type() != ScalarType::Long, "MPS does not support cumsum op with int64 input");
+
+  // issue #103810551: cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
+  // fixed in macOS 13.3
+  bool castInputData = (isIntegralType(input.scalar_type()) &&
+                        input.scalar_type() != ScalarType::Int &&
+                        input.scalar_type() != ScalarType::Long);
+
+  TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
+              "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");
+
   mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
                 ^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
-       // cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
-       if (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int) {
+
+       if (castInputData) {
            inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
        }
        auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
                                               axis: dim
                                               name: nil];
-       if (result.scalar_type()!= input.scalar_type() ||
-           (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int)) {
+       if ((mps::getMPSDataType(result.scalar_type()) != [rc dataType]) || castInputData) {
          return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
        }
        return rc;
diff --git a/test/test_mps.py b/test/test_mps.py
index 4a930ce..0c77fa5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2840,7 +2840,7 @@
             helper(torch.int64)
         except Exception as e:
             e_string = str(e)
-            self.assertEqual(e_string, "MPS does not support cumsum op with int64 input")
+            self.assertEqual(e_string, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
 
     def test_cumsum_minus_one_axis(self):
         def helper(dtype):
@@ -9550,7 +9550,7 @@
         'cos': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
         'cosh': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
         'cov': ['f32'],
-        'cumsum': ['f16', 'f32', 'int16', 'int32'],
+        'cumsum': ['i8', 'b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
         'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'diag': ['f32', 'i32'],
         'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
@@ -10181,7 +10181,7 @@
                 self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
 
             except Exception as e:
-                if any(s in str(e).lower() for s in ["float16", "div truc rounding"]):
+                if any(s in str(e).lower() for s in ["int64", "float16", "div truc rounding"]):
                     self.skipTest(f"Expected Runtime Error: {str(e)}")
 
                 if not generate_new_truth: