Add test_make_fx_model_train example (#980) (#82011)

Summary: Pull Request resolved: https://github.com/pytorch/functorch/pull/980

Test Plan: CI should pass

Differential Revision: D38078694

Pulled By: mostafaelhoushi

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82011
Approved by: https://github.com/Chillee
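
The new tests trace a model's forward and backward pass (and, in the second test, a simple SGD-style weight update) through make_fx combined with torch.nn.utils._stateless.functional_call, then check the traced graph's outputs against eager mode. A minimal sketch of the pattern being tested, mirroring the diff below (the name train_step is illustrative, not part of the test):

    import torch
    import torch.nn.utils._stateless as stateless
    from torch.fx.experimental.proxy_tensor import make_fx

    model = torch.nn.Linear(5, 5)

    def train_step(x, params):
        # run the module with externally supplied parameters, then backprop
        out = stateless.functional_call(model, params, x).sum()
        out.backward()
        # simple SGD-style update on the supplied parameters
        return [p - 1e-4 * p.grad for p in params.values()]

    x = torch.randn(3, 5, requires_grad=True)
    params = dict(model.named_parameters())
    traced = make_fx(train_step)(x, params)  # fx.GraphModule covering fwd + bwd + update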
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e71395c..f3e194c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -4,6 +4,8 @@
 import torch
 import unittest
 import warnings
+import torch.nn.utils._stateless as stateless
+from collections.abc import Iterable
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
@@ -274,6 +276,72 @@
 
         self.assertEqual(fx_module(x), decomposed_module(x))
 
+    def test_make_fx_model_fwd_bwd(self, device):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(x, params):
+            out = stateless.functional_call(model, params, x).sum()
+            out.backward()
+            return list(params.values())
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        fx_f = make_fx(f)(input, params)
+        # fx may change the order of parameters in the returned list, so compare against both outputs
+        self.assertTrue(
+            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
+            or
+            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
+        )
+        self.assertTrue(
+            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
+            or
+            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
+        )
+
+    def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, x):
+                return self.linear(x).relu()
+
+        model = Foo()
+
+        def f(args, params, buffers):
+            if not isinstance(args, Iterable):
+                args = [args]
+            params_and_buffers = {**params, **buffers}
+            out = stateless.functional_call(model, params_and_buffers, args)
+            out.sum().backward()
+            return [p - 1e-4 * p.grad for p in params.values()]
+
+        input = torch.randn(3, 5, requires_grad=True)
+        params = dict(model.named_parameters())
+        buffers = dict(model.named_buffers())
+        fx_f = make_fx(f)(input, params, buffers)
+        # fx may change the order of parameters in the returned list, so compare against both outputs
+        # there is also a small numerical difference, so atol is relaxed from the default 1e-08 to 1e-03
+        self.assertTrue(
+            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
+            or
+            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
+        )
+        self.assertTrue(
+            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
+            or
+            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
+        )
+
 # TODO: Need to test the guards themselves specifically as well
 @skipIfNoSympy
 class TestSymbolicTracing(TestCase):