[MPS] Handle MPS failures of test_modules.py in common_modules.py (#95334)

- Also cleaned up `test_modules.py` by removing its MPS-specific skip decorators (`skipIfMps`).
- Added `skipMPS` in `common_modules.py` for tests that are unsupported or failing on the MPS backend.
   (We'll remove `skipMPS` from those tests once a fix is available for them.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95334
Approved by: https://github.com/kulinseth, https://github.com/albanD
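
At a glance, the change converges on the pattern sketched below. This is illustrative only: the `ModuleInfo` entry mirrors the Hardshrink change in this diff, not the exact upstream entry.

```python
# Minimal sketch of the new skip pattern, assuming a PyTorch dev checkout.
# skipMPS and DecorateInfo are the real helpers touched by this PR; the
# Hardshrink choice mirrors its entry in common_modules.py below.
import torch
from torch.testing._internal.common_device_type import skipMPS
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_modules import (
    ModuleInfo, module_inputs_torch_nn_Hardshrink)

example_info = ModuleInfo(
    torch.nn.Hardshrink,
    module_inputs_func=module_inputs_torch_nn_Hardshrink,
    skips=(
        # Blanket skip on the mps device type (not supported there yet);
        # dropped once a fix lands, per the note above.
        DecorateInfo(skipMPS),
    ),
)
```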
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 2b1dec7..7dc1d6f 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -634,7 +634,7 @@
   }
 
   // Empty index
-  if (num_indices == 0) {
+  if (num_indices == 0 || self.numel() == 0) {
     return output;
   }
 
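The added `self.numel() == 0` guard makes advanced indexing into an empty tensor return early instead of dispatching a gather with nothing to read. A hedged repro sketch (requires an MPS-capable machine):

```python
# num_indices != 0 here (one index tensor), but self.numel() == 0,
# so the new early return fires on the MPS path.
import torch

x = torch.empty(0, 3, device="mps")
idx = torch.tensor([], dtype=torch.long, device="mps")
print(x[idx].shape)  # torch.Size([0, 3]), matching CPU behavior
```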
diff --git a/test/run_test.py b/test/run_test.py
index 53cbe4f..d92ca21 100755
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -1265,7 +1265,7 @@
         options.exclude.extend(CPP_TESTS)
 
     if options.mps:
-        selected_tests = ["test_mps", "test_metal"]
+        selected_tests = ["test_mps", "test_metal", "test_modules"]
     else:
         # Exclude all mps tests otherwise
         options.exclude.extend(["test_mps", "test_metal"])
diff --git a/test/test_modules.py b/test/test_modules.py
index 7a797f8..4463843 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -13,7 +13,7 @@
 from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
-    gradgradcheck, skipIfMps, skipIfTorchInductor)
+    gradgradcheck, skipIfTorchInductor)
 from unittest.mock import patch, call
 
 
@@ -42,7 +42,6 @@
         _check_module(module.named_parameters(), "Parameter")
         _check_module(module.named_buffers(), "Buffer")
 
-    @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db)
     def test_forward(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
@@ -211,7 +210,6 @@
             m.__repr__()
             str(m)
 
-    @skipIfMps
     @modules(module_db)
     def test_pickle(self, device, dtype, module_info, training):
         # Test that module can be pickled and unpickled.
@@ -326,7 +324,6 @@
                 obj.grad = None
         self._traverse_obj(obj, inner_zero_grad)
 
-    @skipIfMps
     @modules(module_db)
     @skipIfTorchInductor("to be fixed")
     def test_non_contiguous_tensors(self, device, dtype, module_info, training):
@@ -585,7 +582,6 @@
                         if cpu_output.requires_grad:
                             check_backward(cpu_output, gpu_output)
 
-    @skipIfMps
     @modules(module_db)
     @skipIfTorchInductor("to be fixed")
     def test_memory_format(self, device, dtype, module_info, training):
@@ -685,7 +681,6 @@
 
     # Test whether train and eval modes differ for each module. Use to verify
     # that the ModuleInfo entry flag is correct.
-    @skipIfMps  # the test doesn't work on MPS as double types are not supported
     @modules(module_db, train_eval_mode=TrainEvalMode.train_only)
     def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training):
         module_cls = module_info.module_cls
@@ -720,7 +715,7 @@
                 else:
                     raise e
 
-instantiate_device_type_tests(TestModule, globals())
+instantiate_device_type_tests(TestModule, globals(), allow_mps=True)
 
 if __name__ == '__main__':
     run_tests()
diff --git a/test/test_mps.py b/test/test_mps.py
index f245ef5..28eba43 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -10525,7 +10525,7 @@
 
     @ops(mps_ops_error_inputs_modifier(test_error_inputs_op_db), dtypes=OpDTypes.none)
     def test_error_inputs(self, device, op):
-        self.assertEqual(device, "mps")
+        self.assertEqual(device, "mps:0")
 
         mps_samples = op.error_inputs(device)
 
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 6d05f41..661df20 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -10,10 +10,9 @@
 import unittest
 import os
 import torch
-import torch.backends.mps
 from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
     skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN, \
-    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, \
+    IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, IS_WINDOWS, TEST_MPS, \
     _TestParametrizer, compose_parametrize_fns, dtype_name, \
     NATIVE_DEVICES, skipIfTorchDynamo
 from torch.testing._internal.common_cuda import _get_torch_cuda_version, \
@@ -532,12 +531,25 @@
 
 class MPSTestBase(DeviceTypeTestBase):
     device_type = 'mps'
+    primary_device: ClassVar[str]
+
+    @classmethod
+    def get_primary_device(cls):
+        return cls.primary_device
+
+    @classmethod
+    def get_all_devices(cls):
+        # currently only one device is supported on MPS backend
+        prim_device = cls.get_primary_device()
+        return [prim_device]
+
+    @classmethod
+    def setUpClass(cls):
+        cls.primary_device = 'mps:0'
 
     def _should_stop_test_suite(self):
         return False
 
-    # TODO: Maybe override `_get_dtypes`, `_get_precision_override`
-
 # Adds available device-type-specific test base classes
 def get_device_type_test_bases():
     # set type to List[Any] due to mypy list-of-union issue:
@@ -633,10 +645,9 @@
     generic_members = set(generic_test_class.__dict__.keys()) - set(empty_class.__dict__.keys())
     generic_tests = [x for x in generic_members if x.startswith('test')]
 
-    # MPS backend support is disabled in `get_device_type_test_bases` while support is being ramped
-    # up, so allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
+    # Allow callers to specifically opt tests into being tested on MPS, similar to `include_lazy`
     test_bases = device_type_test_bases.copy()
-    if allow_mps and torch.backends.mps.is_available() and MPSTestBase not in test_bases:
+    if allow_mps and TEST_MPS and MPSTestBase not in test_bases:
         test_bases.append(MPSTestBase)
     # Filter out the device types based on user inputs
     desired_device_type_test_bases = filter_desired_device_types(test_bases, except_for, only_for)
@@ -903,6 +914,12 @@
     def __init__(self, dep, reason):
         super().__init__(dep, reason, device_type='meta')
 
+# Skips a test on MPS if the condition is true.
+class skipMPSIf(skipIf):
+
+    def __init__(self, dep, reason):
+        super().__init__(dep, reason, device_type='mps')
+
 # Skips a test on XLA if the condition is true.
 class skipXLAIf(skipIf):
 
@@ -1350,6 +1367,9 @@
 def skipXLA(fn):
     return skipXLAIf(True, "Marked as skipped for XLA")(fn)
 
+def skipMPS(fn):
+    return skipMPSIf(True, "test doesn't work on MPS backend")(fn)
+
 # TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now.
 #  This should probably enumerate all available device type test base classes.
 def get_all_device_types() -> List[str]:
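The new `skipMPS`/`skipMPSIf` helpers follow the existing `skipXLA`/`skipXLAIf` pattern. A hypothetical device-type test using them might look like this (names other than the decorators are made up):

```python
# The decorators only take effect once the class is instantiated per
# device type via instantiate_device_type_tests.
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, skipMPS, skipMPSIf)
from torch.testing._internal.common_utils import TestCase, run_tests

class TestExample(TestCase):
    @skipMPS  # unconditionally skipped on the mps device type
    def test_not_on_mps(self, device):
        pass

    @skipMPSIf(True, "tracked in a follow-up issue")  # conditional variant
    def test_conditionally_skipped(self, device):
        pass

instantiate_device_type_tests(TestExample, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()
```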
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 683e501..5c182b2 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -10,10 +10,11 @@
 from torch.nn.utils.rnn import pack_padded_sequence
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import TEST_CUDNN
-from torch.testing._internal.common_dtype import floating_types, floating_and_complex_types_and, get_all_fp_dtypes
+from torch.testing._internal.common_dtype import (
+    floating_types, floating_and_complex_types_and, get_all_fp_dtypes, complex_types_and)
 from torch.testing._internal.common_device_type import (
     _TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
-    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipCUDAVersionIn)
+    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipCUDAVersionIn)
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_nn import nllloss_reference, get_reduction
 from torch.testing._internal.common_utils import (
@@ -2239,20 +2240,23 @@
     ModuleInfo(torch.nn.AdaptiveAvgPool1d,
                module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # Fails on the MPS backend when the input size is not evenly divisible by the output size
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.AdaptiveAvgPool2d,
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool2d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # Fails on the MPS backend when the input size is not evenly divisible by the output size
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.AdaptiveAvgPool3d,
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool3d,
                skips=(
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.AdaptiveMaxPool1d,
                module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool1d,
@@ -2270,7 +2274,8 @@
                module_inputs_func=module_inputs_torch_nn_AdaptiveMaxPool3d,
                skips=(
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.AvgPool1d,
                module_inputs_func=module_inputs_torch_nn_AvgPool1d,
@@ -2288,13 +2293,16 @@
                skips=(
                    # No channels_last support for AvgPool1d as it does not take 4D inputs
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.BatchNorm1d,
                train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),
                    # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
                    # RuntimeError: tried to get Double out of SymInt
                    DecorateInfo(
@@ -2313,7 +2321,9 @@
                train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),
                    # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
                    # RuntimeError: tried to get Double out of SymInt
                    DecorateInfo(
@@ -2332,7 +2342,8 @@
                train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm3d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
                    # RuntimeError: tried to get Double out of SymInt
                    DecorateInfo(
@@ -2380,6 +2391,9 @@
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
                                 device_type='cuda', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -2393,7 +2407,8 @@
                    DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # Conv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2411,7 +2426,8 @@
                    DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   DecorateInfo(skipIfMps, 'TestModule',
+                                dtypes=complex_types_and(torch.chalf, torch.float64, torch.complex128)),
                    # Not implmented for chalf on CPU
                    DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
                                 dtypes=(torch.chalf,), device_type='cpu'),
@@ -2441,7 +2457,8 @@
                    DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=7603), 'TestModule', 'test_memory_format'),
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   DecorateInfo(skipIfMps, 'TestModule',
+                                dtypes=complex_types_and(torch.chalf, torch.float64, torch.complex128)),
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
@@ -2449,7 +2466,10 @@
                    # These fail only on ROCm
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
                                 dtypes=[torch.complex32], active_if=TEST_WITH_ROCM),
-                   # Not implmented for chalf on CPU
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32]),
+                   # Not implemented for chalf on CPU
                    DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_forward',
                                 dtypes=(torch.chalf,), device_type='cpu'),
                    DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_memory_format',
@@ -2478,7 +2498,8 @@
                    DecorateInfo(skipCUDAIfCudnnVersionLessThan(version=8005), 'TestModule', 'test_memory_format'),
                    # Failure on ROCM for float32 issue #70125
                    DecorateInfo(skipCUDAIfRocm, 'TestModule', 'test_memory_format', dtypes=[torch.float32]),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # ConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2514,14 +2535,16 @@
                module_inputs_func=module_inputs_torch_nn_FractionalMaxPool2d,
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
                ),
     ModuleInfo(torch.nn.FractionalMaxPool3d,
                module_inputs_func=module_inputs_torch_nn_FractionalMaxPool3d,
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
                ),
     ModuleInfo(torch.nn.L1Loss,
@@ -2565,6 +2588,9 @@
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
                                 device_type='cuda', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -2581,7 +2607,8 @@
                    # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
                    # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                    DecorateInfo(skipMeta),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # LazyConv3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2623,6 +2650,9 @@
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='cuda',
                                 dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format",
+                                device_type='mps', dtypes=[torch.float32]),
                ),
                decorators=(
                    DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'),
@@ -2639,7 +2669,8 @@
                    # Lazy modules don't currently play well with ModuleInfo tests on the meta device.
                    # See https://github.com/pytorch/pytorch/issues/70505 for more info.
                    DecorateInfo(skipMeta),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # LazyConvTranspose3d is not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    # This was wrongly being skipped before and needs investigation.
                    # See https://github.com/pytorch/pytorch/issues/80247
                    DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format"),
@@ -2697,7 +2728,8 @@
                gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
                skips=(
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.NLLLoss,
                module_inputs_func=module_inputs_torch_nn_NLLLoss,
@@ -2731,7 +2763,7 @@
                module_inputs_func=module_inputs_torch_nn_GroupNorm,
                dtypes=get_all_fp_dtypes(include_bfloat16=True, include_half=False),
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64, torch.bfloat16]),
                    # Tracking at https://github.com/pytorch/pytorch/issues/98089
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_cpu_gpu_parity'),
                    # No channels_last support for GroupNorm currently.
@@ -2742,7 +2774,8 @@
     ModuleInfo(torch.nn.Hardshrink,
                module_inputs_func=module_inputs_torch_nn_Hardshrink,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),),
                ),
     ModuleInfo(torch.nn.Hardswish,
                module_inputs_func=module_inputs_torch_nn_Hardswish,
@@ -2774,14 +2807,16 @@
                module_inputs_func=partial(module_inputs_torch_nn_InstanceNormNd, N=3),
                train_and_eval_differ=True,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),
                    # No channels_last support for InstanceNorm3d currently.
                    DecorateInfo(unittest.skip("Skipped!"), 'TestModule', 'test_memory_format'),)
                ),
     ModuleInfo(torch.nn.LocalResponseNorm,
                module_inputs_func=module_inputs_torch_nn_LocalResponseNorm,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # uses avg_pool3d which is not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.LayerNorm,
                module_inputs_func=module_inputs_torch_nn_LayerNorm,
@@ -2856,12 +2891,16 @@
     ModuleInfo(torch.nn.ReLU6,
                module_inputs_func=module_inputs_torch_nn_ReLU6,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.PReLU,
                module_inputs_func=module_inputs_torch_nn_PReLU,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.RNNCell,
                module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
@@ -2922,12 +2961,15 @@
     ModuleInfo(torch.nn.Softplus,
                module_inputs_func=module_inputs_torch_nn_Softplus,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.Softshrink,
                module_inputs_func=module_inputs_torch_nn_Softshrink,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.Softsign,
                module_inputs_func=module_inputs_torch_nn_Softsign,
@@ -2947,12 +2989,15 @@
     ModuleInfo(torch.nn.Threshold,
                module_inputs_func=module_inputs_torch_nn_Threshold,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.Mish,
                module_inputs_func=module_inputs_torch_nn_Mish,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # not supported on MPS backend
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.RNN,
                train_and_eval_differ=True,
@@ -2971,7 +3016,8 @@
                train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_LSTM,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),),
+                   # LSTM with projections is not currently supported with MPS
+                   DecorateInfo(skipMPS),),
                decorators=rnn_gru_lstm_module_info_decorators),
     ModuleInfo(torch.nn.ReflectionPad1d,
                module_inputs_func=module_inputs_torch_nn_ReflectionPad1d,
@@ -3014,7 +3060,9 @@
     ModuleInfo(torch.nn.SELU,
                module_inputs_func=module_inputs_torch_nn_SELU,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   # test fails on MPS backend and is being investigated.
+                   # See https://github.com/pytorch/pytorch/issues/100914
+                   DecorateInfo(skipMPS),)
                ),
     ModuleInfo(torch.nn.ZeroPad1d,
                module_inputs_func=module_inputs_torch_nn_ZeroPad1d,
@@ -3024,12 +3072,16 @@
     ModuleInfo(torch.nn.ZeroPad2d,
                module_inputs_func=module_inputs_torch_nn_ZeroPad2d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
                ),
     ModuleInfo(torch.nn.ZeroPad3d,
                module_inputs_func=module_inputs_torch_nn_ZeroPad3d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
                ),
     ModuleInfo(torch.nn.ConstantPad1d,
                module_inputs_func=module_inputs_torch_nn_ConstantPad1d,
@@ -3039,11 +3091,15 @@
     ModuleInfo(torch.nn.ConstantPad2d,
                module_inputs_func=module_inputs_torch_nn_ConstantPad2d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
                ),
     ModuleInfo(torch.nn.ConstantPad3d,
                module_inputs_func=module_inputs_torch_nn_ConstantPad3d,
                skips=(
-                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
+                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),
+                   # Fails with channels last test on MPS backend
+                   DecorateInfo(unittest.expectedFailure, "TestModule", "test_memory_format", device_type='mps'),)
                )
 ]
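The dtype-scoped variant remains available for cases like the ConvTranspose entries above, where MPS lacks only double and complex support. A sketch using the newly imported `complex_types_and`:

```python
# Sketch: skip only unsupported dtypes on MPS instead of the whole module.
import torch
from torch.testing._internal.common_dtype import complex_types_and
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_utils import skipIfMps

skip_unsupported = DecorateInfo(
    skipIfMps, 'TestModule',
    # complex64/complex128 plus the extras listed here
    dtypes=complex_types_and(torch.chalf, torch.float64),
)
```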
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 88eb9c3..927115f 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -63,6 +63,7 @@
 import torch
 import torch.backends.cudnn
 import torch.backends.mkl
+import torch.backends.mps
 import torch.backends.xnnpack
 import torch.cuda
 from torch import Tensor
@@ -955,6 +956,7 @@
 TEST_FAIRSEQ = _check_module_exists('fairseq')
 TEST_SCIPY = _check_module_exists('scipy')
 TEST_MKL = torch.backends.mkl.is_available()
+TEST_MPS = torch.backends.mps.is_available()
 TEST_CUDA = torch.cuda.is_available()
 TEST_NUMBA = _check_module_exists('numba')
 
@@ -1143,7 +1145,7 @@
 def skipIfMps(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
-        if torch.backends.mps.is_available():
+        if TEST_MPS:
             raise unittest.SkipTest("test doesn't currently work with MPS")
         else:
             fn(*args, **kwargs)
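
A small sketch of how the new module-level `TEST_MPS` flag can gate a test directly, mirroring the refactored `skipIfMps` above:

```python
# TEST_MPS is computed once at import time from
# torch.backends.mps.is_available(), so it works with plain unittest skips.
import unittest
import torch
from torch.testing._internal.common_utils import TEST_MPS

class TestMPSGated(unittest.TestCase):
    @unittest.skipUnless(TEST_MPS, "MPS backend not available")
    def test_ones_on_mps(self):
        x = torch.ones(2, device="mps")
        self.assertEqual(x.device.type, "mps")
```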