Add test_make_fx_model_train example (#980) (#82011)
Summary: Pull Request resolved: https://github.com/pytorch/functorch/pull/980
Test Plan: CI should pass
Differential Revision: D38078694
Pulled By: mostafaelhoushi
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82011
Approved by: https://github.com/Chillee
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index e71395c..f3e194c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -4,6 +4,8 @@
import torch
import unittest
import warnings
+import torch.nn.utils._stateless as stateless
+from collections.abc import Iterable
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import DecorateInfo
from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed
@@ -274,6 +276,72 @@
self.assertEqual(fx_module(x), decomposed_module(x))
+ def test_make_fx_model_fwd_bwd(self, device):
+ class Foo(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(5, 5)
+
+ def forward(self, x):
+ return self.linear(x).relu()
+
+ model = Foo()
+
+ def f(x, params):
+ out = stateless.functional_call(model, params, x).sum()
+ out.backward()
+ return list(params.values())
+ input = torch.randn(3, 5, requires_grad=True)
+ params = dict(model.named_parameters())
+ fx_f = make_fx(f)(input, params)
+ # fx may change the order of parameters in list, so using set() to compare
+ self.assertTrue(
+ torch.allclose(fx_f(input, params)[0], f(input, params)[0])
+ or
+ torch.allclose(fx_f(input, params)[0], f(input, params)[1])
+ )
+ self.assertTrue(
+ torch.allclose(fx_f(input, params)[1], f(input, params)[0])
+ or
+ torch.allclose(fx_f(input, params)[1], f(input, params)[1])
+ )
+
+ def test_make_fx_model_fwd_bwd_wgtupdate(self, device):
+ class Foo(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.linear = torch.nn.Linear(5, 5)
+
+ def forward(self, x):
+ return self.linear(x).relu()
+
+ model = Foo()
+
+ def f(args, params, buffers):
+ if not isinstance(args, Iterable):
+ args = [args]
+ params_and_buffers = {**params, **buffers}
+ out = stateless.functional_call(model, params_and_buffers, args)
+ out.sum().backward()
+ return [p - 1e-4 * p.grad for p in params.values()]
+
+ input = torch.randn(3, 5, requires_grad=True)
+ params = dict(model.named_parameters())
+ buffers = dict(model.named_buffers())
+ fx_f = make_fx(f)(input, params, buffers)
+ # fx may change the order of parameters in list, so using set() to compare
+ # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
+ self.assertTrue(
+ torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
+ or
+ torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
+ )
+ self.assertTrue(
+ torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
+ or
+ torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
+ )
+
# TODO: Need to test the guards themselves specifically as well
@skipIfNoSympy
class TestSymbolicTracing(TestCase):