[MPS] Remove casts from reduction/cumsum/sort ops starting with macOS 13.3 (#95817)
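MPS in macOS 13.3 added support for int64 in reduction ops, cumsum, sort, and argsort. When running on macOS 13.3 or later, this change skips the hard-coded int64 casts and error messages, allowing these ops to run natively with int64. On earlier macOS versions the casts and errors are kept, and the errors now note that support was added in macOS 13.3.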
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95817
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 4a930ce..0c77fa5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2840,7 +2840,7 @@
helper(torch.int64)
except Exception as e:
e_string = str(e)
- self.assertEqual(e_string, "MPS does not support cumsum op with int64 input")
+ self.assertEqual(e_string, "MPS does not support cumsum op with int64 input. Support has been added in macOS 13.3")
def test_cumsum_minus_one_axis(self):
def helper(dtype):
@@ -9550,7 +9550,7 @@
'cos': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
'cosh': ['b8', 'f32', 'i16', 'i32', 'u8', 'i64'],
'cov': ['f32'],
- 'cumsum': ['f16', 'f32', 'int16', 'int32'],
+ 'cumsum': ['i8', 'b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
'deg2rad': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'diag': ['f32', 'i32'],
'diag_embed': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64'],
@@ -10181,7 +10181,7 @@
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
except Exception as e:
- if any(s in str(e).lower() for s in ["float16", "div truc rounding"]):
+ if any(s in str(e).lower() for s in ["int64", "float16", "div truc rounding"]):
self.skipTest(f"Expected Runtime Error: {str(e)}")
if not generate_new_truth: