[MPS] Remove remaining casts from 13.3 (#95870)
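MPSGraph gained native int64 support in macOS 13.3, so the remaining
workaround casts can be dropped: remove the int64 special case from the
cast helper in OperationUtils.mm, replace the castToIHFTypes /
castFromIHFTypes call sites in ReduceOps.mm and Sort.mm with explicit
dtype checks plus castMPSTensor, and reword the
MPS_CHECK_INT64_OP_SUPPORTED warning to describe the downcast it implies.
Also adds a test that reducing int64 values near sys.maxsize on MPS
matches the CPU result.

A minimal sketch of the behavior the new test covers (assumes a machine
running macOS 13.3+ with an MPS-capable GPU):

    import sys
    import torch

    x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10], device="mps")
    # Pre-13.3 the int64 input had to be downcast, losing precision on
    # values this large; with native int64 support the sums should match.
    print(torch.sum(x_mps), torch.sum(x_mps.cpu()))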
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95870
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 622efc9..bf8eaa0 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -245,9 +245,10 @@
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

-#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
- if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
- TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, ", casting to int32. Support has been added in macOS 13.3"); \
+#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
+ if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
+ TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
+ ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}

} // namespace mps
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 0ecdfd0..4547512 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -45,7 +45,7 @@
condition = condition && (dataType != MPSDataTypeInt64);
}
if (condition) {
- dataType = ((dataType & MPSDataTypeFloatBit) || (dataType == MPSDataTypeInt64)) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+ dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
return [mpsGraph castTensor:inputTensor
toType:dataType
name:@"castInputTensor"];
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index ceacc61..732db9e 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -191,7 +191,7 @@
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

- MPSDataType input_type = getMPSDataType(input_t.scalar_type());
+ auto inputScalarType = input_t.scalar_type();
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor = inputTensor;
@@ -200,14 +200,15 @@
(dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
(dtype.value() == kLong && macOS13_3_plus))) {
inputCastType = getMPSDataType(dtype.value());
- } else if (!is_macos_13_or_newer() && input_type == MPSDataTypeFloat16) {
- inputCastType = MPSDataTypeFloat32;
+ } else if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
+ (inputScalarType != kLong || !macOS13_3_plus)) {
+ inputCastType = getMPSDataType(kFloat);
+ } else if (!is_macos_13_or_newer() && inputScalarType == kHalf) {
+ inputCastType = getMPSDataType(kFloat);
}

if (inputCastType != MPSDataTypeInvalid) {
castInputTensor = castMPSTensor(mpsGraph, inputTensor, inputCastType);
- } else {
- castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
}

MPSGraphTensor* castOutputTensor = nil;
@@ -1503,7 +1504,7 @@
NSString* ns_key = [[apparent_in_shape valueForKey:@"description"] componentsJoinedByString:@","];
string key = func_name + ":" +
to_string(dim_) + ":" +
- getTensorsStringKey(input_t) + ":" +
+ getTensorsStringKey(input_t) + ":" +
string([ns_key UTF8String]);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);

@@ -1514,10 +1515,15 @@
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), apparent_in_shape);
+ auto inputScalarType = input_t.scalar_type();
+ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(inputScalarType), apparent_in_shape);
MPSGraphTensor* argreduceOutTensor = nil;

- MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+ MPSGraphTensor* castInputTensor = inputTensor;
+ if (inputScalarType != kInt && inputScalarType != kHalf && inputScalarType != kFloat &&
+ (inputScalarType != kLong || !macOS13_3_plus)) {
+ castInputTensor = castMPSTensor(mpsGraph, inputTensor, kFloat);
+ }

if (reduction_type == MPSReductionType::MAX) {
argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor: castInputTensor
axis: (NSInteger)dim_
diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm
index 7b9706a..00ec364 100644
--- a/aten/src/ATen/native/mps/operations/Sort.mm
+++ b/aten/src/ATen/native/mps/operations/Sort.mm
@@ -66,12 +66,16 @@
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"sort_out"];
- sortedTensor = castFromIHFTypes(mpsGraph, sortedTensor, values, /*includesInt64=*/macOS13_3_plus);
+ if ([sortedTensor dataType] != getMPSDataType(values.scalar_type())) {
+ sortedTensor = castMPSTensor(mpsGraph, sortedTensor, values.scalar_type());
+ }
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"argsort_out"];
- argSortedTensor = castFromIHFTypes(mpsGraph, argSortedTensor, indices, /*includesInt64=*/macOS13_3_plus);
+ if ([argSortedTensor dataType] != getMPSDataType(indices.scalar_type())) {
+ argSortedTensor = castMPSTensor(mpsGraph, argSortedTensor, indices.scalar_type());
+ }
newCachedGraph->valuesTensor = sortedTensor;
newCachedGraph->indicesTensor = argSortedTensor;
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 0c77fa5..9877612 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3852,6 +3852,15 @@
helper(2, 8, 4, 4, "min", torch.float16)
helper(2, 8, 4, 4, "min", torch.int64)

+ @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
+ def test_reduction_sum_max_long_val(self):
+ x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")
+ x_cpu = x_mps.detach().clone().cpu()
+
+ res_mps = torch.sum(x_mps)
+ res_cpu = torch.sum(x_cpu)
+ self.assertEqual(res_mps, res_cpu)
+
# Test forward max
# Note - don't test grad now
def test_max_el(self):