[MPS] Add error inputs check (#98167)
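Add a `TestErrorInputs` suite that runs each OpInfo's `error_inputs` on the MPS device and checks that the expected exception is raised. Ops that currently do not raise, or that hit unsupported dtypes/features, are marked `expectedFailure`.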
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98167
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index da92ece..9bf767c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -37,13 +37,15 @@
SpectralFuncInfo,
BinaryUfuncInfo,
)
-from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests
+from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
import numpy as np
import torch
import torch.utils._pytree as pytree
from itertools import product
+# Give the consistency and error-input suites independent copies of op_db: the
+# *_modifier helpers attach xfail decorators to the OpInfo objects they yield,
+# so sharing one list would leak those decorators from one suite into the other.
+test_consistency_op_db, test_error_inputs_op_db = (copy.deepcopy(op_db) for _ in range(2))
# Copied from `test_ops.py` for the purposes of duplicating `test_numpy_ref`
_ref_test_ops = tuple(
@@ -744,6 +746,57 @@
dtypes=MACOS_12_3_XFAILLIST[key]))
yield op
+def mps_ops_error_inputs_modifier(ops):
+    # Yield only the ops that define error inputs, marking known MPS failures
+    # as expectedFailure. Error input samples are not parametrized by dtype,
+    # so entries are keyed by op name + variant only and the xfail is applied
+    # without any dtype restriction.
+    xfail_keys = {
+        # Exceptions are not raised
+        '__rmod__',
+        '__rsub__',
+        'bernoulli',
+        'clamp_max',
+        'clamp_min',
+        'index_add',
+        'trace',
+        'nn.functional.max_pool2d',
+        'nn.functional.gelu',
+        'masked_scatter',
+
+        # unsupported float64 dtype
+        'cat',
+        'complex',
+        'multinomial',
+        'nn.functional.conv1d',
+        'nn.functional.conv2d',
+        'gather',
+        'scatter',
+        'scatter_add',
+
+        # unsupported complex dtypes
+        'masked_fill',
+        'gradient',
+        'fft.hfft',
+        'fft.irfft',
+
+        # MPS does not support tensor dimensions > 16
+        'amax',
+        'amin',
+
+        # unimplemented
+        'logcumsumexp',
+    }
+
+    for op in ops:
+        # Ops without error inputs have nothing for TestErrorInputs to run.
+        if op.error_inputs_func is None:
+            continue
+
+        if op.name + op.variant_test_name in xfail_keys:
+            # decorators may be None or a non-list sequence; rebuild it as a
+            # fresh list with an unconditional expectedFailure appended.
+            current = [] if op.decorators is None else list(op.decorators)
+            op.decorators = current + [DecorateInfo(unittest.expectedFailure)]
+
+        yield op
+
# Same logic as test_cuda.py
if not torch.backends.mps.is_available():
print('MPS not available, skipping tests', file=sys.stderr)
@@ -10223,7 +10276,7 @@
NEW_ALLOW_LIST = defaultdict(list)
NEW_ALLOW_LIST_GRAD = defaultdict(list)
- @ops(mps_ops_modifier(op_db), allowed_dtypes=MPS_DTYPES)
+ @ops(mps_ops_modifier(test_consistency_op_db), allowed_dtypes=MPS_DTYPES)
def test_output_match(self, device, dtype, op):
self.assertEqual(device, "cpu")
key = op.name + op.variant_test_name
@@ -10275,7 +10328,7 @@
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
- @ops(mps_ops_grad_modifier(copy.deepcopy(op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
+ @ops(mps_ops_grad_modifier(copy.deepcopy(test_consistency_op_db)), allowed_dtypes=MPS_GRAD_DTYPES)
def test_output_grad_match(self, device, dtype, op):
self.assertEqual(device, "cpu")
key = op.name + op.variant_test_name
@@ -10375,6 +10428,30 @@
self.assertEqual(cpu_grad_inputs, mps_grad_inputs, atol=atol, rtol=rtol)
+
+class TestErrorInputs(TestCaseMPS):
+    # Runs each OpInfo's error inputs directly on MPS and checks that the
+    # declared exception type and message regex are raised.
+    @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
+    def test_error_inputs(self, device, op):
+        self.assertEqual(device, "mps")
+
+        for error_input in op.error_inputs(device):
+            sample = error_input.sample_input
+            args = [sample.input, *sample.args]
+
+            # tensor_split() requires its second tensor argument
+            # ("tensor_indices_or_sections") to be on the CPU.
+            if op.name == "tensor_split" and isinstance(args[1], torch.Tensor):
+                args[1] = args[1].cpu()
+
+            with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                op(*args, **sample.kwargs)
+
+
# Copied from `TestCommon` in `test_ops.py`, just enough to duplicate the `test_numpy_ref` for MPS
@skipIfSlowGradcheckEnv
class TestCommon(TestCase):
@@ -10433,6 +10510,7 @@
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
# to achieve this.
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
+instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
if __name__ == "__main__":