[MPS] Add error inputs check (#98167)

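Adds a `TestErrorInputs` suite that runs each OpInfo's `error_inputs` samples on MPS and checks the expected exception type and message; ops that do not yet raise correctly on MPS are marked as expected failures.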

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98167
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index da92ece..9bf767c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -37,13 +37,15 @@
     SpectralFuncInfo,
     BinaryUfuncInfo,
 )
-from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests
+from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
 from torch.testing._internal.common_nn import NNTestCase
 import numpy as np
 import torch
 import torch.utils._pytree as pytree
 from itertools import product
 
+test_error_inputs_op_db = copy.deepcopy(op_db)  # own copy: error-input xfail decorators are appended to these ops
+test_consistency_op_db = copy.deepcopy(op_db)  # own copy: consistency-test xfail decorators are appended to these ops
 
 # Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`
 _ref_test_ops = tuple(
@@ -744,6 +746,57 @@
                          dtypes=MACOS_12_3_XFAILLIST[key]))
         yield op
 
+def mps_ops_error_inputs_modifier(ops):
+    # Error inputs take no dtype argument, so xfails are a plain set of `op.name + op.variant_test_name` keys.
+    XFAIL_KEYS = {
+        # Exceptions are not raised
+        '__rmod__',
+        '__rsub__',
+        'bernoulli',
+        'clamp_max',
+        'clamp_min',
+        'index_add',
+        'masked_scatter',
+        'nn.functional.gelu',
+        'nn.functional.max_pool2d',
+        'trace',
+
+        # unsupported float64 dtype
+        'cat',
+        'complex',
+        'gather',
+        'multinomial',
+        'nn.functional.conv1d',
+        'nn.functional.conv2d',
+        'scatter',
+        'scatter_add',
+
+        # unsupported complex dtypes
+        'fft.hfft',
+        'fft.irfft',
+        'gradient',
+        'masked_fill',
+
+        # MPS does not support tensor dimensions > 16
+        'amax',
+        'amin',
+
+        # unimplemented
+        'logcumsumexp',
+    }
+
+    def append_decorator(op, decorator) -> None:
+        # Build a new list (op.decorators may be None or another sequence type) instead of mutating it.
+        op.decorators = [*op.decorators, decorator] if op.decorators is not None else [decorator]
+
+    for op in ops:
+        # Ops without error inputs have nothing to check, so they are not yielded at all.
+        if op.error_inputs_func is None:
+            continue
+        if op.name + op.variant_test_name in XFAIL_KEYS:
+            append_decorator(op, DecorateInfo(unittest.expectedFailure))
+        yield op
+
 # Same logic as test_cuda.py
 if not torch.backends.mps.is_available():
     print('MPS not available, skipping tests', file=sys.stderr)
@@ -10223,7 +10276,7 @@
     NEW_ALLOW_LIST = defaultdict(list)
     NEW_ALLOW_LIST_GRAD = defaultdict(list)
 
-    @ops(mps_ops_modifier(op_db), allowed_dtypes=MPS_DTYPES)
+    @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
     def test_output_match(self, device, dtype, op):
         self.assertEqual(device, "cpu")
         key = op.name + op.variant_test_name
@@ -10275,7 +10328,7 @@
             self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
 
 
-    @ops(mps_ops_grad_modifier(copy.deepcopy(op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
+    @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
     def test_output_grad_match(self, device, dtype, op):
         self.assertEqual(device, "cpu")
         key = op.name + op.variant_test_name
@@ -10375,6 +10428,30 @@
 
             self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
 
+
+class TestErrorInputs(TestCaseMPS):
+    # Checks that each OpInfo's error_inputs samples raise the declared exception and message on MPS.
+    # dtypes=OpDTypes.none: error input samples are not parametrized over dtypes.
+    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
+    def test_error_inputs(self, device, op):
+        self.assertEqual(device, "mps")
+
+        for error_input in op.error_inputs(device):
+            sample = error_input.sample_input
+            expected_type, expected_regex = error_input.error_type, error_input.error_regex
+
+            args = [sample.input, *sample.args]
+            kwargs = sample.kwargs
+
+            # tensor_split() expects its second tensor argument ("tensor_indices_or_sections")
+            # to be on the CPU, so move it off the MPS device before calling the op.
+            if op.name == "tensor_split" and isinstance(args[1], torch.Tensor):
+                args[1] = args[1].cpu()
+
+            with self.assertRaisesRegex(expected_type, expected_regex):
+                op(*args, **kwargs)
+
+
 # Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
 @skipIfSlowGradcheckEnv
 class TestCommon(TestCase):
@@ -10433,6 +10510,7 @@
 # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
 # to achieve this.
 instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
+instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")  # error inputs run directly on MPS
 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
 
 if __name__ == "__main__":