blob: 8db8f94b1b4195cd1be7f5a8a502bcbfe4d3994d [file] [log] [blame]
import contextlib
import warnings
import weakref
from typing import ContextManager, List, Optional, Tuple, TYPE_CHECKING
import torch
from torch._C._functorch import (
_unwrap_functional_tensor,
_wrap_functional_tensor,
current_level,
peek_interpreter_stack,
TransformType,
)
from torch._guards import Source
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._python_dispatch import (
is_traceable_wrapper_subclass,
transform_subclass,
)
from torch.utils.weak import WeakIdRef
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import SymbolicContext
DimList = List
def safe_is_leaf(t):
try:
return t.is_leaf
except RuntimeError:
# inference mode can trigger this
return False
def safe_grad(t):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
return t.grad
def assert_eq(a, b):
assert a == b, f"{a} != {b}"
def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
def go(m1, m2):
assert_eq(m1.dtype, m2.dtype)
if not skip_symbolic:
assert_eq(m1.shape, m2.shape)
assert_eq(m1.requires_grad, m2.requires_grad)
assert_eq(m1.is_leaf, m2.is_leaf)
assert_eq(m1.grad_fn is None, m2.grad_fn is None)
assert_eq(m1.is_sparse, m2.is_sparse)
assert_eq(m1.is_inference(), m2.is_inference())
assert_eq(m1.is_conj(), m2.is_conj())
assert_eq(m1.is_neg(), m2.is_neg())
assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None)
if safe_grad(m1) is not None:
go(safe_grad(m1), safe_grad(m2))
if m1.is_sparse:
assert_eq(m1.dense_dim(), m2.dense_dim())
assert_eq(m1.sparse_dim(), m2.sparse_dim())
assert_eq(m1.is_coalesced(), m2.is_coalesced())
else:
if not skip_symbolic:
assert_eq(m1.stride(), m2.stride())
assert_eq(m1.storage_offset(), m2.storage_offset())
assert_eq(m1._is_view(), m2._is_view())
if m1._is_view():
go(m1._base, m2._base)
# TODO: test if is resizable (no direct query for this atm)
# TODO: audit AutogradMeta to see if it matches
# TODO: test forward AD
return go(m1, m2)
# This is a class for converting multiple tensors into meta tensors which
# share the same view/storage structure. The operation model is you allocate
# one of these, and then call it repeatedly on all the tensors you want to
# convert. It's important to use the same object for tensors you want to
# share storage because this is how we correlate shared storages to the same
# meta storages. This class will hold weak references to cached tenosrs
# and tensor storages.
class MetaConverter:
def __init__(self):
self.storage_memo = {}
self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.maybe_storages_to_delete = []
self.check_expired_frequency = 128
self.check_expired_count = 0
self.hit = 0
self.miss = 0
self.del_hook = None
self.arg_cnt = 0
def successful(self):
return self.hit > 0 and self.miss == 0
def check_for_expired_weak_storages(self):
new_li = []
stor_to_delete = []
for obj in self.maybe_storages_to_delete:
if not obj.expired():
new_li.append(obj)
else:
stor_to_delete.append(obj)
for obj in stor_to_delete:
self.storage_memo.pop(obj, None)
self.maybe_storages_to_delete = new_li
# if for some reason we have aquired many storages which have not expired
# even though a tensor with their storage has expired (aliasing or otherwise)
# check for expired storages less often so as to bound the amount of work we
# do checking for expired storages
self.check_expired_frequency = max(
self.check_expired_frequency, len(self.maybe_storages_to_delete)
)
def get_tensor_memo(self, t):
return self.tensor_memo.get(WeakIdRef(t), None)
def set_tensor_memo(self, t, v):
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
if t.is_sparse or t.is_mkldnn:
weak_st = None
else:
weak_st = StorageWeakRef(t._typed_storage())
tensor_ref_key = WeakIdRef(t)
def del_ten():
# tensor outlives the converter
self_ref = self_weak_ref()
if self_ref is None:
return
# on shutdown, tensor_ref_key may not be in memo
self_ref.tensor_memo.pop(tensor_ref_key, None)
if weak_st and weak_st.expired():
self_ref.storage_memo.pop(weak_st, None)
elif weak_st is not None:
# [expired-storages]
# NB: even though the tensor has died,
# the deallocation of its storage can take longer,
# even when the storage has no other uses/views.
# In this case, the StorageWeakRef object will be kept alive
# longer than it needs to be, however the storage itself
# will be deallocated. We retain the possibly dead storages
# and periodically check if any of them are expired and
# can be freed.
self_ref.maybe_storages_to_delete.append(weak_st)
weakref.finalize(t, del_ten)
self.tensor_memo[tensor_ref_key] = v
# NB: doesn't actually return a storage, because meta storage is
# not supported
def meta_storage(self, s, callback):
# NB: TypedStorage is freshly allocated and cannot be used as hash
# key index.
# Use a Weak Ref to s in order to not leak memory
swr = StorageWeakRef(s)
if swr not in self.storage_memo:
self.storage_memo[swr] = callback(
lambda: torch.empty(s.size(), dtype=torch.uint8, device="meta")
).untyped_storage()
return self.storage_memo[swr]
# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
# precisely to the Source.name() of the tensor we're fakeifying and
# corresponds to a valid Python expression. When we construct sub-names
# as part of this process, we will maintain this invariant! (Even though
# other users of this may not need it this property to be upheld.)
def meta_tensor(
self,
t,
shape_env=None,
callback=lambda t: t(),
source: Optional[Source] = None,
symbolic_context: Optional["SymbolicContext"] = None,
):
from torch._subclasses.fake_tensor import FakeTensor
if source is None:
from torch._dynamo.source import ConstantSource
# TODO: make a dedicated UnknownSource for this?
source = ConstantSource(
f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
)
# This indicates you set no_dispatch() before calling into this
# function. This is an error: we may be creating fake tensors and
# will perform operations on them which need fake tensor mode to
# be active. You will segfault if you are in a no_dispatch() block.
assert not torch._C._dispatch_tls_local_exclude_set().has(
torch._C.DispatchKey.Python
)
arg_cnt = self.arg_cnt
self.arg_cnt += 1
# When we make as_strided calls, we end up generating a guard
# that the new as_strided tensor is in bounds for the old storage
# for the base (since as_strided calls can "bust" out of their
# bounding box.) This guard is unnecessary: if a user is able
# to provide us a tensor with the view base setup this way, we
# don't need to produce a guard, because the fact that they
# were able to produce the view base means its in bounds.
#
# Now, ordinarily, this guard would be harmless. However, the
# generated guard refers to variables bound on the base variable.
# At the moment, Dynamo doesn't actually guard on x._base, because
# according to Voz this results in a lot of spurious invalidations,
# and also if the user doesn't directly make use of _base, its
# pointless anyway (because programs should be parametric over
# whether or not the input tensor is a view or not--unless you're
# mutating the input, but that's a whole 'nother ballgame). So
# for expediency, we suppress these guards so we don't have to
# deal with this (yet, anyway.)
#
# NB: An old version of this code suppressed guards for ALL operations
# happening during meta conversion, not just as_strided calls.
# This is too aggressive: we do duck sizing and 0/1 simplification
# as we allocate variables, and we do need to register guards for
# these cases.
maybe_suppress = contextlib.nullcontext
if shape_env is not None:
maybe_suppress = shape_env.suppress_guards
def sym_sizes_strides_storage_offset(
t, src
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
if shape_env is not None:
if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
# Don't reallocate the sizes; the shape envs are the same,
# so reuse the old sizes/strides/etc
return (t.size(), t.stride(), t.storage_offset())
else:
return shape_env.create_symbolic_sizes_strides_storage_offset(
t,
src,
# Assume that the set of dims that are dynamic are the same between
# the wrapper tensor and any inner tensors.
# We can revisit this if this assumption does not hold
# for any important subclasses later.
symbolic_context=symbolic_context,
)
else:
assert symbolic_context is None
return (t.size(), t.stride(), t.storage_offset())
# see expired-storages
self.check_expired_count += 1
if self.check_expired_count >= self.check_expired_frequency:
self.check_for_expired_weak_storages()
self.check_expired_count = 0
if self.get_tensor_memo(t) is None:
with torch.inference_mode(t.is_inference()):
if t.is_sparse:
is_leaf = safe_is_leaf(t)
r = callback(
lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
t.sparse_dim(),
t.dense_dim(),
t.shape,
dtype=t.dtype,
layout=torch.sparse_coo,
device="meta",
)
)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
# Note [is_coalesced is dispatched]
# Strangely enough, is_coalesced() is a dispatched operator,
# which means that it will get caught by fake tensor mode.
# Ordinarily this would error, but there's some logic in
# fake tensor ensure this doesn't happen.
r._coalesced_(t.is_coalesced())
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
with torch.enable_grad():
r = r.clone()
r._coalesced_(t.is_coalesced())
elif t.is_mkldnn:
is_leaf = safe_is_leaf(t)
sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
t, source
)
r = callback(
lambda: torch.empty_strided(
sizes, strides, dtype=t.dtype, device="meta"
)
)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
with torch.enable_grad():
r = r.clone()
elif t._is_view():
# Construct views in two steps: recursively meta-fy their
# base, and then create view(s) off that. NB: doing it
# directly from storage is WRONG because this won't cause
# version counters to get shared.
assert t._is_view()
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
StatelessSymbolicContext,
)
if shape_env and not t.is_nested and not t._base.is_nested:
base_symbolic_context = StatelessSymbolicContext(
dynamic_sizes=[DimDynamic.STATIC] * t._base.dim(),
constraint_sizes=[None] * t._base.dim(),
)
else:
base_symbolic_context = None
base = self.meta_tensor(
t._base,
shape_env,
callback,
source=AttrSource(source, "_base"),
symbolic_context=base_symbolic_context,
)
def is_c_of_r(complex_dtype, real_dtype):
return (
utils.is_complex_dtype(complex_dtype)
and utils.corresponding_real_dtype(complex_dtype)
== real_dtype
)
# In some situations, MetaConverter may be called in a
# context where autograd is disabled. For the _is_view
# assert to pass, we have to setup the autograd view
# metadata anyway. Do this by reenabling the
# ADInplaceOrView key. This is kind of a hack.
old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
torch._C.DispatchKey.ADInplaceOrView
)
torch._C._dispatch_tls_set_dispatch_key_excluded(
torch._C.DispatchKey.ADInplaceOrView, False
)
try:
if base.dtype == t.dtype:
pass
elif is_c_of_r(base.dtype, t.dtype):
base = torch.view_as_real(base)
elif is_c_of_r(t.dtype, base.dtype):
base = torch.view_as_complex(base)
else:
# This is not guaranteed to succeed. If it fails, it
# means there is another dtype-converting view function
# that hasn't been handled here
base = base.view(t.dtype)
# This is very tricky. Naively, you might expect this
# to hold:
#
# if t.requires_grad and not safe_is_leaf(t)
# assert t._base.requires_grad
#
# But it's not true! As you can see in the following
# program:
#
# x = torch.zeros(4)
# y = x.view(1, 4)
# y.requires_grad = True
# z = y.view(1, 1, 4)
# assert z._base is x
#
# So we may have to do *two* views out of the base to
# recreate this situation.
def _view_from_base(base, t):
if t.is_nested:
# Nested tensors do not support as_strided, and
# hence,always have _view_func available.
#
# The unsafe version of _view_func omits
# checking whether the base passed in has the same
# metadata as the original base the view_func
# was originally executed with. (1) It is OK here,
# because we're calling it on the meta-ified base,
# so the metadata is guaranteed to be the same.
# (2) It is necessary because we don't actually
# want to guard on the base's metadata here.
return t._view_func_unsafe(base)
else:
(
sizes,
strides,
storage_offset,
) = sym_sizes_strides_storage_offset(t, source)
return base.as_strided(sizes, strides, storage_offset)
if safe_is_leaf(t):
# Leaf views that track view metadata are created by
# creating a view inside a no_grad block
with torch.no_grad(), maybe_suppress():
r = _view_from_base(base, t)
# As it's a leaf, we can directly assign requires_grad
r.requires_grad = t.requires_grad
else:
if t._base.requires_grad == t.requires_grad:
# Easy case, just run the view op
with torch.enable_grad(), maybe_suppress():
r = _view_from_base(base, t)
# NB: We don't actaully faithfully replicate
# autograd connectivity, but that doesn't matter
# today. See following for more info:
# https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
else:
# Obscure case. Create a leaf view and give it the
# correct requires_grad, then do the final view.
# NB: Can't have a non-leaf without requiring grad!
assert t.requires_grad
with torch.no_grad():
mid = base.view(base.shape)
mid.requires_grad = t.requires_grad
with torch.enable_grad(), maybe_suppress():
r = _view_from_base(mid, t)
# The CreationMeta influences whether or not inplace
# mutation is an error or not. So we need to make
# sure we properly propagate this as well.
torch._C._autograd._set_creation_meta(
r, torch._C._autograd._get_creation_meta(t)
)
finally:
torch._C._dispatch_tls_set_dispatch_key_excluded(
torch._C.DispatchKey.ADInplaceOrView, old_exclude
)
else:
is_leaf = safe_is_leaf(t)
if not t.is_nested:
# Nested tensor subclasses have special logic for
# creating symbolic size/strides/storage_offset
(
sizes,
strides,
storage_offset,
) = sym_sizes_strides_storage_offset(t, source)
def empty_create(inner_t, inner_src):
(
inner_sizes,
inner_strides,
inner_storage_offset,
) = sym_sizes_strides_storage_offset(inner_t, inner_src)
return torch.empty_strided(
inner_sizes,
inner_strides,
dtype=inner_t.dtype,
device="meta",
)
# If we have a subclass that desugars into dense tensors,
# perform our callback on each inner tensor.
if is_traceable_wrapper_subclass(t):
# Note: transform_subclass will use __tensor_unflatten__ to generate
# a fresh subclass wrapper, which is why sizes/strides are not passed in
# to the creation function here.
# We assume that if the inner tensors of the subclass are given symbolic sizes,
# their sizes will be used to construct the (symbolic) sizes of the wrapper tensor.
from torch._dynamo.source import AttrSource
if t.is_nested:
# Avoid circular import
from torch._dynamo.source import (
TensorProperty,
TensorPropertySource,
)
# For nested tensors, manually do transform_subclass
# so we can insert some special processing on ctx
attrs, ctx = t.__tensor_flatten__()
transformed_tensors_dict = {}
orig_shape_env = None
for attr in attrs:
inner_t = getattr(t, attr)
if orig_shape_env is None:
orig_shape_env = (
inner_t.fake_mode.shape_env
if isinstance(inner_t, FakeTensor)
else None
)
transformed_tensors_dict[attr] = callback(
lambda: empty_create(
inner_t, AttrSource(source, attr)
)
)
# We expect JaggedTensor to have a 'ragged_size' in
# its context
assert isinstance(ctx, dict)
assert "ragged_size" in ctx
assert isinstance(t._size[1], torch.SymInt)
if orig_shape_env is shape_env:
# It's already fake and the shape envs line up, reuse the old size
# Do not assert singleton_int; it may already
# be a variable
ctx["ragged_size"] = t._size[1]
else:
assert t._size[1].node.singleton_int() is not None
# Replace the eager ragged size with our freshly
# allocated jagged size that has a source
ctx["ragged_size"] = shape_env.create_symintnode(
shape_env.create_symbol(
t._size[1],
TensorPropertySource(
source, TensorProperty.SIZE, 1
),
),
hint=t._size[1],
)
r = type(t).__tensor_unflatten__(
transformed_tensors_dict, ctx
)
else:
r = transform_subclass(
t,
lambda attr, inner_t: callback(
lambda: empty_create(
inner_t,
AttrSource(source, attr),
)
),
)
else:
r = callback(
lambda: torch.empty_strided(
sizes,
strides,
dtype=t.dtype,
device="meta",
)
)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = t.requires_grad
if not is_leaf:
# Fake up some autograd history.
with torch.enable_grad():
# preserve_format is the default, but we want to
# emphasize how important it is to preserve
# format here
r = r.clone(memory_format=torch.preserve_format)
# Graph-Break for wrapped tensors
if torch._C._functorch.is_functorch_wrapped_tensor(t):
return NotImplemented
s = t.untyped_storage()
swr = StorageWeakRef(s)
if swr not in self.storage_memo and (
r.is_nested
or (
r.stride() == strides
and r.storage_offset() == storage_offset
)
):
# You're normal and happy, install the fresh storage into the memo
self.storage_memo[swr] = r.untyped_storage()
else:
# You're in crazy town; somehow you gave us a tensor
# that wasn't a view, but had nonzero storage offset,
# nontrivial strides (such that clone() couldn't
# preserve them), or already aliases with another
# tensor's storage. The most typical way to end
# up here is with set_. So use set_ to bludgeon this
# in.
r_s = self.meta_storage(s, callback=callback)
# NB: In principle, this should always work, but there
# is some subtle difference in the autograd metadata
# that means we will backprop the set_ call, even if
# r is declared as an input to grad.
# See https://github.com/pytorch/pytorch/issues/87956
# for the reproducer.
# NB: The in_kernel_invocation_manager here is necessary
# for fake tensor. If we run the set_ call with fake
# tensor on, r will improperly report that it is NOT a
# meta tensor but a cpu tensor, and then the set_ call
# will fail due to device mismatch. no_dispatch() is
# not enough, because the fake tensor will still claim
# to be a CPU tensor and you'll end up in the CPU
# kernel. Arguably this is a hack; a cleaner way to
# solve this is to have a FakeStorage concept which
# would report it's CPU device--no problem now! But
# this is difficult to do because we don't have storage
# subclasses. Relevant test is
# DynamicShapesFunctionTests::test_add_dynamic_shapes in
# test/dynamo/test_dynamic_shapes.py
maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
from torch._subclasses.fake_tensor import (
in_kernel_invocation_manager,
maybe_get_fake_mode,
)
mb_fake_mode = maybe_get_fake_mode(r)
if mb_fake_mode is not None:
maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
with maybe_fake_mgr, torch.no_grad():
r.set_(r_s, storage_offset, sizes, strides)
if safe_grad(t) is not None:
from torch._dynamo.source import AttrSource
r.grad = self.meta_tensor(
safe_grad(t),
shape_env,
callback,
source=AttrSource(source, "grad"),
symbolic_context=symbolic_context,
)
torch._C._set_conj(r, t.is_conj())
torch._C._set_neg(r, t.is_neg())
# This can be skipped if necessary for performance reasons
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
self.set_tensor_memo(t, r)
return self.get_tensor_memo(t)
def __call__(
self,
t,
shape_env=None,
*,
callback=lambda t: t(),
source=None,
symbolic_context=None,
):
# TODO: zero tensors? We appear to have eliminated them by
# excluding complex for now
if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
if t.device.type != "xla" and any(
[
t.is_sparse_csr,
t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
t.is_quantized,
t._is_view() and t._base is not None and t._base.is_sparse,
torch._is_functional_tensor(t),
t.device.type in ("lazy"),
# We need a way to test if a tensor is batched but there
# is no official APi to do it
# torch._C._is_batched(t),
]
):
# TODO: sparse should support meta
# NB technically to('meta') does work but our logging
# instrumentation will see the meta conversions and the
# tests all break so we just exclude this. In any case
# the to conversion isn't really right anyhow.
if torch._is_functional_tensor(t) and t.device.type != "lazy":
if t._is_view():
raise RuntimeError(
"Cannot safely fakify a view because this process drops the view information right now."
)
st = peek_interpreter_stack()
assert (
st is None or st.key() == TransformType.Functionalize
), "Expect st to be either None or have Functionalize transform key."
if st is None:
# the case of AOTAutograd
torch._sync(t)
unwrap_t = torch._from_functional_tensor(t)
with torch._dispatch.python.suspend_functionalization():
fake_t = self.meta_tensor(
unwrap_t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
)
out = torch._to_functional_tensor(fake_t)
torch._mirror_autograd_meta_to(fake_t, out)
return out
else:
# torch.func.functionalize
reapply_views = torch._C._functionalization_reapply_views_tls()
unwrap_t = _unwrap_functional_tensor(t, reapply_views)
pop_st_ctx = (
torch._functorch.pyfunctorch.temporarily_pop_interpreter_stack()
)
with pop_st_ctx:
fake_t = self.meta_tensor(
unwrap_t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
)
return _wrap_functional_tensor(fake_t, current_level())
self.miss += 1
return NotImplemented
else:
self.hit += 1
r = self.meta_tensor(
t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
)
if type(t) is torch.nn.Parameter:
# NB: Cannot directly use Parameter constructor
# because that would force a detach, not desirable
r._is_param = True
return r
elif torch.overrides.is_tensor_like(t):
self.miss += 1
return NotImplemented
else:
# non-Tensor types don't count as hit or miss
return t
import torch._prims_common as utils