[MPS] Remove remaining casts from 13.3 (#95870)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95870
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 0c77fa5..9877612 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3852,6 +3852,15 @@
helper(2, 8, 4, 4, "min", torch.float16)
helper(2, 8, 4, 4, "min", torch.int64)
+ @unittest.skipIf(product_version < 13.3, "Long data type supported from macOS 13.3 and above")
+ def test_reduction_sum_max_long_val(self):  # check that an int64 sum on MPS matches the CPU result
+ x_mps = torch.tensor([sys.maxsize, sys.maxsize - 10, sys.maxsize - 5, sys.maxsize - 18], device="mps")  # Python ints -> int64 tensor
+ x_cpu = x_mps.detach().clone().cpu()  # CPU copy used as the reference
+ # NOTE(review): on 64-bit builds the exact sum of these four values exceeds the int64 range, so this compares CPU vs MPS overflow behavior - confirm that is intended
+ res_mps = torch.sum(x_mps)
+ res_cpu = torch.sum(x_cpu)
+ self.assertEqual(res_mps, res_cpu)
+
# Test forward max
# Note - don't test grad now
def test_max_el(self):