Ensure that symbolic variables incorporate fresh constraints before they're used (#87254)

cc @jansel @lezcano @fdrocha
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87254
Approved by: https://github.com/jansel
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index c697c8c..7f46b47 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -942,16 +942,30 @@
         self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
         self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
 
-    def test_symbolic_meta(self):
+    def test_metadata(self):
         def f(a, b):
             d = a.new_empty(a.shape[0] + b.shape[0])
             return d
         fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
-        fx_g.graph.eliminate_dead_code()
-        fx_g.recompile()
         meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
         meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
-        self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj() == meta_d.meta['val'].expr)
+        self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr)
+
+    def test_metadata_fresh(self):
+        def f(x):
+            assert x.shape[0] == 3
+            return x.cos()
+
+        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
+        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
+        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
+        self.assertTrue(meta_cos.meta['val'].shape[0].get_pyobj().expr == 3)
+        # Check that the input's expr has been updated even though the
+        # constraint (the assert in f) was only introduced afterwards.
+        self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3)
+
+
+
 
     def test_return_symint(self):
         def f(x):