[MPS] Add test consistency from OpInfo based tests from PR 78504 (#79532)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79532
Approved by: https://github.com/albanD, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index c274078..b1d493d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -13,15 +13,21 @@
import torch.nn as nn
import torch.nn.functional as F
import itertools
+from collections import defaultdict
from torch._six import inf
from torch.nn import Parameter
-from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN, gradcheck, gradgradcheck
+from torch.testing._internal.common_utils import \
+ (gradcheck, gradgradcheck, run_tests, TestCase, download_file,
+ TEST_WITH_UBSAN)
from torch.testing import make_tensor
from torch.testing._comparison import TensorLikePair
+from torch.testing._internal.common_dtype import get_all_dtypes
import torch.backends.mps
from torch.distributions import Uniform, Exponential
from functools import partial
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
import numpy as np
import torch
@@ -790,7 +796,6 @@
helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
-
def test_instance_norm(self):
def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
@@ -3269,6 +3274,14 @@
# Empty test - Currently failing! Empty tensor not handled!
# helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
+ def test_constant_pad(self):
+ m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
+ input_cpu = torch.randn(1, 16, 16, 16)
+ input_mps = input_cpu.detach().clone().to("mps")
+ r_cpu = m(input_cpu)
+ r_mps = m(input_mps)
+ self.assertEqual(r_cpu, r_mps.to("cpu"))
+
def test_pad(self):
def helper(shape, padding, op):
inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
@@ -4833,6 +4846,10 @@
a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
self.assertEqual(a_cpu, a_mps)
+# These tests were taken from test/test_view_ops.py
+# They are subset of those tests as currently only this subset is working.
+# This whole `class` will be removed when we add generic device testing. There
+# are no additional tests added apart from what is part of test_view_ops.py
class TestViewOpsMPS(TestCase):
exact_dtype = True
@@ -4844,7 +4861,7 @@
return False
# Note: only validates storage on native device types
# because some accelerators, like XLA, do not expose storage
- if base.device.type == 'cpu' or base.device.type == 'cuda':
+ if base.device.type == 'mps':
if base.storage().data_ptr() != other.storage().data_ptr():
return False
@@ -4999,7 +5016,7 @@
v = torch.squeeze(t)
self.assertTrue(self.is_view_of(t, v))
v[0, 1] = 0
- self.assertEqual(t, v._base)
+ self.assertTrue(t is v._base)
def test_squeeze_inplace_view(self, device="mps"):
t = torch.ones(5, 5, device=device)
@@ -5007,7 +5024,7 @@
v = v.squeeze_()
self.assertTrue(self.is_view_of(t, v))
v[0, 1] = 0
- self.assertEqual(t, v._base)
+ self.assertTrue(t is v._base)
def test_unsqueeze_view(self, device="mps"):
t = torch.ones(5, 5, device=device)
@@ -5591,14 +5608,14 @@
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
# RuntimeError: Invalid device for storage: mps
- # def test_contiguous(self, device="mps"):
- # x = torch.randn(1, 16, 5, 5, device=device)
- # self.assertTrue(x.is_contiguous())
- # stride = list(x.stride())
- # stride[0] = 20
- # # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
- # x.set_(x.storage(), 0, x.size(), stride)
- # self.assertTrue(x.is_contiguous())
+ def test_contiguous(self, device="mps"):
+ x = torch.randn(1, 16, 5, 5, device=device)
+ self.assertTrue(x.is_contiguous())
+ stride = list(x.stride())
+ stride[0] = 20
+ # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
+ x.set_(x.storage(), 0, x.size(), stride)
+ self.assertTrue(x.is_contiguous())
def test_resize_all_dtypes_and_devices(self, device="mps"):
shape = (2, 2)
@@ -5744,8 +5761,11 @@
with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
torch.testing.assert_close(a, inf)
- with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
- torch.testing.assert_close(a, nan)
+ # TODO: The NaN test is failing when all the tests in test_mps are run
+ # together but passes when run separately. There seems to be memory
+ # corruption which needs to be fixed for this test to be enabled.
+ # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
+ # torch.testing.assert_close(a, nan)
@unittest.expectedFailure
def test_mps_compat(self):
@@ -5809,7 +5829,604 @@
self.assertEqual(x2.device.type, "cpu")
+MPS_DTYPES = get_all_dtypes()
+for t in [torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16]:
+ del MPS_DTYPES[MPS_DTYPES.index(t)]
+class TestConsistency(TestCase):
+ # TODO: This is only used while some ops are being added.
+ # This list should contain all ops and dtypes eventually
+ # This can be generated automatically in the `new_mps_allowlist.txt` file
+ # by doing `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`
+ # You most likely do NOT want to modify this manually
+ ALLOWLIST_OP = {
+ '__radd__': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ '__rand__': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ '__rmul__': ['torch.bool',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ '__ror__': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ '__rxor__': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ '_masked.normalize': ['torch.float32'],
+ 'abs': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.uint8'],
+ 'add': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'addcdiv': ['torch.float32'],
+ 'addcmul': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'addmv': ['torch.float32'],
+ 'addr': ['torch.float32'],
+ 'all': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'any': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'argmax': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'asin': ['torch.float32'],
+ 'asinh': ['torch.float32'],
+ 'atan': ['torch.float32'],
+ 'atan2': ['torch.float32'],
+ 'atanh': ['torch.float32'],
+ 'atleast_1d': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'atleast_2d': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'atleast_3d': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'baddbmm': ['torch.float32'],
+ 'bitwise_and': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'bitwise_left_shift': ['torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'bitwise_not': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'bitwise_or': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'bitwise_right_shift': ['torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'bitwise_xor': ['torch.bool',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'bmm': ['torch.float32'],
+ 'ceil': ['torch.float32'],
+ 'chunk': ['torch.float16', 'torch.float32', 'torch.int64'],
+ 'clone': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'column_stack': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'conj': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'conj_physical': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'contiguous': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'corrcoef': ['torch.float32'],
+ 'deg2rad': ['torch.float32'],
+ 'diag': ['torch.float32', 'torch.int32'],
+ 'diagflat': ['torch.int32'],
+ 'diff': ['torch.float32'],
+ 'dist': ['torch.float32'],
+ 'dot': ['torch.float32', 'torch.int32'],
+ 'einsum': ['torch.float32'],
+ 'erf': ['torch.float32'],
+ 'fill': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'flatten': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'floor': ['torch.float32'],
+ 'hstack': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'index_select': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'isinf': ['torch.float16', 'torch.float32'],
+ 'isnan': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'kron': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'linalg.norm': ['torch.float16',
+ 'torch.float32',
+ 'torch.float16',
+ 'torch.float32'],
+ 'linalg.svd': ['torch.float32'],
+ 'linalg.vector_norm': ['torch.float16'],
+ 'log1p': ['torch.float32'],
+ 'log_softmax': ['torch.float32'],
+ 'logaddexp': ['torch.float32'],
+ 'logaddexp2': ['torch.float32'],
+ 'masked_select': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'mm': ['torch.float32'],
+ 'mv': ['torch.float32'],
+ 'neg': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32'],
+ 'nn.functional.adaptive_max_pool1d': ['torch.float32'],
+ 'nn.functional.adaptive_max_pool2d': ['torch.float32'],
+ 'nn.functional.binary_cross_entropy': ['torch.float32'],
+ 'nn.functional.celu': ['torch.float32'],
+ 'nn.functional.elu': ['torch.float32'],
+ 'nn.functional.embedding': ['torch.float16', 'torch.float32'],
+ 'nn.functional.feature_alpha_dropout': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'nn.functional.hardtanh': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'nn.functional.hinge_embedding_loss': ['torch.float32'],
+ 'nn.functional.kl_div': ['torch.float32'],
+ 'nn.functional.l1_loss': ['torch.float32'],
+ 'nn.functional.leaky_relu': ['torch.float32'],
+ 'nn.functional.mse_loss': ['torch.float16', 'torch.float32'],
+ 'nn.functional.relu': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'nn.functional.relu6': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'nn.functional.selu': ['torch.float32'],
+ 'nn.functional.silu': ['torch.float32'],
+ 'nn.functional.smooth_l1_loss': ['torch.float32'],
+ 'nn.functional.softmin': ['torch.float32'],
+ 'nn.functional.threshold': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'nn.functional.upsample_bilinear': ['torch.float32'],
+ 'norm': ['torch.float32', 'torch.float16', 'torch.float32'],
+ 'positive': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'rad2deg': ['torch.float32'],
+ 'ravel': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'real': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'repeat_interleave': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'resize_': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'resize_as_': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'resolve_conj': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'resolve_neg': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'round': ['torch.float32'],
+ 'sgn': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'sign': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.uint8'],
+ 'sin': ['torch.float32'],
+ 'sinh': ['torch.float32'],
+ 'softmax': ['torch.float32'],
+ 'split': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'sqrt': ['torch.float32'],
+ 'square': ['torch.float32'],
+ 'squeeze': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'stack': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'sub': ['torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'sum_to_size': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'svd': ['torch.float32'],
+ 't': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'tanh': ['torch.float32'],
+ 'tensordot': ['torch.float32'],
+ 'topk': ['torch.float32'],
+ 'tril': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'triu': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'true_divide': ['torch.float32'],
+ 'trunc': ['torch.float32'],
+ 'unsqueeze': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'view': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'view_as': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'vsplit': ['torch.bool',
+ 'torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8'],
+ 'vstack': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64'],
+ 'zero_': ['torch.float16',
+ 'torch.float32',
+ 'torch.int16',
+ 'torch.int32',
+ 'torch.int64',
+ 'torch.uint8']}
+
+ # These ops that are problematic. So never run them even when
+ # generating the new allowlist.
+ # If the dtype list is None, all dtypes are excluded.
+ # All the entries in this list should be removed
+ BLOCKLIST = {
+ # Functions that hang
+ 'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool],
+ # Functions that hard crash
+ 'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64],
+ 'nn.functional.nll_loss': [torch.float32],
+ 'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32],
+ 'nn.functional.smooth_l1_loss': [torch.float16], 'std': [torch.float16],
+ 'stft': [torch.float32], 'var': [torch.float16],
+
+ # These were moved from ALLOWLIST to BLOCK as they are not working
+ # locally
+ 'tile': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
+ 'repeat': ['torch.float16', 'torch.float32', 'torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
+ '__radd__': ['torch.bool', 'torch.uint8'],
+ '__rmul__': ['torch.uint8'],
+ 'add': ['torch.bool', 'torch.uint8'],
+ 'square': ['torch.int32', 'torch.int64', 'torch.uint8'],
+ 'addr': ['torch.int16', 'torch.int32', 'torch.int64', 'torch.uint8'],
+ 'diag': ['torch.int64'],
+ 'diagflat': ['torch.int64'],
+
+ # Functions that are flaky
+ # These are detected as "ok" by the expect case but actually fail to run sometimes
+ 'H': None,
+ 'T': None,
+ 'as_strided': None,
+ 'broadcast_tensors': None,
+ 'broadcast': None,
+ 'broadcast_to': None,
+ 'diagonal': None,
+ 'divfloor_rounding': None,
+ 'divno_rounding_mode': None,
+ 'divtrunc_rounding': None,
+ 'dsplit': None,
+ 'hsplit': None,
+ 'empty': None,
+ 'expand_as': None,
+ 'expand': None,
+ 'ge': None,
+ 'ne': None,
+ 'le': None,
+ 'lt': None,
+ 'gt': None,
+ 'transpose': None,
+ 'splitlist_args': None,
+ 'select': None,
+ 'reshape': None,
+ 'reshape_as': None,
+ 'permute': None,
+ 'norm': None,
+ 'nn.functional.pixel_unshuffle': None,
+ 'nn.functional.pixel_shuffle': None,
+ 'nn.functional.cross_entropy': None,
+ 'nn.functional.one_hot': None,
+ 'narrow': None,
+ 'movedim': None,
+ 'minreduction_with_dim': None,
+ 'minreduction_no_dim': None,
+ 'minbinary': None,
+ 'meshgridvariadic_tensors': None,
+ 'meshgridlist_of_tensors': None,
+ 'maxreduction_with_dim': None,
+ 'maxreduction_no_dim': None,
+ 'maxbinary': None,
+ 'maximum': None,
+ 'minimum': None,
+ 'mT': None,
+ 'mH': None,
+ 'outer': None,
+ 'softmaxwith_dtype': None,
+ 'rounddecimals_neg_3': None,
+ 'rounddecimals_3': None,
+ 'rounddecimals_0': None,
+ 'normnuc': None,
+ 'nn.functional.softminwith_dtype': None,
+ 'nn.functional.feature_alpha_dropoutwith_train': None,
+ 'log_softmaxdtype': None,
+ 'split_with_sizes': None,
+ 'trapezoid': None,
+ 'eq': None,
+ 'mul': None,
+ 'cartesian_prod': None,
+ 'nonzero': None,
+ 'bool': None,
+ 'inner': None,
+ 'dstack': None,
+ 'take_along_dim': None,
+ }
+
+ # Used for accept mode only
+ NEW_ALLOW_LIST = defaultdict(list)
+
+ @ops(op_db, allowed_dtypes=MPS_DTYPES)
+ def test_output_match(self, device, dtype, op):
+ self.assertEqual(device, "cpu")
+ if not torch.backends.mps.is_available():
+ self.skipTest("MPS is not available")
+
+ key = op.name + op.variant_test_name
+ if key in self.BLOCKLIST:
+ if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]:
+ self.skipTest(f"Running test with {op.name} hangs so skipping")
+
+ # Make this an expecttest manually
+ # When this env variable is set, generate a new ALLOWLIST_OP
+ # that reflects the current state of what passes or not
+ if os.environ.get("EXPECTTEST_ACCEPT", None) == "1":
+ generate_new_truth = True
+ else:
+ generate_new_truth = False
+
+ if not generate_new_truth:
+ if op.name not in self.ALLOWLIST_OP:
+ self.skipTest(f"{op.name} is not in the allow list for test on MPS")
+ else:
+ if str(dtype) not in self.ALLOWLIST_OP[op.name]:
+ self.skipTest(f"{op.name} is in the allow list for MPS but {dtype} is excluded")
+
+ try:
+ cpu_samples = op.sample_inputs(device, dtype)
+
+ for cpu_sample in cpu_samples:
+ mps_sample = cpu_sample.transform(lambda x: x.to("mps") if isinstance(x, torch.Tensor) else x)
+
+ # TODO: This checks only the function variant. We should also check the method and inplace version
+ # when they exist
+ cpu_args = [cpu_sample.input] + list(cpu_sample.args)
+ cpu_kwargs = cpu_sample.kwargs
+ mps_args = [mps_sample.input] + list(mps_sample.args)
+ mps_kwargs = mps_sample.kwargs
+
+ cpu_out = op(*cpu_args, **cpu_kwargs)
+ mps_out = op(*mps_args, **mps_kwargs)
+ self.assertEqual(cpu_out, mps_out)
+ except Exception as e:
+ if not generate_new_truth:
+ raise e
+ else:
+ if generate_new_truth:
+ self.NEW_ALLOW_LIST[op.name].append(str(dtype))
+
+ # We could write it only once. But I don't know how to detect that the current test is the last one
+ # So each test append to the dict and write it.
+ with open("new_mps_allowlist.txt", "w") as f:
+ pprint.pprint(self.NEW_ALLOW_LIST, stream=f)
+
+# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
+# This requires mps to be properly registered in the device generic test framework which is not the
+# case right now.
+instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
if __name__ == "__main__":
run_tests()