Stopped ProxyTensor from turning aten::lift tensors into proxy objects (#81024)

```
def f():
    # `val` is a 0-dim tensor constant; torch.full converts it to a scalar fill value
    val = torch.tensor(float('inf'))
    return torch.full((100, 100), val)
```
Today we turn `val` into a ProxyTensor, and then raise an error when `torch.full` tries to convert `val` into a scalar.
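
For reference, tracing the snippet above with `make_fx` (the same entry point the new test below uses via `self._test(f, [])`) is a minimal repro; before this change the trace fails, after it the trace goes through:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f():
    val = torch.tensor(float('inf'))
    return torch.full((100, 100), val)

# Previously this raised while converting the ProxyTensor `val` into
# a scalar fill value; with this change the trace succeeds.
traced = make_fx(f)()
print(traced.graph)
```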

`aten::lift` is invoked whenever we construct a tensor constant like `torch.tensor(5)`, so this change simply prevents those lifted tensors from being turned into ProxyTensors unnecessarily.
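
Conceptually, the fix amounts to special-casing `aten::lift` in the tracer's dispatch so its output stays a plain tensor. A rough sketch of the idea (the `LiftAwareTracer` name and structure here are illustrative assumptions, not the actual `proxy_tensor` internals):

```
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LiftAwareTracer(TorchDispatchMode):
    # Illustrative sketch only; the real logic lives in
    # torch.fx.experimental.proxy_tensor and differs in detail.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        out = func(*args, **(kwargs or {}))
        if func is torch.ops.aten.lift.default:
            # Lifted tensors are constants (e.g. torch.tensor(5)):
            # return them unwrapped instead of proxying them, so later
            # scalar conversions (like torch.full's fill value) succeed.
            return out
        # ... otherwise, wrap `out` in a ProxyTensor and record the call ...
        return out
```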

cc: @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81024
Approved by: https://github.com/ezyang
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index d8d3150..65bab0f 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -12,6 +12,7 @@
 from torch._decomp import decomposition_table
 from torch.testing._internal.common_device_type import ops
 from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter
+from torch.utils._pytree import tree_map
 
 # Copied from functorch
 def xfail(op_name, variant_name='', *, device_type=None, dtypes=None):
@@ -60,40 +61,45 @@
                   UserWarning)
 
 
+def _create_new_input(x):
+    if not isinstance(x, torch.Tensor):
+        return x
+    if x.dtype != torch.float:
+        return x + 1
+    if x.is_leaf:
+        return torch.rand_like(x, requires_grad=True)
+    else:
+        return torch.rand_like(x)
+
 class TestProxyTensor(TestCase):
+    def _test(self, f, inps):
+        fx_f = make_fx(f)(*inps)
+        new_inps = tree_map(_create_new_input, inps)
+        self.assertEqual(fx_f(*new_inps), f(*new_inps))
+
     def test_make_fx_simple(self, device):
         def f(x):
             return torch.sin(x)
-        inp = torch.randn(3)
-        fx_f = make_fx(f)(inp)
-
-        new_inp = torch.randn(3)
-        self.assertEqual(fx_f(new_inp), f(new_inp))
+        self._test(f, (torch.randn(3),))
 
     def test_scalar_device(self, device):
         def f(a, b):
             return a + b
-        inps = [torch.randn(3, device=device), torch.tensor(5)]
-        fx_f = make_fx(f)(*inps)
-        self.assertEqual(fx_f(*inps), f(*inps))
-
+        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])
 
     @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
     def test_resnet18_backward_trace(self, device):
         mod = torchvision.models.resnet18()
 
         def f(x):
+            for a in mod.parameters():
+                a.grad = None
             out = mod(x)
             out.sum().backward()
             return [a.grad for a in mod.parameters()]
 
         inp = torch.randn(3, 3, 250, 250, requires_grad=True)
-        grads = f(inp)
-
-        mod.zero_grad()
-        mod(inp).sum().backward()
-        grads2 = [a.grad for a in mod.parameters()]
-        self.assertEqual(grads, grads2)
+        self._test(f, [inp])
 
     def test_proxy_tensor(self):
         def f_grad(x):
@@ -106,11 +112,7 @@
             return x.grad
 
         for f in [f_grad, f_backward]:
-            traced_graph = make_fx(f)(torch.randn(3, requires_grad=True))
-            inp = torch.randn(3, requires_grad=True)
-            traced_graph_out = traced_graph(inp)
-            assert inp.grad is None
-            torch.testing.assert_close(traced_graph_out, f(inp))
+            self._test(f, [torch.randn(3, requires_grad=True)])
 
     def test_inplace_metadata(self):
         def f(x):
@@ -119,9 +121,7 @@
             assert x.shape[-1] == 1
             return x
 
-        inps = [torch.randn(5)]
-        fx_f = make_fx(f)(*inps)
-        self.assertEqual(fx_f(*inps), f(*inps))
+        self._test(f, [torch.randn(5)])
 
     def test_mode_tracing_factory_function(self):
         def f(x):
@@ -157,6 +157,13 @@
         self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
                              for node in traced.graph.nodes if node.op == 'call_function']))
 
+    def test_tensor_constants(self):
+        def f():
+            val = torch.tensor(float('inf'))
+            return torch.full((100, 100), val)
+
+        self._test(f, [])
+
     def test_decomposition_interpreter(self):
         def fn(x):
             return torch.nn.functional.silu(x)
@@ -212,26 +219,6 @@
     xfail('quantile'),
     xfail('tensor_split'),
     xfail('corrcoef'),
-    # Masked failures (creating a scalar tensor just to call `.item` on it)
-    xfail('_masked.amax'),
-    xfail('_masked.amax'),
-    xfail('_masked.amin'),
-    xfail('_masked.argmax'),
-    xfail('_masked.argmin'),
-    xfail('_masked.cumprod'),
-    xfail('_masked.cumsum'),
-    xfail('_masked.log_softmax'),
-    xfail('_masked.logaddexp'),
-    xfail('_masked.logsumexp'),
-    xfail('_masked.mean'),
-    xfail('_masked.median'),
-    xfail('_masked.norm'),
-    xfail('_masked.prod'),
-    xfail('_masked.softmax'),
-    xfail('_masked.softmin'),
-    xfail('_masked.std'),
-    xfail('_masked.sum'),
-    xfail('_masked.var'),
 
     # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
     xfail('sparse.sampled_addmm'),
@@ -260,6 +247,28 @@
     xfail('cholesky_inverse'),
     # ASAN failures due to divide by 0
     skip('nn.functional.nll_loss'),
+    # Masked failures (creating a scalar tensor just to call `.item` on it)
+    xfail('_masked.amax'),
+    xfail('_masked.amax'),
+    xfail('_masked.amin'),
+    xfail('_masked.argmax'),
+    xfail('_masked.argmin'),
+    xfail('_masked.cumprod'),
+    xfail('_masked.cumsum'),
+    xfail('_masked.log_softmax'),
+    xfail('_masked.logaddexp'),
+    xfail('_masked.logsumexp'),
+    xfail('_masked.mean'),
+    xfail('_masked.median'),
+    xfail('_masked.norm'),
+    xfail('_masked.prod'),
+    xfail('_masked.softmax'),
+    xfail('_masked.softmin'),
+    xfail('_masked.std'),
+    xfail('_masked.sum'),
+    xfail('_masked.var'),
+    # Same as masked failures - preventing torch.tensor constants from turning into proxytensors causes issues with faketensors
+    xfail('__getitem__'),
 }