[nn] zero_grad() set_to_none default True (#92731)
Attempts to fix #92656
BC-breaking! This changes the default of zero_grad() in torch.optim and torch.nn so that gradients are set to None instead of being zeroed in place. We are changing the default because setting gradients to None has proven performance wins, and existing code has typically not regressed under the new behavior.
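To make the new contract concrete, here is a minimal sketch of the default versus the explicit opt-out (assuming a PyTorch build that includes this change; the shapes below are arbitrary):

```python
# Sketch of the new zero_grad() default, assuming this change is in the build.
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(3, 4)).sum().backward()
opt.step()

# New default: zero_grad() sets each .grad to None instead of zeroing it.
opt.zero_grad()  # the same default now applies to model.zero_grad()
assert all(p.grad is None for p in model.parameters())

# The old behavior is still available by opting out explicitly.
model(torch.randn(3, 4)).sum().backward()
opt.zero_grad(set_to_none=False)
assert all(torch.equal(p.grad, torch.zeros_like(p)) for p in model.parameters())
```

The aten::zero_ rows removed from the profiler test expectations below illustrate where the win comes from: no per-gradient zeroing kernels are launched on the default path.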
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92731
Approved by: https://github.com/ngimel
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index dd16d9c..28f17f1 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -45,8 +45,7 @@
for (auto& parameter : module->parameters()) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto grad = parameter.grad();
- ASSERT_TRUE(grad.defined());
- ASSERT_EQ(grad.sum().item<float>(), 0);
+ ASSERT_FALSE(grad.defined());
}
}
@@ -66,14 +65,14 @@
ASSERT_TRUE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
- module.zero_grad();
+ module.zero_grad(false); // set_to_none = false
ASSERT_TRUE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
- module.zero_grad(true); // set_to_none = true
+ module.zero_grad();
ASSERT_FALSE(module.x.grad().defined());
ASSERT_FALSE(module.y.grad().defined());
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index 3e0474c..e67ba92 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -268,8 +268,8 @@
self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
o.zero_grad()
- self.assertFalse(m.weight.grad)
- self.assertFalse(m.bias.grad)
+ self.assertIsNone(m.weight.grad)
+ self.assertIsNone(m.bias.grad)
def test_constructor(self):
"""Check the robustness of the ZeroRedundancyOptimizer constructor by
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index 8444272..70b21b6 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -844,14 +844,17 @@
if key.storage.allocation_id == max(ids | {-1})
}
- def _run_and_check_parameters_and_gradients(self, inner_fn, model):
+ def _run_and_check_parameters_and_gradients(self, inner_fn, model, grads_none: bool = False):
with profile() as prof:
inner_fn()
memory_profile = prof._memory_profile()
- def assert_category(t: torch.Tensor, category: _memory_profiler.Category):
+ def assert_category(t: torch.Tensor, category: _memory_profiler.Category, should_be_none: bool = False):
+ if should_be_none:
+ assert t is None, "tensor should be None but is not."
+ return
self.assertIsNotNone(t)
categories = self._lookup_tensor_categories(t, memory_profile)
self.assertGreater(len(categories), 0)
@@ -859,7 +862,7 @@
for p in model.parameters():
assert_category(p, _memory_profiler.Category.PARAMETER)
- assert_category(p.grad, _memory_profiler.Category.GRADIENT)
+ assert_category(p.grad, _memory_profiler.Category.GRADIENT, grads_none)
# Rely on internal asserts
_ = memory_profile.timeline
@@ -929,16 +932,15 @@
_ = model(torch.ones((2, 2)))
def fwd_bwd_step():
+ optimizer.zero_grad()
y = model(torch.ones((2, 2)))
torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
optimizer.step()
- optimizer.zero_grad()
# If we profile the first step then gradients will not have been
# created when we call `model.forward`, so if we don't call `.backward`
# then gradients are never created.
- with self.assertRaises(AssertionError):
- self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)
+ self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model, grads_none=True)
# On the first step we must rely on `AccumulateGrad`, since gradients
# did not exist when `model.forward` was called.
@@ -1078,10 +1080,10 @@
def inner_fn():
y = model(torch.ones((2, 2)))
- torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
- optimizer.step()
optimizer.zero_grad()
+ torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
+ optimizer.step()
self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
self.assertEqual(len(list(model.parameters())), 6)
@@ -1220,9 +1222,7 @@
-- Optimizer --------------------------------------------------------------------------------------------
aten::add_.Tensor 3 (PARAMETER), 25 (GRADIENT) -> 3 (PARAMETER)
- aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER)
- aten::zero_ 25 (GRADIENT) -> 25 (GRADIENT)
- aten::zero_ 23 (GRADIENT) -> 23 (GRADIENT)""",
+ aten::add_.Tensor 5 (PARAMETER), 23 (GRADIENT) -> 5 (PARAMETER)""",
)
def test_categories_e2e_simple_module_fwd(self) -> None:
@@ -1317,9 +1317,7 @@
aten::clone 9 (GRADIENT) -> 11 (OPTIMIZER_STATE)
aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE)
aten::detach 11 (OPTIMIZER_STATE) -> 11 (OPTIMIZER_STATE)
- aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)
- aten::zero_ 7 (GRADIENT) -> 7 (GRADIENT)
- aten::zero_ 9 (GRADIENT) -> 9 (GRADIENT)""",
+ aten::add_.Tensor 3 (PARAMETER), 11 (OPTIMIZER_STATE) -> 3 (PARAMETER)""",
)
def test_categories_e2e_sequential_fwd(self) -> None:
@@ -1550,9 +1548,9 @@
destroy ??? 27(v1) 2 kB
increment_version PARAMETER 2(v0) 1024 kB
destroy ??? 29(v1) 1024 kB
- increment_version GRADIENT 16(v0) 128 kB
- increment_version GRADIENT 17(v0) 2 kB
- increment_version GRADIENT 13(v0) 1024 kB""")
+ destroy GRADIENT 16(v0) 128 kB
+ destroy GRADIENT 17(v0) 2 kB
+ destroy GRADIENT 13(v0) 1024 kB""")
if __name__ == "__main__":
diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py
index c31b1ea..c0497da 100644
--- a/test/profiler/test_profiler.py
+++ b/test/profiler/test_profiler.py
@@ -2669,10 +2669,10 @@
)
optimizer = torch.optim.Adam(model.parameters())
cases = (
- (1, lambda: optimizer.zero_grad()),
- (1, lambda: model.zero_grad()),
- (0, lambda: optimizer.zero_grad(set_to_none=True)),
- (0, lambda: model.zero_grad(set_to_none=True))
+ (0, lambda: optimizer.zero_grad()),
+ (0, lambda: model.zero_grad()),
+ (1, lambda: optimizer.zero_grad(set_to_none=False)),
+ (1, lambda: model.zero_grad(set_to_none=False))
)
num_matched = []
for _, fn in cases:
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index 3b6d7ee..26116c6 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -565,7 +565,7 @@
# Try calling zero_grad()
net.zero_grad()
for p in net.parameters():
- self.assertEqual(p.grad, torch.zeros_like(p))
+ assert p.grad is None, "zero_grad defaults to setting grads to None"
# Test train(), eval(), training (a property)
self.assertTrue(net.training)
diff --git a/test/test_mps.py b/test/test_mps.py
index d7e560e..423f3ba 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5957,24 +5957,27 @@
self.assertIsNotNone(module.weight.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
module.zero_grad()
- self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+ self.assertIsNone(module.weight.grad)
module.bias.requires_grad = True
module.zero_grad()
- self.assertIsNotNone(module.weight.grad)
+ self.assertIsNone(module.weight.grad)
self.assertIsNone(module.bias.grad)
module(i).sum().backward()
self.assertIsNotNone(module.weight.grad)
self.assertIsNotNone(module.bias.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
self.assertGreater(module.bias.grad.data.abs().sum(), 0)
- module.zero_grad()
+
+ # Force set to zeros.
+ module.zero_grad(set_to_none=False)
self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
- # Force set to None.
- module.zero_grad(set_to_none=True)
+ module.zero_grad()
self.assertIsNone(module.weight.grad)
+ self.assertIsNone(module.bias.grad)
+
def test_no_grad(self):
for dtype in [torch.bfloat16, torch.float, torch.double]:
diff --git a/test/test_nn.py b/test/test_nn.py
index 90bafbb..e76737b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -219,25 +219,24 @@
self.assertIsNotNone(module.weight.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
module.zero_grad()
- self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+ self.assertIsNone(module.weight.grad)
module.bias.requires_grad = True
module.zero_grad()
- self.assertIsNotNone(module.weight.grad)
+ self.assertIsNone(module.weight.grad)
self.assertIsNone(module.bias.grad)
module(i).sum().backward()
self.assertIsNotNone(module.weight.grad)
self.assertIsNotNone(module.bias.grad)
self.assertGreater(module.weight.grad.data.abs().sum(), 0)
self.assertGreater(module.bias.grad.data.abs().sum(), 0)
- module.zero_grad()
+ module.zero_grad(set_to_none=False) # Force set to zeros.
self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
- # Force set to None.
- module.zero_grad(set_to_none=True)
+ module.zero_grad()
self.assertIsNone(module.weight.grad)
-
+ self.assertIsNone(module.bias.grad)
def test_no_grad(self):
for dtype in [torch.bfloat16, torch.float, torch.double]:
diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h
index ff0348e..20d1024 100644
--- a/torch/csrc/api/include/torch/nn/module.h
+++ b/torch/csrc/api/include/torch/nn/module.h
@@ -302,7 +302,7 @@
virtual void to(torch::Device device, bool non_blocking = false);
/// Recursively zeros out the `grad` value of each registered parameter.
- virtual void zero_grad(bool set_to_none = false);
+ virtual void zero_grad(bool set_to_none = true);
/// Attempts to cast this `Module` to the given `ModuleType`.
///
diff --git a/torch/distributed/_shard/sharded_optim/api.py b/torch/distributed/_shard/sharded_optim/api.py
index ec4f9e6..c2bfad6 100644
--- a/torch/distributed/_shard/sharded_optim/api.py
+++ b/torch/distributed/_shard/sharded_optim/api.py
@@ -40,7 +40,7 @@
self.param_groups = self._optim.param_groups
self.state = self._optim.state
- def zero_grad(self, set_to_none: bool = False): # type: ignore[override]
+ def zero_grad(self, set_to_none: bool = True): # type: ignore[override]
r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
Args:
diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py
index d623014..6e0216d 100644
--- a/torch/distributed/nn/api/remote_module.py
+++ b/torch/distributed/nn/api/remote_module.py
@@ -447,7 +447,7 @@
def requires_grad_(self: T, requires_grad: bool = True) -> T: # type: ignore[return]
_raise_not_supported(self.requires_grad_.__name__)
- def zero_grad(self, set_to_none: bool = False) -> None:
+ def zero_grad(self, set_to_none: bool = True) -> None:
_raise_not_supported(self.zero_grad.__name__)
def share_memory(self: T) -> T: # type: ignore[return]
diff --git a/torch/distributed/optim/post_localSGD_optimizer.py b/torch/distributed/optim/post_localSGD_optimizer.py
index 4c60399..f171768 100644
--- a/torch/distributed/optim/post_localSGD_optimizer.py
+++ b/torch/distributed/optim/post_localSGD_optimizer.py
@@ -102,7 +102,7 @@
self.optim.step()
self.averager.average_parameters(params=self.param_groups)
- def zero_grad(self, set_to_none: bool = False): # type: ignore[override]
+ def zero_grad(self, set_to_none: bool = True): # type: ignore[override]
self.optim.zero_grad(set_to_none=set_to_none)
def add_param_group(self, param_group):
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py
index b1d5671..80884c8 100644
--- a/torch/nn/modules/module.py
+++ b/torch/nn/modules/module.py
@@ -2319,7 +2319,7 @@
p.requires_grad_(requires_grad)
return self
- def zero_grad(self, set_to_none: bool = False) -> None:
+ def zero_grad(self, set_to_none: bool = True) -> None:
r"""Sets gradients of all model parameters to zero. See similar function
under :class:`torch.optim.Optimizer` for more context.
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index aadb0ff..0d395a9 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -405,7 +405,7 @@
update_group(g, ng) for g, ng in zip(groups, saved_groups)]
self.__setstate__({'state': state, 'param_groups': param_groups})
- def zero_grad(self, set_to_none: bool = False):
+ def zero_grad(self, set_to_none: bool = True):
r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
Args:
diff --git a/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py b/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
index 414e079..cd6c66c 100644
--- a/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
+++ b/torch/testing/_internal/distributed/rpc/examples/parameter_server_test.py
@@ -47,7 +47,10 @@
def update_and_fetch_model(ps_rref, grads):
self = ps_rref.local_value()
for p, g in zip(self.model.parameters(), grads):
- p.grad += g
+ if p.grad is None:
+ p.grad = g
+ else:
+ p.grad += g
with self.lock:
timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates")
self.curr_update_size += 1