Made tracing of proxy symints lazy (#85185)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85185
Approved by: https://github.com/ezyang
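Previously, every operation performed on a proxy SymInt was traced eagerly into
the graph, so guard checks evaluated during tracing (e.g. `sym_size < 0`,
`sym_size == sym_size`) left dead `lt`/`eq` nodes behind even though nothing
consumed their results. With this change, proxy SymInt operations are recorded
lazily, only once their results are actually used, which is what the updated
expected traces and guard assertions below exercise.

A minimal sketch of the asserted behavior (illustrative code, not taken from
this PR; it assumes `make_fx(..., tracing_mode="symbolic")` as used throughout
test_proxy_tensor.py, plus the `sym_float` helper named in the trace below):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import sym_float

def f(a):
    s = a.shape[0]       # a proxy SymInt under symbolic tracing
    assert s == s        # guards are evaluated eagerly...
    assert not (s < 0)   # ...but their results are never consumed
    # sym_float(s) *is* consumed by the division, so it must be traced
    return torch.ops.prims.div.default(a, sym_float(s))

g = make_fx(f, tracing_mode="symbolic")(torch.randn(4))
# With lazy SymInt tracing, g.code should contain sym_size, sym_float
# and div, but no lt/eq nodes for the unused guard checks.
print(g.code)
```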
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index ddf0cde..bc83773 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -846,9 +846,7 @@
         self.assertExpectedInline(r, """\
 def forward(self, a_1):
     sym_size = torch.ops.aten.sym_size(a_1, 0)
-    sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size)
-    lt = sym_size < 0
-    eq = sym_size == sym_size; sym_size = None
+    sym_float = torch.fx.experimental.symbolic_shapes.sym_float(sym_size); sym_size = None
     div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
     return div""")
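Note the knock-on effect in the expected trace above: since `lt` and `eq` are
no longer materialized, the last use of `sym_size` becomes the `sym_float`
call, so the `sym_size = None` cleanup moves onto that line.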
@@ -919,6 +917,16 @@
         fx_g = _trace(f, 7, 7, 4, 3)
         self._assert_no_guards(fx_g, 2)
 
+        def f(a, b, c, d, e):
+            vals = [a, b, c, d, e]
+            x = a
+            for idx in range(len(vals) - 1):
+                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
+            return x
+
+        fx_g = _trace(f, 2, 4, 8, 16, 32)
+        self._assert_no_guards(fx_g, 1)
+
         def f(a, b):
             a = a.view(b.shape[0])
             return a + b.sum()
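The new test chains `cat` with a broadcasting add over 1-D inputs of length
2, 4, 8, 16 and 32: each `cat([x, vals[idx]])` doubles the running length,
and the following `+ vals[idx + 1]` forces that doubled length to equal the
next input's length, so every size is a power-of-two multiple of the first.
Passing `1` to `_assert_no_guards` presumably asserts that these equalities
are absorbed as symbol replacements rather than guards, leaving a single
free symbol. For reference, a hypothetical sketch of the `_trace` helper
(the real one is defined earlier in test_proxy_tensor.py; its exact
signature and body are assumed here):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

# Hypothetical stand-in for the test file's _trace helper: build one
# 1-D tensor per requested length and trace f with symbolic shapes,
# so each input size starts out as its own symbol.
def _trace(f, *sizes):
    inps = [torch.randn(s) for s in sizes]
    return make_fx(f, tracing_mode="symbolic")(*inps)
```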
@@ -948,6 +956,7 @@
+
 make_fx_failures = {
     # unknown
     xfail('allclose'),