Reorganize test_proxy_tensor.py per tracing mode (#82739)

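I intend to run all of the proxy tensor tests on each of our tracing
modes, but to do that the tests must first be parametrized on tracing
mode. This commit does that refactor without adding any new tests:
TestProxyTensor now passes a `tracing_mode` class attribute (default
"real") to make_fx, and tests that only apply to a single mode move
into TestRealProxyTensor and TestFakeProxyTensor.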

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82739
Approved by: https://github.com/eellison
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 263a3b1..19fdc8c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -120,8 +120,10 @@
         return torch.rand_like(x)
 
 class TestProxyTensor(TestCase):
+    tracing_mode = "real"
+
     def _test(self, f, inps):
-        fx_f = make_fx(f)(*inps)
+        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
         new_inps = tree_map(_create_new_input, inps)
         self.assertEqual(fx_f(*new_inps), f(*new_inps))
 
@@ -279,7 +281,7 @@
             return x + torch.randn(x.shape)
 
         # default behavior should trace factory functions
-        traced = make_fx(f)(torch.randn(3))
+        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
         self.assertTrue(
             any(
                 node.target == aten.randn.default
@@ -287,23 +289,11 @@
             )
         )
 
-    def test_mode_tracing_factory_function_no_factory_function(self):
-        def f(x):
-            return x + torch.randn(x.shape)
-        # setting the flag to false should not trace factory functions
-        traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
-        self.assertFalse(
-            any(
-                node.target == aten.randn.default
-                for node in traced.graph.nodes
-            )
-        )
-
     def test_make_fx_overloads(self):
         def f(x):
             return x.cos() + torch.randn(x.shape)
 
-        traced = make_fx(f)(torch.randn(3))
+        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
 
         self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
                              for node in traced.graph.nodes if node.op == 'call_function']))
@@ -316,29 +306,20 @@
         self._test(f, [])
 
     def test_constant_proxy_tensor(self):
-        from torch.fx.experimental.proxy_tensor import make_fx
-
         def f():
             val = torch.tensor(float('inf'))
             return torch.full((100, 100), val)
 
-        g = make_fx(f)()
+        g = make_fx(f, tracing_mode=self.tracing_mode)()
         self.assertEqual(g(), f())
 
     def test_constant_proxy_tensor_mut(self):
-        from torch.fx.experimental.proxy_tensor import make_fx
-
         def f():
             val = torch.tensor(float(1))
             val.add_(2)
             return torch.full((100, 100), val)
 
-        g = make_fx(f)()
-        self.assertEqual(g(), f())
-        # In case we mutated shared state in the g graph!
-        self.assertEqual(g(), f())
-
-        g = make_fx(f, tracing_mode="fake")()
+        g = make_fx(f, tracing_mode=self.tracing_mode)()
         self.assertEqual(g(), f())
         # In case we mutated shared state in the g graph!
         self.assertEqual(g(), f())
@@ -349,37 +330,15 @@
             r, = torch.unbind(val, 0)
             return r.item()
 
-        g = make_fx(f)()
+        g = make_fx(f, tracing_mode=self.tracing_mode)()
         self.assertEqual(g(), f())
 
-    def test_issue82547(self):
-        x = nn.Parameter(torch.randn(3, 3))
-
-        def f():
-            return torch.ops.aten.t.default(x)
-        self.assertRaisesRegex(Exception, "non-Fake Tensor", lambda: make_fx(f, tracing_mode="fake")())
-
-        class A(torch.Tensor):
-            pass
-
-        x = A(torch.randn(3, 3))
-        self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")())
-
-    def test_use_fake_and_tensor(self):
-        def f(x, y):
-            z = torch.tensor([2.0, 3.0])
-            return x + y + z
-
-        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
-        x, y = torch.randn(2), torch.randn(2)
-        self.assertEqual(g(x, y), f(x, y))
-
     def test_decomposition_interpreter(self):
         def fn(x):
             return torch.nn.functional.silu(x)
 
         x = torch.rand((4, 4))
-        fx_module = make_fx(fn, decomposition_table=None)(x)
+        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)
 
         found_silu = False
         for n in fx_module.graph.nodes:
@@ -421,7 +380,7 @@
             return list(params.values())
         input = torch.randn(3, 5, requires_grad=True)
         params = dict(model.named_parameters())
-        fx_f = make_fx(f)(input, params)
+        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
         # fx may change the order of parameters in list, so using set() to compare
         self.assertTrue(
             torch.allclose(fx_f(input, params)[0], f(input, params)[0])
@@ -456,7 +415,7 @@
         input = torch.randn(3, 5, requires_grad=True)
         params = dict(model.named_parameters())
         buffers = dict(model.named_buffers())
-        fx_f = make_fx(f)(input, params, buffers)
+        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
         # fx may change the order of parameters in list, so using set() to compare
         # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
         self.assertTrue(
@@ -470,6 +429,60 @@
             torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
         )
 
+
+class TestRealProxyTensor(TestCase):
+    """Proxy tensor tests specific to real tracing (make_fx's default tracing mode)."""
+
+    def test_mode_tracing_factory_function_no_factory_function(self):
+        def f(x):
+            return x + torch.randn(x.shape)
+        # with trace_factory_functions=False, aten.randn must not appear in the graph
+        traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
+        self.assertFalse(
+            any(
+                node.target == aten.randn.default
+                for node in traced.graph.nodes
+            )
+        )
+
+
+class TestFakeProxyTensor(TestCase):
+    """Proxy tensor tests specific to make_fx(..., tracing_mode="fake")."""
+
+    def test_issue82547(self):
+        x = nn.Parameter(torch.randn(3, 3))
+
+        def f():
+            return torch.ops.aten.t.default(x)
+        self.assertRaisesRegex(Exception, "non-Fake Tensor", lambda: make_fx(f, tracing_mode="fake")())
+
+        class A(torch.Tensor):
+            pass
+
+        x = A(torch.randn(3, 3))
+        self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")())
+
+    def test_use_fake_and_tensor(self):
+        def f(x, y):
+            z = torch.tensor([2.0, 3.0])
+            return x + y + z
+
+        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
+        x, y = torch.randn(2), torch.randn(2)
+        self.assertEqual(g(x, y), f(x, y))
+
+    def test_constant_proxy_tensor_mut(self):
+        def f():
+            val = torch.tensor(float(1))
+            val.add_(2)
+            return torch.full((100, 100), val)
+
+        g = make_fx(f, tracing_mode="fake")()
+        self.assertEqual(g(), f())
+        # In case we mutated shared state in the g graph!
+        self.assertEqual(g(), f())
+
+
 # TODO: Need to test the guards themselves specifically as well
 @skipIfNoSympy
 class TestSymbolicTracing(TestCase):