[MPS] Block uint8 data type for unary and binary ops on macOS 12 (#94876)
Blocks the uint8 data type for unary and binary ops on macOS 12 and skips the affected MPS tests on macOS versions earlier than 13.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94876
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index b3740b5..e3374b0 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Owner(s): ["module: mps"]
+import platform
import sys
import math
import random
@@ -62,6 +63,8 @@
TestCase = object # noqa: F811
NNTestCase = object # noqa: F811
+product_version = float('.'.join(platform.mac_ver()[0].split('.')[:2]) or -1)  # release is '' off macOS
+
# Determine whether to enable MPS memory leak check (uses same code as CUDA).
TEST_MPS_MEM_LEAK_CHECK = os.getenv('PYTORCH_TEST_MPS_MEM_LEAK_CHECK', '0') == '1'
@@ -2238,6 +2241,7 @@
y_cpu = torch.full((2, 2), 247, device='cpu', dtype=torch.uint8)
self.assertEqual(y_mps, y_cpu)
+ @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
# See https://github.com/pytorch/pytorch/issues/84995
def test_div_bugs(self):
for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
@@ -3366,6 +3370,7 @@
self.assertEqual(result_cpu, result_mps.to('cpu'))
+ @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_signed_vs_unsigned_comparison(self):
cpu_x = torch.tensor((-1, 2, 3), device='cpu', dtype=torch.uint8)
mps_x = torch.tensor((-1, 2, 3), device='mps', dtype=torch.uint8)
@@ -8351,6 +8356,7 @@
self.assertEqual(v[boolIndices], torch.tensor([True], dtype=torch.bool, device=device))
self.assertEqual(len(w), 2)
+ @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_bool_indices_accumulate(self, device="mps"):
mask = torch.zeros(size=(10, ), dtype=torch.uint8, device=device)
mask = mask > 0
@@ -8541,6 +8547,7 @@
self.assertEqual(res.shape, src.shape)
[helper(device="mps", dtype=dtype) for dtype in [torch.float, torch.int32]]
+ @unittest.skipIf(product_version < 13.0, "Skipped on macOS 12")
def test_index_src_datatype(self):
def helper(device, dtype):
orig_dtype = dtype