[MPS] Add OpInfo-based consistency tests from PR 78504 (#79532)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79532
Approved by: https://github.com/albanD, https://github.com/malfet
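
The new `TestConsistency` class runs each op from `op_db` on identical CPU and
MPS inputs and asserts that the outputs match. It is gated by ALLOWLIST_OP
(op/dtype pairs known to pass) and BLOCKLIST (ops that hang, crash, or are
flaky). To regenerate the allowlist after improving op coverage, run the suite
in accept mode, as described in the test body below:

    EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU
    # passing (op, dtype) pairs are written to new_mps_allowlist.txt
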
diff --git a/test/test_mps.py b/test/test_mps.py
index c274078..b1d493d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -13,15 +13,22 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import itertools
+from collections import defaultdict
+import pprint  # used by the TestConsistency accept mode below
 from torch._six import inf
 from torch.nn import Parameter
-from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN, gradcheck, gradgradcheck
+from torch.testing._internal.common_utils import (
+    gradcheck, gradgradcheck, run_tests, TestCase, download_file,
+    TEST_WITH_UBSAN)
 from torch.testing import make_tensor
 from torch.testing._comparison import TensorLikePair
+from torch.testing._internal.common_dtype import get_all_dtypes
 import torch.backends.mps
 from torch.distributions import Uniform, Exponential
 from functools import partial
 
+from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_device_type import ops, instantiate_device_type_tests
 from torch.testing._internal.common_nn import NNTestCase
 import numpy as np
 import torch
@@ -790,7 +796,6 @@
             helper((2, 3, 4, 5), (4, 5), elementwise_affine=elementwise_affine)
             helper((2, 3, 4, 5, 6), (4, 5, 6), elementwise_affine=elementwise_affine)
 
-
     def test_instance_norm(self):
         def helper(shape, eps=1, momentum=0.1, wts=False, channels_last=False, track_running_stats=True, test_module=False):
 
@@ -3269,6 +3274,14 @@
         # Empty test - Currently failing! Empty tensor not handled!
         # helper([0, 2, 4, 5], [2, 0, 4, 5], [2, 5, 0, 5])
 
+    def test_constant_pad(self):
+        m = torch.nn.ConstantPad2d((-2, -2, -2, -2), 3.5)
+        input_cpu = torch.randn(1, 16, 16, 16)
+        input_mps = input_cpu.detach().clone().to("mps")
+        r_cpu = m(input_cpu)
+        r_mps = m(input_mps)
+        self.assertEqual(r_cpu, r_mps.to("cpu"))
+
     def test_pad(self):
         def helper(shape, padding, op):
             inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
@@ -4833,6 +4846,10 @@
         a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
         self.assertEqual(a_cpu, a_mps)
 
+# These tests were taken from test/test_view_ops.py.
+# They are a subset of those tests, as only this subset currently works on MPS.
+# This whole class will be removed when generic device testing is added; there
+# are no tests here beyond what is already in test/test_view_ops.py.
 class TestViewOpsMPS(TestCase):
     exact_dtype = True
 
@@ -4844,7 +4861,7 @@
             return False
         # Note: only validates storage on native device types
         # because some accelerators, like XLA, do not expose storage
-        if base.device.type == 'cpu' or base.device.type == 'cuda':
+        if base.device.type == 'mps':
             if base.storage().data_ptr() != other.storage().data_ptr():
                 return False
 
@@ -4999,7 +5016,7 @@
         v = torch.squeeze(t)
         self.assertTrue(self.is_view_of(t, v))
         v[0, 1] = 0
-        self.assertEqual(t, v._base)
+        self.assertTrue(t is v._base)
 
     def test_squeeze_inplace_view(self, device="mps"):
         t = torch.ones(5, 5, device=device)
@@ -5007,7 +5024,7 @@
         v = v.squeeze_()
         self.assertTrue(self.is_view_of(t, v))
         v[0, 1] = 0
-        self.assertEqual(t, v._base)
+        self.assertTrue(t is v._base)
 
     def test_unsqueeze_view(self, device="mps"):
         t = torch.ones(5, 5, device=device)
@@ -5591,14 +5608,13 @@
         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
 
-    # RuntimeError: Invalid device for storage: mps
-    # def test_contiguous(self, device="mps"):
-    #     x = torch.randn(1, 16, 5, 5, device=device)
-    #     self.assertTrue(x.is_contiguous())
-    #     stride = list(x.stride())
-    #     stride[0] = 20
-    #     # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
-    #     x.set_(x.storage(), 0, x.size(), stride)
-    #     self.assertTrue(x.is_contiguous())
+    def test_contiguous(self, device="mps"):
+        x = torch.randn(1, 16, 5, 5, device=device)
+        self.assertTrue(x.is_contiguous())
+        stride = list(x.stride())
+        stride[0] = 20
+        # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
+        x.set_(x.storage(), 0, x.size(), stride)
+        self.assertTrue(x.is_contiguous())
 
     def test_resize_all_dtypes_and_devices(self, device="mps"):
         shape = (2, 2)
@@ -5744,8 +5761,11 @@
         with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
             torch.testing.assert_close(a, inf)
 
-        with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
-            torch.testing.assert_close(a, nan)
+        # TODO: The NaN test fails when all the tests in test_mps are run
+        # together but passes when run separately. There seems to be memory
+        # corruption that needs to be fixed before this test can be enabled.
+        # with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close!"):
+        #     torch.testing.assert_close(a, nan)
 
     @unittest.expectedFailure
     def test_mps_compat(self):
@@ -5809,7 +5829,598 @@
             self.assertEqual(x2.device.type, "cpu")
 
 
+# Dtypes to exercise on MPS: all dtypes except those the backend does not support.
+MPS_DTYPES = [t for t in get_all_dtypes()
+              if t not in (torch.double, torch.cdouble, torch.cfloat, torch.int8, torch.bfloat16)]
 
+class TestConsistency(TestCase):
+    # TODO: This list is only needed while support for each op is being added.
+    # Eventually it should contain all ops and dtypes.
+    # It can be regenerated automatically into the `new_mps_allowlist.txt` file
+    # by running `EXPECTTEST_ACCEPT=1 python test_mps.py TestConsistencyCPU`.
+    # You most likely do NOT want to modify it manually.
+    ALLOWLIST_OP = {
+        '__radd__': ['torch.float32',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '__rand__': ['torch.bool',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '__rmul__': ['torch.bool',
+                     'torch.float32',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '__ror__': ['torch.bool',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64'],
+        '__rxor__': ['torch.bool',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64'],
+        '_masked.normalize': ['torch.float32'],
+        'abs': ['torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.uint8'],
+        'add': ['torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'addcdiv': ['torch.float32'],
+        'addcmul': ['torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64'],
+        'addmv': ['torch.float32'],
+        'addr': ['torch.float32'],
+        'all': ['torch.bool',
+                'torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'any': ['torch.bool',
+                'torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'argmax': ['torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64'],
+        'asin': ['torch.float32'],
+        'asinh': ['torch.float32'],
+        'atan': ['torch.float32'],
+        'atan2': ['torch.float32'],
+        'atanh': ['torch.float32'],
+        'atleast_1d': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'atleast_2d': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'atleast_3d': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'baddbmm': ['torch.float32'],
+        'bitwise_and': ['torch.bool',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'bitwise_left_shift': ['torch.int16',
+                               'torch.int32',
+                               'torch.int64',
+                               'torch.uint8'],
+        'bitwise_not': ['torch.bool',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'bitwise_or': ['torch.bool',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'bitwise_right_shift': ['torch.int16',
+                                'torch.int32',
+                                'torch.int64',
+                                'torch.uint8'],
+        'bitwise_xor': ['torch.bool',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'bmm': ['torch.float32'],
+        'ceil': ['torch.float32'],
+        'chunk': ['torch.float16', 'torch.float32', 'torch.int64'],
+        'clone': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'column_stack': ['torch.float16',
+                         'torch.float32',
+                         'torch.int16',
+                         'torch.int32',
+                         'torch.int64',
+                         'torch.uint8'],
+        'conj': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'conj_physical': ['torch.bool',
+                          'torch.float16',
+                          'torch.float32',
+                          'torch.int16',
+                          'torch.int32',
+                          'torch.int64',
+                          'torch.uint8'],
+        'contiguous': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'corrcoef': ['torch.float32'],
+        'deg2rad': ['torch.float32'],
+        'diag': ['torch.float32', 'torch.int32'],
+        'diagflat': ['torch.int32'],
+        'diff': ['torch.float32'],
+        'dist': ['torch.float32'],
+        'dot': ['torch.float32', 'torch.int32'],
+        'einsum': ['torch.float32'],
+        'erf': ['torch.float32'],
+        'fill': ['torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64'],
+        'flatten': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64'],
+        'floor': ['torch.float32'],
+        'hstack': ['torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64'],
+        'index_select': ['torch.float16',
+                         'torch.float32',
+                         'torch.int16',
+                         'torch.int32',
+                         'torch.int64'],
+        'isinf': ['torch.float16', 'torch.float32'],
+        'isnan': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'kron': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'linalg.norm': ['torch.float16',
+                        'torch.float32'],
+        'linalg.svd': ['torch.float32'],
+        'linalg.vector_norm': ['torch.float16'],
+        'log1p': ['torch.float32'],
+        'log_softmax': ['torch.float32'],
+        'logaddexp': ['torch.float32'],
+        'logaddexp2': ['torch.float32'],
+        'masked_select': ['torch.bool',
+                          'torch.float16',
+                          'torch.float32',
+                          'torch.int16',
+                          'torch.int32',
+                          'torch.int64',
+                          'torch.uint8'],
+        'mm': ['torch.float32'],
+        'mv': ['torch.float32'],
+        'neg': ['torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32'],
+        'nn.functional.adaptive_max_pool1d': ['torch.float32'],
+        'nn.functional.adaptive_max_pool2d': ['torch.float32'],
+        'nn.functional.binary_cross_entropy': ['torch.float32'],
+        'nn.functional.celu': ['torch.float32'],
+        'nn.functional.elu': ['torch.float32'],
+        'nn.functional.embedding': ['torch.float16', 'torch.float32'],
+        'nn.functional.feature_alpha_dropout': ['torch.bool',
+                                                'torch.float16',
+                                                'torch.float32',
+                                                'torch.int16',
+                                                'torch.int32',
+                                                'torch.int64',
+                                                'torch.uint8'],
+        'nn.functional.hardtanh': ['torch.float32',
+                                   'torch.int16',
+                                   'torch.int32',
+                                   'torch.int64'],
+        'nn.functional.hinge_embedding_loss': ['torch.float32'],
+        'nn.functional.kl_div': ['torch.float32'],
+        'nn.functional.l1_loss': ['torch.float32'],
+        'nn.functional.leaky_relu': ['torch.float32'],
+        'nn.functional.mse_loss': ['torch.float16', 'torch.float32'],
+        'nn.functional.relu': ['torch.float32',
+                               'torch.int16',
+                               'torch.int32',
+                               'torch.int64',
+                               'torch.uint8'],
+        'nn.functional.relu6': ['torch.float32',
+                                'torch.int16',
+                                'torch.int32',
+                                'torch.int64',
+                                'torch.uint8'],
+        'nn.functional.selu': ['torch.float32'],
+        'nn.functional.silu': ['torch.float32'],
+        'nn.functional.smooth_l1_loss': ['torch.float32'],
+        'nn.functional.softmin': ['torch.float32'],
+        'nn.functional.threshold': ['torch.float32',
+                                    'torch.int16',
+                                    'torch.int32',
+                                    'torch.int64',
+                                    'torch.uint8'],
+        'nn.functional.upsample_bilinear': ['torch.float32'],
+        'norm': ['torch.float16', 'torch.float32'],
+        'positive': ['torch.float16',
+                     'torch.float32',
+                     'torch.int16',
+                     'torch.int32',
+                     'torch.int64',
+                     'torch.uint8'],
+        'rad2deg': ['torch.float32'],
+        'ravel': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'real': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'repeat_interleave': ['torch.bool',
+                              'torch.float16',
+                              'torch.float32',
+                              'torch.int16',
+                              'torch.int32',
+                              'torch.int64',
+                              'torch.uint8'],
+        'resize_': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64',
+                    'torch.uint8'],
+        'resize_as_': ['torch.bool',
+                       'torch.float16',
+                       'torch.float32',
+                       'torch.int16',
+                       'torch.int32',
+                       'torch.int64',
+                       'torch.uint8'],
+        'resolve_conj': ['torch.bool',
+                         'torch.float16',
+                         'torch.float32',
+                         'torch.int16',
+                         'torch.int32',
+                         'torch.int64',
+                         'torch.uint8'],
+        'resolve_neg': ['torch.bool',
+                        'torch.float16',
+                        'torch.float32',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'round': ['torch.float32'],
+        'sgn': ['torch.bool',
+                'torch.float16',
+                'torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64',
+                'torch.uint8'],
+        'sign': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.uint8'],
+        'sin': ['torch.float32'],
+        'sinh': ['torch.float32'],
+        'softmax': ['torch.float32'],
+        'split': ['torch.bool',
+                  'torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'sqrt': ['torch.float32'],
+        'square': ['torch.float32'],
+        'squeeze': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64',
+                    'torch.uint8'],
+        'stack': ['torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8'],
+        'sub': ['torch.float32',
+                'torch.int16',
+                'torch.int32',
+                'torch.int64'],
+        'sum_to_size': ['torch.bool',
+                        'torch.float16',
+                        'torch.float32',
+                        'torch.int16',
+                        'torch.int32',
+                        'torch.int64',
+                        'torch.uint8'],
+        'svd': ['torch.float32'],
+        't': ['torch.bool',
+              'torch.float16',
+              'torch.float32',
+              'torch.int16',
+              'torch.int32',
+              'torch.int64',
+              'torch.uint8'],
+        'tanh': ['torch.float32'],
+        'tensordot': ['torch.float32'],
+        'topk': ['torch.float32'],
+        'tril': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'triu': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'true_divide': ['torch.float32'],
+        'trunc': ['torch.float32'],
+        'unsqueeze': ['torch.bool',
+                      'torch.float16',
+                      'torch.float32',
+                      'torch.int16',
+                      'torch.int32',
+                      'torch.int64',
+                      'torch.uint8'],
+        'view': ['torch.bool',
+                 'torch.float16',
+                 'torch.float32',
+                 'torch.int16',
+                 'torch.int32',
+                 'torch.int64',
+                 'torch.uint8'],
+        'view_as': ['torch.bool',
+                    'torch.float16',
+                    'torch.float32',
+                    'torch.int16',
+                    'torch.int32',
+                    'torch.int64',
+                    'torch.uint8'],
+        'vsplit': ['torch.bool',
+                   'torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64',
+                   'torch.uint8'],
+        'vstack': ['torch.float16',
+                   'torch.float32',
+                   'torch.int16',
+                   'torch.int32',
+                   'torch.int64'],
+        'zero_': ['torch.float16',
+                  'torch.float32',
+                  'torch.int16',
+                  'torch.int32',
+                  'torch.int64',
+                  'torch.uint8']}
+
+    # These ops are problematic, so never run them, even when generating
+    # the new allowlist. Dtypes must be torch dtype objects (not strings)
+    # so that the `dtype in ...` check below matches. If the dtype list
+    # is None, all dtypes are excluded. Eventually this list should be empty.
+    BLOCKLIST = {
+        # Functions that hang
+        'masked_fill': [torch.bool, torch.uint8, torch.float32], 'where': [torch.bool],
+        # Functions that hard crash
+        'nn.functional.kl_div': [torch.int16, torch.int32, torch.int64],
+        'nn.functional.nll_loss': [torch.float32],
+        'nn.functional.padreflect': [torch.float32], 'nn.functional.padreplicate': [torch.float32],
+        'nn.functional.smooth_l1_loss': [torch.float16], 'std': [torch.float16],
+        'stft': [torch.float32], 'var': [torch.float16],
+
+        # These were moved from ALLOWLIST to BLOCKLIST as they are not
+        # working locally
+        'tile': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        'repeat': [torch.float16, torch.float32, torch.int16, torch.int32, torch.int64, torch.uint8],
+        '__radd__': [torch.bool, torch.uint8],
+        '__rmul__': [torch.uint8],
+        'add': [torch.bool, torch.uint8],
+        'square': [torch.int32, torch.int64, torch.uint8],
+        'addr': [torch.int16, torch.int32, torch.int64, torch.uint8],
+        'diag': [torch.int64],
+        'diagflat': [torch.int64],
+
+        # Functions that are flaky
+        # These are detected as "ok" by the expect case but actually fail to run sometimes
+        'H': None,
+        'T': None,
+        'as_strided': None,
+        'broadcast_tensors': None,
+        'broadcast': None,
+        'broadcast_to': None,
+        'diagonal': None,
+        'divfloor_rounding': None,
+        'divno_rounding_mode': None,
+        'divtrunc_rounding': None,
+        'dsplit': None,
+        'hsplit': None,
+        'empty': None,
+        'expand_as': None,
+        'expand': None,
+        'ge': None,
+        'ne': None,
+        'le': None,
+        'lt': None,
+        'gt': None,
+        'transpose': None,
+        'splitlist_args': None,
+        'select': None,
+        'reshape': None,
+        'reshape_as': None,
+        'permute': None,
+        'norm': None,
+        'nn.functional.pixel_unshuffle': None,
+        'nn.functional.pixel_shuffle': None,
+        'nn.functional.cross_entropy': None,
+        'nn.functional.one_hot': None,
+        'narrow': None,
+        'movedim': None,
+        'minreduction_with_dim': None,
+        'minreduction_no_dim': None,
+        'minbinary': None,
+        'meshgridvariadic_tensors': None,
+        'meshgridlist_of_tensors': None,
+        'maxreduction_with_dim': None,
+        'maxreduction_no_dim': None,
+        'maxbinary': None,
+        'maximum': None,
+        'minimum': None,
+        'mT': None,
+        'mH': None,
+        'outer': None,
+        'softmaxwith_dtype': None,
+        'rounddecimals_neg_3': None,
+        'rounddecimals_3': None,
+        'rounddecimals_0': None,
+        'normnuc': None,
+        'nn.functional.softminwith_dtype': None,
+        'nn.functional.feature_alpha_dropoutwith_train': None,
+        'log_softmaxdtype': None,
+        'split_with_sizes': None,
+        'trapezoid': None,
+        'eq': None,
+        'mul': None,
+        'cartesian_prod': None,
+        'nonzero': None,
+        'bool': None,
+        'inner': None,
+        'dstack': None,
+        'take_along_dim': None,
+    }
+
+    # Used for accept mode only
+    NEW_ALLOW_LIST = defaultdict(list)
+
+    @ops(op_db, allowed_dtypes=MPS_DTYPES)
+    def test_output_match(self, device, dtype, op):
+        self.assertEqual(device, "cpu")
+        if not torch.backends.mps.is_available():
+            self.skipTest("MPS is not available")
+
+        key = op.name + op.variant_test_name
+        if key in self.BLOCKLIST:
+            if self.BLOCKLIST[key] is None or dtype in self.BLOCKLIST[key]:
+                self.skipTest(f"Running test with {op.name} hangs so skipping")
+
+        # Make this a manual expecttest: when the EXPECTTEST_ACCEPT env
+        # variable is set to "1", generate a new ALLOWLIST_OP that reflects
+        # the current state of which ops and dtypes pass.
+        generate_new_truth = os.environ.get("EXPECTTEST_ACCEPT", None) == "1"
+
+        if not generate_new_truth:
+            if op.name not in self.ALLOWLIST_OP:
+                self.skipTest(f"{op.name} is not in the allowlist for MPS")
+            elif str(dtype) not in self.ALLOWLIST_OP[op.name]:
+                self.skipTest(f"{op.name} is in the allowlist for MPS but {dtype} is excluded")
+
+        try:
+            cpu_samples = op.sample_inputs(device, dtype)
+
+            for cpu_sample in cpu_samples:
+                mps_sample = cpu_sample.transform(lambda x: x.to("mps") if isinstance(x, torch.Tensor) else x)
+
+                # TODO: This checks only the function variant. We should also check the method and inplace version
+                # when they exist
+                cpu_args = [cpu_sample.input] + list(cpu_sample.args)
+                cpu_kwargs = cpu_sample.kwargs
+                mps_args = [mps_sample.input] + list(mps_sample.args)
+                mps_kwargs = mps_sample.kwargs
+
+                cpu_out = op(*cpu_args, **cpu_kwargs)
+                mps_out = op(*mps_args, **mps_kwargs)
+                self.assertEqual(cpu_out, mps_out)
+        except Exception:
+            if not generate_new_truth:
+                raise
+        else:
+            if generate_new_truth:
+                self.NEW_ALLOW_LIST[op.name].append(str(dtype))
+
+                # Ideally we would write this file only once, but there is no easy way
+                # to detect the last test, so each test appends to the dict and rewrites it.
+                with open("new_mps_allowlist.txt", "w") as f:
+                    pprint.pprint(self.NEW_ALLOW_LIST, stream=f)
+
+# TODO: Actually instantiate this test for the "mps" device to better reflect what it does.
+# This requires MPS to be properly registered in the device-generic test framework,
+# which is not the case right now.
+instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
 
 if __name__ == "__main__":
     run_tests()
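
For reference, the core check that `test_output_match` applies to each OpInfo
sample can be sketched standalone. This is a minimal illustration, not part of
the patch: the op (`tanh`) and the input shape are arbitrary choices, and it
needs an MPS-capable machine to run.

    import torch

    # Mirror of the per-sample check in test_output_match: run the same op
    # on identical CPU and MPS inputs and compare the results.
    x_cpu = torch.randn(4, 4, dtype=torch.float32)
    x_mps = x_cpu.detach().clone().to("mps")

    out_cpu = torch.tanh(x_cpu)
    out_mps = torch.tanh(x_mps)

    # The test's assertEqual tolerates small numeric differences between
    # backends; torch.testing.assert_close is the standalone equivalent.
    torch.testing.assert_close(out_cpu, out_mps.cpu())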