Ensure that symbolic variables incorporate fresh constraints before they're used (#87254)

cc @jansel @lezcano @fdrocha
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87254
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index c850da7..787f0ec 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -3801,25 +3801,6 @@
         ]
         self.common(forward, args)
 
-    @unittest.skip("https://github.com/pytorch/torchdynamo/issues/1297")
-    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
-    def test_symbolic(self):
-        def f(x):
-            x = x.cos()
-            x = x.view(x.shape[0] * 2, -1)
-            return (x,)
-
-        traced = make_fx(f, tracing_mode="symbolic")(
-            torch.randn(8, 4, device=self.device)
-        )
-        compiled = compile_fx_inner(traced, [torch.randn(8, 4, device=self.device)])
-
-        out = compiled([torch.randn(8, 4, device=self.device)])
-        self.assertEqual(out[0].shape, (16, 2))
-
-        out = compiled([torch.randn(12, 4, device=self.device)])
-        self.assertEqual(out[0].shape, (24, 2))
-
     @requires_cuda()
     @patch.object(config.triton, "cudagraphs", False)
     def test_unspec_inputs(self):
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index c697c8c..7f46b47 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -942,16 +942,28 @@
         self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
         self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
 
-    def test_symbolic_meta(self):
+    def test_metadata(self):
         def f(a, b):
             d = a.new_empty(a.shape[0] + b.shape[0])
             return d
         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
-        fx_g.graph.eliminate_dead_code()
-        fx_g.recompile()
         meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
         meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
-        self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj() == meta_d.meta['val'].expr)
+        self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr)
+
+    def test_metadata_fresh(self):
+        def f(x):
+            # Evaluating this comparison during symbolic tracing is what adds the constraint on x.shape[0]
+            assert x.shape[0] == 3
+            return x.cos()
+
+        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
+        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
+        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
+        self.assertTrue(meta_cos.meta['val'].shape[0].get_pyobj().expr == 3)
+        # Checks if the input expr has been updated even though the constraint
+        # happened afterwards
+        self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3)
 
     def test_return_symint(self):
         def f(x):