Fix activation checkpoint for mps (#104787)
Fixes https://github.com/pytorch/pytorch/issues/104478
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104787
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 0a16fb6..a9dcbd9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3695,6 +3695,18 @@
helper((2, 8, 4, 5), torch.int16)
+ def test_activation_checkpoint_does_not_error(self):
+ from torch.utils.checkpoint import checkpoint
+
+ for use_reentrant in (True, False):
+ a = torch.tensor(1., device="mps", requires_grad=True)
+
+ def fn(x):
+ return x.sin().cos().exp()
+
+ out = checkpoint(fn, a, use_reentrant=use_reentrant)
+ out.backward()
+
class TestLogical(TestCaseMPS):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)