Make proxy tensor support item() calls on torch.tensor constants (#81192)
This PR is doing a few interrelated things, all of which are necessary to get correctness. Read the comment in torch/fx/experimental/proxy_tensor.py for the high level overview.
Let's break down the parts of this PR:
* Bug fix where `enable_torch_dispatch_mode` with `None` doesn't work. This make `enable_torch_dispatch_mode(current_mode.inner)` work which is the basis for how we temporarily disable fake tensor mode.
* Bug fix for when fake tensor mode is combined with a non-mode tensor subclass. This actually could be ablated from this PR but it affects where the logic for allowing non fake tensor inputs with lift goes, so it's all in here in one go. There are some relevant tests for the fix in fake tensor, but it turns out I didn't need this because I'm always using proxy tensors as a mode (which ensures the ordering is right.)
* New `lift_fresh` view operator. Note that like lift, we have to manually write the functionalize kernel for these functions.
* The actual change, which is to save constants when we see them in the proxy tensor mode, and then propagate them as we go (because otherwise you'll handle mutations on constants incorrectly--see test.)
This is mildly BC-breaking if anyone was previously interposing on
at::lift, but this operator was relatively new and I checked
functorch which has no explicit reference to lift. So I think it
should not be too disruptive.
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81192
Approved by: https://github.com/samdow, https://github.com/bdhirsh
diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp
index 1aa9481..a8e3c4d 100644
--- a/aten/src/ATen/FunctionalInverses.cpp
+++ b/aten/src/ATen/FunctionalInverses.cpp
@@ -172,6 +172,10 @@
return mutated_view;
}
+Tensor FunctionalInverses::lift_fresh_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views) {
+ return mutated_view;
+}
+
Tensor FunctionalInverses::slice_copy_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, int64_t dim, c10::optional<int64_t> start, c10::optional<int64_t> end, int64_t step) {
// Pessimism: we can't reapply views for slice_scatter.
return base.slice_scatter(mutated_view, dim, start, end, step);
diff --git a/aten/src/ATen/FunctionalizeFallbackKernel.cpp b/aten/src/ATen/FunctionalizeFallbackKernel.cpp
index abe3743..8137e3b 100644
--- a/aten/src/ATen/FunctionalizeFallbackKernel.cpp
+++ b/aten/src/ATen/FunctionalizeFallbackKernel.cpp
@@ -14,6 +14,7 @@
#else
#include <ATen/ops/_to_copy.h>
#include <ATen/ops/to_native.h>
+#include <ATen/ops/lift_fresh_copy.h>
#include <ATen/ops/resize.h>
#include <ATen/ops/as_strided.h>
#include <ATen/ops/as_strided_copy.h>
@@ -176,6 +177,22 @@
return at::functionalization::impl::to_functional_tensor(self);
}
+at::Tensor lift_functionalize_copy(const at::Tensor & self) {
+ TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
+ return at::functionalization::impl::to_functional_tensor(self.clone());
+}
+
+at::Tensor lift_fresh_functionalize(const at::Tensor & self) {
+ TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
+ return at::functionalization::impl::to_functional_tensor(self);
+}
+
+at::Tensor lift_fresh_functionalize_copy(const at::Tensor & self) {
+ TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(self));
+ at::AutoDispatchSkipFunctionalize guard;
+ return at::functionalization::impl::to_functional_tensor(at::lift_fresh_copy(self));
+}
+
bool device_opted_into_functionalization(c10::Device self_device, c10::optional<c10::Device> tgt_device) {
// If the target device is empty, then the output tensor should be on the same device as the input
auto real_tgt_device = tgt_device.has_value() ? tgt_device.value() : self_device;
@@ -276,6 +293,8 @@
TORCH_LIBRARY_IMPL(aten, Functionalize, m) {
m.impl("resize_", TORCH_FN(resize__functionalization));
m.impl("lift", TORCH_FN(lift_functionalize));
+ m.impl("lift_fresh", TORCH_FN(lift_fresh_functionalize));
+ m.impl("lift_fresh_copy", TORCH_FN(lift_fresh_functionalize_copy));
m.impl("_to_copy", TORCH_FN(_to_copy_functionalize));
m.impl("_unsafe_view", TORCH_FN(_unsafe_view_functionalize));
}
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index ce40f04..c46d2dc 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -3407,6 +3407,11 @@
return self;
}
+// See notes in native_functions.yaml
+at::Tensor lift_fresh(const at::Tensor& self) {
+ return self;
+}
+
at::Tensor& _fw_primal_copy_out(const at::Tensor & self, int64_t level, at::Tensor & out) {
auto tmp = self._fw_primal(level);
out.copy_(tmp);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index afda967..879f77b 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6333,13 +6333,31 @@
MPS: set_mps_
autogen: set, set.out
+# Not making it CompositeImplicitAutograd because lift
+# should be a primitive w.r.t. functorch
+
+# TODO: this should have a view annotation
+# TODO: shouldn't be a method
- func: lift(Tensor self) -> Tensor
variants: method
dispatch:
- # Not making it CompositeImplicitAutograd because lift
- # should be a primitive w.r.t. functorch
CompositeExplicitAutograd: lift
+# lift_fresh is called with an argument that is guaranteed to be
+# fresh (i.e., newly allocated). This is ONLY called from a
+# torch.tensor call; if you FX trace a lift_fresh, you are obligated
+# to convert this into a lift_fresh_copy (because FX will violate the
+# freshness invariant when tracing).
+- func: lift_fresh(Tensor(a) self) -> Tensor(a)
+ dispatch:
+ CompositeExplicitAutograd: lift_fresh
+
+# Like lift, but it clones the input.
+- func: lift_fresh_copy(Tensor self) -> Tensor
+ tags: view_copy
+ dispatch:
+ CompositeExplicitAutograd: lift_fresh_copy
+
- func: is_set_to(Tensor self, Tensor tensor) -> bool
variants: method
device_check: NoCheck
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 65bab0f..e2d3782 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -164,6 +164,43 @@
self._test(f, [])
+ def test_constant_proxy_tensor(self):
+ from torch.fx.experimental.proxy_tensor import make_fx
+
+ def f():
+ val = torch.tensor(float('inf'))
+ return torch.full((100, 100), val)
+
+ g = make_fx(f)()
+ self.assertEqual(g(), f())
+
+ def test_constant_proxy_tensor_mut(self):
+ from torch.fx.experimental.proxy_tensor import make_fx
+
+ def f():
+ val = torch.tensor(float(1))
+ val.add_(2)
+ return torch.full((100, 100), val)
+
+ g = make_fx(f)()
+ self.assertEqual(g(), f())
+ # In case we mutated shared state in the g graph!
+ self.assertEqual(g(), f())
+
+ g = make_fx(f, use_fake=True)()
+ self.assertEqual(g(), f())
+ # In case we mutated shared state in the g graph!
+ self.assertEqual(g(), f())
+
+ def test_use_fake_and_tensor(self):
+ def f(x, y):
+ z = torch.tensor([2.0, 3.0])
+ return x + y + z
+
+ g = make_fx(f, use_fake=True)(torch.randn(2), torch.randn(2))
+ x, y = torch.randn(2), torch.randn(2)
+ self.assertEqual(g(x, y), f(x, y))
+
def test_decomposition_interpreter(self):
def fn(x):
return torch.nn.functional.silu(x)
@@ -247,28 +284,6 @@
xfail('cholesky_inverse'),
# ASAN failures due to divide by 0
skip('nn.functional.nll_loss'),
- # Masked failures (creating a scalar tensor just to call `.item` on it)
- xfail('_masked.amax'),
- xfail('_masked.amax'),
- xfail('_masked.amin'),
- xfail('_masked.argmax'),
- xfail('_masked.argmin'),
- xfail('_masked.cumprod'),
- xfail('_masked.cumsum'),
- xfail('_masked.log_softmax'),
- xfail('_masked.logaddexp'),
- xfail('_masked.logsumexp'),
- xfail('_masked.mean'),
- xfail('_masked.median'),
- xfail('_masked.norm'),
- xfail('_masked.prod'),
- xfail('_masked.softmax'),
- xfail('_masked.softmin'),
- xfail('_masked.std'),
- xfail('_masked.sum'),
- xfail('_masked.var'),
- # Same as masked failures - preventing torch.tensor constants from turning into proxytensors causes issues with faketensors
- xfail('__getitem__'),
}
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 475f479..2233d50 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -1152,6 +1152,16 @@
with PoliteMode():
a.abs()
+ def test_disable_mode(self):
+ class FailEverythingMode(TorchDispatchMode):
+ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+ raise RuntimeError("arf")
+
+ with FailEverythingMode() as m:
+ self.assertRaises(RuntimeError, lambda: torch.ones([2, 3]))
+ with enable_torch_dispatch_mode(None, replace=m):
+ torch.ones([2, 3])
+
def test_make_wrapper_subclass_with_modes(self):
class ModeTensor(torch.Tensor):
def __new__(cls, elem, mode):
diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py
index 77aeeb1..a17d45a 100644
--- a/tools/autograd/gen_python_functions.py
+++ b/tools/autograd/gen_python_functions.py
@@ -150,7 +150,7 @@
"copy", # only used by the functionalization pass
"fill.Tensor", # only used by the functionalization pass
"fill.Scalar", # only used by the functionalization pass
- "lift",
+ "lift.*",
"normal_functional", # only used by the functionalization pas
]
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 359001c..39e07fe 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1100,7 +1100,9 @@
# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
-@register_decomposition([aten.detach, aten.lift, aten.alias], disable_meta=True)
+@register_decomposition(
+ [aten.detach, aten.lift, aten.lift_fresh, aten.alias], disable_meta=True
+)
def nop_decomposition(x):
return x
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 7b7f867..5eb982c 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -290,7 +290,7 @@
def __init__(self, fake_mode, elem, device: Union[torch.device, str]):
# elem does not need to be recorded, because FakeTensor *is a* elem
- assert elem.device.type == "meta"
+ assert elem.device.type == "meta", elem
device = device if isinstance(device, torch.device) else torch.device(device)
assert device.type != "meta"
self.fake_device = device
@@ -334,6 +334,14 @@
else:
return args[0].fake_device
+ # Because fake mode can return NotImplemented (if it sees a subclass
+ # it doesn't know how to deal with), this test here is important
+ # because the next dispatch after a fake mode will attempt to use
+ # subclasses of tensors to dispatch, and any FakeTensor arguments
+ # will be considered eligible.
+ if any(not issubclass(t, FakeTensor) and t is not torch.Tensor for t in types):
+ return NotImplemented
+
fake_mode = None
for arg in itertools.chain(tree_flatten(args)[0], tree_flatten(kwargs)[0]):
if isinstance(arg, FakeTensor):
@@ -446,18 +454,6 @@
# TODO: apply as no_dispatch decorator
converter = self.fake_tensor_converter
- # this is generated from torch.tensor(), which does not use the
- # dispatcher, to allow wrapper subclasses to wrap the new tensor
- # we need to handle before error checking
- if func == torch.ops.aten.lift.default:
- assert (
- len(kwargs) == 0
- and len(args) == 1
- and type(args[0]) is torch.Tensor
- )
- with no_dispatch():
- return converter(self, args[0])
-
def wrap(e, device=None):
if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
return converter(self, e, device)
@@ -468,20 +464,54 @@
# are not FakeTensors. For now, throw if any non-Fake Tensor inputs
# and just support constructors. TODO: extend more broadly
conversion_made = False
+ subclass_seen = False
def check_non_fake_tensor(x):
- nonlocal conversion_made
+ nonlocal conversion_made, subclass_seen
conversion_made = conversion_made or (
isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor)
)
+ subclass_seen = subclass_seen or (
+ isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor)
+ and type(x) is not torch.Tensor
+ )
tree_map(check_non_fake_tensor, args)
tree_map(check_non_fake_tensor, kwargs)
+ # Suppose we enable fake tensor mode. This means that fake tensor
+ # mode will run first. But what if we do an operation that
+ # involves a tensor subclass that will desugar into normal tensor
+ # operations? Without this line, fake tensor mode will run first,
+ # decide that a conversion was made (since there was a non fake
+ # tensor argument), and report an error that converting non
+ # fake tensor is not supported. What we actually wanted to happen
+ # was to give the subclass a chance to figure out what it wants to
+ # before erroring out. Returning NotImplemented here allows this.
+ #
+ # NB: If you're seeing a mysterious infinite loop involving fake
+ # tensor, it might be related to this line. Though I'm not sure
+ # how you'll know to read this comment, as this line won't show up
+ # in the stack trace.
+ if subclass_seen:
+ return NotImplemented
+
+ # this is generated from torch.tensor(), which does not use the
+ # dispatcher, to allow wrapper subclasses to wrap the new tensor
+ # we need to handle before error checking
+ if func in [torch.ops.aten.lift_fresh.default, torch.ops.aten.lift_fresh_copy.default]:
+ assert (
+ len(kwargs) == 0
+ and len(args) == 1
+ and type(args[0]) is torch.Tensor
+ ), f"{args} {kwargs}"
+ with no_dispatch():
+ return converter(self, args[0])
+
if conversion_made:
raise Exception(
"Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
- f"Please convert all Tensors to FakeTensors first. Found in {func}"
+ f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})"
)
for run_impl_check, op_impl in op_implementations:
diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp
index 2d65d65..60ea20c 100644
--- a/torch/csrc/utils/tensor_new.cpp
+++ b/torch/csrc/utils/tensor_new.cpp
@@ -414,8 +414,9 @@
at::tracer::impl::NoTracerDispatchMode tracer_guard;
// lift has no autograd implementation, so we need to make sure we don't try
// to dispatch to it.
+ // TODO: arguably it should have an autograd implementation that noops
at::AutoDispatchBelowADInplaceOrView guard;
- return tensor.lift();
+ return at::lift_fresh(tensor);
}
Tensor new_from_data_copy(
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index c739deb..31bf620 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -15,7 +15,7 @@
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from contextlib import contextmanager, nullcontext
-from torch.utils._python_dispatch import TorchDispatchMode
+from torch.utils._python_dispatch import TorchDispatchMode, enable_torch_dispatch_mode
__all__ = ["ProxyTensor", "PythonKeyTracer", "dispatch_trace", "make_fx", "enable_strict", "DecompositionInterpreter"]
aten = torch.ops.aten
@@ -39,11 +39,11 @@
global IS_STRICT
IS_STRICT = val
-def wrap_output(inner_res, proxy_res):
+def wrap_output(inner_res, proxy_res, **kwargs):
def wrap_with_proxy(e, proxy):
if isinstance(e, torch.Tensor):
with no_dispatch():
- return ProxyTensor(e, proxy)
+ return ProxyTensor(e, proxy, **kwargs)
else:
return e
@@ -60,6 +60,16 @@
return inner_res
+def maybe_disable_fake_tensor_mode():
+ # TODO: figure out if this API generally makes sense and bake it into the
+ # library
+ mb_fake_mode = torch._C._get_torch_dispatch_mode()
+ if isinstance(mb_fake_mode, FakeTensorMode):
+ return enable_torch_dispatch_mode(mb_fake_mode.inner, replace=mb_fake_mode)
+ else:
+ return nullcontext()
+
+
def proxy_call(func_overload, args, kwargs=None):
if kwargs is None:
kwargs = {}
@@ -67,6 +77,11 @@
if func_overload in CURRENT_DECOMPOSITION_TABLE:
return CURRENT_DECOMPOSITION_TABLE[func_overload](*args, **kwargs)
if func_overload == aten._local_scalar_dense.default:
+ t, = args
+ assert not kwargs
+ if t.constant is not None:
+ with maybe_disable_fake_tensor_mode():
+ return t.constant.item()
raise RuntimeError("It appears that you're trying to get value out of a tracing tensor - erroring out! "
"It's likely that this is caused by data-dependent control flow or similar."
"Try torch.fx.experimental.proxy_tensor.enable_strict(False) to disable this check")
@@ -88,14 +103,66 @@
args[0].proxy = proxy_res
proxy_res.node.meta['tensor_meta'] = _extract_tensor_metadata(args[0])
inner_res = func_overload(*pytree.tree_map(unwrap_elem, args), **pytree.tree_map(unwrap_elem, kwargs))
+
# Needed to sync up metadata for in-place operators that modify metadata
+ # TODO: instead forward the metadata to the inner tensor so updating
+ # is not necessary
if torch.Tag.inplace_view in func_overload.tags: # type: ignore[attr-defined]
with no_dispatch():
func_overload(*args, **kwargs)
+ # In some circumstances, we will be tracing in a situation where a tensor
+ # is *statically* known to be a constant (currently, this only happens if
+ # you run torch.tensor; deterministic factory functions like torch.arange
+ # don't get this treatment). When the tensor in question is small, it's
+ # helpful to due constant propagation in case we call item() (in which
+ # case we can return the constant value that is known, rather than give
+ # an error.) The logic here tests if constant propagation is possible
+ # (because all of the inputs are constant). If so, we disable fake tensor
+ # mode (if it is on) and do true compute on the constant.
+ #
+ # It's worth highlighting that we're making a policy decision here.
+ # There is a potential that the tensor is actually quite large, and we
+ # don't actually want to run the compute. The tensor being quite large
+ # is one of the reasons why factory functions don't get this treatment
+ # (since they can be quite large; if a parameter is initialized to a
+ # constant value it will be!) Similarly, there is also a potential
+ # to run an operator that blows up the size of a small tensor; we don't
+ # protect against this case, but we could force, e.g., only single
+ # element constant computation by testing the numel of the result before
+ # propagating const-ness. Similarly, we don't require the constant to
+ # live on CPU, but we could.
+ all_constant = True
+ any_constant = False
+
+ def check_constant(e):
+ nonlocal all_constant, any_constant
+ if isinstance(e, ProxyTensor):
+ if e.constant is None:
+ all_constant = False
+ else:
+ any_constant = True
+
+ pytree.tree_map(check_constant, args)
+ pytree.tree_map(check_constant, kwargs)
+
+ def unwrap_constant(e):
+ if isinstance(e, ProxyTensor):
+ return e.constant
+ return e
+
+ constant = None
+ # NB: do NOT include factories as constants
+ if all_constant and any_constant:
+ with maybe_disable_fake_tensor_mode():
+ constant = func_overload(
+ *pytree.tree_map(unwrap_constant, args),
+ **pytree.tree_map(unwrap_constant, kwargs)
+ )
+
# TODO(chilli): Enable this after it's been refactored to work with wrapper tensor subclasses in general
# pytree.tree_map(lambda x: check_metadata_consistency(x, ProxyTensor), (inner_res, args, kwargs))
- return wrap_output(inner_res, proxy_res)
+ return wrap_output(inner_res, proxy_res, constant=constant)
class ProxyTensor(torch.Tensor):
@@ -104,7 +171,7 @@
@staticmethod
- def __new__(cls, elem, proxy, *, requires_grad=None):
+ def __new__(cls, elem, proxy, *, requires_grad=None, constant=None):
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
elem.shape, dtype=elem.dtype, layout=elem.layout, device=elem.device,
@@ -113,7 +180,7 @@
)
return r
- def __init__(self, elem, proxy, *, requires_grad=None):
+ def __init__(self, elem, proxy, *, requires_grad=None, constant=None):
if elem.is_sparse:
proxy.node.meta['tensor_meta'] = {}
else:
@@ -123,6 +190,7 @@
assert not (isinstance(elem, ProxyTensor) and elem.proxy.tracer is proxy.tracer)
self.elem = elem
self.proxy = proxy
+ self.constant = constant
def __deepcopy__(self, memo):
return self.clone()
@@ -220,13 +288,58 @@
return args[0]
if any(tuple(isinstance(arg, ProxyTensor) for arg in pytree.tree_flatten(args)[0])):
return proxy_call(func_overload, args, kwargs)
+ # When we trace through a torch.tensor invocation, you never actually
+ # see a torch.ops.aten.tensor call. Instead, the way this function is
+ # implemented internally is that we allocate a plain tensor (this is
+ # *guaranteed* to be a plain tensor, we disable all modes when doing
+ # so), and then call at::lift_fresh on it (to give modes a chance to do
+ # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed
+ # to be freshly allocated, so we want lift_fresh to be a no-op (directly
+ # returning the input argument).
+ #
+ # Here is the basic problem: when we trace this sequence of executions
+ # into an FX graph, what happens to this call sequence? Traditionally,
+ # tensor constants get interned as buffers on the FX GraphModule. But
+ # this is dangerous. Consider:
+ #
+ # x = torch.tensor(1)
+ # x.add_(2)
+ #
+ # Naively, this traces into:
+ #
+ # t = self._tensor_constant0 # initialized to torch.tensor(1)
+ # x = torch.ops.aten.lift_fresh(t)
+ # x.add_(2)
+ #
+ # If lift_fresh returns t directly, the subsequent add_ call will
+ # modify the tensor constant. Really, the problem is we've violated
+ # the invariant the the argument to lift is fresh. So what we should
+ # preserve the invariant by replacing lift_fresh with lift_fresh_copy:
+ #
+ # t = self._tensor_constant0 # initialized to torch.tensor(1)
+ # x = torch.ops.aten.lift_fresh_copy(t)
+ # x.add_(2)
+ #
+ # This is what the overload modification does.
else:
+ if func_overload is torch.ops.aten.lift_fresh.default:
+ func_overload = torch.ops.aten.lift_fresh_copy.default
+
proxy_res = self.tracer.create_proxy('call_function', func_overload, args, kwargs,
name=self.tracer.graph._target_to_str(func.__name__))
inner_res = func_overload(*args, **kwargs)
- return wrap_output(inner_res, proxy_res)
+ # If this is a lift, the input tensor is guaranteed to be a
+ # constant, so we keep a copy of the original argument along so
+ # we can query it if we're asked to item() it at some later point
+ is_lift = func_overload is torch.ops.aten.lift_fresh_copy.default
+ if is_lift:
+ with maybe_disable_fake_tensor_mode():
+ constant = args[0].clone()
+ else:
+ constant = None
+ return wrap_output(inner_res, proxy_res, constant=constant)
class DecompositionInterpreter(torch.fx.Interpreter):
@@ -262,6 +375,15 @@
return super().run(*args, **kwargs)
def make_fx(f, decomposition_table=None, trace_factory_functions=True, use_fake=False):
+ if use_fake and not trace_factory_functions:
+ raise ValueError("""\
+use_fake and not trace_factory_functions is not currently supported; if
+proxy tensor is not executed as a mode, fake tensors must not be executed
+as a mode either (otherwise, we will incorrectly intern fake tensors into
+the traced graph module.) However, non-mode execution of fake tensors
+is not currently supported (although, in principle, it could be; file
+a bug if you need this)""")
+
if decomposition_table is None:
decomposition_table = {}
diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py
index 21c4018..8f00e81 100644
--- a/torch/utils/_mode_utils.py
+++ b/torch/utils/_mode_utils.py
@@ -81,11 +81,12 @@
)
# NB: we don't require TorchFunctionMode/PythonMode since this is intended to also
# let you directly pass a Tensor subclass type to "mode-ify" it.
- required_fn = "__" + mode_info.mode_name + "__"
- if not hasattr(mode, required_fn):
- raise ValueError(
- f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}'
- )
+ if mode is not None:
+ required_fn = "__" + mode_info.mode_name + "__"
+ if not hasattr(mode, required_fn):
+ raise ValueError(
+ f'The argument passed to enable_{mode_info.mode_name}_mode must implement {required_fn}'
+ )
mode_info.set_mode(mode)
try:
yield mode # type: ignore[misc]
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 37ced1f..e0aa157 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -688,6 +688,9 @@
if isinstance(g, NativeFunctionsViewGroup):
# functionalization needs to register kernels for view + view_inplace ops
+ # See Note [Functionalization <> torch.Tensor constructor]
+ if str(g.view.func.name) == "lift_fresh":
+ return []
view_str = [emit_registration_helper(g.view)]
if g.view_inplace is not None:
assert g.view_inplace.is_view_op