[nn] zero_grad() set_to_none default True (#92731)
Attempts to fix #92656
BC-breaking! This changes the default of `zero_grad()` in `optim` and in `nn` so that gradients are set to `None` instead of being reset to zero tensors. We are changing the default because there are proven perf wins and existing code has typically not regressed due to this change. Passing `set_to_none=False` restores the previous behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92731
Approved by: https://github.com/ngimel
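As a minimal sketch of the new default (assuming this PR is applied; the toy model and optimizer below are purely illustrative, not part of the change), `zero_grad()` now drops gradients to `None` unless `set_to_none=False` is passed explicitly:

```python
import torch

# Illustrative toy model/optimizer.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.ones(2)).sum().backward()

# New default: grads are set to None rather than zeroed in place.
optimizer.zero_grad()
assert all(p.grad is None for p in model.parameters())

# The previous behavior is still available explicitly.
model(torch.ones(2)).sum().backward()
optimizer.zero_grad(set_to_none=False)
assert all(torch.equal(p.grad, torch.zeros_like(p)) for p in model.parameters())

# nn.Module.zero_grad() follows the same default.
model(torch.ones(2)).sum().backward()
model.zero_grad()
assert all(p.grad is None for p in model.parameters())
```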
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index dd16d9c..28f17f1 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -45,8 +45,7 @@
for (auto& parameter : module->parameters()) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto grad = parameter.grad();
- ASSERT_TRUE(grad.defined());
- ASSERT_EQ(grad.sum().item<float>(), 0);
+ ASSERT_FALSE(grad.defined());
}
}
@@ -66,14 +65,14 @@
ASSERT_TRUE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
- module.zero_grad();
+ module.zero_grad(false); // set_to_none = false
ASSERT_TRUE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
- module.zero_grad(true); // set_to_none = true
+ module.zero_grad();
ASSERT_FALSE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index 3e0474c..e67ba92 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -268,8 +268,8 @@
self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
o.zero_grad()
- self.assertFalse(m.weight.grad)
- self.assertFalse(m.bias.grad)
+ self.assertIsNone(m.weight.grad)
+ self.assertIsNone(m.bias.grad)
def test_constructor(self):
"""Check the robustness of the ZeroRedundancyOptimizer constructor by
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index 8444272..70b21b6 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -844,14 +844,17 @@
if key.storage.allocation_id == max(ids | {-1})
}
- def _run_and_check_parameters_and_gradients(self, inner_fn, model):
+ def _run_and_check_parameters_and_gradients(self, inner_fn, model, grads_none: bool = False):
with profile() as prof:
inner_fn()
memory_profile = prof._memory_profile()
- def assert_category(t: torch.Tensor, category: _memory_profiler.Category):
+ def assert_category(t: torch.Tensor, category: _memory_profiler.Category, should_be_none: bool = False):
+ if should_be_none:
+ assert t is None, "tensor should be None but is not."
+ return
self.assertIsNotNone(t)
categories = self._lookup_tensor_categories(t, memory_profile)
self.assertGreater(len(categories), 0)
@@ -859,7 +862,7 @@
for p in model.parameters():
assert_category(p, _memory_profiler.Category.PARAMETER)
- assert_category(p.grad, _memory_profiler.Category.GRADIENT)
+ assert_category(p.grad, _memory_profiler.Category.GRADIENT, grads_none)
# Rely on internal asserts
_ = memory_profile.timeline
@@ -929,16 +932,15 @@
_ = model(torch.ones((2, 2)))
def fwd_bwd_step():
+ optimizer.zero_grad()
y = model(torch.ones((2, 2)))
torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
optimizer.step()
- optimizer.zero_grad()
# If we profile the first step then gradients will not have been
# created when we call `model.forward`, so if we don't call `.backward`
# then gradients are never created.
- with self.assertRaises(AssertionError):
- self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)
+ self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model, grads_none=True)
# On the first step we must rely on `AccumulateGrad`, since gradients
# did not exist when `model.forward` was called.
@@ -1078,10 +1080,10 @@
def inner_fn():
y = model(torch.ones((2, 2)))
- torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
- optimizer.step()
optimizer.zero_grad()
+ torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
+ optimizer.step()
self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
self.assertEqual(len(list(model.parameters())), 6)
@@ -1220,9 +1222,7 @@
-- Optimizer --------------------------------------------------------------------------------------------
aten::add_.Tensor 3 (PARAMETER), 25 (GRADIENT) -> 3 (PARAMETER)
- aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER)
- aten::zero_ 25 (GRADIENT) -> 25 (GRADIENT)
- aten::zero_ 23 (GRADIENT) -> 23 (GRADIENT)""",
+ aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER)""",
)
def test_categories_e2e_simple_module_fwd(self) -> None:
@@ -1317,9 +1317,7 @@
aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE)
aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE)
aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE)
- aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)
- aten::zero_ 7 (GRADIENT) -> 7 (GRADIENT)
- aten::zero_ 9 (GRADIENT) -> 9 (GRADIENT)""",
+ aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
)
def test_categories_e2e_sequential_fwd(self) -> None:
@@ -1550,9 +1548,9 @@
destroy ??? 27(v1) 2 kB
increment_version PARAMETER 2(v0) 1024 kB
destroy ??? 29(v1) 1024 kB
- increment_version GRADIENT 16(v0) 128 kB
- increment_version GRADIENT 17(v0) 2 kB
- increment_version GRADIENT 13(v0) 1024 kB""")
+ destroy GRADIENT 16(v0) 128 kB
+ destroy GRADIENT 17(v0) 2 kB
+ destroy GRADIENT 13(v0) 1024 kB""")
if __name__ == "__main__":
diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py
index c31b1ea..c0497da 100644
--- a/test/profiler/test_profiler.py
+++ b/test/profiler/test_profiler.py
@@ -2669,10 +2669,10 @@
)
optimizer = torch.optim.Adam(model.parameters())
cases = (
- (1, lambda: optimizer.zero_grad()),
- (1, lambda: model.zero_grad()),
- (0, lambda: optimizer.zero_grad(set_to_none=True)),
- (0, lambda: model.zero_grad(set_to_none=True))
+ (0, lambda: optimizer.zero_grad()),
+ (0, lambda: model.zero_grad()),
+ (1, lambda: optimizer.zero_grad(set_to_none=False)),
+ (1, lambda: model.zero_grad(set_to_none=False))
)
num_matched = []
for _, fn in cases:
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index 3b6d7ee..26116c6 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -565,7 +565,7 @@
# Try calling zero_grad()
net.zero_grad()
for p in net.parameters():
- self.assertEqual(p.grad, torch.zeros_like(p))
+ assert p.grad is None, "zero_grad defaults to setting grads to None"
# Test train(), eval(), training (a property)
self.assertTrue(net.training)
diff --git a/test/test_mps.py b/test/test_mps.py
index d7e560e..423f3ba 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5957,24 +5957,27 @@
self.assertIsNotNone(module.weight.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
module.zero_grad()
- self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+ self.assertIsNone(module.weight.grad)
module.bias.requires_grad = True
module.zero_grad()
- self.assertIsNotNone(module.weight.grad)
+ self.assertIsNone(module.weight.grad)
self.assertIsNone(module.bias.grad)
module(i).sum().backward()
self.assertIsNotNone(module.weight.grad)
self.assertIsNotNone(module.bias.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
self.assertGreater(module.bias.grad.data.abs().sum(), 0)
- module.zero_grad()
+
+ # Force set to zeros.
+ module.zero_grad(set_to_none=False)
self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
- # Force set to None.
- module.zero_grad(set_to_none=True)
+ module.zero_grad()
self.assertIsNone(module.weight.grad)
+ self.assertIsNone(module.bias.grad)
+
def test_no_grad(self):
for dtype in [torch.bfloat16, torch.float, torch.double]:
diff --git a/test/test_nn.py b/test/test_nn.py
index 90bafbb..e76737b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -219,25 +219,24 @@
self.assertIsNotNone(module.weight.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
module.zero_grad()
- self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+ self.assertIsNone(module.weight.grad)
module.bias.requires_grad = True
module.zero_grad()
- self.assertIsNotNone(module.weight.grad)
+ self.assertIsNone(module.weight.grad)
self.assertIsNone(module.bias.grad)
module(i).sum().backward()
self.assertIsNotNone(module.weight.grad)
self.assertIsNotNone(module.bias.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
self.assertGreater(module.bias.grad.data.abs().sum(), 0)
- module.zero_grad()
+ module.zero_grad(set_to_none=False) # Force set to zeros.
self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
- # Force set to None.
- module.zero_grad(set_to_none=True)
+ module.zero_grad()
self.assertIsNone(module.weight.grad)
-
+ self.assertIsNone(module.bias.grad)
def test_no_grad(self):
for dtype in [torch.bfloat16, torch.float, torch.double]: