Fix dumb make_fx issue (#84011)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84011
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 09de59f..105b1d5 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -525,6 +525,29 @@
             torch.allclose(fx_f(input, params)[1], f(input, params)[1])
         )
 
+    def test_make_fx_model_double_param(self):
+        """Trace a module that applies the same LayerNorm twice with make_fx."""
+        class Emformer(torch.nn.Module):
+            def __init__(
+                self,
+                input_dim: int = 256,
+            ) -> None:
+                super().__init__()
+                self.layer_norm = torch.nn.LayerNorm(input_dim)
+
+            def forward(mod_self, x):  # noqa: B902
+                # `self` is the enclosing test case; the weight must still be
+                # a Tensor both before and after the layer's first use.
+                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
+                y = mod_self.layer_norm(x)
+                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
+                return mod_self.layer_norm(y)
+
+        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
+        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
+        self.assertEqual(len(ops), 2)
+
+
     def test_make_fx_model_fwd_bwd_wgtupdate(self):
         class Foo(torch.nn.Module):
             def __init__(self):