Apply new symbolic shape strategy to make_fx symbolic mode (#85260)
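This changes some test expectations, and the new results look legitimate: an extra size/stride contiguity guard is now recorded, `test_guards_equal` now sees one more free symbol, and `where` is no longer an expected failure. I also
added some debug helpers that I found useful while
working on this; the updated test assertions now print the guards and symbol values when they fail.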
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85260
Approved by: https://github.com/albanD
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 42231fd..4904a09 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -806,7 +806,8 @@
shape_env = self._test_dynamic(f, [(3, 4)], test_inputs)
self.assertTrue(shape_env.evaluate_guards_for_args(torch.randn(4, 5)))
self.assertFalse(shape_env.evaluate_guards_for_args(torch.randn(25, 5)))
- assert len(shape_env.guards) == 1
+ # one guard for size/stride contiguity, and one substantive guard
+ assert len(shape_env.guards) == 2, "\n" + shape_env.format_guards()
def test_binary_broadcast(self):
def f(a, b):
@@ -894,8 +895,8 @@
self.assertTrue(meta_c.meta['val'].shape[0].get_pyobj() == meta_d.meta['val'].expr)
def _assert_no_guards(self, fx_g, free_symbols):
- self.assertEqual(_get_free_symbols(fx_g.shape_env), free_symbols)
- self.assertEqual(len(fx_g.shape_env.get_nontrivial_guards()), 0)
+ assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
+ assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()
def test_guards_equal(self):
def f(a, b):
@@ -908,7 +909,7 @@
self._assert_no_guards(fx_g, 3)
fx_g = _trace(f, (5, 1), (1, 5))
- self._assert_no_guards(fx_g, 2)
+ self._assert_no_guards(fx_g, 3)
def f(a, b, c, d):
a = a + b
@@ -1309,7 +1310,6 @@
xfail('view_as_complex', ''), # aten.view_as_complex.default - couldn't find symbolic meta function/decomposition
xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
- xfail('where', ''), # expected predicate to be bool, got torch.float32
xfail('zero_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('zeros_like', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition