Stopped ProxyTensor from turning aten::lift tensors into proxy objects (#81024)
```
def f():
val = torch.tensor(float('inf'))
return torch.full((100, 100), val)
```
Today we turn `val` into a ProxyTensor, and then complain when we try to convert `val` into a scalar.
We call `aten::lift` when we call `torch.tensor(5)`, so this change simply prevents those lifted tensors from being turned into ProxyTensors unnecessarily.
cc: @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81024
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d8d3150..65bab0f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -12,6 +12,7 @@
from torch._decomp import decomposition_table
from torch.testing._internal.common_device_type import ops
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
+from torch.utils._pytree import tree_map
# Copied from functorch
def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
@@ -60,40 +61,45 @@
UserWarning)
+def _create_new_input(x):
+ if not isinstance(x, torch.Tensor):
+ return x
+ if x.dtype != torch.float:
+ return x + 1
+ if x.is_leaf:
+ return torch.rand_like(x, requires_grad=True)
+ else:
+ return torch.rand_like(x)
+
class TestProxyTensor(TestCase):
+ def _test(self, f, inps):
+ fx_f = make_fx(f)(*inps)
+ new_inps = tree_map(_create_new_input, inps)
+ self.assertEqual(fx_f(*new_inps), f(*new_inps))
+
def test_make_fx_simple(self, device):
def f(x):
return torch.sin(x)
- inp = torch.randn(3)
- fx_f = make_fx(f)(inp)
-
- new_inp = torch.randn(3)
- self.assertEqual(fx_f(new_inp), f(new_inp))
+ self._test(f, (torch.randn(3),))
def test_scalar_device(self, device):
def f(a, b):
return a + b
- inps = [torch.randn(3, device=device), torch.tensor(5)]
- fx_f = make_fx(f)(*inps)
- self.assertEqual(fx_f(*inps), f(*inps))
-
+ self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_resnet18_backward_trace(self, device):
mod = torchvision.models.resnet18()
def f(x):
+ for a in mod.parameters():
+ a.grad = None
out = mod(x)
out.sum().backward()
return [a.grad for a in mod.parameters()]
inp = torch.randn(3, 3, 250, 250, requires_grad=True)
- grads = f(inp)
-
- mod.zero_grad()
- mod(inp).sum().backward()
- grads2 = [a.grad for a in mod.parameters()]
- self.assertEqual(grads, grads2)
+ self._test(f, [inp])
def test_proxy_tensor(self):
def f_grad(x):
@@ -106,11 +112,7 @@
return x.grad
for f in [f_grad, f_backward]:
- traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
- inp = torch.randn(3, requires_grad=True)
- traced_graph_out = traced_graph(inp)
- assert inp.grad is None
- torch.testing.assert_close(traced_graph_out, f(inp))
+ self._test(f, [torch.randn(3, requires_grad=True)])
def test_inplace_metadata(self):
def f(x):
@@ -119,9 +121,7 @@
assert x.shape[-1] == 1
return x
- inps = [torch.randn(5)]
- fx_f = make_fx(f)(*inps)
- self.assertEqual(fx_f(*inps), f(*inps))
+ self._test(f, [torch.randn(5)])
def test_mode_tracing_factory_function(self):
def f(x):
@@ -157,6 +157,13 @@
self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
for node in traced.graph.nodes if node.op == 'call_function']))
+ def test_tensor_constants(self):
+ def f():
+ val = torch.tensor(float('inf'))
+ return torch.full((100, 100), val)
+
+ self._test(f, [])
+
def test_decomposition_interpreter(self):
def fn(x):
return torch.nn.functional.silu(x)
@@ -212,26 +219,6 @@
xfail('quantile'),
xfail('tensor_split'),
xfail('corrcoef'),
- # Masked failures (creating a scalar tensor just to call `.item` on it)
- xfail('_masked.amax'),
- xfail('_masked.amax'),
- xfail('_masked.amin'),
- xfail('_masked.argmax'),
- xfail('_masked.argmin'),
- xfail('_masked.cumprod'),
- xfail('_masked.cumsum'),
- xfail('_masked.log_softmax'),
- xfail('_masked.logaddexp'),
- xfail('_masked.logsumexp'),
- xfail('_masked.mean'),
- xfail('_masked.median'),
- xfail('_masked.norm'),
- xfail('_masked.prod'),
- xfail('_masked.softmax'),
- xfail('_masked.softmin'),
- xfail('_masked.std'),
- xfail('_masked.sum'),
- xfail('_masked.var'),
# Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
xfail('sparse.sampled_addmm'),
@@ -260,6 +247,28 @@
xfail('cholesky_inverse'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),
+ # Masked failures (creating a scalar tensor just to call `.item` on it)
+ xfail('_masked.amax'),
+ xfail('_masked.amax'),
+ xfail('_masked.amin'),
+ xfail('_masked.argmax'),
+ xfail('_masked.argmin'),
+ xfail('_masked.cumprod'),
+ xfail('_masked.cumsum'),
+ xfail('_masked.log_softmax'),
+ xfail('_masked.logaddexp'),
+ xfail('_masked.logsumexp'),
+ xfail('_masked.mean'),
+ xfail('_masked.median'),
+ xfail('_masked.norm'),
+ xfail('_masked.prod'),
+ xfail('_masked.softmax'),
+ xfail('_masked.softmin'),
+ xfail('_masked.std'),
+ xfail('_masked.sum'),
+ xfail('_masked.var'),
+ # Same as masked failures - preventing torch.tensor constants from turning into proxytensors causes issues with faketensors
+ xfail('__getitem__'),
}