Ensure that symbolic variables incorporate fresh constraints before they're used (#87254)
cc @jansel @lezcano @fdrocha
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87254
Approved by: https://github.com/jansel
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index c697c8c..7f46b47 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -942,16 +942,30 @@
self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
- def test_symbolic_meta(self):
+ def test_metadata(self):
def f(a, b):
d = a.new_empty(a.shape[0] + b.shape[0])
return d
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
- fx_g.graph.eliminate_dead_code()
- fx_g.recompile()
meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
- self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj() == meta_d.meta['val'].expr)
+        # new_empty's output size must carry the same symbolic ``expr`` as the add node that computed it.
+ self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj().expr == meta_d.meta['val'].expr)
+
+    def test_metadata_fresh(self):
+        # Constraints discovered while tracing must show up in the ``val``
+        # metadata of every node, including nodes created before the constraint.
+        def f(x):
+            # This comparison introduces the constraint x.shape[0] == 3.
+            assert x.shape[0] == 3
+            return x.cos()
+
+        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
+        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
+        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
+        self.assertTrue(meta_cos.meta['val'].shape[0].get_pyobj().expr == 3)
+        # Checks that the input's expr was updated even though the constraint
+        # was introduced after the placeholder node had been created.
+        self.assertTrue(meta_inp.meta['val'].shape[0].get_pyobj().expr == 3)
def test_return_symint(self):
def f(x):