[MPS] Handle MPS failures of test_modules.py in common_modules.py (#95334)
- Also cleaned up `test_modules.py` by removing its blanket `skipIfMps` decorators.
- Added `skipMPS` in `common_modules.py` for tests that are unsupported or failing on the MPS backend.
(We'll remove `skipMPS` from those tests once fixes are available for them.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95334
Approved by: https://github.com/kulinseth, https://github.com/albanD
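At a glance, the diff touches seven files: an empty-tensor guard in the MPS indexing kernel (`Indexing.mm`), MPS test selection in `run_test.py`, removal of the blanket skips in `test_modules.py`, the `mps:0` device string in `test_mps.py`, an `MPSTestBase` primary device plus the `skipMPS`/`skipMPSIf` decorators in `common_device_type.py`, per-module MPS skips in `common_modules.py`, and a shared `TEST_MPS` flag in `common_utils.py`.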
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 2b1dec7..7dc1d6f 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -634,7 +634,7 @@
}
// Empty index
- if (num_indices == 0) {
+ if (num_indices == 0 || self.numel() == 0) {
return output;
}
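For context, the early return in `Indexing.mm` now also covers an empty `self` tensor, not just an empty index list. A minimal sketch of the guarded case (assumes an MPS-enabled build; the exact failure mode before the fix is not shown in this diff):

```python
import torch

if torch.backends.mps.is_available():
    x = torch.empty(0, device="mps")                       # empty source tensor
    idx = torch.tensor([], dtype=torch.long, device="mps")
    # With self.numel() == 0, the op now returns the empty output early
    # instead of dispatching an MPS kernel over zero elements.
    print(x[idx])  # tensor([], device='mps:0')
```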
diff --git a/test/run_test.py b/test/run_test.py
index 53cbe4f..d92ca21 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -1265,7 +1265,7 @@
options.exclude.extend(CPP_TESTS)
if options.mps:
- selected_tests = ["test_mps", "test_metal"]
+ selected_tests = ["test_mps", "test_metal", "test_modules"]
else:
# Exclude all mps tests otherwise
options.exclude.extend(["test_mps", "test_metal"])
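With this change, `python test/run_test.py --mps` also runs `test_modules`. The exclusion list for non-MPS runs is unchanged, since `test_modules` still runs on the other backends.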
diff --git a/test/test_modules.py b/test/test_modules.py
index 7a797f8..4463843 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -13,7 +13,7 @@
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
- gradgradcheck, skipIfMps, skipIfTorchInductor)
+ gradgradcheck, skipIfTorchInductor)
from unittest.mock import patch, call
@@ -42,7 +42,6 @@
_check_module(module.named_parameters(), "Parameter")
_check_module(module.named_buffers(), "Buffer")
- @skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db)
def test_forward(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
@@ -211,7 +210,6 @@
m.__repr__()
str(m)
- @skipIfMps
@modules(module_db)
def test_pickle(self, device, dtype, module_info, training):
# Test that module can be pickled and unpickled.
@@ -326,7 +324,6 @@
obj.grad = None
self._traverse_obj(obj, inner_zero_grad)
- @skipIfMps
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_non_contiguous_tensors(self, device, dtype, module_info, training):
@@ -585,7 +582,6 @@
if cpu_output.requires_grad:
check_backward(cpu_output, gpu_output)
- @skipIfMps
@modules(module_db)
@skipIfTorchInductor("to be fixed")
def test_memory_format(self, device, dtype, module_info, training):
@@ -685,7 +681,6 @@
# Test whether train and eval modes differ for each module. Use to verify
# that the ModuleInfo entry flag is correct.
- @skipIfMps # the test doesn't work on MPS as double types are not supported
@modules(module_db, train_eval_mode=TrainEvalMode.train_only)
def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
@@ -720,7 +715,7 @@
else:
raise e
-instantiate_device_type_tests(TestModule, globals())
+instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
if __name__ == '__main__':
run_tests()
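Since the blanket `skipIfMps` decorators are gone, MPS coverage is now controlled per-module via `common_modules.py` (see below), and the test class opts in through `allow_mps=True`. A hedged sketch of the opt-in pattern (`TestExample` and `test_add` are illustrative names, not from this PR):

```python
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests

class TestExample(TestCase):
    def test_add(self, device):
        x = torch.ones(2, device=device)
        self.assertEqual((x + x).sum().item(), 4.0)

# Generates TestExampleCPU, TestExampleCUDA, ... and, with this PR,
# TestExampleMPS when an MPS device is available.
instantiate_device_type_tests(TestExample, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()
```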
diff --git a/test/test_mps.py b/test/test_mps.py
index f245ef5..28eba43 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10525,7 +10525,7 @@
@ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
def test_error_inputs(self, device, op):
- self.assertEqual(device, "mps")
+ self.assertEqual(device, "mps:0")
mps_samples = op.error_inputs(device)
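The expected device string becomes `mps:0` because `MPSTestBase` now defines a primary device (see the `common_device_type.py` hunk below), matching how CUDA tests see `cuda:0`. For reference:

```python
import torch

# A device constructed with an explicit index stringifies with ":0".
print(str(torch.device("mps", 0)))  # mps:0
print(str(torch.device("mps")))     # mps
```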
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 6d05f41..661df20 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -10,10 +10,9 @@
import unittest
import os
import torch
-import torch.backends.mps
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
- IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, \
+ IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, \
_TestParametrizer, compose_parametrize_fns, dtype_name, \
NATIVE_DEVICES, skipIfTorchDynamo
from torch.testing._internal.common_cuda import _get_torch_cuda_version, \
@@ -532,12 +531,25 @@
class MPSTestBase(DeviceTypeTestBase):
device_type = 'mps'
+ primary_device: ClassVar[str]
+
+ @classmethod
+ def get_primary_device(cls):
+ return cls.primary_device
+
+ @classmethod
+ def get_all_devices(cls):
+ # currently only one device is supported on MPS backend
+ prim_device = cls.get_primary_device()
+ return [prim_device]
+
+ @classmethod
+ def setUpClass(cls):
+ cls.primary_device = 'mps:0'
def _should_stop_test_suite(self):
return False
- # TODO: Maybe override `_get_dtypes`, `_get_precision_override`
-
# Adds available device-type-specific test base classes
def get_device_type_test_bases():
# set type to List[Any] due to mypy list-of-union issue:
@@ -633,10 +645,9 @@
generic_members = set(generic_test_class.__dict__.keys()) - set(empty_class.__dict__.keys())
generic_tests = [x for x in generic_members if x.startswith('test')]
- # MPS backend support is disabled in `get_device_type_test_bases` while support is being ramped
- # up, so allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
+ # allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
test_bases = device_type_test_bases.copy()
- if allow_mps and torch.backends.mps.is_available() and MPSTestBase not in test_bases:
+ if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
test_bases.append(MPSTestBase)
# Filter out the device types based on user inputs
desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for)
@@ -903,6 +914,12 @@
def __init__(self, dep, reason):
super().__init__(dep, reason, device_type='meta')
+# Skips a test on MPS if the condition is true.
+class skipMPSIf(skipIf):
+
+ def __init__(self, dep, reason):
+ super().__init__(dep, reason, device_type='mps')
+
# Skips a test on XLA if the condition is true.
class skipXLAIf(skipIf):
@@ -1350,6 +1367,9 @@
def skipXLA(fn):
return skipXLAIf(True, "Marked as skipped for XLA")(fn)
+def skipMPS(fn):
+ return skipMPSIf(True, "test doesn't work on MPS backend")(fn)
+
# TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now.
# This should probably enumerate all available device type test base classes.
def get_all_device_types() -> List[str]:
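`skipMPSIf`/`skipMPS` mirror the existing per-device skip decorators: they only take effect when the test is instantiated for the MPS device. A hedged usage sketch (`TestFoo` and the placeholder condition are illustrative, not from this PR):

```python
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, skipMPS, skipMPSIf)

RUNNING_ON_OLD_MACOS = False  # illustrative placeholder condition

class TestFoo(TestCase):
    @skipMPS  # unconditionally skipped on MPS, still runs elsewhere
    def test_double_precision(self, device):
        pass

    @skipMPSIf(RUNNING_ON_OLD_MACOS, "op requires a newer macOS")
    def test_conditional(self, device):
        pass

instantiate_device_type_tests(TestFoo, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()
```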
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 683e501..5c182b2 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -10,10 +10,11 @@
from torch.nn.utils.rnn import pack_padded_sequence
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDNN
-from torch.testing._internal.common_dtype import floating_types, floating_and_complex_types_and, get_all_fp_dtypes
+from torch.testing._internal.common_dtype import (
+ floating_types, floating_and_complex_types_and, get_all_fp_dtypes, complex_types_and)
from torch.testing._internal.common_device_type import (
_TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
- skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipCUDAVersionIn)
+ skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipCUDAVersionIn)
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_nn import nllloss_reference, get_reduction
from torch.testing._internal.common_utils import (
@@ -2239,20 +2240,23 @@
ModuleInfo(torch.nn.AdaptiveAvgPool1d,
module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # Fails on MPS if the input size is not evenly divisible by the output size
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.AdaptiveAvgPool2d,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # Fails on MPS if the input size is not evenly divisible by the output size
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.AdaptiveAvgPool3d,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d,
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.AdaptiveMaxPool1d,
module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d,
@@ -2270,7 +2274,8 @@
module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d,
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.AvgPool1d,
module_inputs_func=module_inputs_torch_nn_AvgPool1d,
@@ -2288,13 +2293,16 @@
skips=(
# No channels_last support for AvgPool1d as it does not take 4D inputs
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.BatchNorm1d,
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),
# tracking here rather than in the list in test_aotdispatch.py as eval mode passes
# RuntimeError: tried to get Double out of SymInt
DecorateInfo(
@@ -2313,7 +2321,9 @@
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),
# tracking here rather than in the list in test_aotdispatch.py as eval mode passes
# RuntimeError: tried to get Double out of SymInt
DecorateInfo(
@@ -2332,7 +2342,8 @@
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),
# tracking here rather than in the list in test_aotdispatch.py as eval mode passes
# RuntimeError: tried to get Double out of SymInt
DecorateInfo(
@@ -2380,6 +2391,9 @@
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
device_type='cuda', dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+ device_type='mps', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -2393,7 +2407,8 @@
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # Conv3d is not supported on MPS backend
+ DecorateInfo(skipMPS),
# This was wrongly being skipped before and needs investigation.
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2411,7 +2426,8 @@
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ DecorateInfo(skipIfMps, 'TestModule',
+ dtypes=complex_types_and(torch.chalf, torch.float64, torch.complex128)),
# Not implmented for chalf on CPU
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
dtypes=(torch.chalf,), device_type='cpu'),
@@ -2441,7 +2457,8 @@
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ DecorateInfo(skipIfMps, 'TestModule',
+ dtypes=complex_types_and(torch.chalf, torch.float64, torch.complex128)),
# This was wrongly being skipped before and needs investigation.
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
@@ -2449,7 +2466,10 @@
# These fail only on ROCm
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
dtypes=[torch.complex32], active_if=TEST_WITH_ROCM),
- # Not implmented for chalf on CPU
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+ device_type='mps', dtypes=[torch.float32]),
+ # Not implemented for chalf on CPU
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
dtypes=(torch.chalf,), device_type='cpu'),
DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
@@ -2478,7 +2498,8 @@
DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
# Failure on ROCM for float32 issue #70125
DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # ConvTranspose3d is not supported on MPS backend
+ DecorateInfo(skipMPS),
# This was wrongly being skipped before and needs investigation.
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2514,14 +2535,16 @@
module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
),
ModuleInfo(torch.nn.FractionalMaxPool3d,
module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d,
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
),
ModuleInfo(torch.nn.L1Loss,
@@ -2565,6 +2588,9 @@
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
device_type='cuda', dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+ device_type='mps', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -2581,7 +2607,8 @@
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # LazyConv3d is not supported on MPS backend
+ DecorateInfo(skipMPS),
# This was wrongly being skipped before and needs investigation.
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2623,6 +2650,9 @@
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+ device_type='mps', dtypes=[torch.float32]),
),
decorators=(
DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -2639,7 +2669,8 @@
# Lazy modules don't currently play well with ModuleInfo tests on the meta device.
# See https://github.com/pytorch/pytorch/issues/70505 for more info.
DecorateInfo(skipMeta),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # LazyConvTranspose3d is not supported on MPS backend
+ DecorateInfo(skipMPS),
# This was wrongly being skipped before and needs investigation.
# See https://github.com/pytorch/pytorch/issues/80247
DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2697,7 +2728,8 @@
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
skips=(
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.NLLLoss,
module_inputs_func=module_inputs_torch_nn_NLLLoss,
@@ -2731,7 +2763,7 @@
module_inputs_func=module_inputs_torch_nn_GroupNorm,
dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=False),
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64, torch.bfloat16]),
# Tracking at https://github.com/pytorch/pytorch/issues/98089
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
# No channels_last support for GroupNorm currently.
@@ -2742,7 +2774,8 @@
ModuleInfo(torch.nn.Hardshrink,
module_inputs_func=module_inputs_torch_nn_Hardshrink,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),),
),
ModuleInfo(torch.nn.Hardswish,
module_inputs_func=module_inputs_torch_nn_Hardswish,
@@ -2774,14 +2807,16 @@
module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3),
train_and_eval_differ=True,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),
# No channels_last support for InstanceNorm3d currently.
DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
),
ModuleInfo(torch.nn.LocalResponseNorm,
module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # uses avg_pool3d which is not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.LayerNorm,
module_inputs_func=module_inputs_torch_nn_LayerNorm,
@@ -2856,12 +2891,16 @@
ModuleInfo(torch.nn.ReLU6,
module_inputs_func=module_inputs_torch_nn_ReLU6,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.PReLU,
module_inputs_func=module_inputs_torch_nn_PReLU,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.RNNCell,
module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
@@ -2922,12 +2961,15 @@
ModuleInfo(torch.nn.Softplus,
module_inputs_func=module_inputs_torch_nn_Softplus,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.Softshrink,
module_inputs_func=module_inputs_torch_nn_Softshrink,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.Softsign,
module_inputs_func=module_inputs_torch_nn_Softsign,
@@ -2947,12 +2989,15 @@
ModuleInfo(torch.nn.Threshold,
module_inputs_func=module_inputs_torch_nn_Threshold,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.Mish,
module_inputs_func=module_inputs_torch_nn_Mish,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # not supported on MPS backend
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.RNN,
train_and_eval_differ=True,
@@ -2971,7 +3016,8 @@
train_and_eval_differ=True,
module_inputs_func=module_inputs_torch_nn_LSTM,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
+ # LSTM with projections is not currently supported with MPS
+ DecorateInfo(skipMPS),),
decorators=rnn_gru_lstm_module_info_decorators),
ModuleInfo(torch.nn.ReflectionPad1d,
module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
@@ -3014,7 +3060,9 @@
ModuleInfo(torch.nn.SELU,
module_inputs_func=module_inputs_torch_nn_SELU,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ # test fails on MPS backend and is being investigated.
+ # See https://github.com/pytorch/pytorch/issues/100914
+ DecorateInfo(skipMPS),)
),
ModuleInfo(torch.nn.ZeroPad1d,
module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
@@ -3024,12 +3072,16 @@
ModuleInfo(torch.nn.ZeroPad2d,
module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
),
ModuleInfo(torch.nn.ZeroPad3d,
module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
),
ModuleInfo(torch.nn.ConstantPad1d,
module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
@@ -3039,11 +3091,15 @@
ModuleInfo(torch.nn.ConstantPad2d,
module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
),
ModuleInfo(torch.nn.ConstantPad3d,
module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
skips=(
- DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+ DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+ # Fails with channels last test on MPS backend
+ DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
)
]
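The recurring pattern above: a bare `DecorateInfo(skipMPS)` (no test class/name arguments) skips every `TestModule` test for that module on MPS, while narrower entries pin a skip or expected failure to a specific test, device, and dtype. A hedged sketch of the two scopes side by side (this `ModuleInfo` entry is illustrative, not an addition from this PR):

```python
ModuleInfo(torch.nn.SomeModule,  # illustrative module
           module_inputs_func=module_inputs_torch_nn_SomeModule,  # illustrative
           skips=(
               # broad: skip all TestModule tests on the MPS device
               DecorateInfo(skipMPS),
               # narrow: expected failure for one test/device/dtype combination
               DecorateInfo(unittest.expectedFailure, "TestModule",
                            "test_memory_format", device_type='mps',
                            dtypes=[torch.float32]),)
           ),
```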
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 88eb9c3..927115f 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -63,6 +63,7 @@
import torch
import torch.backends.cudnn
import torch.backends.mkl
+import torch.backends.mps
import torch.backends.xnnpack
import torch.cuda
from torch import Tensor
@@ -955,6 +956,7 @@
TEST_FAIRSEQ = _check_module_exists('fairseq')
TEST_SCIPY = _check_module_exists('scipy')
TEST_MKL = torch.backends.mkl.is_available()
+TEST_MPS = torch.backends.mps.is_available()
TEST_CUDA = torch.cuda.is_available()
TEST_NUMBA = _check_module_exists('numba')
@@ -1143,7 +1145,7 @@
def skipIfMps(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
- if torch.backends.mps.is_available():
+ if TEST_MPS:
raise unittest.SkipTest("test doesn't currently work with MPS")
else:
fn(*args, **kwargs)
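`TEST_MPS` gives test code a single cached availability flag, consistent with `TEST_CUDA` and `TEST_MKL`. A small sketch of using it outside the device-type machinery:

```python
import unittest
import torch
from torch.testing._internal.common_utils import TEST_MPS

class TestAvailability(unittest.TestCase):
    @unittest.skipUnless(TEST_MPS, "no MPS device available")
    def test_cpu_roundtrip(self):
        x = torch.arange(4.0, device="mps")
        self.assertEqual(x.cpu().tolist(), [0.0, 1.0, 2.0, 3.0])

if __name__ == "__main__":
    unittest.main()
```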