[MPS] Handle MPS failures of test_modules.py in common_modules.py (#95334)
- Also removed the `skipMPS` decorations from `test_modules.py`.
- Added `skipMPS` for unsupported or failing tests on MPS backend in common_modules.py.
(We'll remove `skipMPS` from those tests once a fix is available for them.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95334
Approved by: https://github.com/kulinseth, https://github.com/albanD
diff --git a/test/run_test.py b/test/run_test.py
index 53cbe4f..d92ca21 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -1265,7 +1265,7 @@
options.exclude.extend(CPP_TESTS)
if options.mps:
- selected_tests = ["test_mps", "test_metal"]
+ selected_tests = ["test_mps", "test_metal", "test_modules"]
else:
# Exclude all mps tests otherwise
options.exclude.extend(["test_mps", "test_metal"])
diff --git a/test/test_modules.py b/test/test_modules.py
index 7a797f8..4463843 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -13,7 +13,7 @@
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
- gradgradcheck, skipIfMps, skipIfTorchInductor)
+ gradgradcheck, skipIfTorchInductor)
from unittest.mock import patch, call
@@ -42,7 +42,6 @@
_check_module(module.named_parameters(), "Parameter")
_check_module(module.named_buffers(), "Buffer")
- @skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db)
def test_forward(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
@@ -211,7 +210,6 @@
m.__repr__()
str(m)
- @skipIfMps
@modules(module_db)
def test_pickle(self, device, dtype, module_info, training):
# Test that module can be pickled and unpickled.
@@ -326,7 +324,6 @@
obj.grad = None
self._traverse_obj(obj, inner_zero_grad)
- @skipIfMps
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
@@ -585,7 +582,6 @@
if cpu_output.requires_grad:
check_backward(cpu_output, gpu_output)
- @skipIfMps
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_memory_format(self, device, dtype, module_info, training):
@@ -685,7 +681,6 @@
# Test whether train and eval modes differ for each module. Use to verify
# that the ModuleInfo entry flag is correct.
- @skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db, train_eval_mode=TrainEvalMode.train_only)
def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
@@ -720,7 +715,7 @@
else:
raise e
-instantiate_device_type_tests(TestModule, globals())
+instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
if __name__ == '__main__':
run_tests()
diff --git a/test/test_mps.py b/test/test_mps.py
index f245ef5..28eba43 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10525,7 +10525,7 @@
@ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
def test_error_inputs(self, device, op):
- self.assertEqual(device, "mps")
+ self.assertEqual(device, "mps:0")
mps_samples = op.error_inputs(device)