Delete ProxyTensor wrapper subclass (#83330)
I was working on https://github.com/pytorch/torchdynamo/issues/80, and my
working hypothesis for the error was that proxy tensor was not
advertising the correct dispatch keys, causing AMP to behave differently
under tracing. I could have fixed this directly by replicating fake
tensor's dispatch-key fix in proxy tensor, but that would have meant
repeating myself.
This PR is the result. It completely deletes the ProxyTensor wrapper
subclass, so that while tracing, the tensors flowing through the
program are the *original* real or fake tensors, depending on what the
user requested in the top-level API. There is no more wrapping. To
store the Proxy objects needed for tracing, I store them directly on
the tensor's __dict__, keyed by the tracer. (Note: I never clean up
old entries from this map at the moment; this is easily fixed by using
a weak map.)
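Concretely, the association looks roughly like this (a trimmed sketch of track_tensor/_ProxyTensor from the diff below; the SymInt handling is omitted):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from torch.fx import Proxy

@dataclass
class _ProxyTensor:
    proxy: Proxy
    constant: Optional[torch.Tensor]

def track_tensor(tensor, proxy, *, constant, tracer):
    # Key the entry by the tracer so that concurrent tracers don't clobber
    # each other; assigning into __dict__ directly allows a non-string key.
    tensor.__dict__[tracer] = _ProxyTensor(proxy, constant)
```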
Benefits of doing this:
* No more tiptoeing around creating new ProxyTensors under no_dispatch();
we never create new tensors (except when we call the underlying func),
so you don't have to worry about accidentally tracing them.
* No more syncing up metadata from in-place operators. In particular,
https://github.com/pytorch/pytorch/issues/81526 is mooted.
* This fixes https://github.com/pytorch/torchdynamo/issues/519, as we no longer need to teach proxy tensor to support sparse tensors.
* No more schlepping symbolic integers from the inner fake tensor to the
outer proxy tensor. If you can make a fake tensor with symbolic ints,
you're done; there is nothing else to do.
To avoid having to rewrite all of the guts, when I get to the actual
proxy tensor handler, I first "fetch" the stored _ProxyTensor data from
the tensors' __dict__ via a tree_map, and then operate on the resulting
data as before. A more optimized implementation is possible.
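Roughly, the fetch step looks like this (a sketch of the inline helper in proxy_call from the diff below; fetch_proxies is an illustrative name, the real code inlines it):

```python
import torch
import torch.utils._pytree as pytree

def fetch_proxies(tracer, args, kwargs):
    # Swap every tracked tensor for its stored _ProxyTensor record;
    # anything untracked passes through unchanged.
    def fetch(t):
        if isinstance(t, torch.Tensor) and tracer in t.__dict__:
            return t.__dict__[tracer]
        return t
    return pytree.tree_map(fetch, (args, kwargs))
```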
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83330
Approved by: https://github.com/Chillee
diff --git a/aten/src/ATen/TensorSubclassLikeUtils.h b/aten/src/ATen/TensorSubclassLikeUtils.h
index 5c01ce9..2dc55bf 100644
--- a/aten/src/ATen/TensorSubclassLikeUtils.h
+++ b/aten/src/ATen/TensorSubclassLikeUtils.h
@@ -1,5 +1,6 @@
#pragma once
#include <ATen/ATen.h>
+#include <c10/core/impl/TorchDispatchModeTLS.h>
namespace at {
@@ -39,16 +40,22 @@
DispatchKeySet(BackendComponent::MetaBit);
inline bool isTensorSubclassLike(const Tensor& tensor) {
+ if (c10::impl::dispatch_mode_enabled())
+ return true;
auto key_set = tensor.unsafeGetTensorImpl()->key_set();
return !(key_set & kTensorSubclassLike).empty();
}
inline bool areAnyTensorSubclassLike(TensorList tensors) {
+ if (c10::impl::dispatch_mode_enabled())
+ return true;
return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
}
inline bool areAnyOptionalTensorSubclassLike(
const c10::List<c10::optional<Tensor>>& tensors) {
+ if (c10::impl::dispatch_mode_enabled())
+ return true;
return std::any_of(
tensors.begin(), tensors.end(), [](const optional<Tensor>& opt_tensor) {
return (
diff --git a/functorch/functorch/_src/python_key.py b/functorch/functorch/_src/python_key.py
index 5fe0aff..e7c8058 100644
--- a/functorch/functorch/_src/python_key.py
+++ b/functorch/functorch/_src/python_key.py
@@ -3,8 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-__all__ = ["make_fx", "ProxyTensor", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
-from torch.fx.experimental.proxy_tensor import make_fx, ProxyTensor, dispatch_trace, PythonKeyTracer, decompose
+__all__ = ["make_fx", "dispatch_trace", "PythonKeyTracer", "pythonkey_decompose"]
+from torch.fx.experimental.proxy_tensor import make_fx, dispatch_trace, PythonKeyTracer, decompose
pythonkey_decompose = decompose
-PythonTensor = ProxyTensor
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 82997ca..64b9d94 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -175,7 +175,9 @@
def _test(self, f, inps):
fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
new_inps = tree_map(_create_new_input, inps)
- self.assertEqual(fx_f(*new_inps), f(*new_inps))
+ r1 = fx_f(*new_inps)
+ r2 = f(*new_inps)
+ self.assertEqual(r1, r2)
def test_make_fx_simple(self):
def f(x):
@@ -284,11 +286,10 @@
self.assertTrue(is_any_sigmoid(gm))
return torch.digamma(x)
- with self.assertRaisesRegex(AssertionError, "ProxyTensor is wrapped with another Tensor subclass"):
- traced = make_fx(f2_logging_tensor)(torch.randn(3))
- self.assertFalse(is_any_sum(traced))
- self.assertFalse(is_any_sigmoid(traced)) # this fails, sigmoid is traced with LoggingTensor
- self.assertTrue(is_any_digamma(traced))
+ traced = make_fx(f2_logging_tensor)(torch.randn(3))
+ self.assertFalse(is_any_sum(traced))
+ self.assertFalse(is_any_sigmoid(traced)) # this fails, sigmoid is traced with LoggingTensor
+ self.assertTrue(is_any_digamma(traced))
def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
def f(x):
@@ -514,6 +515,8 @@
model = Foo()
def f(args, params, buffers):
+ for p in params.values():
+ p.grad = None
if not isinstance(args, Iterable):
args = [args]
params_and_buffers = {**params, **buffers}
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index b791486..938ef48 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1238,10 +1238,10 @@
dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1)) # type: ignore[misc]
if self.dim() <= 1:
- return self
+ return self.view(self.shape)
if dim0 == dim1:
- return self
+ return self.view(self.shape)
perm = list(range(self.dim()))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
return torch.permute(self, perm)
diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py
index b4db75a..3c72841 100644
--- a/torch/_prims_common/wrappers.py
+++ b/torch/_prims_common/wrappers.py
@@ -269,15 +269,18 @@
def backwards_not_supported(prim):
+ def redispatch_prim(args, kwargs):
+ g = torch._C._AutoDispatchBelowAutograd()
+ try:
+ return prim(*args, **kwargs)
+ finally:
+ del g
+
class BackwardsNotSupported(torch.autograd.Function):
@staticmethod
def forward(ctx, args_spec, *flat_args):
args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
- g = torch._C._AutoDispatchBelowAutograd()
- try:
- return prim(*args, **kwargs)
- finally:
- del g
+ return redispatch_prim(args, kwargs)
@staticmethod
def backward(ctx, *args):
@@ -286,7 +289,20 @@
@wraps(prim)
def _autograd_impl(*args, **kwargs):
flat_args, args_spec = tree_flatten((args, kwargs))
- return BackwardsNotSupported.apply(args_spec, *flat_args)
+ if torch.is_grad_enabled() and any(a.requires_grad for a in flat_args if isinstance(a, torch.Tensor)):
+ # TODO: There is a subtle bug here: prims like copy_to
+ # return their input argument after mutating it; and custom
+ # autograd function will incorrectly turn the result into
+ # a view which will fail test_python_ref_executor tests.
+ # At the moment, we sidestep this by observing that the
+ # unit tests don't ever try to run the executor with
+ # autograd, so we don't exercise the buggy case, but if
+ # you ever want to feed autograd through this, be aware
+ # of it! We need a way of properly implementing autograd
+ # for mutating operations in Python to do this.
+ return BackwardsNotSupported.apply(args_spec, *flat_args)
+ else:
+ return redispatch_prim(args, kwargs)
return _autograd_impl
diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp
index aa3f17f..a48777d 100644
--- a/torch/csrc/autograd/variable.cpp
+++ b/torch/csrc/autograd/variable.cpp
@@ -45,6 +45,8 @@
self_impl->set_version_counter(
impl::version_counter(backward_info_.value().base_));
attr_version_ = self_impl->version_counter().current_version();
+ TORCH_INTERNAL_ASSERT(
+ backward_info_.value().base_.unsafeGetTensorImpl() != self_impl);
}
if (shared_view_info_) {
TORCH_INTERNAL_ASSERT(
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 225035a..554c142 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -7,23 +7,24 @@
import functools
from typing import Any, Dict, Optional, Tuple, Callable, Union
import torch
-from torch._C import _disabled_torch_function_impl
import torch.utils._pytree as pytree
from torch.fx import Tracer, GraphModule
from torch._subclasses.fake_tensor import FakeTensorMode
import torch.fx as fx
-from torch.utils._mode_utils import no_dispatch
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager, nullcontext
import inspect
+from dataclasses import dataclass
from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
from torch._subclasses import FakeTensor
from .symbolic_shapes import ShapeEnv, SymDispatchMode, PySymInt
import torch.fx.experimental.symbolic_shapes as symbolic_shapes
+from torch.fx import Proxy
-__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter"]
+__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter"]
aten = torch.ops.aten
+prim = torch.ops.prim
CURRENT_DECOMPOSITION_TABLE: Dict[torch._ops.OpOverload, Callable] = {}
@@ -33,7 +34,6 @@
argnames = ",".join(f"arg{i}" for i in range(nargs))
return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn})
-
@contextmanager
def decompose(decomposition_table):
global CURRENT_DECOMPOSITION_TABLE
@@ -44,7 +44,12 @@
finally:
CURRENT_DECOMPOSITION_TABLE = old_decomposition_table
-def track_metadata(tensor, proxy, tracer):
+def track_tensor(tensor, proxy, *, constant, tracer):
+ # The basic idea is that we need to associate each tensor/SymInt
+ # with a Proxy. How do we setup this association? We just store
+ # the proxy on the __dict__ of the object, keyed on the tracer
+ # (so that if we have multiple tracers at the same time, they
+ # don't clobber each other.)
for i, s in enumerate(tensor.shape):
if isinstance(s, SymInt):
inner_s = s.get_pyobj()
@@ -54,15 +59,13 @@
# use? Maybe complicated and DCE is a better idea
inner_s.__dict__[tracer] = proxy.size(i)
# TODO: also do stride/numel
+ tensor.__dict__[tracer] = _ProxyTensor(proxy, constant)
-def wrap_output(inner_res, proxy_res, *, constant, proxy_mode):
+def track_tensor_tree(inner_res, proxy_res, *, constant, tracer):
def wrap_with_proxy(e, proxy, constant):
if isinstance(e, torch.Tensor):
- track_metadata(e, proxy, proxy_mode.tracer)
- with no_dispatch():
- return ProxyTensor(e, proxy, constant=constant, proxy_mode=proxy_mode)
- else:
- return e
+ track_tensor(e, proxy, tracer=tracer, constant=constant)
+ proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(e)
def get_constant(idx):
if constant is None:
@@ -73,14 +76,13 @@
# Unfortunately, tree_map cannot directly be used here. As the resulting
# object may be a proxy that represents a tuple, we may need to
# explicitly unwrap the proxy by simulating the flattening operations.
- if isinstance(inner_res, tuple):
- return tuple(wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) for idx, e in enumerate(inner_res))
- elif isinstance(inner_res, list):
- return list([wrap_with_proxy(e, proxy_res[idx], get_constant(idx)) for idx, e in enumerate(inner_res)])
+ if isinstance(inner_res, tuple) or isinstance(inner_res, list):
+ for idx, e in enumerate(inner_res):
+ wrap_with_proxy(e, proxy_res[idx], get_constant(idx))
elif isinstance(inner_res, torch.Tensor):
- return wrap_with_proxy(inner_res, proxy_res, constant)
- else:
- return inner_res
+ wrap_with_proxy(inner_res, proxy_res, constant)
+
+ return inner_res
def maybe_disable_fake_tensor_mode():
@@ -93,10 +95,10 @@
return nullcontext()
-def unwrap_elem(e):
- if isinstance(e, ProxyTensor):
- return e.elem
- return e
+@dataclass
+class _ProxyTensor:
+ proxy: Proxy
+ constant: Optional[torch.Tensor]
def fetch_symint_proxy(tracer):
@@ -117,6 +119,7 @@
if func_overload in CURRENT_DECOMPOSITION_TABLE:
with proxy_mode.restore():
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
+
# Some of these are not "real" aten ops and will fail if we
# call _dispatch_has_kernel_for_dispatch_key on them.
# This list is probably incomplete
@@ -126,11 +129,21 @@
if r is not NotImplemented:
return r
+ tracer = proxy_mode.tracer
+
+ def fetch(t):
+ if isinstance(t, torch.Tensor) and tracer in t.__dict__:
+ return t.__dict__[tracer]
+ else:
+ return t
+
+ f_args, f_kwargs = pytree.tree_map(fetch, (args, kwargs))
+
# If there are SymInts, we also should not consider this constant.
# However, fake tensor handling of SymInts is sufficiently broken that
# I couldn't write a test for this case
all_constant = (
- pytree.tree_all_only(ProxyTensor, lambda t: t.constant is not None, (args, kwargs))
+ pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
# TODO: maybe constant SymInts should also be allowed? Not sure if
# this can happen
and pytree.tree_all_only(SymInt, lambda _: False, (args, kwargs))
@@ -140,7 +153,7 @@
# Check if all of the Tensor inputs are constants
if all_constant:
const_args, const_kwargs = pytree.tree_map_only(
- ProxyTensor, lambda t: t.constant, (args, kwargs)
+ _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
)
with maybe_disable_fake_tensor_mode():
return func_overload(*const_args, **const_kwargs)
@@ -152,26 +165,18 @@
proxy_args, proxy_kwargs = pytree.tree_map_only(
SymInt,
fetch_symint_proxy(proxy_mode.tracer),
- pytree.tree_map_only(ProxyTensor, lambda e: e.proxy, (args, kwargs))
+ pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs))
)
- proxy_res = func_overload(*proxy_args, **proxy_kwargs)
+ proxy_out = func_overload(*proxy_args, **proxy_kwargs)
# Kind of a hacky way to test if an op is in-place or not
if func.__name__[-1] == "_" and func.__name__[0] != "_":
# This makes DCE marginally less likely to DCE inplace operations.
# It is not strictly necessary
- args[0].proxy = proxy_res
- proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
+ args[0].proxy = proxy_out
+ proxy_out.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
- elem_args, elem_kwargs = pytree.tree_map(unwrap_elem, (args, kwargs))
- inner_res = func_overload(*elem_args, **elem_kwargs)
-
- # Needed to sync up metadata for in-place operators that modify metadata
- # TODO: instead forward the metadata to the inner tensor so updating
- # is not necessary
- if torch.Tag.inplace_view in func_overload.tags: # type: ignore[attr-defined]
- with no_dispatch():
- func_overload(*args, **kwargs)
+ out = func_overload(*args, **kwargs)
# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
@@ -194,69 +199,19 @@
# element constant computation by testing the numel of the result before
# propagating const-ness. Similarly, we don't require the constant to
# live on CPU, but we could.
- any_constant = pytree.tree_any_only(ProxyTensor, lambda t: t.constant is not None, (args, kwargs))
+ any_constant = pytree.tree_any_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs))
constant = None
# NB: do NOT include factories as constants
if all_constant and any_constant:
with maybe_disable_fake_tensor_mode():
const_args, const_kwargs = pytree.tree_map_only(
- ProxyTensor, lambda t: t.constant, (args, kwargs)
+ _ProxyTensor, lambda t: t.constant, (f_args, f_kwargs)
)
constant = func_overload(*const_args, **const_kwargs)
- # TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
- # pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
- return wrap_output(inner_res, proxy_res, constant=constant, proxy_mode=proxy_mode)
-
-
-class ProxyTensor(torch.Tensor):
- proxy: fx.Proxy
- elem: torch.Tensor
- proxy_mode: "ProxyTorchDispatchMode"
-
- @staticmethod
- def __new__(cls, elem, proxy, *, requires_grad=None, constant=None, proxy_mode):
- new_shape = elem.shape
- new_strides = elem.stride()
-
- return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
- cls,
- new_shape, dtype=elem.dtype, layout=elem.layout, device=elem.device,
- requires_grad=requires_grad if requires_grad is not None else False, strides=new_strides,
- storage_offset=elem.storage_offset()
- )
-
- def __init__(self, elem, proxy, *, requires_grad=None, constant=None, proxy_mode):
- # TODO: hack since _extract_tensor_metadata currently tries to access stride
- if elem.is_sparse or symbolic_shapes.has_symbolic_sizes_strides(elem): # TODO: handle has_sym_ints
- proxy.node.meta['tensor_meta'] = {}
- else:
- proxy.node.meta['tensor_meta'] = _extract_tensor_metadata(self)
- # This detects situations where you accidentally put a ProxyTensor
- # inside a ProxyTensor for the same trace; this is a layering violation
- assert not (isinstance(elem, ProxyTensor) and elem.proxy.tracer is proxy.tracer)
- self.elem = elem
- self.proxy = proxy
- self.constant = constant
- self.proxy_mode = proxy_mode
-
-
- def __deepcopy__(self, memo):
- return self.clone()
-
- def __repr__(self):
- with no_dispatch():
- return f"ProxyTensor({self.elem}, proxy={self.proxy})"
-
- __torch_function__ = _disabled_torch_function_impl
-
- @classmethod
- def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
- raise RuntimeError(
- "Should not be needed as we always trace with modes. May have entered this due to redispatching from"
- "__torch_dispatch__ into another op without restoring dispatch mode"
- )
+ track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
+ return out
class PythonKeyTracer(Tracer):
@@ -304,33 +259,21 @@
return GraphModule(tracer.root, graph, name)
-def wrap_key(f, inps, proxy_mode, tracer):
- flat_inps, _ = pytree.tree_flatten(inps)
+def wrap_key(f, tensors, tracer):
+ flat_tensors, tensors_spec = pytree.tree_flatten(tensors)
@functools.wraps(f)
- def wrapped(*args):
- flat_args, args_spec = pytree.tree_flatten(args)
- assert (len(flat_args) == len(flat_inps))
- for idx, arg in enumerate(flat_args):
- if isinstance(flat_inps[idx], torch.Tensor):
- with no_dispatch():
- track_metadata(flat_inps[idx], arg, tracer)
- flat_args[idx] = ProxyTensor(
- flat_inps[idx],
- arg,
- requires_grad=(flat_inps[idx].is_leaf and flat_inps[idx].requires_grad),
- proxy_mode=proxy_mode,
- )
- else:
- flat_args[idx] = flat_inps[idx]
+ def wrapped(*proxies):
+ flat_proxies, proxies_spec = pytree.tree_flatten(proxies)
+ assert len(flat_proxies) == len(flat_tensors)
+ track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)
- tree_args = pytree.tree_unflatten(flat_args, args_spec)
- out = f(*tree_args)
- flat_outs, out_spec = pytree.tree_flatten(out)
- for idx in range(len(flat_outs)):
- if isinstance(flat_outs[idx], torch.Tensor) and isinstance(flat_outs[idx], ProxyTensor):
- flat_outs[idx] = flat_outs[idx].proxy
- return pytree.tree_unflatten(flat_outs, out_spec)
+ out = f(*tensors)
+ return pytree.tree_map_only(
+ torch.Tensor,
+ lambda t: t.__dict__[tracer].proxy if tracer in t.__dict__ else t,
+ out
+ )
return wrapped
@@ -340,6 +283,7 @@
self.tracer = tracer
self.enable_tracing = True
self.sym_mode = ProxySymDispatchMode(tracer)
+ self.trace_state = {}
def __torch_dispatch__(self, func_overload, types, args=(), kwargs=None):
with self.sym_mode.enable(False):
@@ -364,7 +308,13 @@
if func_overload == aten.lift.default:
return args[0]
- if any(tuple(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0])):
+ if func in [prim.device]:
+ return func_overload(*args, **kwargs)
+
+ if any(
+ isinstance(arg, torch.Tensor) and self.tracer in arg.__dict__
+ for arg in pytree.tree_flatten(args)[0]
+ ):
out = proxy_call(self, func_overload, args, kwargs)
# When we trace through a torch.tensor invocation, you never actually
# see a torch.ops.aten.tensor call. Instead, the way this function is
@@ -401,10 +351,11 @@
# This is what the overload modification does.
else:
flat_args = pytree.tree_flatten((args, kwargs))[0]
- handled_types = [torch.Tensor, ProxyTensor, torch.nn.Parameter]
+ handled_types = [torch.Tensor, _ProxyTensor, torch.nn.Parameter]
# If there are any tensor subclasses, we need to handle those tensor subclasses first
- if any([isinstance(arg, torch.Tensor) and type(arg) not in handled_types for arg in flat_args]):
+ # TODO: we could use types to test this
+ if any(isinstance(arg, torch.Tensor) and type(arg) not in handled_types for arg in flat_args):
return NotImplemented
if func_overload is torch.ops.aten.lift_fresh.default:
@@ -412,10 +363,10 @@
n_args, n_kwargs = pytree.tree_map_only(SymInt, fetch_symint_proxy(self.tracer), (args, kwargs))
- proxy_res = self.tracer.create_proxy('call_function', func_overload, n_args, n_kwargs,
+ proxy_out = self.tracer.create_proxy('call_function', func_overload, n_args, n_kwargs,
name=self.tracer.graph._target_to_str(func.__name__))
- inner_res = func_overload(*args, **kwargs)
+ out = func_overload(*args, **kwargs)
# If this is a lift, the input tensor is guaranteed to be a
# constant, so we keep a copy of the original argument along so
@@ -426,14 +377,14 @@
constant = args[0].clone()
else:
constant = None
- out = wrap_output(inner_res, proxy_res, constant=constant, proxy_mode=self)
+ track_tensor_tree(out, proxy_out, constant=constant, tracer=self.tracer)
def assert_proxy_tensor(e):
if isinstance(e, torch.Tensor):
- assert isinstance(e, ProxyTensor), \
- f"Internal Error: ProxyTensor is incorrectly baking a tensor constant into the graph: {str(e)}"
+ assert self.tracer in e.__dict__, \
+ f"Internal Error: make_fx is incorrectly baking a tensor constant into the graph: {str(e)}"
- # When we trace factory functions, we expect that tensor outputs are *always* ProxyTensors.
+ # When we trace factory functions, we expect that tensor outputs are *always* tracked.
# (Except for torch.tensor() constants handled through lift(), which is handled
# specially further up).
pytree.tree_map(assert_proxy_tensor, out)
@@ -447,6 +398,9 @@
def __init__(self, tracer):
super().__init__()
self.tracer = tracer
+ # When false, we don't trace operations. If you do this, you MUST
+ # call track_tensor/track_tensor_tree on all results of the operation
+ # to ensure we can adequately track the results
self.enable_tracing = True
@contextmanager
@@ -477,6 +431,8 @@
return out
+# TODO: I'm not sure what the point of this class is; you can just
+# make_fx through a regular Interpreter
class DecompositionInterpreter(torch.fx.Interpreter):
def __init__(self, module: torch.fx.GraphModule, new_graph: torch.fx.Graph, decomposition_table=None, **kwargs):
super().__init__(module, **kwargs)
@@ -489,20 +445,24 @@
def placeholder(self, target, args, kwargs):
out = super().placeholder(target, args, kwargs)
+ proxy = torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer)
+ track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
# TODO handle case where the first character of target is '*'
- return ProxyTensor(out, torch.fx.Proxy(self.new_graph.placeholder(target), self.tracer), proxy_mode=self.mode)
+ return out
def get_attr(self, target, args, kwargs):
out = super().get_attr(target, args, kwargs)
- return ProxyTensor(out, torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer), proxy_mode=self.mode)
+ proxy = torch.fx.Proxy(self.new_graph.get_attr(target), self.tracer)
+ track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
+ return out
- # call_function, call_method, call_module get traced automatically by the ProxyTensors.
+ # call_function, call_method, call_module get traced automatically by the outer mode.
def output(self, target, args, kwargs):
out = super().output(target, args, kwargs)
def unwrap(e):
- return e.proxy.node if isinstance(e, ProxyTensor) else e
+ return e.__dict__[self.tracer].proxy.node if self.tracer in e.__dict__ else e
self.new_graph.output(pytree.tree_map(unwrap, out))
return out
@@ -581,7 +541,7 @@
func = f
with decompose(decomposition_table), fake_tensor_mode, sym_mode, proxy_mode: # type: ignore[attr-defined]
- t = dispatch_trace(wrap_key(func, args, proxy_mode, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
+ t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
# TODO: kind of a bad way to do it, should maybe figure out a better way
t.shape_env = shape_env # type: ignore[assignment]
@@ -601,6 +561,7 @@
@contextlib.contextmanager
def disable_proxy_modes_tracing():
+ # TODO: This probably doesn't correctly also disable ProxySymDispatchMode
modes = get_torch_dispatch_modes()
proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
olds = [m.enable_tracing for m in proxy_tensor_modes]
@@ -622,19 +583,6 @@
"""
wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)
- unwrapped_all_args = [unwrap_elem(a) for a in all_args]
-
- # Current implementation doesn't support the case when ProxyTensor is
- # wrapped with another Tensor subclass
- # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
- # TODO: Once https://github.com/pytorch/pytorch/pull/82549 is merged, we can
- # remove this
- assert all(
- getattr(a, "elem", None) is None
- for a in unwrapped_all_args
- if isinstance(a, torch.Tensor)
- ), "ProxyTensor is wrapped with another Tensor subclass"
-
with disable_proxy_modes_tracing():
- gm = make_fx(wrapped)(unwrapped_all_args)
+ gm = make_fx(wrapped)(all_args)
return gm
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1ef2c9a..33d8eef 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -19549,9 +19549,8 @@
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
# RuntimeError: no _refs support for torch.Tensor.tolist
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'),
- # RuntimeError: .tolist() is not supported for tensor subclasses.
- DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor'),
- )
+ ),
+ supports_nvfuser=False,
),
PythonRefInfo(
"_refs.hsplit",