Fix activation checkpoint for mps (#104787)

Activation checkpointing previously errored on the MPS backend; with this change `torch.utils.checkpoint.checkpoint` runs and backpropagates on MPS tensors for both the reentrant and non-reentrant paths, and a regression test is added to `test_mps.py`.

Fixes https://github.com/pytorch/pytorch/issues/104478

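A minimal repro of the behavior the new test exercises (a sketch, assuming a machine where MPS is available; not part of the diff below):

```python
import torch
from torch.utils.checkpoint import checkpoint

if torch.backends.mps.is_available():
    def fn(x):
        # Intermediate activations are recomputed during backward instead of stored.
        return x.sin().cos().exp()

    for use_reentrant in (True, False):
        a = torch.tensor(1.0, device="mps", requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=use_reentrant)
        out.backward()  # previously raised on MPS; should now complete
```
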
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104787
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 0a16fb6..a9dcbd9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3695,6 +3695,18 @@
 
         helper((2, 8, 4, 5), torch.int16)
 
+    def test_activation_checkpoint_does_not_error(self):
+        from torch.utils.checkpoint import checkpoint
+
+        for use_reentrant in (True, False):
+            a = torch.tensor(1., device="mps", requires_grad=True)
+
+            def fn(x):
+                return x.sin().cos().exp()
+
+            out = checkpoint(fn, a, use_reentrant=use_reentrant)
+            out.backward()
+
 class TestLogical(TestCaseMPS):
     def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
         return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)