[nn] zero_grad() set_to_none default True (#92731)

Attempts to fix #92656

BC-breaking! This changes the default of `zero_grad()` in `torch.optim` and `torch.nn` so that gradients are set to `None` instead of being zeroed in place (i.e. `set_to_none=True` is now the default). We are changing the default because there are proven perf wins and existing code has typically not regressed due to this change. Code that relies on zeroed gradient tensors can keep the old behavior by passing `set_to_none=False` explicitly.
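
For illustration, a minimal sketch of the behavior change (the module and shapes here are arbitrary, chosen only for the example):

```python
import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(1, 4)).sum().backward()

# New default: zero_grad() sets gradients to None (set_to_none=True).
model.zero_grad()
assert model.weight.grad is None and model.bias.grad is None

# The old behavior is still available by opting out explicitly.
model(torch.randn(1, 4)).sum().backward()
model.zero_grad(set_to_none=False)
assert torch.equal(model.weight.grad, torch.zeros_like(model.weight))
```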

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92731
Approved by: https://github.com/ngimel
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index dd16d9c..28f17f1 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -45,8 +45,7 @@
   for (auto& parameter : module->parameters()) {
     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
     auto grad = parameter.grad();
-    ASSERT_TRUE(grad.defined());
-    ASSERT_EQ(grad.sum().item<float>(), 0);
+    ASSERT_FALSE(grad.defined());
   }
 }
 
@@ -66,14 +65,14 @@
   ASSERT_TRUE(module.x.grad().defined());
   ASSERT_FALSE(module.y.grad().defined());
 
-  module.zero_grad();
+  module.zero_grad(false); // set_to_none = false
 
   ASSERT_TRUE(module.x.grad().defined());
   ASSERT_FALSE(module.y.grad().defined());
 
   ASSERT_EQ(module.x.grad().sum().item<float>(), 0);
 
-  module.zero_grad(true); // set_to_none = true
+  module.zero_grad();
 
   ASSERT_FALSE(module.x.grad().defined());
   ASSERT_FALSE(module.y.grad().defined());
diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py
index 3e0474c..e67ba92 100644
--- a/test/distributed/optim/test_zero_redundancy_optimizer.py
+++ b/test/distributed/optim/test_zero_redundancy_optimizer.py
@@ -268,8 +268,8 @@
         self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
         self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
         o.zero_grad()
-        self.assertFalse(m.weight.grad)
-        self.assertFalse(m.bias.grad)
+        self.assertIsNone(m.weight.grad)
+        self.assertIsNone(m.bias.grad)
 
     def test_constructor(self):
         """Check the robustness of the ZeroRedundancyOptimizer constructor by
diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py
index 8444272..70b21b6 100644
--- a/test/profiler/test_memory_profiler.py
+++ b/test/profiler/test_memory_profiler.py
@@ -844,14 +844,17 @@
             if key.storage.allocation_id == max(ids | {-1})
         }
 
-    def _run_and_check_parameters_and_gradients(self, inner_fn, model):
+    def _run_and_check_parameters_and_gradients(self, inner_fn, model, grads_none: bool = False):
 
         with profile() as prof:
             inner_fn()
 
         memory_profile = prof._memory_profile()
 
-        def assert_category(t: torch.Tensor, category: _memory_profiler.Category):
+        def assert_category(t: torch.Tensor, category: _memory_profiler.Category, should_be_none: bool = False):
+            if should_be_none:
+                assert t is None, "tensor should be None but is not."
+                return
             self.assertIsNotNone(t)
             categories = self._lookup_tensor_categories(t, memory_profile)
             self.assertGreater(len(categories), 0)
@@ -859,7 +862,7 @@
 
         for p in model.parameters():
             assert_category(p, _memory_profiler.Category.PARAMETER)
-            assert_category(p.grad, _memory_profiler.Category.GRADIENT)
+            assert_category(p.grad, _memory_profiler.Category.GRADIENT, grads_none)
 
         # Rely on internal asserts
         _ = memory_profile.timeline
@@ -929,16 +932,15 @@
             _ = model(torch.ones((2, 2)))
 
         def fwd_bwd_step():
+            optimizer.zero_grad()
             y = model(torch.ones((2, 2)))
             torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
             optimizer.step()
-            optimizer.zero_grad()
 
         # If we profile the first step then gradients will not have been
         # created when we call `model.forward`, so if we don't call `.backward`
         # then gradients are never created.
-        with self.assertRaises(AssertionError):
-            self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)
+        self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model, grads_none=True)
 
         # On the first step we must rely on `AccumulateGrad`, since gradients
         # did not exist when `model.forward` was called.
@@ -1078,10 +1080,10 @@
 
         def inner_fn():
             y = model(torch.ones((2, 2)))
-            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
             optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
-            optimizer.step()
             optimizer.zero_grad()
+            torch.nn.functional.mse_loss(y, torch.rand((2, 1))).backward()
+            optimizer.step()
 
         self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model)
         self.assertEqual(len(list(model.parameters())), 6)
@@ -1220,9 +1222,7 @@
 
             -- Optimizer --------------------------------------------------------------------------------------------
             aten::add_.Tensor                        3 (PARAMETER), 25 (GRADIENT)                  -> 3 (PARAMETER)
-            aten::add_.Tensor                        5 (PARAMETER), 23 (GRADIENT)                  -> 5 (PARAMETER)
-            aten::zero_                              25 (GRADIENT)                                 -> 25 (GRADIENT)
-            aten::zero_                              23 (GRADIENT)                                 -> 23 (GRADIENT)""",
+            aten::add_.Tensor                        5 (PARAMETER), 23 (GRADIENT)                  -> 5 (PARAMETER)""",
         )
 
     def test_categories_e2e_simple_module_fwd(self) -> None:
@@ -1317,9 +1317,7 @@
             aten::clone                              9 (GRADIENT)                                  -> 11 (OPTIMIZER_STATE)
             aten::detach                             11 (OPTIMIZER_STATE)                          -> 11 (OPTIMIZER_STATE)
             aten::detach                             11 (OPTIMIZER_STATE)                          -> 11 (OPTIMIZER_STATE)
-            aten::add_.Tensor                        3 (PARAMETER), 11 (OPTIMIZER_STATE)           -> 3 (PARAMETER)
-            aten::zero_                              7 (GRADIENT)                                  -> 7 (GRADIENT)
-            aten::zero_                              9 (GRADIENT)                                  -> 9 (GRADIENT)""",
+            aten::add_.Tensor                        3 (PARAMETER), 11 (OPTIMIZER_STATE)           -> 3 (PARAMETER)""",
         )
 
     def test_categories_e2e_sequential_fwd(self) -> None:
@@ -1550,9 +1548,9 @@
             destroy                    ???                         27(v1)            2 kB
             increment_version          PARAMETER                    2(v0)         1024 kB
             destroy                    ???                         29(v1)         1024 kB
-            increment_version          GRADIENT                    16(v0)          128 kB
-            increment_version          GRADIENT                    17(v0)            2 kB
-            increment_version          GRADIENT                    13(v0)         1024 kB""")
+            destroy                    GRADIENT                    16(v0)          128 kB
+            destroy                    GRADIENT                    17(v0)            2 kB
+            destroy                    GRADIENT                    13(v0)         1024 kB""")
 
 
 if __name__ == "__main__":
diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py
index c31b1ea..c0497da 100644
--- a/test/profiler/test_profiler.py
+++ b/test/profiler/test_profiler.py
@@ -2669,10 +2669,10 @@
         )
         optimizer = torch.optim.Adam(model.parameters())
         cases = (
-            (1, lambda: optimizer.zero_grad()),
-            (1, lambda: model.zero_grad()),
-            (0, lambda: optimizer.zero_grad(set_to_none=True)),
-            (0, lambda: model.zero_grad(set_to_none=True))
+            (0, lambda: optimizer.zero_grad()),
+            (0, lambda: model.zero_grad()),
+            (1, lambda: optimizer.zero_grad(set_to_none=False)),
+            (1, lambda: model.zero_grad(set_to_none=False))
         )
         num_matched = []
         for _, fn in cases:
diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index 3b6d7ee..26116c6 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -565,7 +565,7 @@
         # Try calling zero_grad()
         net.zero_grad()
         for p in net.parameters():
-            self.assertEqual(p.grad, torch.zeros_like(p))
+            assert p.grad is None, "zero_grad defaults to setting grads to None"
 
         # Test train(), eval(), training (a property)
         self.assertTrue(net.training)
diff --git a/test/test_mps.py b/test/test_mps.py
index d7e560e..423f3ba 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5957,24 +5957,27 @@
         self.assertIsNotNone(module.weight.grad)
         self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         module.zero_grad()
-        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+        self.assertIsNone(module.weight.grad)
 
         module.bias.requires_grad = True
         module.zero_grad()
-        self.assertIsNotNone(module.weight.grad)
+        self.assertIsNone(module.weight.grad)
         self.assertIsNone(module.bias.grad)
         module(i).sum().backward()
         self.assertIsNotNone(module.weight.grad)
         self.assertIsNotNone(module.bias.grad)
         self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         self.assertGreater(module.bias.grad.data.abs().sum(), 0)
-        module.zero_grad()
+
+        # Force set to zeros.
+        module.zero_grad(set_to_none=False)
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
         self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
 
-        # Force set to None.
-        module.zero_grad(set_to_none=True)
+        module.zero_grad()
         self.assertIsNone(module.weight.grad)
+        self.assertIsNone(module.bias.grad)
+
 
     def test_no_grad(self):
         for dtype in [torch.bfloat16, torch.float, torch.double]:
diff --git a/test/test_nn.py b/test/test_nn.py
index 90bafbb..e76737b 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -219,25 +219,24 @@
         self.assertIsNotNone(module.weight.grad)
         self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         module.zero_grad()
-        self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
+        self.assertIsNone(module.weight.grad)
 
         module.bias.requires_grad = True
         module.zero_grad()
-        self.assertIsNotNone(module.weight.grad)
+        self.assertIsNone(module.weight.grad)
         self.assertIsNone(module.bias.grad)
         module(i).sum().backward()
         self.assertIsNotNone(module.weight.grad)
         self.assertIsNotNone(module.bias.grad)
         self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         self.assertGreater(module.bias.grad.data.abs().sum(), 0)
-        module.zero_grad()
+        module.zero_grad(set_to_none=False)   # Force set to zeros.
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
         self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())
 
-        # Force set to None.
-        module.zero_grad(set_to_none=True)
+        module.zero_grad()
         self.assertIsNone(module.weight.grad)
-
+        self.assertIsNone(module.bias.grad)
 
     def test_no_grad(self):
         for dtype in [torch.bfloat16, torch.float, torch.double]: