Change skipIfs to xfails in test_mps.py for test_isin (#125412)
Follow-up to #124896: replace the `skipIf` markers on the newly added `isin` tests with expected failures instead of skipping them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125412
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 24c4e2d..cbf8874 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -61,10 +61,17 @@
)
)
+def xfailIf(condition):
+ def wrapper(func):
+ if condition:
+ return unittest.expectedFailure(func)
+ else:
+ return func
+ return wrapper
+
def xfailIfMacOS14_4Plus(func):
return unittest.expectedFailure(func) if product_version > 14.3 else func # noqa: F821
-
def mps_ops_grad_modifier(ops):
XFAILLIST_GRAD = {
@@ -901,9 +908,9 @@
'fft.rfft2': None,
'fft.rfftn': None,
'stft': None,
- # Error in TestConsistencyCPU.test_output_match_isin_cpu_int32,
+ # TestConsistencyCPU.test_output_match_isin_cpu fails for integer dtypes,
# not reproducible in later OS. Added assert to op if used in < 14.0
- 'isin': None,
+ 'isin': [torch.int64, torch.int32, torch.int16, torch.uint8, torch.int8],
})
UNDEFINED_XFAILLIST = {
@@ -8218,7 +8225,6 @@
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
- @unittest.skipIf(product_version < 14.0, "Skipped on MacOS < 14.0")
def test_isin(self):
def helper(dtype):
shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),
@@ -8237,15 +8243,19 @@
B_mps = B.clone().detach().to('mps')
cpu_ref = torch.isin(A, B, invert=inverted)
- if dtype is torch.float16:
+ if dtype in [torch.float16, torch.bfloat16]:
cpu_ref.type(dtype)
mps_out = torch.isin(A_mps, B_mps, invert=inverted)
self.assertEqual(mps_out, cpu_ref)
- [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8]]
+ dtypes = [torch.float32, torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.uint8, torch.int8]
+ if product_version < 14.0:
+ # Integer dtypes are known to fail on macOS < 14.0, so only floating-point dtypes are tested there
+ dtypes = [torch.float32, torch.float16, torch.bfloat16]
- @unittest.skipIf(product_version < 14.0, "Skipped on MacOS < 14.0")
+ [helper(dtype) for dtype in dtypes]
+
def test_isin_asserts(self):
A = torch.randn(size=[1, 4], device='mps', dtype=torch.float32)
B = torch.randn(size=[1, 4], device='mps', dtype=torch.float16)