| import collections |
| import dataclasses |
| import enum |
| from typing import Any, Optional, Union |
| |
| from torch._guards import ChainedSource, GuardSource, Source |
| |
| from . import utils |
| from .bytecode_transformation import create_call_function, create_instruction |
| from .utils import enum_repr |
| |
| _GUARD_SOURCE_NN_MODULE = { |
| GuardSource.LOCAL: GuardSource.LOCAL_NN_MODULE, |
| GuardSource.GLOBAL: GuardSource.GLOBAL_NN_MODULE, |
| GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL_NN_MODULE, |
| GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL_NN_MODULE, |
| GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL_FSDP_MODULE, |
| GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL_FSDP_MODULE, |
| } |
| |
| _GUARD_SOURCE_NOT_NN_MODULE = { |
| GuardSource.LOCAL: GuardSource.LOCAL, |
| GuardSource.GLOBAL: GuardSource.GLOBAL, |
| GuardSource.LOCAL_NN_MODULE: GuardSource.LOCAL, |
| GuardSource.GLOBAL_NN_MODULE: GuardSource.GLOBAL, |
| GuardSource.LOCAL_FSDP_MODULE: GuardSource.LOCAL, |
| GuardSource.GLOBAL_FSDP_MODULE: GuardSource.GLOBAL, |
| } |
| |
| |
def is_constant_source(source):
    """Return True if ``source`` is, or reports itself as, a constant source.

    A source qualifies when it is a ``ConstantSource`` instance or when its
    ``guard_source()`` compares equal to ``GuardSource.CONSTANT``.  A
    ``guard_source()`` that raises ``NotImplementedError`` is treated as
    "not constant" rather than propagated.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        reports_constant = bool(source.guard_source() == GuardSource.CONSTANT)
    except NotImplementedError:
        # Best effort: a source without a guard source is simply not constant.
        return False
    return reports_constant
| |
| |
| def is_input_source(source): |
| return source.guard_source() in [ |
| GuardSource.LOCAL, |
| GuardSource.GLOBAL, |
| GuardSource.LOCAL_NN_MODULE, |
| GuardSource.GLOBAL_NN_MODULE, |
| GuardSource.LOCAL_FSDP_MODULE, |
| GuardSource.GLOBAL_FSDP_MODULE, |
| ] |
| |
| |
def reconstruct_getitem(
    source: Union["GetItemSource", "ODictGetItemSource"], codegen, index_is_slice
):
    """Build the instructions that load ``source.base`` followed by its index.

    The list returned by ``source.base.reconstruct(codegen)`` is extended in
    place and returned.  The index is loaded in one of three ways:

    * a ``Source`` index is reconstructed through its own ``reconstruct``;
    * when ``index_is_slice`` is true, the slice rebuilt by
      ``source.unpack_slice()`` is loaded as a constant (``source`` must be a
      ``GetItemSource``);
    * otherwise ``source.index`` itself is loaded as a constant.

    The subscript/call instruction itself is left to the caller.
    """
    instructions = source.base.reconstruct(codegen)
    index = source.index

    if isinstance(index, Source):
        instructions.extend(index.reconstruct(codegen))
        return instructions

    if index_is_slice:
        assert isinstance(source, GetItemSource)
        constant = source.unpack_slice()
    else:
        constant = index
    instructions.append(codegen.create_load_const(constant))
    return instructions
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class LocalSource(Source): |
| local_name: str |
| cell_or_freevar: bool = False |
| |
| def reconstruct(self, codegen): |
| return [codegen.create_load(self.local_name)] |
| |
| def guard_source(self): |
| return GuardSource.LOCAL |
| |
| def name(self): |
| return f"L[{repr(self.local_name)}]" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class RandomValueSource(Source): |
| random_call_index: int |
| |
| def guard_source(self): |
| return GuardSource.RANDOM_VALUE |
| |
| def reconstruct(self, codegen): |
| return [ |
| codegen.create_load(codegen.tx.output.random_values_var), |
| codegen.create_load_const(self.random_call_index), |
| create_instruction("BINARY_SUBSCR"), |
| ] |
| |
| def name(self): |
| return f"random_value_{self.random_call_index}" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class GeneratorStateSource(Source): |
| device: str |
| initial_seed: int |
| |
| def guard_source(self): |
| return GuardSource.RANDOM_VALUE |
| |
| def reconstruct(self, codegen): |
| # generator state is a torch.ByteTensor, so we reuse TensorVariable reconstruction in codegen.py |
| raise NotImplementedError() |
| |
| def name(self): |
| name = f"generator_state_{self.device}_{self.initial_seed}" |
| return f"L[{name}]" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class GlobalSource(Source): |
| global_name: str |
| |
| def reconstruct(self, codegen): |
| return [codegen.create_load_global(self.global_name, False, add=True)] |
| |
| def guard_source(self): |
| return GuardSource.GLOBAL |
| |
| def name(self): |
| return f"G[{repr(self.global_name)}]" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class DummyGlobalSource(Source): |
| def reconstruct(self, codegen): |
| raise NotImplementedError() |
| |
| def guard_source(self): |
| return GuardSource.GLOBAL |
| |
| def name(self): |
| return "" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class GlobalWeakRefSource(Source): |
| global_name: str |
| |
| def reconstruct(self, codegen): |
| return [ |
| codegen.create_load_global(self.global_name, True, add=True), |
| *create_call_function(0, False), |
| ] |
| |
| def guard_source(self): |
| return GuardSource.GLOBAL |
| |
| def name(self): |
| return f"G[{repr(self.global_name)}]()" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class AttrSource(ChainedSource): |
| member: str |
| |
| def __post_init__(self): |
| assert self.base, "Can't construct an AttrSource without a valid base source" |
| if "." in self.member: |
| member_parts = self.member.split(".") |
| object.__setattr__( |
| self, "base", AttrSource(self.base, ".".join(member_parts[:-1])) |
| ) |
| object.__setattr__(self, "member", member_parts[-1]) |
| |
| def reconstruct(self, codegen): |
| return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member) |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def name(self): |
| if not self.member.isidentifier(): |
| return f"getattr({self.base.name()}, {self.member!r})" |
| return f"{self.base.name()}.{self.member}" |
| |
| |
@dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource):
    """``AttrSource`` whose guard source is re-tagged as an NN-module kind."""

    def guard_source(self):
        """Map the base's guard source through ``_GUARD_SOURCE_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
| |
| |
class TensorProperty(enum.Enum):
    """Tensor properties that a ``TensorPropertySource`` can refer to."""

    SIZE = 0
    STRIDE = 1
    STORAGE_OFFSET = 2

    def method_name(self):
        """Return the name of the tensor method that yields this property.

        Returns ``None`` for a member that has no entry in the table below.
        """
        # Built inside the method: a dict in the class body would be turned
        # into an enum member.
        method_names = {
            TensorProperty.SIZE: "size",
            TensorProperty.STRIDE: "stride",
            TensorProperty.STORAGE_OFFSET: "storage_offset",
        }
        return method_names.get(self)
| |
| |
@dataclasses.dataclass(frozen=True)
class TensorPropertySource(ChainedSource):
    """Source for a size, stride or storage offset of the tensor ``base``.

    ``idx`` selects the dimension for ``SIZE`` / ``STRIDE`` and must be
    ``None`` for ``STORAGE_OFFSET``; this is asserted on construction.
    """

    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        if self.prop is not TensorProperty.STORAGE_OFFSET:
            # SIZE and STRIDE are per-dimension and need an index.
            assert self.idx is not None
        else:
            assert self.idx is None

    def name(self):
        """Return the guard expression for the property.

        Raises ``AssertionError`` for a property that is not handled.
        """
        if self.prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if self.prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        if self.prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        raise AssertionError(f"unhandled {self.prop}")

    def guard_source(self):
        """Delegate to the base source."""
        return self.base.guard_source()

    def reconstruct(self, codegen):
        """Emit ``base.<method>(idx)``, or ``base.<method>()`` without an index."""
        has_index = self.idx is not None
        instructions = [
            *self.base.reconstruct(codegen),
            codegen.create_load_attr(self.prop.method_name()),
        ]
        if has_index:
            instructions.append(codegen.create_load_const(self.idx))
        # The call takes one argument when an index was pushed, else none.
        instructions.extend(create_call_function(1 if has_index else 0, True))
        return instructions
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class NegateSource(ChainedSource): |
| def __post_init__(self): |
| assert self.base is not None |
| |
| def reconstruct(self, codegen): |
| raise NotImplementedError() |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def name(self): |
| # NB: use method call so that function stripping regexes work |
| return f"{self.base.name()}.__neg__()" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class DefaultsSource(ChainedSource): |
| idx_key: Union[int, str] |
| is_kw: bool = False |
| field: str = dataclasses.field(init=False, repr=False, compare=False) |
| _name: str = dataclasses.field(init=False, repr=False, compare=False) |
| |
| def __post_init__(self): |
| assert ( |
| self.base |
| ), "Base must be a valid source in order to properly track and guard this Defaults to its origin." |
| if self.is_kw: |
| assert isinstance(self.idx_key, str) |
| object.__setattr__(self, "field", "__kwdefaults__") |
| object.__setattr__( |
| self, "_name", f"{self.base.name()}.{self.field}['{self.idx_key}']" |
| ) |
| else: |
| assert isinstance(self.idx_key, int) |
| object.__setattr__(self, "field", "__defaults__") |
| object.__setattr__( |
| self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" |
| ) |
| |
| def reconstruct(self, codegen): |
| instrs = self.base.reconstruct(codegen) |
| instrs.extend(codegen.create_load_attrs(self.field)) |
| instrs.extend( |
| [ |
| codegen.create_load_const(self.idx_key), |
| create_instruction("BINARY_SUBSCR"), |
| ] |
| ) |
| return instrs |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def name(self): |
| return self._name |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class GetItemSource(ChainedSource): |
| index: Any |
| index_is_slice: bool = False |
| |
| def __post_init__(self): |
| assert self.base is not None |
| if isinstance(self.index, slice): |
| # store the hashable version of the slice so the whole GetItemSource is hashable |
| super().__setattr__("index", self.index.__reduce__()) |
| super().__setattr__("index_is_slice", True) |
| |
| def reconstruct(self, codegen): |
| return [ |
| *reconstruct_getitem(self, codegen, index_is_slice=self.index_is_slice), |
| create_instruction("BINARY_SUBSCR"), |
| ] |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def unpack_slice(self): |
| assert self.index_is_slice |
| slice_class, slice_args = self.index |
| return slice_class(*slice_args) |
| |
| def name(self): |
| if isinstance(self.index, Source): |
| return f"{self.base.name()}[{self.index.name()}]" |
| else: |
| if self.index_is_slice: |
| return f"{self.base.name()}[{self.unpack_slice()!r}]" |
| elif isinstance(self.index, enum.Enum): |
| return f"{self.base.name()}[{enum_repr(self.index, self.guard_source().is_local())}]" |
| else: |
| return f"{self.base.name()}[{self.index!r}]" |
| |
| |
@dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource):
    """``GetItemSource`` variant that goes through ``tuple_iterator_getitem``."""

    def name(self):
        """Return ``___tuple_iterator_getitem(<base>, <repr of index>)``."""
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"

    def reconstruct(self, codegen):
        """Import the helper, then call it with the base and the index."""
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        instructions = list(self.base.reconstruct(codegen))
        instructions.append(codegen.create_load_const(self.index))
        instructions.extend(create_call_function(2, True))
        return instructions
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class TypeSource(ChainedSource): |
| def __post_init__(self): |
| assert self.base is not None |
| |
| def reconstruct(self, codegen): |
| codegen.load_import_from("builtins", "type") |
| return self.base.reconstruct(codegen) + create_call_function(1, True) |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def name(self): |
| return f"type({self.base.name()})" |
| |
| |
# NB - SuperSource is a weird one.
# It is our only source with 2 bases, so we use the object
# as the base, rather than the type. Since an invocation
# like super(Foo, foo) is represented here, the source object base is more spiritually
# aligned with the instance than with the type.
# This whole construction is questionable though, and we should probably find a way to
# avoid this exception to our otherwise nice source parentage invariant.
| @dataclasses.dataclass(frozen=True) |
| class SuperSource(ChainedSource): |
| type: Source |
| |
| def __post_init__(self): |
| assert self.type is not None |
| assert self.base is not None |
| |
| def reconstruct(self, codegen): |
| codegen.load_import_from("builtins", "super") |
| return ( |
| self.type.reconstruct(codegen) |
| + self.base.reconstruct(codegen) |
| + create_call_function(2, True) |
| ) |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def name(self): |
| return f"super({self.type.name()}, {self.base.name()})" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class ODictGetItemSource(ChainedSource): |
| index: Any |
| |
| def __post_init__(self): |
| assert self.base is not None |
| |
| def reconstruct(self, codegen): |
| return [ |
| codegen._create_load_const(collections.OrderedDict.__getitem__), |
| *reconstruct_getitem(self, codegen, index_is_slice=False), |
| *create_call_function(2, True), |
| ] |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def name(self): |
| if isinstance(self.index, type): |
| rep = f'__load_module("{self.index.__module__}").{self.index.__qualname__}' |
| return f"___odict_getitem({self.base.name()}, {rep})" |
| elif isinstance(self.index, Source): |
| return f"___odict_getitem({self.base.name()}, {self.index.name()})" |
| else: |
| return f"___odict_getitem({self.base.name()}, {self.index!r})" |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class NNModuleSource(ChainedSource): |
| def reconstruct(self, codegen): |
| return self.base.reconstruct(codegen) |
| |
| def guard_source(self): |
| return _GUARD_SOURCE_NN_MODULE[self.base.guard_source()] |
| |
| def name(self): |
| return self.base.name() |
| |
| |
@dataclasses.dataclass(frozen=True)
class NotNNModuleSource(NNModuleSource):
    """``NNModuleSource`` variant that strips the module tag instead of adding it."""

    def guard_source(self):
        """Map the base's guard source through ``_GUARD_SOURCE_NOT_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[base_kind]
| |
| |
@dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource):
    """``NNModuleSource`` subclass (presumably for FSDP-managed modules, by name)."""

    # NOTE(review): this override performs the same ``_GUARD_SOURCE_NN_MODULE``
    # lookup as ``NNModuleSource.guard_source``, so a plain LOCAL / GLOBAL base
    # is tagged *_NN_MODULE rather than *_FSDP_MODULE -- confirm that an
    # FSDP-specific mapping was not intended here.
    def guard_source(self):
        """Map the base's guard source through ``_GUARD_SOURCE_NN_MODULE``."""
        base_kind = self.base.guard_source()
        return _GUARD_SOURCE_NN_MODULE[base_kind]
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class GlobalStateSource(Source): |
| def name(self): |
| return "" |
| |
| def guard_source(self): |
| return GuardSource.GLOBAL |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class ConstantSource(Source): |
| source_name: str |
| |
| def reconstruct(self, codegen): |
| return [codegen.create_load_global(self.source_name, False, add=False)] |
| |
| def guard_source(self): |
| return GuardSource.CONSTANT |
| |
| def name(self): |
| return self.source_name |
| |
| def make_guard(self, fn, is_volatile=False): |
| raise NotImplementedError() |
| |
| |
| @dataclasses.dataclass(frozen=True) |
| class NumpyTensorSource(ChainedSource): |
| def name(self) -> str: |
| return f"__as_tensor({self.base.name()})" |
| |
| def guard_source(self): |
| return self.base.guard_source() |
| |
| def reconstruct(self, codegen): |
| codegen.load_import_from("torch", "as_tensor") |
| return self.base.reconstruct(codegen) + create_call_function(1, True) |
| |
| |
| # This is a synthetic source that is associated with the singleton |
| # shape env guard we always register for all frames. We get the actual |
| # guard contents from the ambient ShapeEnv |
| @dataclasses.dataclass(frozen=True) |
| class ShapeEnvSource(Source): |
| def name(self): |
| return "" |
| |
| def guard_source(self): |
| return GuardSource.SHAPE_ENV |
| |
| |
def is_from_local_source(source: Source, *, allow_cell_or_freevar=True):
    """Return whether the chain of ``base`` links under ``source`` ends in a local.

    ``ChainedSource`` wrappers are followed through ``.base`` until a
    non-chained source is reached.  The result is True only if that source
    is a ``LocalSource`` and, when ``allow_cell_or_freevar`` is false, its
    ``cell_or_freevar`` flag is not set.
    """
    root = source
    while isinstance(root, ChainedSource):
        root = root.base
    if not isinstance(root, LocalSource):
        return False
    rejected = not allow_cell_or_freevar and root.cell_or_freevar
    return not rejected