[MPS] Remove casts from reduction/cumsum/sort ops starting with macOS 13.3 (#95817)
MPS on macOS 13.3 adds support for int64 in reduction ops / cumsum / sort / argsort. This change removes the hard-coded casts and the error messages that were needed prior to macOS 13.3, allowing these ops to run natively with int64 on macOS 13.3 and newer.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95817
Approved by: https://github.com/kulinseth
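
The snippet below is a minimal sketch (not part of the patch) of what this change enables: with a PyTorch build that has the MPS backend, running on macOS 13.3 or newer, the int64 reduction/cumsum/sort/argsort calls should execute natively and match the CPU results; on older macOS these ops either warn and cast to int32 or raise. The helper name check_int64_ops is made up for illustration.

import torch

def check_int64_ops():
    # Hypothetical helper, for illustration only: exercises the int64 paths
    # touched by this PR (reduction, cumsum, sort, argsort) on the MPS device
    # and compares each result against the CPU reference.
    if not torch.backends.mps.is_available():
        print("MPS backend not available; skipping.")
        return
    x_cpu = torch.randint(-10, 10, (4, 5), dtype=torch.int64)
    x_mps = x_cpu.to("mps")

    # Reduction and cumsum in int64 (before macOS 13.3 these warned or errored and used int32).
    assert torch.equal(x_mps.sum(dim=1).cpu(), x_cpu.sum(dim=1))
    assert torch.equal(torch.cumsum(x_mps, dim=1).cpu(), torch.cumsum(x_cpu, dim=1))

    # sort/argsort in int64; the argsort indices are checked by gathering the values back.
    assert torch.equal(torch.sort(x_mps, dim=1).values.cpu(), torch.sort(x_cpu, dim=1).values)
    idx = torch.argsort(x_mps, dim=1)
    assert torch.equal(torch.gather(x_mps, 1, idx).cpu(), torch.sort(x_cpu, dim=1).values)
    print("int64 reduction/cumsum/sort/argsort match the CPU results.")

if __name__ == "__main__":
    check_int64_ops()
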
diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm
index 0576f9b..792e7a8 100644
--- a/aten/src/ATen/mps/MPSDevice.mm
+++ b/aten/src/ATen/mps/MPSDevice.mm
@@ -99,8 +99,9 @@
sampleGridWithSourceTensor:coordinateTensor:layout:normalizeCoordinates:relativeCoordinates:alignCorners:paddingMode:samplingMode:constantValue:name:)] == YES;
static bool _macos_13_2_plus = [mpsCD instancesRespondToSelector:@selector(convolution3DWithSourceTensor:weightsTensor:descriptor:name:)] == YES;
static bool _macos_13_3_plus = NO;
- if (@available(macOS 13.3, *))
+ if (@available(macOS 13.3, *)) {
_macos_13_3_plus = YES;
+ }
switch (version) {
case MacOSVersion::MACOS_VER_13_0_PLUS: return _macos_13_0_plus;
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 689d58f..622efc9 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -56,8 +56,8 @@
Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output);
bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
-MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input);
-MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input);
+MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
+MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false);
// The MPSShape could vary based on memory format
MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous);
@@ -93,6 +93,7 @@
MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor);
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType);
+MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType);
MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor);
MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar);
@@ -244,6 +245,10 @@
// Common math operations
MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
+#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
+ if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
+ TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, ", casting to int32. Support has been added in macOS 13.3"); \
+ }
} // namespace mps
} // namespace native
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index c5e8b5d1f..0ecdfd0 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -38,15 +38,17 @@
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
// Int32, Half and Float32 types. These utilities are to help cast to these
// types.
-MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input) {
+MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
- if (dataType != MPSDataTypeInt32 &&
- dataType != MPSDataTypeFloat32 &&
- dataType != MPSDataTypeFloat16) {
- dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
- return [mpsGraph castTensor:inputTensor
- toType:dataType
- name:@"castInputTensor"];
+ bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
+ if (includesInt64) {
+ condition = condition && (dataType != MPSDataTypeInt64);
+ }
+ if (condition) {
+ dataType = ((dataType & MPSDataTypeFloatBit) || (dataType == MPSDataTypeInt64)) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+ return [mpsGraph castTensor:inputTensor
+ toType:dataType
+ name:@"castInputTensor"];
}
return inputTensor;
}
@@ -54,14 +56,16 @@
// #issue 104398441 sortWithTensor and argsortWithTensor has support of
// Int32, Half and Float32 types. These utilities are to help cast from these
// types.
-MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input) {
+MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64) {
MPSDataType dataType = getMPSDataType(input.scalar_type());
- if (dataType != MPSDataTypeInt32 &&
- dataType != MPSDataTypeFloat32 &&
- dataType != MPSDataTypeFloat16) {
- inputTensor = [mpsGraph castTensor:inputTensor
- toType:dataType
- name:@"castInputTensor"];
+ bool condition = (dataType != MPSDataTypeInt32) && (dataType != MPSDataTypeFloat32) && (dataType != MPSDataTypeFloat16);
+ if (includesInt64) {
+ condition = condition && (dataType != MPSDataTypeInt64);
+ }
+ if (condition) {
+ inputTensor = [mpsGraph castTensor:inputTensor
+ toType:dataType
+ name:@"castInputTensor"];
}
return inputTensor;
}
@@ -399,6 +403,10 @@
// this is meant to suppress the availability warning on castTensor
// we pass ScalarType instead of MPSDataType to handle MPSDataTypeBoolean's availability too
+MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType) {
+ return [mpsGraph castTensor:tensor toType:toType name:@"castTensor"];
+}
+
MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType) {
return [mpsGraph castTensor:tensor toType:getMPSScalarType(toType) name:@"castTensor"];
}
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 583d12d..ceacc61 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -138,13 +138,10 @@
const Tensor& output_t,
MPSReductionType reduction_type,
const std::string& func_name) {
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);
- // issue 103641234, reduction ops does not have int64 support
- if (input_t.scalar_type() == ScalarType::Long) {
- TORCH_WARN_ONCE("MPS: no support for int64 reduction ops, casting it to int32");
- }
- IntArrayRef input_shape = input_t.sizes();
-
+ auto input_shape = input_t.sizes();
if (opt_dim.has_value()) {
IntArrayRef dim = opt_dim.value();
for (const auto dim_val : dim) {
@@ -172,7 +169,6 @@
}
return;
}
-
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
std::string dtype_str = dtype.has_value() ? mps::getMPSTypeString(dtype.value()) : "";
@@ -199,22 +195,19 @@
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* castInputTensor = inputTensor;
- MPSDataType inputCastDtype = MPSDataTypeInvalid;
+ MPSDataType inputCastType = MPSDataTypeInvalid;
if (dtype.has_value() &&
- (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt)) {
- inputCastDtype = getMPSDataType(dtype.value());
- } else if (input_type != MPSDataTypeInt32 &&
- input_type != MPSDataTypeFloat32 &&
- input_type != MPSDataTypeFloat16) {
- inputCastDtype = MPSDataTypeFloat32;
+ (dtype.value() == kFloat || dtype.value() == kHalf || dtype.value() == kInt ||
+ (dtype.value() == kLong && macOS13_3_plus))) {
+ inputCastType = getMPSDataType(dtype.value());
} else if (!is_macos_13_or_newer() && input_type == MPSDataTypeFloat16) {
- inputCastDtype = MPSDataTypeFloat32;
+ inputCastType = MPSDataTypeFloat32;
}
- if (inputCastDtype != MPSDataTypeInvalid) {
- castInputTensor = [mpsGraph castTensor:inputTensor
- toType:inputCastDtype
- name:@"castInputTensor"];
+ if (inputCastType != MPSDataTypeInvalid) {
+ castInputTensor = castMPSTensor(mpsGraph, inputTensor, inputCastType);
+ } else {
+ castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
}
MPSGraphTensor* castOutputTensor = nil;
@@ -276,14 +269,9 @@
name:nil];
}
- MPSGraphTensor* outputTensor = nil;
-
- if (output_t.scalar_type() != ScalarType::Float) {
- outputTensor = [mpsGraph castTensor:castOutputTensor
- toType:getMPSDataType(output_t.scalar_type())
- name:@"outputTensor"];
- } else {
- outputTensor = castOutputTensor;
+ MPSGraphTensor* outputTensor = castOutputTensor;
+ if (getMPSDataType(output_t.scalar_type()) != [castOutputTensor dataType]) {
+ outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
}
newCachedGraph->inputTensor_ = inputTensor;
@@ -959,6 +947,9 @@
return;
}
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_out");
+
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "any()");
@@ -987,29 +978,18 @@
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* outputTensor;
MPSDataType input_type = getMPSDataType(input_t.scalar_type());
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
- if (input_type != MPSDataTypeInt32 &&
- input_type != MPSDataTypeFloat32 &&
- input_type != MPSDataTypeFloat16) {
- MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeInt32
- name:@"any_all"];
- MPSGraphTensor* outputCastedTensor = [mpsGraph reductionOrWithTensor:inputCastedTensor
- axis:dim_
- name:nil];
- outputTensor = [mpsGraph castTensor:outputCastedTensor
- toType:MPSDataTypeBool
- name:@"any"];
- } else {
- MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionOrWithTensor:inputTensor
- axis:dim_
- name:nil];
- outputTensor = [mpsGraph castTensor:outputUncastedTensor
- toType:MPSDataTypeBool
- name:@"any"];
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+ MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor
+ axis:dim_
+ name:nil];
+ MPSGraphTensor* outputTensor = castOutputTensor;
+ if (MPSDataTypeBool != [castOutputTensor dataType]) {
+ outputTensor = [mpsGraph castTensor:castOutputTensor
+ toType:MPSDataTypeBool
+ name:@"outputTensor"];
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@@ -1044,6 +1024,9 @@
return;
}
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "any_all_out");
+
auto cache_ = MPSGraphCache::getInstance();
auto stream = at::mps::getCurrentMPSStream();
@@ -1061,29 +1044,16 @@
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* outputTensor;
MPSDataType input_type = getMPSDataType(input_t.scalar_type());
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+ MPSGraphTensor* castOutputTensor = [mpsGraph reductionOrWithTensor:castInputTensor
+ axes:nil
+ name:nil];
- if (input_type != MPSDataTypeInt32 &&
- input_type != MPSDataTypeFloat32 &&
- input_type != MPSDataTypeFloat16) {
- MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeInt32
- name:@"any_all"];
- MPSGraphTensor* outputCastedTensor = [mpsGraph reductionOrWithTensor:inputCastedTensor
- axes:nil
- name:nil];
- outputTensor = [mpsGraph castTensor:outputCastedTensor
- toType:MPSDataTypeBool
- name:@"any_all"];
- } else {
- MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionOrWithTensor:inputTensor
- axes:nil
- name:nil];
- outputTensor = [mpsGraph castTensor:outputUncastedTensor
- toType:MPSDataTypeBool
- name:@"any_all"];
+ MPSGraphTensor* outputTensor = castOutputTensor;
+ if (getMPSDataType(output_t.scalar_type()) != [castOutputTensor dataType]) {
+ outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@@ -1119,6 +1089,9 @@
return;
}
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_out");
+
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "all()");
@@ -1147,30 +1120,15 @@
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* outputTensor;
MPSDataType input_type = getMPSDataType(input_t.scalar_type());
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
-
- if (input_type != MPSDataTypeInt32 &&
- input_type != MPSDataTypeFloat32 &&
- input_type != MPSDataTypeFloat16 )
- {
- MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeInt32
- name:@"all_all"];
- MPSGraphTensor* outputCastedTensor = [mpsGraph reductionAndWithTensor:inputCastedTensor
- axis:dim_
- name:nil];
- outputTensor = [mpsGraph castTensor:outputCastedTensor
- toType:MPSDataTypeBool
- name:@"all"];
- } else {
- MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionAndWithTensor:inputTensor
- axis:dim_
- name:nil];
- outputTensor = [mpsGraph castTensor:outputUncastedTensor
- toType:MPSDataTypeBool
- name:@"all"];
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+ MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor
+ axis:dim_
+ name:nil];
+ MPSGraphTensor* outputTensor = castOutputTensor;
+ if (MPSDataTypeBool != [castOutputTensor dataType]) {
+ outputTensor = castMPSTensor(mpsGraph, castOutputTensor, MPSDataTypeBool);
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@@ -1199,6 +1157,9 @@
return;
}
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "all_all_out");
+
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
auto stream = at::mps::getCurrentMPSStream();
@@ -1215,30 +1176,17 @@
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* outputTensor;
MPSDataType input_type = getMPSDataType(input_t.scalar_type());
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_type, input_t_shape);
-
- if (input_type != MPSDataTypeInt32 &&
- input_type != MPSDataTypeFloat32 &&
- input_type != MPSDataTypeFloat16) {
- MPSGraphTensor* inputCastedTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeInt32
- name:@"all_all"];
- MPSGraphTensor* outputCastedTensor = [mpsGraph reductionAndWithTensor:inputCastedTensor
- axes:nil
- name:nil];
- outputTensor = [mpsGraph castTensor:outputCastedTensor
- toType:MPSDataTypeBool
- name:@"all_all"];
- } else {
- MPSGraphTensor* outputUncastedTensor = [mpsGraph reductionAndWithTensor:inputTensor
- axes:nil
- name:nil];
- outputTensor = [mpsGraph castTensor:outputUncastedTensor
- toType:MPSDataTypeBool
- name:@"all_all"];
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+ MPSGraphTensor* castOutputTensor = [mpsGraph reductionAndWithTensor:castInputTensor
+ axes:nil
+ name:nil];
+ MPSGraphTensor* outputTensor = castOutputTensor;
+ if (MPSDataTypeBool != [castOutputTensor dataType]) {
+ outputTensor = castMPSTensor(mpsGraph, castOutputTensor, MPSDataTypeBool);
}
+
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@@ -1268,9 +1216,8 @@
(const Tensor& input_t,
MPSReductionType reduction_type,
const std::string& func_name) {
- if (input_t.scalar_type() == ScalarType::Long) {
- TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
- }
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max");
using CachedGraph = MPSUnaryCachedGraph;
@@ -1297,40 +1244,27 @@
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
- MPSGraphTensor* outputTensor = nil;
- MPSGraphTensor* castInputTensor = nil;
MPSGraphTensor* castOutputTensor = nil;
-
- if (input_t.scalar_type() != ScalarType::Float &&
- input_t.scalar_type() != ScalarType::Int &&
- input_t.scalar_type() != ScalarType::Half) {
- castInputTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeInt32
- name:@"castInputTensor"];
- } else {
- castInputTensor = inputTensor;
- }
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
NSArray<NSNumber*>* axes = getTensorAxes(input_t);
if (reduction_type == MPSReductionType::MAX) {
- outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
+ castOutputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
axes:axes
name:nil];
} else if(reduction_type == MPSReductionType::MIN) {
- outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
+ castOutputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
axes:axes
name:nil];
}
- if(input_t.scalar_type() == ScalarType::Long) {
- castOutputTensor = [mpsGraph castTensor:outputTensor
- toType:MPSDataTypeInt64
- name:@"castInputTensor"];
- } else {
- castOutputTensor = outputTensor;
+ MPSGraphTensor* outputTensor = castOutputTensor;
+ if (getMPSDataType(output_t.scalar_type()) != [castOutputTensor dataType]) {
+ outputTensor = castMPSTensor(mpsGraph, castOutputTensor, output_t.scalar_type());
}
+
newCachedGraph->inputTensor_ = inputTensor;
- newCachedGraph->outputTensor_ = castOutputTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
@@ -1373,7 +1307,8 @@
const Tensor& indices_t,
MPSReductionType reduction_type,
const std::string& func_name) {
- TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support min/max ops with int64 input");
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "min_max_out");
if (output_t.numel() == 0) {
return;
@@ -1423,26 +1358,17 @@
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* outputTensor = nil;
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
- MPSGraphTensor* castInputTensor = inputTensor;
- bool castOutput = false;
- if(input_t.scalar_type() != ScalarType::Float &&
- input_t.scalar_type() != ScalarType::Int &&
- input_t.scalar_type() != ScalarType::Half) {
- castInputTensor = [mpsGraph castTensor:inputTensor
- toType:MPSDataTypeInt32
- name:@"castInputTensor"];
- castOutput = true;
- }
-
- if(reduction_type == MPSReductionType::MAX)
+ if(reduction_type == MPSReductionType::MAX) {
outputTensor = [mpsGraph reductionMaximumWithTensor:castInputTensor
axis:(NSInteger)dim_
name:nil];
- else if(reduction_type == MPSReductionType::MIN)
+ } else if(reduction_type == MPSReductionType::MIN) {
outputTensor = [mpsGraph reductionMinimumWithTensor:castInputTensor
axis:(NSInteger)dim_
name:nil];
+ }
MPSGraphTensor* argreduceOutTensor = nil;
if(reduction_type == MPSReductionType::MAX)
@@ -1454,14 +1380,15 @@
axis:(NSInteger)dim_
name:@"argmax_out"];
- MPSGraphTensor *indicesTensor = [mpsGraph castTensor:argreduceOutTensor
- toType:MPSDataTypeInt64
- name:@"cast_out"];
+ MPSGraphTensor *indicesTensor = argreduceOutTensor;
+ if ([argreduceOutTensor dataType] != MPSDataTypeInt64) {
+ indicesTensor = [mpsGraph castTensor:argreduceOutTensor
+ toType:MPSDataTypeInt64
+ name:@"cast_out"];
+ }
- if (castOutput) {
- outputTensor = [mpsGraph castTensor:outputTensor
- toType:getMPSDataType(output_t.scalar_type())
- name:@"cast_out"];
+ if ([outputTensor dataType] != getMPSDataType(output_t.scalar_type())) {
+ outputTensor = castMPSTensor(mpsGraph, outputTensor, output_t.scalar_type());
}
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@@ -1526,6 +1453,9 @@
using CachedGraph = MPSUnaryCachedGraph;
auto cache_ = MPSGraphCache::getInstance();
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "argmax_argmin_out");
+
int64_t dim_ = -1;
if (dim.has_value()) {
@@ -1585,18 +1515,9 @@
newCachedGraph = new CachedGraph(mpsGraph);
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), apparent_in_shape);
-
- MPSGraphTensor* castInputTensor = inputTensor;
MPSGraphTensor* argreduceOutTensor = nil;
- if (input_t.scalar_type() != ScalarType::Float &&
- input_t.scalar_type() != ScalarType::Int &&
- input_t.scalar_type() != ScalarType::Half) {
- castInputTensor = [mpsGraph castTensor: inputTensor
- toType: MPSDataTypeFloat32
- name: @"castInputTensor"];
- }
-
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
if (reduction_type == MPSReductionType::MAX) {
argreduceOutTensor = [mpsGraph reductionArgMaximumWithTensor: castInputTensor
axis: (NSInteger)dim_
@@ -1606,9 +1527,11 @@
axis: (NSInteger)dim_
name: nil];
}
- MPSGraphTensor* outputTensor = [mpsGraph castTensor: argreduceOutTensor
- toType: MPSDataTypeInt64
- name: @"castOutputTensor"];
+
+ MPSGraphTensor* outputTensor = argreduceOutTensor;
+ if (getMPSDataType(output_t.scalar_type()) != [argreduceOutTensor dataType]) {
+ outputTensor = castMPSTensor(mpsGraph, argreduceOutTensor, output_t.scalar_type());
+ }
MPSGraphTensor* outputClampedTensor = [mpsGraph clampWithTensor: outputTensor
minValueTensor: [mpsGraph constantWithScalar:0 dataType:MPSDataTypeInt64]
@@ -1752,7 +1675,8 @@
return at::median(input_t.to("cpu"));
}
- TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support median op with int64 input");
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median");
using CachedGraph = MPSUnaryCachedGraph;
@@ -1783,20 +1707,11 @@
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
- auto reshapedTensor = [mpsGraph reshapeTensor: inputTensor
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
+
+ auto reshapedTensor = [mpsGraph reshapeTensor: castInputTensor
withShape: @[@-1]
name: nil];
- MPSDataType dataType = [inputTensor dataType];
- // #issue 104398441 sortWithTensor only supports following types, cast if necessary
- if (dataType != MPSDataTypeInt32 &&
- dataType != MPSDataTypeFloat32 &&
- dataType != MPSDataTypeFloat16) {
- dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
- reshapedTensor = [mpsGraph castTensor:reshapedTensor
- toType:dataType
- name:@"castReshapedTensor"];
- }
-
auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor
axis: ((NSUInteger) (int)0)
name: nil];
@@ -1858,6 +1773,8 @@
};
auto cache_ = MPSGraphCache::getInstance();
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
@@ -1888,34 +1805,22 @@
auto mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
- MPSGraphTensor* outputTensor = nil;
- MPSGraphTensor* castInputTensor = inputTensor;
- MPSDataType dataType = getMPSDataType(input_t.scalar_type());
- // #issue 104398441 sortWithTensor only supports following types, cast if necessary
- if (dataType != MPSDataTypeInt32 &&
- dataType != MPSDataTypeFloat32 &&
- dataType != MPSDataTypeFloat16) {
- dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
- castInputTensor = [mpsGraph castTensor:inputTensor
- toType:dataType
- name:@"castInputTensor"];
- }
+ MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, inputTensor, input_t, /*includesInt64=*/macOS13_3_plus);
- MPSGraphTensor * sortedTensor = [mpsGraph
- sortWithTensor:castInputTensor
- axis:((NSUInteger) (int)dim_)
- name:nil];
+ MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
+ axis:((NSUInteger) (int)dim_)
+ name:nil];
- outputTensor = [mpsGraph sliceTensor:sortedTensor
- dimension:dim_
- start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
- length:1
- name:nil];
+ MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
+ dimension:dim_
+ start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+ length:1
+ name:nil];
MPSGraphTensor* argreduceOutTensor = nil;
argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor
- axis:(NSInteger)dim_
- name:@"argmax_out"];
+ axis:(NSInteger)dim_
+ name:@"argmax_out"];
MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
dimension:dim_
start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
@@ -1978,8 +1883,8 @@
bool keepdim,
at::Tensor & values,
at::Tensor & indices){
-
- TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support median ops with int64 input");
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, "median_out");
int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
native::zero_numel_check_dims(input_t, dim_, "max()");
diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm
index d2155d2..2052bc8 100644
--- a/aten/src/ATen/native/mps/operations/Repeat.mm
+++ b/aten/src/ATen/native/mps/operations/Repeat.mm
@@ -230,10 +230,10 @@
Tensor repeat_interleave_mps(const Tensor& repeat_, c10::optional<int64_t> output_size) {
Tensor output;
Tensor repeat = repeat_;
- if (repeat.scalar_type() == kLong) {
+ if (repeat.scalar_type() == kLong && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
// #103810551: `repeat_interleave_common` uses cumsum to calculate the final shape of output,
// which currently doesn't support int64_t as input. Casting internally the indices to int32_t.
- TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32");
+ TORCH_WARN_ONCE("MPS: no support for int64 repeats mask, casting it to int32. Support has been added in macOS 13.3");
repeat = repeat.to(kInt);
}
AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_mps", [&]() {
diff --git a/aten/src/ATen/native/mps/operations/Sort.mm b/aten/src/ATen/native/mps/operations/Sort.mm
index 4b3bb69..7b9706a 100644
--- a/aten/src/ATen/native/mps/operations/Sort.mm
+++ b/aten/src/ATen/native/mps/operations/Sort.mm
@@ -18,6 +18,10 @@
const Tensor& values,
const Tensor& indices) {
using namespace mps;
+
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
+ MPS_CHECK_INT64_OP_SUPPORTED(self, macOS13_3_plus, "sort_stable_out");
+
values.copy_(self);
// check if self is scalar
dim = maybe_wrap_dim(dim, self.dim(), true);
@@ -35,9 +39,6 @@
indices.copy_(cpu_indices);
return;
}
- if (self.scalar_type() == ScalarType::Long) {
- TORCH_WARN_ONCE("MPS: no support for int64 min/max ops, casting it to int32");
- }
MPSStream* stream = getCurrentMPSStream();
struct CachedGraph : public MPSCachedGraph {
@@ -60,17 +61,17 @@
newCachedGraph = new CachedGraph(mpsGraph);
newCachedGraph->selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), input_shape);
- MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self);
+ MPSGraphTensor* castInputTensor = castToIHFTypes(mpsGraph, newCachedGraph->selfTensor, self, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor * sortedTensor = [mpsGraph sortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"sort_out"];
- sortedTensor = castFromIHFTypes(mpsGraph, sortedTensor, values);
+ sortedTensor = castFromIHFTypes(mpsGraph, sortedTensor, values, /*includesInt64=*/macOS13_3_plus);
MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:castInputTensor
axis:(NSInteger)dim
descending:(BOOL)descending
name:@"argsort_out"];
- argSortedTensor = castFromIHFTypes(mpsGraph, argSortedTensor, indices);
+ argSortedTensor = castFromIHFTypes(mpsGraph, argSortedTensor, indices, /*includesInt64=*/macOS13_3_plus);
newCachedGraph->valuesTensor = sortedTensor;
newCachedGraph->indicesTensor = argSortedTensor;
}
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index 76b7e25..6396ff3 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -418,6 +418,7 @@
c10::optional<ScalarType> dtype,
const Tensor& result) {
+ bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
auto nDims = self.dim();
auto wrapped_dim = maybe_wrap_dim(dim, nDims);
TORCH_CHECK(wrapped_dim >=0 && wrapped_dim < std::max(1LL, self.ndimension()), "Expected wrapped dim to be between 0 and ", self.ndimension(), " but got ", wrapped_dim , "(original dim is ", dim, ")");
@@ -428,18 +429,26 @@
return;
}
auto input = dtype.has_value() ? self.to(dtype.value()) : self;
- TORCH_CHECK(input.scalar_type() != ScalarType::Long, "MPS does not support cumsum op with int64 input");
+
+ // issue #103810551: cumsum is horribly broken for int8 and int16; since the chance of overflow is high, cast to int32
+ // fixed in macOS 13.3
+ bool castInputData = (isIntegralType(input.scalar_type()) &&
+ input.scalar_type() != ScalarType::Int &&
+ input.scalar_type() != ScalarType::Long);
+
+ TORCH_CHECK(macOS13_3_plus || input.scalar_type() != ScalarType::Long,
+ "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3");
+
mps::unary_op(input, result, "cumsum_out_mp" + std::to_string(dim),
^ MPSGraphTensor* (MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
- // cumsum is horribly broken for int8, int16 and as chances for overflow is pretty high, cast to int32
- if (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int) {
+
+ if (castInputData) {
inputTensor = mps::castMPSTensor(mpsGraph, inputTensor, ScalarType::Int);
}
auto rc = [mpsGraph cumulativeSumWithTensor: inputTensor
axis: dim
name: nil];
- if (result.scalar_type()!= input.scalar_type() ||
- (isIntegralType(input.scalar_type()) && input.scalar_type() !=ScalarType::Int)) {
+ if ((mps::getMPSDataType(result.scalar_type()) != [rc dataType]) || castInputData) {
return mps::castMPSTensor(mpsGraph, rc, result.scalar_type());
}
return rc;
diff --git a/test/test_mps.py b/test/test_mps.py
index 4a930ce..0c77fa5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2840,7 +2840,7 @@
helper(torch.int64)
except Exception as e:
e_string = str(e)
- self.assertEqual(e_string, "MPS does not support cumsum op with int64 input")
+ self.assertEqual(e_string, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
def test_cumsum_minus_one_axis(self):
def helper(dtype):
@@ -9550,7 +9550,7 @@
'cos': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
'cosh': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
'cov': ['f32'],
- 'cumsum': ['f16', 'f32', 'int16', 'int32'],
+ 'cumsum': ['i8', 'b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'diag': ['f32', 'i32'],
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
@@ -10181,7 +10181,7 @@
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
except Exception as e:
- if any(s in str(e).lower() for s in ["float16", "div truc rounding"]):
+ if any(s in str(e).lower() for s in ["int64", "float16", "div truc rounding"]):
self.skipTest(f"Expected Runtime Error: {str(e)}")
if not generate_new_truth: