[MPS] Remove remaining casts from 13.3 (#95870)

MPSGraph gained native int64 support in macOS 13.3, so the int64 -> int32 workarounds are no longer needed there. Reword the MPS_CHECK_INT64_OP_SUPPORTED warning to mention the downcast (int32/float32) and that native int64 support arrived in macOS 13.3, simplify the cast-selection logic in OperationUtils.mm, replace the castToIHFTypes/castFromIHFTypes helpers in ReduceOps.mm and Sort.mm with explicit casts that are applied only when the input dtype is not natively supported, and add a regression test that sums int64 values near sys.maxsize, which would overflow if silently downcast to int32.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95870
Approved by: https://github.com/kulinseth
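
A minimal sketch of the behavior this change targets, mirroring the new test_reduction_sum_max_long_val test below; it is illustrative only and not part of the patch:

import sys
import torch

# int64 values near sys.maxsize overflow int32, so the MPS result only matches
# the CPU result exactly when the reduction runs natively in int64 (macOS 13.3+).
# On older macOS versions a one-time downcast warning is emitted instead.
if torch.backends.mps.is_available():
    x_cpu = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5])
    x_mps = x_cpu.to("mps")
    print(torch.sum(x_cpu), torch.sum(x_mps).cpu())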
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 622efc9..bf8eaa0 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -245,9 +245,10 @@
 // Common math operations
 MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
 
-#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                                   \
-  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                               \
-     TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, ", casting to int32. Support has been added in macOS 13.3");    \
+#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name)                                           \
+  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) {                                                       \
+     TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name,                                                         \
+     ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3.");   \
   }
 
 } // namespace mps
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 0ecdfd0..4547512 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -45,7 +45,7 @@
     condition = condition && (dataType != MPSDataTypeInt64);
   }
   if (condition) {
-    dataType = ((dataType & MPSDataTypeFloatBit) || (dataType == MPSDataTypeInt64)) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+    dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
     return [mpsGraph castTensor:inputTensor
                          toType:dataType
                            name:@"castInputTensor"];
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index ceacc61..732db9e 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -191,7 +191,7 @@
         @autoreleasepool {
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
-          MPSDataType input_type = getMPSDataType(input_t.scalar_type());
+          auto inputScalarType = input_t.scalar_type();
 
           MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
           MPSGraphTensor* castInputTensor = inputTensor;
@@ -200,14 +200,15 @@
              (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
              (dtype.value() == kLong && macOS13_3_plus))) {
             inputCastType = getMPSDataType(dtype.value());
-          } else if (!is_macos_13_or_newer() && input_type == MPSDataTypeFloat16) {
-            inputCastType = MPSDataTypeFloat32;
+          } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
+                    (inputScalarType != kLong || !macOS13_3_plus)) {
+            inputCastType = getMPSDataType(kFloat);
+          } else if (!is_macos_13_or_newer() && inputScalarType == kHalf) {
+            inputCastType = getMPSDataType(kFloat);
           }
 
           if (inputCastType != MPSDataTypeInvalid) {
             castInputTensor = castMPSTensor(mpsGraph, inputTensor, inputCastType);
-          } else {
-            castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
           }
 
           MPSGraphTensor* castOutputTensor = nil;
@@ -1503,7 +1504,7 @@
     NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","];
     string key = func_name                                + ":" +
                  to_string(dim_)                          + ":" +
-                 getTensorsStringKey(input_t) + ":" +
+                 getTensorsStringKey(input_t)             + ":" +
                  string([ns_key UTF8String]);
     CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
@@ -1514,10 +1515,15 @@
           MPSGraph* mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), apparent_in_shape);
+          auto inputScalarType = input_t.scalar_type();
+          MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(inputScalarType), apparent_in_shape);
           MPSGraphTensor* argreduceOutTensor = nil;
 
-          MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+          MPSGraphTensor* castInputTensor = inputTensor;
+          if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
+             (inputScalarType != kLong || !macOS13_3_plus)) {
+            castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat);
+          }
           if (reduction_type == MPSReductionType::MAX) {
             argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor: castInputTensor
                                                                     axis: (NSInteger)dim_
diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm
index 7b9706a..00ec364 100644
--- a/aten/src/ATen/native/mps/operations/Sort.mm
+++ b/aten/src/ATen/native/mps/operations/Sort.mm
@@ -66,12 +66,16 @@
                                                                 axis:(NSInteger)dim
                                                                 descending:(BOOL)descending
                                                                 name:@"sort_out"];
-            sortedTensor = castFromIHFTypes(mpsGraph, sortedTensor, values, /*includesInt64=*/macOS13_3_plus);
+            if ([sortedTensor dataType] != getMPSDataType(values.scalar_type())) {
+              sortedTensor = castMPSTensor(mpsGraph, sortedTensor, values.scalar_type());
+            }
             MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
                                                                      axis:(NSInteger)dim
                                                                      descending:(BOOL)descending
                                                                      name:@"argsort_out"];
-            argSortedTensor = castFromIHFTypes(mpsGraph, argSortedTensor, indices, /*includesInt64=*/macOS13_3_plus);
+            if ([argSortedTensor dataType] != getMPSDataType(indices.scalar_type())) {
+              argSortedTensor = castMPSTensor(mpsGraph, argSortedTensor, indices.scalar_type());
+            }
             newCachedGraph->valuesTensor = sortedTensor;
             newCachedGraph->indicesTensor = argSortedTensor;
         }
diff --git a/test/test_mps.py b/test/test_mps.py
index 0c77fa5..9877612 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3852,6 +3852,15 @@
         helper(2, 8, 4, 4, "min", torch.float16)
         helper(2, 8, 4, 4, "min", torch.int64)
 
+    @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
+    def test_reduction_sum_max_long_val(self):
+        x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
+        x_cpu = x_mps.detach().clone().cpu()
+
+        res_mps = torch.sum(x_mps)
+        res_cpu = torch.sum(x_cpu)
+        self.assertEqual(res_mps, res_cpu)
+
     # Test forward max
     # Note - don't test grad now
     def test_max_el(self):