| # mypy: allow-untyped-defs |
| from __future__ import annotations |
| |
| import contextlib |
| |
| import dataclasses |
| import warnings |
| import weakref |
| from dataclasses import dataclass |
| from typing import ( |
| Any, |
| Callable, |
| ClassVar, |
| ContextManager, |
| Dict, |
| List, |
| Optional, |
| Tuple, |
| Type, |
| TYPE_CHECKING, |
| Union, |
| ) |
| from typing_extensions import TypeAlias |
| |
| import torch |
| from torch._C._autograd import CreationMeta |
| from torch._C._functorch import ( |
| _add_batch_dim, |
| _unwrap_functional_tensor, |
| _wrap_functional_tensor, |
| get_unwrapped, |
| is_batchedtensor, |
| is_functorch_wrapped_tensor, |
| is_gradtrackingtensor, |
| is_legacy_batchedtensor, |
| maybe_get_bdim, |
| maybe_get_level, |
| peek_interpreter_stack, |
| ) |
| from torch._logging import trace_structured |
| from torch.utils._mode_utils import no_dispatch |
| |
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
| from torch.utils.weak import WeakIdKeyDictionary |
| |
| if TYPE_CHECKING: |
| from torch._C._functorch import CInterpreter |
| from torch._guards import Source |
| |
| # Import here to avoid cycle |
| from torch._subclasses.fake_tensor import FakeTensorMode |
| |
| # Import the following modules during type checking to enable code intelligence features. |
| # Do not import them unconditionally, as they import sympy, and importing sympy is very slow. |
| from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext |
| |
| DimList = List |
| |
| |
| def safe_is_leaf(t): |
| try: |
| return t.is_leaf |
| except RuntimeError: |
| # inference mode can trigger this |
| return False |
| |
| |
| def safe_grad(t): |
| with warnings.catch_warnings(): |
| warnings.filterwarnings("ignore", "The .grad attribute of a Tensor") |
| return t.grad |
| |
| |
| def assert_eq(a, b): |
| assert a == b, f"{a} != {b}" |
| |
| |
| def assert_metadata_eq( |
| assert_eq, |
| m1: Union[MetaTensorDesc, torch.Tensor], |
| m2: torch.Tensor, |
| *, |
| skip_symbolic=False, |
| skip_leaf=False, |
| ): |
| if isinstance(m1, torch.Tensor): |
| m1 = MetaTensorDescriber().describe_tensor(m1) |
| |
| def go(m1, m2): |
| assert_eq(m1.dtype, m2.dtype) |
| if not skip_symbolic: |
| assert_eq(m1.shape, m2.shape) |
| assert_eq(m1.requires_grad, m2.requires_grad) |
| if not skip_leaf: |
| assert_eq(m1.is_leaf, m2.is_leaf) |
| # MetaTensorDesc doesn't store grad_fn; inferred from leaf |
| # assert_eq(m1.grad_fn is None, m2.grad_fn is None) |
| assert_eq(m1.is_sparse, m2.is_sparse) |
| assert_eq(m1.is_inference, m2.is_inference()) |
| assert_eq(m1.is_conj, m2.is_conj()) |
| assert_eq(m1.is_neg, m2.is_neg()) |
| assert_eq(m1.grad is not None, safe_grad(m2) is not None) |
| if m1.grad is not None: |
| go(m1.grad, safe_grad(m2)) |
| if m1.is_sparse: |
| assert_eq(m1.dense_dim, m2.dense_dim()) |
| assert_eq(m1.sparse_dim, m2.sparse_dim()) |
| assert_eq(m1.is_coalesced, m2.is_coalesced()) |
| else: |
| if not skip_symbolic: |
| assert_eq(m1.stride, m2.stride()) |
| assert_eq(m1.storage_offset, m2.storage_offset()) |
| assert_eq(m1.is_view, m2._is_view()) |
| if m1.is_view: |
| go(m1.base, m2._base) |
| # TODO: test if is resizable (no direct query for this atm) |
| # TODO: audit AutogradMeta to see if it matches |
| # TODO: test forward AD |
| |
| return go(m1, m2) |
| |
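| # Typical call, mirroring how meta_tensor() validates its result (the tensor |
| # names here are illustrative only): |
| #   assert_metadata_eq(assert_eq, real_t, fake_t, skip_symbolic=True) |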
| |
| def is_sparse_coo(t): |
| return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo |
| |
| |
| def is_sparse_compressed_layout(layout): |
| return layout in { |
| torch.sparse_csr, |
| torch.sparse_csc, |
| torch.sparse_bsr, |
| torch.sparse_bsc, |
| } |
| |
| |
| def is_sparse_compressed(t): |
| return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout) |
| |
| |
| def is_sparse_any(t): |
| return is_sparse_coo(t) or is_sparse_compressed(t) |
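| # e.g. is_sparse_any(torch.randn(3).to_sparse()) is True, while a plain |
| # strided tensor returns False for all of the predicates above. |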
| |
| |
| # Don't use id() directly, because id values can get reallocated over time. |
| MetaStorageId: TypeAlias = int |
| MetaTensorId: TypeAlias = int |
| |
| |
| DESCRIBER_NEXT_ID = 0 |
| |
| |
| class MetaTensorDescriber: |
| """ |
| Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc |
| for it, which is enough information to reconstruct a meta tensor/fake tensor |
| corresponding to a Tensor as faithfully as possible. |
| |
| This is a stateful conversion object because we keep track of the IDs |
| of the tensors/storages passed to us, so we can consistently give |
| the same ID when we see the same tensor/storage. |
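| |
| A rough usage sketch (hypothetical; the tensors are illustrative only): |
| |
| describer = MetaTensorDescriber() |
| desc_a = describer.describe_tensor(torch.randn(2, 3)) |
| desc_b = describer.describe_tensor(torch.randn(2, 3))  # distinct tensor, distinct id |
| |
| Describing the same tensor twice through the same describer yields descriptions |
| that share an id, which keeps aliasing relationships consistent across calls. |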
| """ |
| |
| def __init__(self, *, copy_data=False): |
| global DESCRIBER_NEXT_ID |
| self.id = DESCRIBER_NEXT_ID |
| DESCRIBER_NEXT_ID += 1 |
| self.next_tensor_id: MetaTensorId = 0 |
| self.next_storage_id: MetaStorageId = 0 |
| # Tensor -> int |
| self.lookup_tensor = WeakIdKeyDictionary() |
| # Storage -> int |
| self.lookup_storage = WeakIdKeyDictionary() |
| self.copy_data = copy_data |
| self.traced_tensors = set() |
| self.traced_storages = set() |
| |
| def get_tensor_id(self, t: torch.Tensor): |
| if t not in self.lookup_tensor: |
| self.lookup_tensor[t] = self.next_tensor_id |
| self.next_tensor_id += 1 |
| return self.lookup_tensor[t] |
| |
| def get_storage_id(self, s: torch.UntypedStorage): |
| if s not in self.lookup_storage: |
| self.lookup_storage[s] = self.next_storage_id |
| self.next_storage_id += 1 |
| return self.lookup_storage[s] |
| |
| def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False): |
| r = MetaStorageDesc( |
| id=self.get_storage_id(s), |
| size=s.size(), |
| # NB: We don't do the copy yet; copy happens when we start |
| # creating the new storages |
| data=s if self.copy_data else None, |
| ) |
| if trace and r.id not in self.traced_storages: |
| trace_structured( |
| "describe_storage", |
| metadata_fn=lambda: r.as_json(self.id), |
| ) |
| self.traced_storages.add(r.id) |
| return r |
| |
| def describe_tensor( |
| self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False |
| ): |
| is_leaf = safe_is_leaf(t) |
| is_view = t._is_view() |
| is_sparse = t.is_sparse |
| layout = t.layout |
| is_nested = t.is_nested |
| is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t) |
| is_functorch_wrapped = is_functorch_wrapped_tensor(t) |
| is_mkldnn = t.is_mkldnn |
| is_batchedtensor_v = is_batchedtensor(t) |
| is_legacy_batchedtensor_v = is_legacy_batchedtensor(t) |
| is_gradtrackingtensor_v = is_gradtrackingtensor(t) |
| is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v |
| is_functional = torch._is_functional_tensor(t) |
| |
| storage = None |
| # NB: For compatibility, I default this to zero, as sometimes people |
| # have stuffed zero into the storage offset even though the tensor |
| # doesn't meaningfully have an offset |
| storage_offset = 0 |
| if not ( |
| is_sparse |
| or is_sparse_compressed_layout(layout) |
| or (is_nested and not is_traceable_wrapper_subclass_v) |
| or is_mkldnn |
| # TODO: TBH, functorch wrapped tensors probably should have |
| # storage associated with them |
| or is_functorch_wrapped |
| or is_legacy_batchedtensor_v |
| ): |
| # NB: We actually don't use storage to do views, but might as well |
| # put it in for accuracy |
| storage = self.describe_storage(t.untyped_storage(), trace=trace) |
| storage_offset = t.storage_offset() |
| |
| stride = None |
| if not ( |
| is_sparse |
| or is_sparse_compressed_layout(layout) |
| or (is_nested and not is_traceable_wrapper_subclass_v) |
| ): |
| # stride/storage_offset are called from is_functorch_wrapped, |
| # view_from_base, empty_create_subclass, |
| # sym_sizes_strides_storage_offset (empty_create) |
| stride = t.stride() |
| |
| # NB: this technically should refer to the functorch unwrapped tensor, but |
| # I am (perhaps abusively) using it to store both the functorch and |
| # non-functorch functional tensor |
| unwrapped = None |
| autograd_meta_from = None |
| current_level = None |
| if is_batchedtensor_v or is_gradtrackingtensor_v: |
| unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace) |
| # xla and lazy tensors present as functional tensors, but we want them |
| # to be handled specially |
| elif is_functional and t.device.type not in ("xla", "lazy"): |
| if t._is_view(): |
| raise RuntimeError( |
| "Cannot safely fakify a view because this process drops the view information right now." |
| ) |
| if not is_functorch_wrapped: |
| torch._sync(t) |
| unwrapped = self.describe_tensor( |
| torch._from_functional_tensor(t), trace=trace |
| ) |
| autograd_meta_from = t |
| else: |
| reapply_views = torch._C._functionalization_reapply_views_tls() |
| # NB: has side effects! |
| unwrapped = self.describe_tensor( |
| _unwrap_functional_tensor(t, reapply_views), trace=trace |
| ) |
| # TODO: It's pretty suspicious that functional tensors don't have a |
| # valid level, so we just grab whatever the current level |
| # is |
| current_level = torch._C._functorch.current_level() |
| |
| maybe_functorch_stack = None |
| if is_functorch_wrapped: |
| with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack: |
| pass |
| |
| attrs = None |
| ctx = None |
| type_v = None |
| if is_traceable_wrapper_subclass_v: |
| assert hasattr(t, "__tensor_flatten__") |
| raw_attrs, ctx = t.__tensor_flatten__() |
| attrs = { |
| attr: self.describe_tensor(getattr(t, attr), trace=trace) |
| for attr in raw_attrs |
| } |
| type_v = type(t) |
| |
| # TODO: Is it important to enable torch.inference_mode before querying |
| # these values? |
| r = MetaTensorDesc( |
| id=self.get_tensor_id(t), |
| storage=storage, |
| is_inference=t.is_inference(), |
| is_leaf=is_leaf, |
| requires_grad=t.requires_grad, |
| # NB: ndim should be OK too, but there is a disaster at |
| # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported |
| # Actually, this means that we have a bit of a problem |
| # here, which is that there is some sensitivity to how exactly an |
| # access is done if you have a __torch_function__ subclass. Maybe |
| # we should disable torch function before doing accesses? |
| ndim=t.dim(), |
| dtype=t.dtype, |
| is_sparse=is_sparse, |
| is_mkldnn=is_mkldnn, |
| is_functorch_wrapped=is_functorch_wrapped, |
| is_batchedtensor=is_batchedtensor_v, |
| is_legacy_batchedtensor=is_legacy_batchedtensor_v, |
| is_gradtrackingtensor=is_gradtrackingtensor_v, |
| is_view=is_view, |
| is_conj=t.is_conj(), |
| is_neg=t.is_neg(), |
| is_parameter=isinstance(t, torch.nn.Parameter), |
| is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v, |
| is_nested=is_nested, |
| is_functional=is_functional, |
| layout=layout, |
| device=t.device, |
| size=t.size(), |
| stride=stride, |
| storage_offset=storage_offset, |
| dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())), |
| sparse_dim=t.sparse_dim() |
| if t.is_sparse or is_sparse_compressed(t) |
| else None, |
| dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None, |
| is_coalesced=t.is_coalesced() if t.is_sparse else None, |
| # TODO: I actually think recursing here is correct, but we have at |
| # least an infinite cycle from base -> values -> base |
| # https://github.com/pytorch/pytorch/issues/122089 |
| crow_indices=self.describe_tensor( |
| t.crow_indices(), recurse=False, trace=trace |
| ) |
| if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} |
| else None, |
| col_indices=self.describe_tensor( |
| t.col_indices(), recurse=False, trace=trace |
| ) |
| if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr} |
| else None, |
| ccol_indices=self.describe_tensor( |
| t.ccol_indices(), recurse=False, trace=trace |
| ) |
| if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} |
| else None, |
| row_indices=self.describe_tensor( |
| t.row_indices(), recurse=False, trace=trace |
| ) |
| if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc} |
| else None, |
| values=self.describe_tensor(t.values(), recurse=False, trace=trace) |
| if recurse and is_sparse_compressed(t) |
| else None, |
| grad=self.describe_tensor(safe_grad(t), trace=trace) |
| if safe_grad(t) is not None |
| else None, |
| creation_meta=torch._C._autograd._get_creation_meta(t) |
| if t._is_view() |
| else None, |
| unwrapped=unwrapped, |
| level=maybe_get_level(t) |
| if is_batchedtensor_v or is_gradtrackingtensor_v |
| else None, |
| bdim=maybe_get_bdim(t) if is_batchedtensor_v else None, |
| base=self.describe_tensor(t._base, trace=trace) |
| if recurse and t._is_view() and t._base is not None |
| else None, |
| fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t), |
| view_func=t._view_func_unsafe, |
| attrs=attrs, |
| ctx=ctx, |
| type=type_v, |
| # NB: even if functorch is enabled, don't actually save the |
| # interpreter stack here unless we are actually functorch wrapped; |
| # it's irrelevant for non-functorch stuff |
| functorch_stack=maybe_functorch_stack, |
| autograd_meta_from=autograd_meta_from, |
| current_level=current_level, |
| data=t if self.copy_data else None, |
| ) |
| if trace and r.id not in self.traced_tensors: |
| trace_structured( |
| "describe_tensor", |
| metadata_fn=lambda: r.as_json(self.id), |
| ) |
| self.traced_tensors.add(r.id) |
| return r |
| |
| |
| @dataclass(frozen=True) |
| class MetaStorageDesc: |
| id: MetaStorageId |
| size: int |
| # NB: this is only populated when copy_data is True. It is not directly |
| # serializable in JSON; you want to do something special here anyway |
| data: Optional[torch.UntypedStorage] |
| |
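| # A hypothetical as_json() payload (sketch only; the ids are illustrative): |
| #   {"id": 3, "describer_id": 0, "size": 64} |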
| def as_json(self, describer_id): |
| return { |
| "id": self.id, |
| "describer_id": describer_id, |
| "size": self.size if isinstance(self.size, int) else repr(self.size), |
| } |
| |
| |
| @dataclass(frozen=True) |
| class MetaTensorDesc: |
| id: MetaTensorId |
| ndim: int |
| dtype: torch.dtype |
| device: torch.device |
| |
| # NB: Sometimes, size, stride and storage_offset contain SymInt, in which |
| # case this is NOT serializable. That only happens when you're |
| # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we |
| # can get rid of this use case entirely. Notably, even if we are |
| # fakeifying a real tensor into a fake tensor with symbolic shapes, the |
| # size here is NOT dynamic |
| # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic |
| # goes through this codepath. But it really should not LOL. |
| # NB: size could potentially be None as you can override it and make it |
| # throw an error, but we don't currently have any subclasses that do this |
| # except C++ nested tensor, and we're going to have nested int to make this |
| # defined on NJT |
| size: Tuple[int, ...] |
| dynamo_dynamic_indices: List[int] |
| |
| layout: torch.layout = torch.strided |
| is_inference: bool = False |
| is_leaf: bool = False |
| requires_grad: bool = False |
| is_sparse: bool = False |
| is_mkldnn: bool = False |
| is_functorch_wrapped: bool = False |
| is_batchedtensor: bool = False |
| is_legacy_batchedtensor: bool = False |
| is_gradtrackingtensor: bool = False |
| is_view: bool = False |
| is_nested: bool = False |
| is_traceable_wrapper_subclass: bool = False |
| is_functional: bool = False |
| is_conj: bool = False |
| is_neg: bool = False |
| is_parameter: bool = False |
| stride: Optional[Tuple[int, ...]] = None |
| storage_offset: int = 0 |
| # NB: We have a choice whether or not to store the id or a direct pointer |
| # to the data structure. For ease of use, we store the data structure, |
| # but this means that when we serialize, we have to swizzle these pointers |
| # back into ids (so we have accurate aliasing relationships) |
| storage: Optional[MetaStorageDesc] = None |
| sparse_dim: Optional[int] = None # is_sparse, is_sparse_compressed |
| dense_dim: Optional[int] = None # is_sparse, is_sparse_compressed |
| is_coalesced: Optional[bool] = None # is_sparse |
| crow_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed |
| col_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed |
| ccol_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed |
| row_indices: Optional[MetaTensorDesc] = None # is_sparse_compressed |
| values: Optional[MetaTensorDesc] = None # is_sparse_compressed |
| unwrapped: Optional[MetaTensorDesc] = None # is_functorch_wrapped |
| bdim: Optional[int] = None # is_functorch_wrapped |
| base: Optional[MetaTensorDesc] = None # is_view |
| attrs: Optional[Dict[str, MetaTensorDesc]] = None # is_traceable_wrapper_subclass |
| creation_meta: Optional[CreationMeta] = None |
| grad: Optional[MetaTensorDesc] = None |
| |
| # Everything below is NOT serializable, need some more work |
| |
| _UNSERIALIZABLE: ClassVar[List[str]] = [ |
| "ctx", |
| "type", |
| "fake_mode", |
| "view_func", |
| "level", |
| "current_level", |
| "functorch_stack", |
| "autograd_meta_from", |
| "data", |
| ] |
| |
| ctx: Optional[object] = None # is_traceable_wrapper_subclass |
| type: Optional[Type] = None # is_traceable_wrapper_subclass |
| fake_mode: Optional[FakeTensorMode] = None |
| view_func: Optional[ |
| Callable[ |
| [ |
| torch.Tensor, |
| Callable[[int], int], |
| Callable[[torch.Tensor], torch.Tensor], |
| ], |
| torch.Tensor, |
| ] |
| ] = None |
| # level looks serializable, but actually it is meaningless without |
| # the functorch_stack below |
| level: Optional[int] = None # is_functorch_wrapped |
| current_level: Optional[int] = None |
| functorch_stack: Optional[List[CInterpreter]] = None |
| autograd_meta_from: Optional[torch.Tensor] = None |
| |
| # This is only populated on copy_data, and typically is not used at all, |
| # except for some of our meta-ification paths that don't properly use |
| # storage (pro-tip: you should use storage) |
| data: Optional[torch.Tensor] = None |
| |
| # Faithfully serializing functorch tensors will not be too difficult. |
| # We only need to consider grad/vmap interpreters, and their internal |
| # state is only bools (mostly what the grad enabled/disabled state |
| # should be in the lower layer). Beyond that, tensors just need to |
| # precisely indicate which particular interpreter they correspond |
| # to (we then replace level with a pointer to the interpreter stack.) |
| # However, this use of functorch is very "non-lexical" so it's not |
| # entirely clear how to make it all lexical again, so we haven't done |
| # it for now. |
| |
| # NB: This will reference numeric IDs, and it is assumed that you've |
| # already serialized everything this recursively references |
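| # A hypothetical, heavily trimmed entry for a plain CPU float tensor might |
| # look roughly like (exact fields depend on which values differ from their |
| # defaults; nested descs are flattened to their ids): |
| #   {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", |
| #    "size": [2, 3], "stride": [3, 1], "is_leaf": True, "storage": 0, |
| #    "describer_id": 7} |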
| def as_json(self, describer_id): |
| def json(k, v): |
| # Some best-effort debugging serialization for unserializable |
| # fields (feel free to add other special cases as appropriate) |
| if k in ["data", "autograd_meta_from"]: |
| return None # never repr these |
| if k in set(MetaTensorDesc._UNSERIALIZABLE): |
| return repr(v) |
| if isinstance(v, (torch.device, torch.dtype, torch.layout)): |
| return repr(v) |
| if isinstance(v, torch.SymInt): |
| return repr(v) |
| if isinstance(v, (tuple, list)): |
| return [json(k, v1) for v1 in v] |
| if isinstance(v, (MetaStorageDesc, MetaTensorDesc)): |
| return v.id |
| if isinstance(v, CreationMeta): |
| return str(v) |
| if k == "attrs" and isinstance(v, dict): |
| return {k1: v1.id for k1, v1 in v.items()} |
| return v |
| |
| r = { |
| field.name: json(field.name, getattr(self, field.name)) |
| for field in dataclasses.fields(self) |
| if not ( |
| getattr(self, field.name) is field.default |
| or ( |
| field.name == "dynamo_dynamic_indices" |
| and not getattr(self, field.name) |
| ) |
| ) |
| } |
| r.update({"describer_id": describer_id}) |
| return r |
| |
| @property |
| def shape(self): |
| return self.size |
| |
| |
| # A more faithful reproduction would do a copy on the entire |
| # storage, but this needs to be done carefully because the |
| # underlying storage could have larger extent than is implied |
| # by size/stride. The real fix is to properly call |
| # meta_storage recursively here. |
| # |
| # These "safe" functions are intended to be used under no_dispatch() mode. |
| # The no_dispatch() here is intended to prevent ambient fake tensor mode from |
| # fakeifying the operation. But if we are given an honest to goodness |
| # FakeTensor as src, we MUST NOT run the copy/clone operation. A better way |
| # to do this would be to not use no_dispatch and instead just disable fake |
| # tensor mode only (allowing for subclass dispatch to occur) |
| def _safe_copy(dst, src): |
| if type(src) is not torch.Tensor: |
| return |
| dst.copy_(src) |
| |
| |
| def _safe_clone(src): |
| if type(src) is not torch.Tensor: |
| return None |
| return src.clone() |
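| # These helpers are expected to be called under no_dispatch() (and usually |
| # torch.no_grad()); a sketch mirroring the copy_data paths below: |
| #   with torch.no_grad(), no_dispatch(): |
| #       r.real_tensor = _safe_clone(t.data) |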
| |
| |
| # This is a class for converting multiple tensors into meta tensors which |
| # share the same view/storage structure. The operation model is you allocate |
| # one of these, and then call it repeatedly on all the tensors you want to |
| # convert. It's important to use the same object for tensors you want to |
| # share storage because this is how we correlate shared storages to the same |
| # meta storages. This class will hold weak references to cached tensors |
| # and tensor storages. |
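| # |
| # A rough usage sketch (hypothetical; real callers also thread in callbacks, |
| # sources, and symbolic contexts): |
| # |
| #   converter = MetaConverter() |
| #   x = torch.randn(4, 4) |
| #   y = x.view(2, 8) |
| #   meta_x = converter(x) |
| #   meta_y = converter(y)  # intended to share storage/view structure with meta_x |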
| class MetaConverter: |
| def __init__(self, *, copy_data: bool = False): |
| # Maps MetaStorageId to UntypedStorage |
| self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary() |
| # Maps MetaTensorId to torch.Tensor (typically a meta tensor or |
| # FakeTensor) |
| self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary() |
| self.hit = 0 |
| self.miss = 0 |
| self.del_hook = None |
| self.arg_cnt = 0 |
| # Ensures real_storage/real_tensor are populated on the resulting |
| # metaified storage/tensor. The naming of this attribute is load |
| # bearing: FakeTensor relies on real tensor being set to exactly this |
| # value |
| self.copy_data = copy_data |
| self.describer = MetaTensorDescriber(copy_data=copy_data) |
| |
| def successful(self): |
| return self.hit > 0 and self.miss == 0 |
| |
| def get_tensor_memo(self, t: MetaTensorDesc): |
| return self.tensor_memo.get(t.id, None) |
| |
| def set_tensor_memo(self, t: MetaTensorDesc, v): |
| self.tensor_memo[t.id] = v |
| |
| def get_storage_memo(self, s: MetaStorageDesc): |
| return self.storage_memo.get(s.id, None) |
| |
| def set_storage_memo(self, s: MetaStorageDesc, v): |
| self.storage_memo[s.id] = v |
| |
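| # NB: Because the memo below is keyed on MetaStorageId, any two tensor descs |
| # that referenced the same original storage resolve to the same meta storage, |
| # which is how aliasing is reconstructed on the meta/fake side. |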
| def meta_storage(self, s: MetaStorageDesc, callback): |
| # If we are fakeifying a tensor that has a secretly-zero-sized storage, |
| # we need to make sure to resize the meta storage too. |
| if self.get_storage_memo(s) is None: |
| r_s = callback( |
| lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"), |
| ).untyped_storage() |
| if self.copy_data: |
| # NB: no_dispatch is needed because internally storage copy is |
| # implemented as Tensor operations |
| with torch.no_grad(), no_dispatch(): |
| assert s.data is not None |
| r_s.real_storage = s.data.clone() |
| self.set_storage_memo(s, r_s) |
| return r_s |
| else: |
| return self.get_storage_memo(s) |
| |
| # This function assumes that it's possible to do the conversion |
| # NB: name here is used in a conventional way by Dynamo; it corresponds |
| # precisely to the Source.name() of the tensor we're fakeifying and |
| # corresponds to a valid Python expression. When we construct sub-names |
| # as part of this process, we will maintain this invariant! (Even though |
| # other users of this may not need this property to be upheld.) |
| def meta_tensor( |
| self, |
| t: MetaTensorDesc, |
| shape_env: Optional[ShapeEnv] = None, |
| callback=lambda t: t(), |
| source: Optional[Source] = None, |
| symbolic_context: Optional[SymbolicContext] = None, |
| ): |
| if source is None: |
| from torch._dynamo.source import ConstantSource |
| |
| # TODO: make a dedicated UnknownSource for this? |
| source = ConstantSource( |
| f"__meta_utils_unknown_tensor{len(self.tensor_memo)}" |
| ) |
| |
| # This indicates you set no_dispatch() before calling into this |
| # function. This is an error: we may be creating fake tensors and |
| # will perform operations on them which need fake tensor mode to |
| # be active. You will segfault if you are in a no_dispatch() block. |
| assert not torch._C._dispatch_tls_local_exclude_set().has( |
| torch._C.DispatchKey.Python |
| ) |
| arg_cnt = self.arg_cnt |
| self.arg_cnt += 1 |
| |
| # When we make as_strided calls, we end up generating a guard |
| # that the new as_strided tensor is in bounds for the old storage |
| # for the base (since as_strided calls can "bust" out of their |
| # bounding box.) This guard is unnecessary: if a user is able |
| # to provide us a tensor with the view base setup this way, we |
| # don't need to produce a guard, because the fact that they |
| # were able to produce the view base means it's in bounds. |
| # |
| # Now, ordinarily, this guard would be harmless. However, the |
| # generated guard refers to variables bound on the base variable. |
| # At the moment, Dynamo doesn't actually guard on x._base, because |
| # according to Voz this results in a lot of spurious invalidations, |
| # and also if the user doesn't directly make use of _base, it's |
| # pointless anyway (because programs should be parametric over |
| # whether or not the input tensor is a view or not--unless you're |
| # mutating the input, but that's a whole 'nother ballgame). So |
| # for expediency, we suppress these guards so we don't have to |
| # deal with this (yet, anyway.) |
| # |
| # NB: An old version of this code suppressed guards for ALL operations |
| # happening during meta conversion, not just as_strided calls. |
| # This is too aggressive: we do duck sizing and 0/1 simplification |
| # as we allocate variables, and we do need to register guards for |
| # these cases. |
| maybe_suppress: Callable[[], Any] = contextlib.nullcontext |
| if shape_env is not None: |
| maybe_suppress = shape_env.suppress_guards |
| |
| def sym_sizes_strides_storage_offset( |
| t: MetaTensorDesc, src, symbolic_context=symbolic_context |
| ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: |
| assert t.stride is not None |
| if shape_env is not None: |
| fake_mode = t.fake_mode |
| if fake_mode is not None and fake_mode.shape_env is shape_env: |
| # Don't reallocate the sizes; the shape envs are the same, |
| # so reuse the old sizes/strides/etc |
| return (t.size, t.stride, t.storage_offset) |
| else: |
| # TODO: deduplicate this |
| t_size = tuple( |
| shape_env._maybe_specialize_sym_int_with_hint(sz) |
| for sz in t.size |
| ) |
| t_stride = tuple( |
| shape_env._maybe_specialize_sym_int_with_hint(sd) |
| for sd in t.stride |
| ) |
| t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint( |
| t.storage_offset |
| ) |
| return shape_env._create_symbolic_sizes_strides_storage_offset( |
| t_size, |
| t_stride, |
| t_storage_offset, |
| [d in t.dynamo_dynamic_indices for d in range(t.ndim)], |
| src, |
| symbolic_context=symbolic_context, |
| ) |
| else: |
| return (t.size, t.stride, t.storage_offset) |
| |
| def empty_create( |
| inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context |
| ): |
| ( |
| inner_sizes, |
| inner_strides, |
| inner_storage_offset, |
| ) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context) |
| return torch.empty_strided( |
| inner_sizes, |
| inner_strides, |
| dtype=inner_t.dtype, |
| device="meta", |
| ) |
| |
| # Creates a subclass instance with empty inner tensors according to the specified |
| # symbolic context. |
| def empty_create_subclass( |
| t: MetaTensorDesc, |
| outer_size, |
| outer_stride, |
| symbolic_context=symbolic_context, |
| callback=callback, |
| source=source, |
| ): |
| from torch._dynamo.source import AttrSource |
| from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext |
| |
| assert t.attrs is not None |
| assert t.type is not None |
| # NB: t.ctx could be None if the subclass in question has no |
| # meaningful context |
| |
| assert symbolic_context is None or isinstance( |
| symbolic_context, SubclassSymbolicContext |
| ) |
| |
| # Note: transform_subclass will use __tensor_unflatten__ to generate |
| # a fresh subclass wrapper with outer sizes / strides according to the |
| # outer symbolic context (passed in to this function). Inner size / stride |
| # / storage offset symbols are allocated according to the appropriate inner |
| # symbolic contexts, after which the checks in transform_subclass() will |
| # relate them to the outer metadata where possible. |
| # |
| # Morally, the code here is the same as transform_subclass, but we've |
| # written it from scratch so that it reads as EmptyCreateSubclass |
| |
| outer_size = outer_size if outer_size is not None else t.size |
| outer_stride = outer_stride if outer_stride is not None else t.stride |
| |
| def transform(attr, inner_t): |
| r = callback( |
| lambda: empty_create( |
| inner_t, |
| AttrSource(source, attr), |
| symbolic_context=( |
| None |
| if symbolic_context is None |
| else symbolic_context.inner_contexts[attr] |
| ), |
| ) |
| ) |
| if self.copy_data: |
| with torch.no_grad(), no_dispatch(): |
| r.real_tensor = torch.empty_strided( |
| inner_t.size, |
| inner_t.stride, |
| dtype=inner_t.dtype, |
| device=inner_t.device, |
| ) |
| assert inner_t.data is not None |
| _safe_copy(r.real_tensor, inner_t.data) |
| return r |
| |
| transformed_tensors_dict = { |
| attr: transform(attr, inner_t) for attr, inner_t in t.attrs.items() |
| } |
| |
| sub = t.type.__tensor_unflatten__( |
| transformed_tensors_dict, t.ctx, outer_size, outer_stride |
| ) |
| |
| # NB: Purposefully guard here to simplify the inner / outer symbols. |
| # Using sym_eq() for symbolic comparison can result in an expression that's too |
| # difficult to guard on, so we use == here. |
| assert sub.shape == outer_size, ( |
| f"Expected return value from {t.type}.__tensor_unflatten__() to have " |
| f"shape equal to {outer_size}, but got: {sub.shape}" |
| ) |
| assert sub.stride() == outer_stride, ( |
| f"Expected return value from {t.type}.__tensor_unflatten__() to have " |
| f"stride equal to {outer_stride}, but got: {sub.stride()}" |
| ) |
| |
| return sub |
| |
| # Returns an all-dynamic symbolic context used for metafying the given tensor with |
| # fully dynamic dims. This is useful when fake-ifying intermediate tensors in |
| # closed-over ViewFunc state, as we don't have symbolic contexts for them, but we |
| # don't want to over-specialize during view replay. |
| def all_dynamic_symbolic_context( |
| t: MetaTensorDesc, source, shape_env, callback |
| ): |
| from torch._dynamo.source import AttrSource |
| from torch.fx.experimental.symbolic_shapes import ( |
| DimDynamic, |
| StatelessSymbolicContext, |
| SubclassSymbolicContext, |
| ) |
| |
| view_base_context: Optional[SymbolicContext] = None |
| if t.is_view: |
| assert t.base is not None |
| view_base_context = all_dynamic_symbolic_context( |
| t.base, AttrSource(source, "_base"), shape_env, callback |
| ) |
| |
| t_symbolic_context: SymbolicContext |
| t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim |
| if t.is_traceable_wrapper_subclass: |
| assert t.attrs is not None |
| inner_contexts: Dict[str, SymbolicContext] = {} |
| for attr, inner in t.attrs.items(): |
| assert isinstance(attr, str) |
| inner_contexts[attr] = all_dynamic_symbolic_context( |
| inner, AttrSource(source, attr), shape_env, callback |
| ) |
| t_symbolic_context = SubclassSymbolicContext( |
| dynamic_sizes=t_dynamic_sizes, |
| constraint_sizes=[None] * t.ndim, |
| inner_contexts=inner_contexts, |
| tensor_source=source, |
| view_base_context=view_base_context, |
| ) |
| else: |
| t_symbolic_context = StatelessSymbolicContext( |
| dynamic_sizes=t_dynamic_sizes, |
| constraint_sizes=[None] * t.ndim, |
| view_base_context=view_base_context, |
| ) |
| |
| return t_symbolic_context |
| |
| # Returns a fake-ified version of an input view tensor t, given an already fake-ified |
| # base. At a high level, we want two things: |
| # 1. fake_t should have the same view relationship to the given fake base as the |
| # input t has to its _base. |
| # 2. fake_t should have symbolic sizes / strides / storage offset according to the |
| # appropriate symbolic context (i.e. from the automatic dynamic algorithm). |
| # |
| # We currently take different strategies across view types: |
| # * For dense -> dense views, accomplish both (1) and (2) simultaneously via an |
| # as_strided() call on the fake-ified base, passing symbolic metadata. |
| # * For views involving subclasses, perform view replay using view funcs to |
| # achieve (1). It's necessary for (2) to swap out any closed-over state in |
| # the view funcs with symbolicized SymInts and fake-ified tensors. Doing this |
| # avoids specialization (and thus over-eager simplification of symbols) that |
| # could occur during view replay on the fake-ified base. |
| # |
| # Examples: |
| # * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled |
| # with an as_strided() call on the fake base passing symbolic metadata. |
| # * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg |
| # is made symbolic to avoid invalid specialization and view replay is then |
| # done to reconstruct the view. |
| # * _nested_from_jagged(values, offsets) is a dense -> subclass view |
| # that returns a subclass instance from a dense values tensor. The offsets |
| # tensor is closed over in the view func, as it can be considered view metadata. |
| # First, the offsets tensor is fake-ified according to the inner symbolic |
| # context and with the correct relationship to the outer size / stride metadata. |
| # Then view replay is done, swapping in the fake offsets so the view replay output |
| # is fully fake with no invalid specialization. |
| def view_from_base( |
| base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env |
| ): |
| # fake-ify t's metadata according to the outer symbolic context |
| (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset( |
| t, source |
| ) |
| if ( |
| not t.is_traceable_wrapper_subclass |
| and not is_traceable_wrapper_subclass(base) |
| ): |
| # Dense -> Dense view case uses as_strided() to construct view relationship. |
| # TODO: Change this logic to use view replay for consistency? |
| # It's likely there is no view func available. |
| with maybe_suppress(): |
| return base.as_strided(sizes, strides, storage_offset) |
| |
| from torch._dynamo.source import EphemeralSource |
| from torch.fx.experimental.symbolic_shapes import ( |
| StatelessSymbolicContext, |
| sym_eq, |
| ) |
| |
| def symint_visitor_fn(s): |
| nonlocal symbolic_context |
| from torch.fx.experimental.symbolic_shapes import DimDynamic |
| |
| all_static_sizes = ( |
| symbolic_context is not None |
| and isinstance(symbolic_context, StatelessSymbolicContext) |
| and all( |
| x is DimDynamic.STATIC for x in symbolic_context.dynamic_sizes |
| ) |
| ) |
| # Can't just rely on shape env being None - dynamo always initializes it |
| if all_static_sizes or shape_env is None: |
| return s |
| |
| # NB: The symbol here is expected to be simplified out because we a priori |
| # allocate inner and outer symbols according to the appropriate symbolic |
| # contexts and prefer those over this symbol during symbol simplification |
| # (via usage of EphemeralSource below). This -shouldn't- happen, but if |
| # this symbol somehow leaks out beyond the view tensor's shape metadata, our |
| # assumption of it being simplified out will fail and it may be guarded on, |
| # which will hard error. |
| sym_source = EphemeralSource("symint_visitor_fn") |
| symbol = shape_env.create_symbol(s, sym_source) |
| return shape_env.create_symintnode(symbol, hint=s, source=sym_source) |
| |
| real_to_fake_mapping = {} |
| if t.is_traceable_wrapper_subclass: |
| assert t.attrs is not None |
| # NB: t.ctx could be None if the subclass in question has no |
| # meaningful context |
| assert t.type is not None |
| |
| # Fake-ify t naively here; this is only done so we can get fake-ified inner |
| # tensors with the correct relationships to the outer sizes / strides for use |
| # in view replay. It's done beforehand here because it's not easy to do when |
| # visiting tensors one-by-one during view replay. |
| # |
| # Example: |
| # Consider a Dense -> NJT view. NJT has (values, offsets) components and we |
| # want a view of values with the offsets closed over. As the offsets component |
| # is needed to describe the output view, it's important that it's fakeified |
| # correctly. |
| fake_t = empty_create_subclass( |
| t, outer_size=sizes, outer_stride=strides |
| ) |
| attrs, _ = fake_t.__tensor_flatten__() |
| for attr in attrs: |
| real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr) |
| |
| def tensor_visitor_fn( |
| visited_t: torch.Tensor, |
| # These arguments are never passed; we just use them to close |
| # over the relevant values |
| shape_env=shape_env, |
| callback=callback, |
| ): |
| # It's possible to close over an undefined tensor (e.g. NJT's lengths). |
| if visited_t is None: |
| return None |
| |
| # NB: visited_t being a Tensor here is very naughty! Should |
| # have already been described |
| |
| # Fake inner tensors of view subclasses will come from the mapping built above. |
| visited_id = self.describer.get_tensor_id(visited_t) |
| fake_visited_t = real_to_fake_mapping.get(visited_id, None) |
| if fake_visited_t is not None: |
| return fake_visited_t |
| |
| visited_desc = self.describer.describe_tensor(visited_t) |
| |
| # For other closed-over tensor state, fake-ify it as all dynamic with an |
| # ephemeral source. This avoids invalid specialization during view replay. |
| # If we find that in practice the usage of ephemeral sources isn't enough |
| # to guarantee that we don't have guards on these symbols, we may need to |
| # explicitly suppress guards (as is done for _base in the dense -> dense |
| # view case). |
| temp_source = EphemeralSource("tensor_visitor_fn") |
| return self.meta_tensor( |
| visited_desc, |
| shape_env, |
| callback, |
| source=temp_source, |
| symbolic_context=all_dynamic_symbolic_context( |
| visited_desc, temp_source, shape_env, callback |
| ), |
| ) |
| |
| # Replay the view, swapping out any non-symbolic SymInts or real tensors |
| # for symbolic SymInts or fake tensors. |
| assert t.view_func is not None |
| # NB: we do NOT suppress guards here; we need to remove ephemeral |
| # sources |
| fake_t = t.view_func(base, symint_visitor_fn, tensor_visitor_fn) |
| |
| # Ensure the output has symbolic shapes according to the outer symbolic context. |
| # These checks should simplify out any symbols created for closed-over view func |
| # SymInts. |
| torch._check(sym_eq(fake_t.size(), sizes)) |
| torch._check(sym_eq(fake_t.stride(), strides)) |
| torch._check(sym_eq(fake_t.storage_offset(), storage_offset)) |
| return fake_t |
| |
| if self.get_tensor_memo(t) is None: |
| GRAD_TENSOR_SENTINEL_VALUE = -2 |
| |
| with torch.inference_mode(t.is_inference): |
| if t.is_sparse: |
| is_leaf = t.is_leaf |
| |
| # The lambda function below is similar to |
| # `t.to(device='meta')` except the latter |
| # preserves nnz value |
| r = callback( |
| lambda: torch.ops.aten._sparse_coo_tensor_with_dims( |
| t.sparse_dim, |
| t.dense_dim, |
| t.size, |
| dtype=t.dtype, |
| layout=torch.sparse_coo, |
| device="meta", |
| ) |
| ) |
| if self.copy_data: |
| # Pray that sparse clone doesn't lose information |
| assert t.data is not None |
| with torch.no_grad(), no_dispatch(): |
| r.real_tensor = _safe_clone(t.data) |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| # Note [is_coalesced is dispatched] |
| # Strangely enough, is_coalesced() is a dispatched operator, |
| # which means that it will get caught by fake tensor mode. |
| # Ordinarily this would error, but there's some logic in |
| # fake tensor to ensure this doesn't happen. |
| r._coalesced_(t.is_coalesced) |
| if t.requires_grad: |
| r.requires_grad = True |
| if t.requires_grad and not is_leaf: |
| # This should probably use DelayedError, |
| # but clone is fine for now for sparse tensors. |
| # (DelayedError does not work for sparse because it causes |
| # the Fake sparse tensor to "lose" its fakeness) |
| r = r.clone() |
| with torch.enable_grad(): |
| r._coalesced_(t.is_coalesced) |
| elif is_sparse_compressed_layout(t.layout): |
| is_leaf = t.is_leaf |
| |
| if t.layout in {torch.sparse_bsr, torch.sparse_bsc}: |
| assert t.sparse_dim is not None |
| assert t.dense_dim is not None |
| assert t.values is not None |
| batch_dim = t.ndim - t.sparse_dim - t.dense_dim |
| blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3] |
| else: |
| blocksize = () |
| if t.layout in {torch.sparse_csr, torch.sparse_bsr}: |
| assert t.crow_indices is not None |
| index_dtype = t.crow_indices.dtype |
| else: |
| assert t.ccol_indices is not None |
| index_dtype = t.ccol_indices.dtype |
| |
| r = callback( |
| lambda: torch.ops.aten._sparse_compressed_tensor_with_dims( |
| 0, |
| t.dense_dim, |
| t.shape, |
| blocksize, |
| index_dtype, |
| layout=t.layout, |
| dtype=t.dtype, |
| device="meta", |
| ) |
| ) |
| if self.copy_data: |
| # Pray sparse clone doesn't lose information |
| assert t.data is not None |
| with torch.no_grad(), no_dispatch(): |
| r.real_tensor = _safe_clone(t.data) |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| if t.requires_grad: |
| r.requires_grad = True |
| if t.requires_grad and not is_leaf: |
| r = torch._C._functions.DelayedError( |
| "Internal error: Tried to backward() through example input", |
| 1, |
| )(r) |
| elif t.is_nested and not t.is_traceable_wrapper_subclass: |
| # TODO: Handle this better in Dynamo? |
| # There are checks there now, but this can still be triggered by a dense |
| # tensor graph input that is a view of a strided NT. |
| from torch._dynamo.exc import unimplemented |
| |
| unimplemented( |
| "strided nested tensors are not supported by meta conversion" |
| ) |
| elif t.is_mkldnn: |
| is_leaf = t.is_leaf |
| sizes, strides, _storage_offset = sym_sizes_strides_storage_offset( |
| t, source |
| ) |
| # TODO: This doesn't seem right, where's the MKLDNN'ness |
| # lol |
| r = callback( |
| lambda: torch.empty_strided( |
| sizes, strides, dtype=t.dtype, device="meta" |
| ) |
| ) |
| if self.copy_data: |
| with torch.no_grad(), no_dispatch(): |
| assert t.size is not None |
| assert t.stride is not None |
| r.real_tensor = torch.empty_strided( |
| t.size, t.stride, dtype=t.dtype, device=t.device |
| ) |
| assert t.data is not None |
| _safe_copy(r.real_tensor, t.data) |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| if t.requires_grad: |
| r.requires_grad = True |
| if t.requires_grad and not is_leaf: |
| r = torch._C._functions.DelayedError( |
| "Internal error: Tried to backward() through example input", |
| 1, |
| )(r) |
| elif t.is_functorch_wrapped: |
| if t.is_view: |
| from torch._dynamo.exc import unimplemented |
| |
| unimplemented( |
| "view functorch tensors are not supported by meta conversion" |
| ) |
| |
| # Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor) |
| # in a FakeTensor |
| def _to_fake_tensor(t: MetaTensorDesc): |
| # TODO: why aren't the recursive calls going to |
| # meta_tensor |
| if t.is_batchedtensor: |
| assert t.unwrapped is not None |
| assert t.level is not None |
| assert t.bdim is not None |
| ft = _to_fake_tensor(t.unwrapped) |
| lvl = t.level |
| bdim = t.bdim |
| # You cannot create functorch tensors without |
| # having the ambient functorch interpreter stack |
| # available, as the level refers to things in the |
| # stack |
| with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( |
| t.functorch_stack |
| ): |
| r = _add_batch_dim(ft, bdim, lvl) |
| elif t.is_gradtrackingtensor: |
| assert t.unwrapped is not None |
| assert t.level is not None |
| disable_functorch = torch._C._DisableFuncTorch |
| with disable_functorch(): |
| ft = _to_fake_tensor(t.unwrapped) |
| lvl = t.level |
| if lvl == GRAD_TENSOR_SENTINEL_VALUE: |
| r = ft |
| else: |
| with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack( |
| t.functorch_stack |
| ): |
| r = torch._C._functorch._wrap_for_grad(ft, lvl) |
| |
| is_leaf = t.is_leaf |
| if t.requires_grad and safe_is_leaf(r): |
| r.requires_grad = True |
| elif t.requires_grad and not is_leaf: |
| r = torch._C._functions.DelayedError( # type: ignore[assignment] |
| "Internal error: Tried to backward() through example input", |
| 1, |
| )( |
| r # type: ignore[arg-type] |
| ) |
| elif t.is_functional: |
| assert t.unwrapped is not None |
| assert t.current_level is not None |
| ft = self.meta_tensor( |
| t.unwrapped, |
| shape_env=shape_env, |
| callback=callback, |
| # NB: reuse these exactly; we treat the |
| # functional tensor as "invisible". |
| # TODO: Actually this all probably doesn't |
| # work, take a closer look. |
| source=source, |
| symbolic_context=symbolic_context, |
| ) |
| r = _wrap_functional_tensor(ft, t.current_level) |
| # TODO: is_leaf/requires_grad? |
| else: |
| assert t.stride is not None |
| |
| sizes = t.size |
| strides = t.stride |
| r = callback( |
| lambda: torch.empty_strided( |
| sizes, |
| strides, |
| dtype=t.dtype, |
| device="meta", |
| ) |
| ) |
| if self.copy_data: |
| with torch.no_grad(), no_dispatch(): |
| r.real_tensor = torch.empty_strided( # type: ignore[attr-defined] |
| t.size, |
| t.stride, |
| dtype=t.dtype, |
| device=t.device, |
| ) |
| assert t.data is not None |
| _safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined] |
| return r |
| |
| r = _to_fake_tensor(t) |
| |
| elif t.is_functional and t.device.type not in ["xla", "lazy"]: |
| assert t.unwrapped is not None |
| assert not t.is_functorch_wrapped # handled above |
| unwrapped = self.meta_tensor( |
| t.unwrapped, |
| shape_env=shape_env, |
| callback=callback, |
| source=source, |
| symbolic_context=symbolic_context, |
| ) |
| r = torch._to_functional_tensor(unwrapped) |
| torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined] |
| |
| elif t.is_view: |
| # Construct views in two steps: recursively meta-fy their |
| # base, and then create view(s) off that. NB: doing it |
| # directly from storage is WRONG because this won't cause |
| # version counters to get shared. |
| |
| assert t.base is not None |
| |
| base_symbolic_context = None |
| if shape_env and symbolic_context is not None: |
| from torch.fx.experimental.symbolic_shapes import ( |
| StatelessSymbolicContext, |
| ) |
| |
| assert isinstance(symbolic_context, StatelessSymbolicContext) |
| # NB: This should generally be set when the input is a view, |
| # but the exception right now is for fake-ifying grads, which is |
| # a work in progress. |
| if symbolic_context.view_base_context is not None: |
| base_symbolic_context = symbolic_context.view_base_context |
| |
| base = self.meta_tensor( |
| t.base, |
| shape_env, |
| callback, |
| source=torch._dynamo.source.AttrSource(source, "_base"), |
| symbolic_context=base_symbolic_context, |
| ) |
| |
| def is_c_of_r(complex_dtype, real_dtype): |
| return ( |
| utils.is_complex_dtype(complex_dtype) |
| and utils.corresponding_real_dtype(complex_dtype) |
| == real_dtype |
| ) |
| |
| # In some situations, MetaConverter may be called in a |
| # context where autograd is disabled. For the _is_view |
| # assert to pass, we have to set up the autograd view |
| # metadata anyway. Do this by reenabling the |
| # ADInplaceOrView key. This is kind of a hack. |
| old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded( |
| torch._C.DispatchKey.ADInplaceOrView |
| ) |
| torch._C._dispatch_tls_set_dispatch_key_excluded( |
| torch._C.DispatchKey.ADInplaceOrView, False |
| ) |
| try: |
| if base.dtype == t.dtype: |
| pass |
| elif is_c_of_r(base.dtype, t.dtype): |
| base = torch.view_as_real(base) |
| elif is_c_of_r(t.dtype, base.dtype): |
| base = torch.view_as_complex(base) |
| else: |
| # This is not guaranteed to succeed. If it fails, it |
| # means there is another dtype-converting view function |
| # that hasn't been handled here |
| base = base.view(t.dtype) |
| |
| # This is very tricky. Naively, you might expect this |
| # to hold: |
| # |
| # if t.requires_grad and not safe_is_leaf(t) |
| # assert t._base.requires_grad |
| # |
| # But it's not true! As you can see in the following |
| # program: |
| # |
| # x = torch.zeros(4) |
| # y = x.view(1, 4) |
| # y.requires_grad = True |
| # z = y.view(1, 1, 4) |
| # assert z._base is x |
| # |
| # So we may have to do *two* views out of the base to |
| # recreate this situation. |
| if t.is_leaf: |
| # Leaf views that track view metadata are created by |
| # creating a view inside a no_grad block |
| with torch.no_grad(): |
| r = view_from_base(base, t) |
| # As it's a leaf, we can directly assign requires_grad |
| r.requires_grad = t.requires_grad |
| else: |
| if t.base.requires_grad == t.requires_grad: |
| # Easy case, just run the view op |
| with torch.enable_grad(): |
| r = view_from_base(base, t) |
| |
| # NB: We don't actually faithfully replicate |
| # autograd connectivity, but that doesn't matter |
| # today. See following for more info: |
| # https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913 |
| else: |
| # Obscure case. Create a leaf view and give it the |
| # correct requires_grad, then do the final view. |
| # NB: Can't have a non-leaf without requiring grad! |
| assert t.requires_grad |
| with torch.no_grad(): |
| mid = base.view(base.shape) |
| mid.requires_grad = t.requires_grad |
| with torch.enable_grad(): |
| r = view_from_base(mid, t) |
| # The CreationMeta influences whether or not inplace |
| # mutation is an error or not. So we need to make |
| # sure we properly propagate this as well. |
| assert t.creation_meta is not None |
| torch._C._autograd._set_creation_meta(r, t.creation_meta) |
| finally: |
| torch._C._dispatch_tls_set_dispatch_key_excluded( |
| torch._C.DispatchKey.ADInplaceOrView, old_exclude |
| ) |
| |
| else: |
| is_leaf = t.is_leaf |
| |
| # Graph-Break for wrapped tensors |
| if ( |
| not (t.is_batchedtensor or t.is_gradtrackingtensor) |
| and t.is_functorch_wrapped |
| ) or t.is_legacy_batchedtensor: |
| return NotImplemented |
| |
| ( |
| sizes, |
| strides, |
| storage_offset, |
| ) = sym_sizes_strides_storage_offset(t, source, symbolic_context) |
| |
| # If we have a subclass that desugars into dense tensors, |
| # perform our callback on each inner tensor. |
| if t.is_traceable_wrapper_subclass: |
| r = empty_create_subclass( |
| t, outer_size=sizes, outer_stride=strides |
| ) |
| else: |
| r = callback( |
| lambda: torch.empty_strided( |
| sizes, |
| strides, |
| dtype=t.dtype, |
| device="meta", |
| ) |
| ) |
| if self.copy_data: |
| with torch.no_grad(), no_dispatch(): |
| assert t.size is not None |
| assert t.stride is not None |
| r.real_tensor = torch.empty_strided( |
| t.size, t.stride, dtype=t.dtype, device=t.device |
| ) |
| _safe_copy(r.real_tensor, t.data) |
| |
| assert safe_is_leaf(r), "the callback you passed in doesn't detach" |
| if t.requires_grad: |
| r.requires_grad = t.requires_grad |
| if not is_leaf: |
| # Fake up some autograd history. |
| # Note: we *used* to call .clone() here to mock up some autograd history. |
| # This is bad for subclasses. |
| # Consider the case where you have a wrapper subclass that is contiguous, |
| # but its inner tensor is noncontiguous. |
| # .clone() (or other ops) will have the side effect of changing |
| # the metadata of the inner tensor. |
| # So instead, we now have a dedicated fn to set autograd history, |
| # without inadvertently changing other metadata. |
| r = torch._C._functions.DelayedError( |
| "Internal error: Tried to backward() through example input", |
| 1, |
| )(r) |
| |
| s = t.storage |
| assert s is not None |
| if s.id not in self.storage_memo and ( |
| r.is_nested |
| or ( |
| r.stride() == strides |
| and r.storage_offset() == storage_offset |
| ) |
| ): |
| # You're normal and happy; install the fresh storage into the memo |
| self.set_storage_memo(s, r.untyped_storage()) |
| if self.copy_data: |
| r.untyped_storage().real_storage = ( |
| r.real_tensor.untyped_storage() |
| ) |
| else: |
| # You're in crazy town; somehow you gave us a tensor |
| # that wasn't a view, but had nonzero storage offset, |
| # nontrivial strides (such that clone() couldn't |
| # preserve them), or already aliases with another |
| # tensor's storage. The most typical way to end |
| # up here is with set_. So use set_ to bludgeon this |
| # in. |
| r_s = self.meta_storage(s, callback=callback) |
| # NB: In principle, this should always work, but there |
| # is some subtle difference in the autograd metadata |
| # that means we will backprop the set_ call, even if |
| # r is declared as an input to grad. |
| # See https://github.com/pytorch/pytorch/issues/87956 |
| # for the reproducer. |
| # NB: The in_kernel_invocation_manager here is necessary |
| # for fake tensor. If we run the set_ call with fake |
| # tensor on, r will improperly report that it is NOT a |
| # meta tensor but a cpu tensor, and then the set_ call |
| # will fail due to device mismatch. no_dispatch() is |
| # not enough, because the fake tensor will still claim |
| # to be a CPU tensor and you'll end up in the CPU |
| # kernel. Arguably this is a hack; a cleaner way to |
| # solve this is to have a FakeStorage concept which |
| # would report its CPU device--no problem now! But |
| # this is difficult to do because we don't have storage |
| # subclasses. Relevant test is |
| # DynamicShapesFunctionTests::test_add_dynamic_shapes in |
| # test/dynamo/test_dynamic_shapes.py |
| maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext() |
| from torch._subclasses.fake_tensor import ( |
| in_kernel_invocation_manager, |
| maybe_get_fake_mode, |
| ) |
| |
| mb_fake_mode = maybe_get_fake_mode(r) |
| if mb_fake_mode is not None: |
| maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode) |
| with torch.no_grad(), maybe_suppress(): |
| with maybe_fake_mgr: |
| r.set_(r_s, storage_offset, sizes, strides) |
| if self.copy_data: |
| with torch.no_grad(), no_dispatch(): |
| r.real_tensor.set_( |
| r_s.real_storage, |
| t.storage_offset, |
| t.size, |
| t.stride, |
| ) |
| |
| if t.grad is not None: |
| from torch._dynamo.source import AttrSource |
| |
| # TODO: Use a valid grad-specific symbolic context instead of recycling |
| # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view(). |
| r.grad = self.meta_tensor( |
| t.grad, |
| shape_env, |
| callback, |
| source=AttrSource(source, "grad"), |
| symbolic_context=symbolic_context, |
| ) |
| torch._C._set_conj(r, t.is_conj) |
| torch._C._set_neg(r, t.is_neg) |
| # This can be skipped if necessary for performance reasons |
| skip_leaf = ( |
| t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE |
| ) |
| assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf) |
| # Thanks to storage resizing, it's possible to end up with a tensor |
| # that advertises a real size, but has a storage that actually has zero bytes. |
| # Need to reflect this in the generated FakeTensor. |
| if t.storage is not None and t.storage.size == 0: |
| r.untyped_storage().resize_(0) |
| |
| if t.is_parameter: |
| r._is_param = True |
| |
| self.set_tensor_memo(t, r) |
| |
| return self.get_tensor_memo(t) |
| |
| def __call__( |
| self, |
| t, |
| shape_env=None, |
| *, |
| callback=lambda t: t(), |
| source=None, |
| symbolic_context=None, |
| # Controls whether or not we should dump the tensor metadata to structured logs |
| # when source is not None. Because we refakify after Dynamo is done, |
| # we don't want to dump info again from AOTAutograd; it is redundant. |
| trace=True, |
| ): |
| # TODO: zero tensors? We appear to have eliminated them by |
| # excluding complex for now |
| |
| # Filter out cases we don't support |
| # TODO: This can probably be simplified quite a bit |
| if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t): |
| if ( |
| # Lazy tensors are not supported. Note that XLA is |
| # implemented on top of lazy tensors but is not excluded here; we |
| # have some special handling for it to support the XLA Dynamo |
| # integration |
| t.device.type == "lazy" |
| or |
| # Quantization is not supported |
| t.is_quantized |
| or |
| # Views out of sparse tensors are not currently supported (plain |
| # sparse is supported though) |
| (t._is_view() and t._base is not None and t._base.is_sparse) |
| ): |
| self.miss += 1 |
| return NotImplemented |
| else: |
| self.hit += 1 |
| elif torch.overrides.is_tensor_like(t): |
| self.miss += 1 |
| return NotImplemented |
| else: |
| # non-Tensor types don't count as hit or miss |
| return t |
| |
| if source is None: |
| trace = False |
| |
| # Describe the tensor. NB: do NOT disable ambient modes; we may need |
| # to query them when figuring out what to put in here |
| t_desc = self.describer.describe_tensor(t, trace=trace) |
| |
| if trace: |
| trace_structured( |
| "describe_source", |
| metadata_fn=lambda: { |
| "describer_id": self.describer.id, |
| "id": t_desc.id, |
| "source": source.name(), |
| }, |
| ) |
| |
| # Do the meta-fication. Here, we disable all the ambient modes, to |
| # better simulate what it would be like to re-fakeify from a fresh |
| # process |
| with contextlib.ExitStack() as exit_stack: |
| exit_stack.enter_context(torch._dispatch.python.suspend_functionalization()) |
| st = peek_interpreter_stack() |
| if st is not None: |
| exit_stack.enter_context( |
| torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() |
| ) |
| |
| r = self.meta_tensor( |
| t_desc, |
| shape_env=shape_env, |
| callback=callback, |
| source=source, |
| symbolic_context=symbolic_context, |
| ) |
| |
| if type(t) is torch.nn.Parameter: |
| # NB: Cannot directly use Parameter constructor |
| # because that would force a detach, which is not desirable |
| r._is_param = True |
| |
| # TODO: return the description for later |
| return r |
| |
| |
| import torch._prims_common as utils |