Fix dumb make_fx issue (#84011)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84011
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 09de59f..105b1d5 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -525,6 +525,29 @@
             torch.allclose(fx_f(input, params)[1], f(input, params)[1])
         )
 
+    def test_make_fx_model_double_param(self):
+        class Emformer(torch.nn.Module):
+            def __init__(
+                self,
+                input_dim: int = 256,
+            ) -> None:
+                super().__init__()
+
+                self.layer_norm = torch.nn.LayerNorm(input_dim)
+
+            def forward(mod_self, x):  # noqa: B902
+                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
+                y = mod_self.layer_norm(x)
+                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
+                z = mod_self.layer_norm(y)
+                return z
+
+
+        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
+        ops = set([n.target for n in gm.graph.nodes if n.op == 'call_function'])
+        self.assertEqual(len(ops), 2)
+
+
     def test_make_fx_model_fwd_bwd_wgtupdate(self):
         class Foo(torch.nn.Module):
             def __init__(self):