Delete ProxyTensor wrapper subclass (#83330)

I was working on https://github.com/pytorch/torchdynamo/issues/80 and my
working hypothesis for what was causing the error was that proxy tensor
was not advertising the correct dispatch keys, causing AMP to behave
differently when you traced.  I could have fixed this directly by
replicating fake tensor's dispatch-key fix so it also applies to proxy
tensor, but I was like, "Why must I repeat myself?"

This PR is the result.  It completely deletes the ProxyTensor wrapper
subclass, so that when we are tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API.  There is no more wrapping.  To
store the Proxy objects necessary for actually doing tracing, I store
the proxies directly on the tensors, keyed by tracer.  (Note: I never
clean up old entries at the moment; this is easily fixed by using a
weak map.)
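
Roughly, the association looks like this (a minimal standalone sketch
based on the _ProxyTensor dataclass and track_tensor in the diff below;
the fx.Graph/Proxy setup here is just illustrative scaffolding, not
code from this PR):

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.fx as fx

@dataclass
class _ProxyTensor:
    proxy: fx.Proxy
    constant: Optional[torch.Tensor]

# Stand-in tracer/proxy: Proxy() creates a default tracer for a raw node.
graph = fx.Graph()
proxy = fx.Proxy(graph.placeholder("t"))
tracer = proxy.tracer

t = torch.randn(3)
# No wrapper subclass: park the proxy on the tensor itself, keyed by the
# tracer so that concurrent tracers don't clobber each other.
t.__dict__[tracer] = _ProxyTensor(proxy, constant=None)

assert t.__dict__[tracer].proxy is proxy
```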

Benefits of doing this:

* No more tip-toeing around no_dispatch() when creating new ProxyTensors;
  we never create new tensors (except when we call the underlying func),
  so you don't have to worry about accidentally tracing them.

* No more syncing up metadata from in-place operators.  In particular,
  https://github.com/pytorch/pytorch/issues/81526 is mooted.

* This fixes https://github.com/pytorch/torchdynamo/issues/519 as we no
  longer need to teach proxy tensor to support sparse tensors.

* No more schlepping symbolic integers from the inner fake tensor to the
  outer proxy tensor.  If you can make a fake tensor with symbolic ints,
  you're done, nothing else to do.

To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored _ProxyTensor records
off the tensors via a tree_map, and then operate on the resulting data
as before.  A more optimized implementation is possible.
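
Concretely, the fetch step looks roughly like this (condensed from
proxy_call in the diff below; fetch_proxies is just an illustrative
wrapper name, not a function added by this PR):

```python
import torch
import torch.utils._pytree as pytree

def fetch_proxies(tracer, args, kwargs):
    # Swap each tracked tensor for its stored _ProxyTensor record;
    # anything untracked passes through unchanged.
    def fetch(t):
        if isinstance(t, torch.Tensor) and tracer in t.__dict__:
            return t.__dict__[tracer]
        return t

    return pytree.tree_map(fetch, (args, kwargs))
```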

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
diff --git a/aten/src/ATen/TensorSubclassLikeUtils.h b/aten/src/ATen/TensorSubclassLikeUtils.h
index 5c01ce9..2dc55bf 100644
--- a/aten/src/ATen/TensorSubclassLikeUtils.h
+++ b/aten/src/ATen/TensorSubclassLikeUtils.h
@@ -1,5 +1,6 @@
 #pragma once
 #include <ATen/ATen.h>
+#include <c10/core/impl/TorchDispatchModeTLS.h>
 
 namespace at {
 
@@ -39,16 +40,22 @@
     DispatchKeySet(BackendComponent::MetaBit);
 
 inline bool isTensorSubclassLike(const Tensor& tensor) {
+  if (c10::impl::dispatch_mode_enabled())
+    return true;
   auto key_set = tensor.unsafeGetTensorImpl()->key_set();
   return !(key_set & kTensorSubclassLike).empty();
 }
 
 inline bool areAnyTensorSubclassLike(TensorList tensors) {
+  if (c10::impl::dispatch_mode_enabled())
+    return true;
   return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
 }
 
 inline bool areAnyOptionalTensorSubclassLike(
     const c10::List<c10::optional<Tensor>>& tensors) {
+  if (c10::impl::dispatch_mode_enabled())
+    return true;
   return std::any_of(
       tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
         return (
diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py
index 5fe0aff..e7c8058 100644
--- a/functorch/functorch/_src/python_key.py
+++ b/functorch/functorch/_src/python_key.py
@@ -3,8 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-__all__ = ["make_fx", "ProxyTensor", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
-from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose
+__all__ = ["make_fx", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
+from torch.fx.experimental.proxy_tensor import make_fx, dispatch_trace, PythonKeyTracer, decompose
 
 pythonkey_decompose = decompose
-PythonTensor = ProxyTensor
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 82997ca..64b9d94 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -175,7 +175,9 @@
     def _test(self, f, inps):
         fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
         new_inps = tree_map(_create_new_input, inps)
-        self.assertEqual(fx_f(*new_inps), f(*new_inps))
+        r1 = fx_f(*new_inps)
+        r2 = f(*new_inps)
+        self.assertEqual(r1, r2)
 
     def test_make_fx_simple(self):
         def f(x):
@@ -284,11 +286,10 @@
             self.assertTrue(is_any_sigmoid(gm))
             return torch.digamma(x)
 
-        with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"):
-            traced = make_fx(f2_logging_tensor)(torch.randn(3))
-            self.assertFalse(is_any_sum(traced))
-            self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
-            self.assertTrue(is_any_digamma(traced))
+        traced = make_fx(f2_logging_tensor)(torch.randn(3))
+        self.assertFalse(is_any_sum(traced))
+        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
+        self.assertTrue(is_any_digamma(traced))
 
     def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
         def f(x):
@@ -514,6 +515,8 @@
         model = Foo()
 
         def f(args, params, buffers):
+            for p in params.values():
+                p.grad = None
             if not isinstance(args, Iterable):
                 args = [args]
             params_and_buffers = {**params, **buffers}
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index b791486..938ef48 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1238,10 +1238,10 @@
     dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1))  # type: ignore[misc]
 
     if self.dim() <= 1:
-        return self
+        return self.view(self.shape)
 
     if dim0 == dim1:
-        return self
+        return self.view(self.shape)
     perm = list(range(self.dim()))
     perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
     return torch.permute(self, perm)
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index b4db75a..3c72841 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -269,15 +269,18 @@
 
 
 def backwards_not_supported(prim):
+    def redispatch_prim(args, kwargs):
+        g = torch._C._AutoDispatchBelowAutograd()
+        try:
+            return prim(*args, **kwargs)
+        finally:
+            del g
+
     class BackwardsNotSupported(torch.autograd.Function):
         @staticmethod
         def forward(ctx, args_spec, *flat_args):
             args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
-            g = torch._C._AutoDispatchBelowAutograd()
-            try:
-                return prim(*args, **kwargs)
-            finally:
-                del g
+            return redispatch_prim(args, kwargs)
 
         @staticmethod
         def backward(ctx, *args):
@@ -286,7 +289,20 @@
     @wraps(prim)
     def _autograd_impl(*args, **kwargs):
         flat_args, args_spec = tree_flatten((args, kwargs))
-        return BackwardsNotSupported.apply(args_spec, *flat_args)
+        if torch.is_grad_enabled() and any(a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)):
+            # TODO: There is a subtle bug here: prims like copy_to
+            # return their input argument after mutating it; and custom
+            # autograd function will incorrectly turn the result into
+            # a view which will fail test_python_ref_executor tests.
+            # At the moment, we sidestep this by observing that the
+            # unit tests don't ever try to run the executor with
+            # autograd, so we don't exercise the buggy case, but if
+            # you ever want to feed autograd through this, be aware
+            # of it!  We need a way of properly implementing autograd
+            # for mutating operations in Python to do this.
+            return BackwardsNotSupported.apply(args_spec, *flat_args)
+        else:
+            return redispatch_prim(args, kwargs)
 
     return _autograd_impl
 
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index aa3f17f..a48777d 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -45,6 +45,8 @@
     self_impl->set_version_counter(
         impl::version_counter(backward_info_.value().base_));
     attr_version_ = self_impl->version_counter().current_version();
+    TORCH_INTERNAL_ASSERT(
+        backward_info_.value().base_.unsafeGetTensorImpl() != self_impl);
   }
   if (shared_view_info_) {
     TORCH_INTERNAL_ASSERT(
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 225035a..554c142 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -7,23 +7,24 @@
 import functools
 from typing import Any, Dict, Optional, Tuple, Callable, Union
 import torch
-from torch._C import _disabled_torch_function_impl
 import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
 from torch._subclasses.fake_tensor import FakeTensorMode
 import torch.fx as fx
-from torch.utils._mode_utils import no_dispatch
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from contextlib import contextmanager, nullcontext
 import inspect
+from dataclasses import dataclass
 
 from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
 from torch._subclasses import FakeTensor
 from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt
 import torch.fx.experimental.symbolic_shapes as symbolic_shapes
+from torch.fx import Proxy
 
-__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter"]
+__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter"]
 aten = torch.ops.aten
+prim = torch.ops.prim
 
 CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
 
@@ -33,7 +34,6 @@
     argnames = ",".join(f"arg{i}" for i in range(nargs))
     return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
 
-
 @contextmanager
 def decompose(decomposition_table):
     global CURRENT_DECOMPOSITION_TABLE
@@ -44,7 +44,12 @@
     finally:
         CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
 
-def track_metadata(tensor, proxy, tracer):
+def track_tensor(tensor, proxy, *, constant, tracer):
+    # The basic idea is that we need to associate each tensor/SymInt
+    # with a Proxy.  How do we set up this association?  We just store
+    # the proxy on the __dict__ of the object, keyed on the tracer
+    # (so that if we have multiple tracers at the same time, they
+    # don't clobber each other.)
     for i, s in enumerate(tensor.shape):
         if isinstance(s, SymInt):
             inner_s = s.get_pyobj()
@@ -54,15 +59,13 @@
             # use?  Maybe complicated and DCE is a better idea
             inner_s.__dict__[tracer] = proxy.size(i)
         # TODO: also do stride/numel
+    tensor.__dict__[tracer] = _ProxyTensor(proxy, constant)
 
-def wrap_output(inner_res, proxy_res, *, constant, proxy_mode):
+def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
     def wrap_with_proxy(e, proxy, constant):
         if isinstance(e, torch.Tensor):
-            track_metadata(e, proxy, proxy_mode.tracer)
-            with no_dispatch():
-                return ProxyTensor(e, proxy, constant=constant, proxy_mode=proxy_mode)
-        else:
-            return e
+            track_tensor(e, proxy, tracer=tracer, constant=constant)
+            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(e)
 
     def get_constant(idx):
         if constant is None:
@@ -73,14 +76,13 @@
     # Unfortunately, tree_map cannot directly be used here. As the resulting
     # object may be a proxy that represents a tuple, we may need to
     # explicitly unwrap the proxy by simulating the flattening operations.
-    if isinstance(inner_res, tuple):
-        return tuple(wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) for idx, e in enumerate(inner_res))
-    elif isinstance(inner_res, list):
-        return list([wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) for idx, e in enumerate(inner_res)])
+    if isinstance(inner_res, tuple) or isinstance(inner_res, list):
+        for idx, e in enumerate(inner_res):
+            wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
     elif isinstance(inner_res, torch.Tensor):
-        return wrap_with_proxy(inner_res, proxy_res, constant)
-    else:
-        return inner_res
+        wrap_with_proxy(inner_res, proxy_res, constant)
+
+    return inner_res
 
 
 def maybe_disable_fake_tensor_mode():
@@ -93,10 +95,10 @@
         return nullcontext()
 
 
-def unwrap_elem(e):
-    if isinstance(e, ProxyTensor):
-        return e.elem
-    return e
+@dataclass
+class _ProxyTensor:
+    proxy: Proxy
+    constant: Optional[torch.Tensor]
 
 
 def fetch_symint_proxy(tracer):
@@ -117,6 +119,7 @@
     if func_overload in CURRENT_DECOMPOSITION_TABLE:
         with proxy_mode.restore():
             return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
+
     # Some of these are not "real" aten ops and will fail if we
     # call _dispatch_has_kernel_for_dispatch_key on them.
     # This list is probably incomplete
@@ -126,11 +129,21 @@
             if r is not NotImplemented:
                 return r
 
+    tracer = proxy_mode.tracer
+
+    def fetch(t):
+        if isinstance(t, torch.Tensor) and tracer in t.__dict__:
+            return t.__dict__[tracer]
+        else:
+            return t
+
+    f_args, f_kwargs = pytree.tree_map(fetch, (args, kwargs))
+
     # If there are SymInts, we also should not consider this constant.
     # However, fake tensor handling of SymInts is sufficiently broken that
     # I couldn't write a test for this case
     all_constant = (
-        pytree.tree_all_only(ProxyTensor, lambda t: t.constant is not None, (args, kwargs))
+        pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
         # TODO: maybe constant SymInts should also be allowed?  Not sure if
         # this can happen
         and pytree.tree_all_only(SymInt, lambda _: False, (args, kwargs))
@@ -140,7 +153,7 @@
         # Check if all of the Tensor inputs are constants
         if all_constant:
             const_args, const_kwargs = pytree.tree_map_only(
-                ProxyTensor, lambda t: t.constant, (args, kwargs)
+                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
             )
             with maybe_disable_fake_tensor_mode():
                 return func_overload(*const_args, **const_kwargs)
@@ -152,26 +165,18 @@
     proxy_args, proxy_kwargs = pytree.tree_map_only(
         SymInt,
         fetch_symint_proxy(proxy_mode.tracer),
-        pytree.tree_map_only(ProxyTensor, lambda e: e.proxy, (args, kwargs))
+        pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
     )
-    proxy_res = func_overload(*proxy_args, **proxy_kwargs)
+    proxy_out = func_overload(*proxy_args, **proxy_kwargs)
 
     # Kind of a hacky way to test if an op is in-place or not
     if func.__name__[-1] == "_" and func.__name__[0] != "_":
         # This makes DCE marginally less likely to DCE inplace operations.
         # It is not strictly necessary
-        args[0].proxy = proxy_res
-        proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
+        args[0].proxy = proxy_out
+        proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
 
-    elem_args, elem_kwargs = pytree.tree_map(unwrap_elem, (args, kwargs))
-    inner_res = func_overload(*elem_args, **elem_kwargs)
-
-    # Needed to sync up metadata for in-place operators that modify metadata
-    # TODO: instead forward the metadata to the inner tensor so updating
-    # is not necessary
-    if torch.Tag.inplace_view in func_overload.tags:  # type: ignore[attr-defined]
-        with no_dispatch():
-            func_overload(*args, **kwargs)
+    out = func_overload(*args, **kwargs)
 
     # In some circumstances, we will be tracing in a situation where a tensor
     # is *statically* known to be a constant (currently, this only happens if
@@ -194,69 +199,19 @@
     # element constant computation by testing the numel of the result before
     # propagating const-ness.  Similarly, we don't require the constant to
     # live on CPU, but we could.
-    any_constant = pytree.tree_any_only(ProxyTensor, lambda t: t.constant is not None, (args, kwargs))
+    any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
 
     constant = None
     # NB: do NOT include factories as constants
     if all_constant and any_constant:
         with maybe_disable_fake_tensor_mode():
             const_args, const_kwargs = pytree.tree_map_only(
-                ProxyTensor, lambda t: t.constant, (args, kwargs)
+                _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
             )
             constant = func_overload(*const_args, **const_kwargs)
 
-    # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
-    # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
-    return wrap_output(inner_res, proxy_res, constant=constant, proxy_mode=proxy_mode)
-
-
-class ProxyTensor(torch.Tensor):
-    proxy: fx.Proxy
-    elem: torch.Tensor
-    proxy_mode: "ProxyTorchDispatchMode"
-
-    @staticmethod
-    def __new__(cls, elem, proxy, *, requires_grad=None, constant=None, proxy_mode):
-        new_shape = elem.shape
-        new_strides = elem.stride()
-
-        return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
-            cls,
-            new_shape, dtype=elem.dtype, layout=elem.layout, device=elem.device,
-            requires_grad=requires_grad if requires_grad is not None else False, strides=new_strides,
-            storage_offset=elem.storage_offset()
-        )
-
-    def __init__(self, elem, proxy, *, requires_grad=None, constant=None, proxy_mode):
-        # TODO: hack since _extract_tensor_metadata currently tries to access stride
-        if elem.is_sparse or symbolic_shapes.has_symbolic_sizes_strides(elem):  # TODO: handle has_sym_ints
-            proxy.node.meta['tensor_meta'] = {}
-        else:
-            proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
-        # This detects situations where you accidentally put a ProxyTensor
-        # inside a ProxyTensor for the same trace; this is a layering violation
-        assert not (isinstance(elem, ProxyTensor) and elem.proxy.tracer is proxy.tracer)
-        self.elem = elem
-        self.proxy = proxy
-        self.constant = constant
-        self.proxy_mode = proxy_mode
-
-
-    def __deepcopy__(self, memo):
-        return self.clone()
-
-    def __repr__(self):
-        with no_dispatch():
-            return f"ProxyTensor({self.elem}, proxy={self.proxy})"
-
-    __torch_function__ = _disabled_torch_function_impl
-
-    @classmethod
-    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
-        raise RuntimeError(
-            "Should not be needed as we always trace with modes. May have entered this due to redispatching from"
-            "__torch_dispatch__ into another op without restoring dispatch mode"
-        )
+    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
+    return out
 
 
 class PythonKeyTracer(Tracer):
@@ -304,33 +259,21 @@
     return GraphModule(tracer.root, graph, name)
 
 
-def wrap_key(f, inps, proxy_mode, tracer):
-    flat_inps, _ = pytree.tree_flatten(inps)
+def wrap_key(f, tensors, tracer):
+    flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
 
     @functools.wraps(f)
-    def wrapped(*args):
-        flat_args, args_spec = pytree.tree_flatten(args)
-        assert (len(flat_args) == len(flat_inps))
-        for idx, arg in enumerate(flat_args):
-            if isinstance(flat_inps[idx], torch.Tensor):
-                with no_dispatch():
-                    track_metadata(flat_inps[idx], arg, tracer)
-                    flat_args[idx] = ProxyTensor(
-                        flat_inps[idx],
-                        arg,
-                        requires_grad=(flat_inps[idx].is_leaf and flat_inps[idx].requires_grad),
-                        proxy_mode=proxy_mode,
-                    )
-            else:
-                flat_args[idx] = flat_inps[idx]
+    def wrapped(*proxies):
+        flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
+        assert len(flat_proxies) == len(flat_tensors)
+        track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
 
-        tree_args = pytree.tree_unflatten(flat_args, args_spec)
-        out = f(*tree_args)
-        flat_outs, out_spec = pytree.tree_flatten(out)
-        for idx in range(len(flat_outs)):
-            if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
-                flat_outs[idx] = flat_outs[idx].proxy
-        return pytree.tree_unflatten(flat_outs, out_spec)
+        out = f(*tensors)
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda t: t.__dict__[tracer].proxy if tracer in t.__dict__ else t,
+            out
+        )
 
     return wrapped
 
@@ -340,6 +283,7 @@
         self.tracer = tracer
         self.enable_tracing = True
         self.sym_mode = ProxySymDispatchMode(tracer)
+        self.trace_state = {}
 
     def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
         with self.sym_mode.enable(False):
@@ -364,7 +308,13 @@
         if func_overload == aten.lift.default:
             return args[0]
 
-        if any(tuple(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0])):
+        if func in [prim.device]:
+            return func_overload(*args, **kwargs)
+
+        if any(
+            isinstance(arg, torch.Tensor) and self.tracer in arg.__dict__
+            for arg in pytree.tree_flatten(args)[0]
+        ):
             out = proxy_call(self, func_overload, args, kwargs)
         # When we trace through a torch.tensor invocation, you never actually
         # see a torch.ops.aten.tensor call. Instead, the way this function is
@@ -401,10 +351,11 @@
         # This is what the overload modification does.
         else:
             flat_args = pytree.tree_flatten((args, kwargs))[0]
-            handled_types = [torch.Tensor, ProxyTensor, torch.nn.Parameter]
+            handled_types = [torch.Tensor, _ProxyTensor, torch.nn.Parameter]
 
             # If there are any tensor subclasses, we need to handle those tensor subclasses first
-            if any([isinstance(arg, torch.Tensor) and type(arg) not in handled_types for arg in flat_args]):
+            # TODO: we could use types to test this
+            if any(isinstance(arg, torch.Tensor) and type(arg) not in handled_types for arg in flat_args):
                 return NotImplemented
 
             if func_overload is torch.ops.aten.lift_fresh.default:
@@ -412,10 +363,10 @@
 
             n_args, n_kwargs = pytree.tree_map_only(SymInt, fetch_symint_proxy(self.tracer), (args, kwargs))
 
-            proxy_res = self.tracer.create_proxy('call_function', func_overload, n_args, n_kwargs,
+            proxy_out = self.tracer.create_proxy('call_function', func_overload, n_args, n_kwargs,
                                                  name=self.tracer.graph._target_to_str(func.__name__))
 
-            inner_res = func_overload(*args, **kwargs)
+            out = func_overload(*args, **kwargs)
 
             # If this is a lift, the input tensor is guaranteed to be a
             # constant, so we keep a copy of the original argument along so
@@ -426,14 +377,14 @@
                     constant = args[0].clone()
             else:
                 constant = None
-            out = wrap_output(inner_res, proxy_res, constant=constant, proxy_mode=self)
+            track_tensor_tree(out, proxy_out, constant=constant, tracer=self.tracer)
 
         def assert_proxy_tensor(e):
             if isinstance(e, torch.Tensor):
-                assert isinstance(e, ProxyTensor), \
-                    f"Internal Error: ProxyTensor is incorrectly baking a tensor constant into the graph: {str(e)}"
+                assert self.tracer in e.__dict__, \
+                    f"Internal Error: make_fx is incorrectly baking a tensor constant into the graph: {str(e)}"
 
-        # When we trace factory functions, we expect that tensor outputs are *always* ProxyTensors.
+        # When we trace factory functions, we expect that tensor outputs are *always* tracked.
         # (Except for torch.tensor() constants handled through lift(), which is handled
         # specially further up).
         pytree.tree_map(assert_proxy_tensor, out)
@@ -447,6 +398,9 @@
     def __init__(self, tracer):
         super().__init__()
         self.tracer = tracer
+        # When false, we don't trace operations.  If you do this, you MUST
+        # call track_tensor/track_tensor_tree on all results of the operation
+        # to ensure we can adequately track the results
         self.enable_tracing = True
 
     @contextmanager
@@ -477,6 +431,8 @@
         return out
 
 
+# TODO: I'm not sure what the point of this class is; you can just
+# make_fx through a regular Interpreter
 class DecompositionInterpreter(torch.fx.Interpreter):
     def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
         super().__init__(module, **kwargs)
@@ -489,20 +445,24 @@
 
     def placeholder(self, target, args, kwargs):
         out = super().placeholder(target, args, kwargs)
+        proxy = torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer)
+        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
         # TODO handle case where the first character of target is '*'
-        return ProxyTensor(out, torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer), proxy_mode=self.mode)
+        return out
 
     def get_attr(self, target, args, kwargs):
         out = super().get_attr(target, args, kwargs)
-        return ProxyTensor(out, torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer), proxy_mode=self.mode)
+        proxy = torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer)
+        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
+        return out
 
-    # call_function, call_method, call_module get traced automatically by the ProxyTensors.
+    # call_function, call_method, call_module get traced automatically by the outer mode.
 
     def output(self, target, args, kwargs):
         out = super().output(target, args, kwargs)
 
         def unwrap(e):
-            return e.proxy.node if isinstance(e, ProxyTensor) else e
+            return e.__dict__[self.tracer].proxy.node if self.tracer in e.__dict__ else e
         self.new_graph.output(pytree.tree_map(unwrap, out))
         return out
 
@@ -581,7 +541,7 @@
             func = f
 
         with decompose(decomposition_table), fake_tensor_mode, sym_mode, proxy_mode:  # type: ignore[attr-defined]
-            t = dispatch_trace(wrap_key(func, args, proxy_mode, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
+            t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
 
         # TODO: kind of a bad way to do it, should maybe figure out a better way
         t.shape_env = shape_env  # type: ignore[assignment]
@@ -601,6 +561,7 @@
 
 @contextlib.contextmanager
 def disable_proxy_modes_tracing():
+    # TODO: This probably doesn't correctly also disable ProxySymDispatchMode
     modes = get_torch_dispatch_modes()
     proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
     olds = [m.enable_tracing for m in proxy_tensor_modes]
@@ -622,19 +583,6 @@
     """
     wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
 
-    unwrapped_all_args = [unwrap_elem(a) for a in all_args]
-
-    # Current implementation doesn't support the case when ProxyTensor is
-    # wrapped with another Tensor subclass
-    # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
-    # TODO: Once https://github.com/pytorch/pytorch/pull/82549 is merged, we can
-    # remove this
-    assert all(
-        getattr(a, "elem", None) is None
-        for a in unwrapped_all_args
-        if isinstance(a, torch.Tensor)
-    ), "ProxyTensor is wrapped with another Tensor subclass"
-
     with disable_proxy_modes_tracing():
-        gm = make_fx(wrapped)(unwrapped_all_args)
+        gm = make_fx(wrapped)(all_args)
     return gm
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1ef2c9a..33d8eef 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19549,9 +19549,8 @@
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
             # RuntimeError: no _refs support for torch.Tensor.tolist
             DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
-            # RuntimeError: .tolist() is not supported for tensor subclasses.
-            DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
-        )
+        ),
+        supports_nvfuser=False,
     ),
     PythonRefInfo(
         "_refs.hsplit",