Make proxy tensor support item() calls on torch.tensor constants (#81192)

This PR does a few interrelated things, all of which are necessary for correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high-level overview.

Let's break down the parts of this PR:

* Bug fix: `enable_torch_dispatch_mode` with `None` didn't work. This makes `enable_torch_dispatch_mode(current_mode.inner)` work, which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This could actually be split out from this PR, but it affects where the logic for allowing non-fake tensor inputs with `lift` goes, so it's all here in one go. There are some relevant tests for the fix in fake tensor; it turns out I didn't strictly need this, because I'm always using proxy tensors as a mode (which ensures the ordering is right).
* New `lift_fresh` view operator. Note that, like `lift`, we have to manually write the functionalization kernel for these operators.
* The actual change: save constants when we see them in proxy tensor mode, and then propagate them as we go (otherwise mutations on constants are handled incorrectly; see the new test). A sketch of the resulting behavior follows this list.
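
As a minimal sketch of the behavior this enables (mirroring the new test
in the diff below; `torch.full` implicitly calls `.item()` on a tensor
fill value):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f():
        val = torch.tensor(float('inf'))  # trace-time constant
        # torch.full extracts the scalar via .item(), which previously
        # failed under proxy tensor tracing
        return torch.full((100, 100), val)

    g = make_fx(f)()  # now traces successfully
    assert torch.equal(g(), f())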

This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator is relatively new, and I checked
functorch, which has no explicit reference to lift. So I think it
should not be too disruptive.
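
For reference, "interposing on at::lift" means intercepting it from a
torch dispatch mode or subclass. A hypothetical sketch (the class and
its handling are illustrative, not code from this PR); such code would
now also want to handle aten::lift_fresh:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class CatchLift(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            if func is torch.ops.aten.lift.default:
                # custom handling of lift would go here; after this PR,
                # aten::lift_fresh needs the same treatment
                pass
            return func(*args, **kwargs)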

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 65bab0f..e2d3782 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -164,6 +164,43 @@
 
         self._test(f, [])
 
+    def test_constant_proxy_tensor(self):
+        from torch.fx.experimental.proxy_tensor import make_fx
+
+        def f():
+            val = torch.tensor(float('inf'))
+            return torch.full((100, 100), val)
+
+        g = make_fx(f)()
+        self.assertEqual(g(), f())
+
+    def test_constant_proxy_tensor_mut(self):
+        from torch.fx.experimental.proxy_tensor import make_fx
+
+        def f():
+            val = torch.tensor(float(1))
+            val.add_(2)
+            return torch.full((100, 100), val)
+
+        g = make_fx(f)()
+        self.assertEqual(g(), f())
+        # In case we mutated shared state in the g graph!
+        self.assertEqual(g(), f())
+
+        g = make_fx(f, use_fake=True)()
+        self.assertEqual(g(), f())
+        # In case we mutated shared state in the g graph!
+        self.assertEqual(g(), f())
+
+    def test_use_fake_and_tensor(self):
+        def f(x, y):
+            z = torch.tensor([2.0, 3.0])
+            return x + y + z
+
+        g = make_fx(f, use_fake=True)(torch.randn(2), torch.randn(2))
+        x, y = torch.randn(2), torch.randn(2)
+        self.assertEqual(g(x, y), f(x, y))
+
     def test_decomposition_interpreter(self):
         def fn(x):
             return torch.nn.functional.silu(x)
@@ -247,28 +284,6 @@
     xfail('cholesky_inverse'),
     # ASAN failures due to divide by 0
     skip('nn.functional.nll_loss'),
-    # Masked failures (creating a scalar tensor just to call `.item` on it)
-    xfail('_masked.amax'),
-    xfail('_masked.amax'),
-    xfail('_masked.amin'),
-    xfail('_masked.argmax'),
-    xfail('_masked.argmin'),
-    xfail('_masked.cumprod'),
-    xfail('_masked.cumsum'),
-    xfail('_masked.log_softmax'),
-    xfail('_masked.logaddexp'),
-    xfail('_masked.logsumexp'),
-    xfail('_masked.mean'),
-    xfail('_masked.median'),
-    xfail('_masked.norm'),
-    xfail('_masked.prod'),
-    xfail('_masked.softmax'),
-    xfail('_masked.softmin'),
-    xfail('_masked.std'),
-    xfail('_masked.sum'),
-    xfail('_masked.var'),
-    # Same as masked failures - preventing torch.tensor constants from turning into proxytensors causes issues with faketensors
-    xfail('__getitem__'),
 }