| import collections |
| import dataclasses |
| from typing import Any |
| |
| from . import utils |
| from .bytecode_transformation import create_instruction |
| from .guards import Guard, GuardSource |
| from .utils import rename_implicit |
| |
# (plain guard source, nn.Module-flavored counterpart) pairs shared by the two
# lookup tables below.
_NN_MODULE_SOURCE_PAIRS = (
    (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE),
    (GuardSource.GLOBAL, GuardSource.GLOBAL_NN_MODULE),
)

# Promote a guard source to its nn.Module flavor; nn.Module-flavored sources
# map to themselves.
_GUARD_SOURCE_NN_MODULE = {
    **{plain: nn_flavored for plain, nn_flavored in _NN_MODULE_SOURCE_PAIRS},
    **{nn_flavored: nn_flavored for _, nn_flavored in _NN_MODULE_SOURCE_PAIRS},
}

# Inverse direction: strip the nn.Module flavor; plain sources map to themselves.
_GUARD_SOURCE_NOT_NN_MODULE = {
    **{plain: plain for plain, _ in _NN_MODULE_SOURCE_PAIRS},
    **{nn_flavored: plain for plain, nn_flavored in _NN_MODULE_SOURCE_PAIRS},
}
| |
| |
def is_constant_source(source):
    """Return True if ``source`` describes a constant.

    A source counts as constant when it is a ``ConstantSource`` instance or when
    its ``guard_source()`` reports ``GuardSource.CONSTANT``. Sources whose
    ``guard_source()`` raises ``NotImplementedError`` (for example ones that do
    not override the base ``Source`` method) are treated as non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        origin = source.guard_source()
    except NotImplementedError:
        return False
    return origin == GuardSource.CONSTANT
| |
| |
@dataclasses.dataclass
class Source:
    """Base class describing where a value comes from.

    Subclasses override ``reconstruct`` (returns a list of instructions, built
    through ``codegen``, that reproduce the value), ``guard_source`` (the
    ``GuardSource`` category) and ``name`` (the string used as a guard's name).
    """

    def reconstruct(self, codegen):
        """Return instructions that reproduce the value; subclasses must override."""
        raise NotImplementedError()

    def guard_source(self):
        """Return the ``GuardSource`` category; subclasses must override."""
        raise NotImplementedError()

    def name(self):
        """Return the name used for guards; subclasses must override."""
        raise NotImplementedError()

    def make_guard(self, fn, is_volatile=False):
        """Build a ``Guard`` named after this source and checked by ``fn``.

        Raises:
            NotImplementedError: if this source is a constant (never guarded).
        """
        origin = self.guard_source()
        if origin is GuardSource.CONSTANT:
            raise NotImplementedError()
        return Guard(self.name(), origin, fn, is_volatile)

    def is_nn_module(self):
        """Return True when this source's guard category is an nn.Module flavor."""
        nn_module_origins = (
            GuardSource.LOCAL_NN_MODULE,
            GuardSource.GLOBAL_NN_MODULE,
        )
        return self.guard_source() in nn_module_origins
| |
| |
@dataclasses.dataclass
class LocalSource(Source):
    """A value held in the local variable ``local_name``."""

    local_name: str

    def reconstruct(self, codegen):
        load_instr = codegen.create_load(self.local_name)
        return [load_instr]

    def guard_source(self):
        return GuardSource.LOCAL

    def name(self):
        # Passed through utils.rename_implicit; presumably this maps
        # compiler-generated local names to usable ones — confirm in utils.
        return rename_implicit(self.local_name)
| |
| |
@dataclasses.dataclass
class RandomValueSource(Source):
    """Entry ``random_call_index`` of the output's random-values container.

    NOTE: this class does not override ``guard_source``, so the base
    ``NotImplementedError`` propagates from ``make_guard``/``is_nn_module``,
    and ``is_constant_source`` reports it as non-constant.
    """

    random_call_index: int

    def reconstruct(self, codegen):
        # Emits the equivalent of `<random_values_var>[random_call_index]`.
        values_var = codegen.tx.output.random_values_var
        return [
            codegen.create_load(values_var),
            codegen.create_load_const(self.random_call_index),
            create_instruction("BINARY_SUBSCR"),
        ]

    def name(self):
        raw_name = f"random_value_{self.random_call_index}"
        return rename_implicit(raw_name)
| |
| |
@dataclasses.dataclass
class GlobalSource(Source):
    """A value held in the global variable ``global_name``."""

    global_name: str

    def reconstruct(self, codegen):
        load_instr = codegen.create_load_global(self.global_name, add=True)
        return [load_instr]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        # Globals are referenced by their plain name, unlike locals.
        return self.global_name
| |
| |
@dataclasses.dataclass
class GlobalWeakRefSource(Source):
    """The result of calling the global ``global_name`` with no arguments.

    Presumably the global holds a weak reference and the call dereferences it —
    confirm against the code that creates this source.
    """

    global_name: str

    def reconstruct(self, codegen):
        load_callee = codegen.create_load_global(self.global_name, add=True)
        call_no_args = create_instruction("CALL_FUNCTION", 0)
        return [load_callee, call_no_args]

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return f"{self.global_name}()"
| |
| |
@dataclasses.dataclass
class AttrSource(Source):
    """Attribute ``member`` read from the value described by ``base``.

    A dotted ``member`` such as ``"a.b.c"`` is normalized into a chain of
    single-attribute sources,
    ``AttrSource(AttrSource(AttrSource(base, "a"), "b"), "c")``, so that
    ``self.member`` never contains a dot.
    """

    base: Source
    member: str

    # dataclasses keeps a user-defined __init__, so this replaces the generated one.
    def __init__(self, base, member):
        super().__init__()
        if "." in member:
            # Peel off the last component; the recursive call handles the prefix.
            prefix, _, member = member.rpartition(".")
            base = AttrSource(base, prefix)
        self.base = base
        self.member = member

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        """Return a Python expression string that reads this attribute.

        ``base.member`` syntax is only well formed when ``member`` is a
        non-keyword identifier. Anything else (e.g. ``"0"``, ``"my-buf"``,
        ``"class"``) is rendered as ``getattr(base, 'member')``. Numeric members
        keep the ``getattr`` form unconditionally, as before.
        """
        import keyword  # local import keeps the module's import block untouched

        member = self.member
        needs_getattr = (
            member.isnumeric()
            or not member.isidentifier()
            or keyword.iskeyword(member)
        )
        if needs_getattr:
            return f"getattr({self.base.name()}, {member!r})"
        return f"{self.base.name()}.{member}"
| |
| |
@dataclasses.dataclass
class GetItemSource(Source):
    """Subscript ``base[index]``, where ``index`` is a constant or another Source."""

    base: Source
    index: Any

    def reconstruct(self, codegen):
        # Builds on (and returns) the list produced by the base's reconstruct.
        output = self.base.reconstruct(codegen)
        output.extend(
            self.index.reconstruct(codegen)
            if isinstance(self.index, Source)
            else [codegen.create_load_const(self.index)]
        )
        output.append(create_instruction("BINARY_SUBSCR"))
        return output

    def guard_source(self):
        # The guard category follows the container, not the index.
        return self.base.guard_source()

    def name(self):
        base_name = self.base.name()
        key = self.index.name() if isinstance(self.index, Source) else repr(self.index)
        return f"{base_name}[{key}]"
| |
| |
@dataclasses.dataclass
class TupleIteratorGetItemSource(GetItemSource):
    """Element ``index`` of the tuple iterator ``base``.

    Rendered as a call to ``utils.tuple_iterator_getitem(base, index)``.
    """

    def reconstruct(self, codegen):
        # load_import_from is invoked for its effect on codegen; its result is
        # not part of the returned list (presumably it emits the callee load
        # directly — confirm in codegen).
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        return [
            *self.base.reconstruct(codegen),
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]

    def name(self):
        base_name = self.base.name()
        return f"___tuple_iterator_getitem({base_name}, {self.index!r})"
| |
| |
@dataclasses.dataclass
class TypeSource(Source):
    """The ``type()`` of the value described by ``base``."""

    base: Source

    def reconstruct(self, codegen):
        # load_import_from is invoked for its effect on codegen; only the
        # argument and the call are returned.
        codegen.load_import_from("builtins", "type")
        call_one_arg = create_instruction("CALL_FUNCTION", 1)
        return [*self.base.reconstruct(codegen), call_one_arg]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        inner_name = self.base.name()
        return f"type({inner_name})"
| |
| |
@dataclasses.dataclass
class ODictGetItemSource(Source):
    """``base[index]`` looked up via ``collections.OrderedDict.__getitem__``."""

    base: Source
    index: Any

    def reconstruct(self, codegen):
        # Calls the unbound OrderedDict.__getitem__ directly, bypassing any
        # __getitem__ override on the object's own class.
        return [
            codegen._create_load_const(collections.OrderedDict.__getitem__),
            *self.base.reconstruct(codegen),
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        base_name = self.base.name()
        return f"___odict_getitem({base_name}, {self.index!r})"
| |
| |
@dataclasses.dataclass
class NNModuleSource(Source):
    """Wrapper that relabels ``inner``'s guard category as an nn.Module flavor.

    Reconstruction and naming are delegated unchanged to ``inner``.
    """

    inner: Source

    def reconstruct(self, codegen):
        return self.inner.reconstruct(codegen)

    def guard_source(self):
        inner_origin = self.inner.guard_source()
        return _GUARD_SOURCE_NN_MODULE[inner_origin]

    def name(self):
        return self.inner.name()
| |
| |
@dataclasses.dataclass
class NotNNModuleSource(NNModuleSource):
    """Wrapper that strips the nn.Module flavor from ``inner``'s guard category.

    Decorated like every other Source subclass. The regenerated __init__,
    __repr__ and __eq__ match the ones previously inherited from NNModuleSource.
    """

    def guard_source(self):
        inner_origin = self.inner.guard_source()
        return _GUARD_SOURCE_NOT_NN_MODULE[inner_origin]
| |
| |
@dataclasses.dataclass
class ConstantSource(Source):
    """A constant referenced by the global name ``source_name``; never guarded."""

    source_name: str

    def reconstruct(self, codegen):
        # add=False differs from GlobalSource's add=True; its exact meaning is
        # defined by codegen.create_load_global — confirm there.
        load_instr = codegen.create_load_global(self.source_name, add=False)
        return [load_instr]

    def guard_source(self):
        return GuardSource.CONSTANT

    def name(self):
        return self.source_name

    def make_guard(self, fn, is_volatile=False):
        # Constants are never guarded.
        raise NotImplementedError()