Correctly error if you pass in tensors where size arguments are expected (#86126)
This also makes symintlist's exception handling match intlist's
exception handling, which eellison fixed.
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86126
Approved by: https://github.com/eellison
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index f1e42c7..5d4c94b 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -875,6 +875,17 @@
self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False)
+ def test_size_with_tensor(self):
+ # Regression test for passing tensors where size arguments are expected.
+ # list(max_size) iterates a tensor, so batch_shape ends up mixing Python
+ # ints (2, and the entries of tensor.shape[:-2]) with tensor elements.
+ # The test asserts that symbolic make_fx tracing rejects this with a
+ # RuntimeError whose message matches "data-dependent".
+ def f(tensor):
+ max_size = torch.tensor([800, 1216], dtype=torch.int64)
+ # For the 3-D input below, tensor.shape[:-2] keeps only the leading dim.
+ batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
+ return tensor.new_empty(batch_shape)
+
+ a = torch.randn(3, 800, 1199)
+ self.assertRaisesRegex(
+ RuntimeError, "data-dependent", lambda: make_fx(f, tracing_mode="symbolic")(a)
+ )
+
def test_expand(self):
def f(a):
b = torch.mul(a, a)