[nn] zero_grad() set_to_none default True (#92731)
Attempts to fix #92656
BC-breaking! This changes the default of zero_grad in optim and in nn so that grads are set to None instead of being filled with zero tensors. The default is changing because setting grads to None has proven perf wins, and existing code has typically not regressed under this behavior; a minimal sketch of the new default follows the approval links below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92731
Approved by: https://github.com/ngimel
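
As a minimal sketch of the new default (not part of this PR's diff; the Linear module, tensor shapes, and SGD optimizer below are illustrative assumptions):

    import torch
    import torch.nn as nn

    module = nn.Linear(4, 2)
    opt = torch.optim.SGD(module.parameters(), lr=0.1)

    module(torch.randn(3, 4)).sum().backward()
    assert module.weight.grad is not None

    # New default: zero_grad() sets .grad to None instead of zeroing in place.
    module.zero_grad()
    assert module.weight.grad is None

    # The optimizer default changes the same way.
    module(torch.randn(3, 4)).sum().backward()
    opt.zero_grad()
    assert module.weight.grad is None

    # The old behavior remains available by passing set_to_none=False.
    module(torch.randn(3, 4)).sum().backward()
    module.zero_grad(set_to_none=False)
    assert torch.equal(module.weight.grad, torch.zeros_like(module.weight))

Code that reads param.grad unconditionally after calling zero_grad() should now expect None.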
diff --git a/test/test_mps.py b/test/test_mps.py
index d7e560e..423f3ba 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5957,24 +5957,27 @@
         self.assertIsNotNone(module.weight.grad)
         self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         module.zero_grad()
-        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+        self.assertIsNone(module.weight.grad)
         module.bias.requires_grad = True
         module.zero_grad()
-        self.assertIsNotNone(module.weight.grad)
+        self.assertIsNone(module.weight.grad)
         self.assertIsNone(module.bias.grad)
         module(i).sum().backward()
         self.assertIsNotNone(module.weight.grad)
         self.assertIsNotNone(module.bias.grad)
         self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         self.assertGreater(module.bias.grad.data.abs().sum(), 0)
-        module.zero_grad()
+
+        # Force set to zeros.
+        module.zero_grad(set_to_none=False)
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
         self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
-        # Force set to None.
-        module.zero_grad(set_to_none=True)
+        module.zero_grad()
         self.assertIsNone(module.weight.grad)
+        self.assertIsNone(module.bias.grad)
+
     def test_no_grad(self):
         for dtype in [torch.bfloat16, torch.float, torch.double]: