| import collections |
| from typing import Any, Callable, Dict, List, Optional, Set |
| |
| from .. import variables |
| from ..exc import unimplemented |
| from ..source import AttrSource, Source |
| from ..utils import dict_values, identity, istype, odict_values |
| |
| |
class MutableLocal:
    """
    Tag object marking a value (list, iter, etc) as having been built in
    local scope, so analysis may mutate it without leaking state.

    Instances carry no data: two markers compare equal only when they are
    the very same object, and the hash is derived from that identity.
    """

    def __eq__(self, other):
        # Identity comparison: distinct markers are never equal.
        return other is self

    def __hash__(self):
        # Kept consistent with __eq__ by hashing on object identity.
        return id(self)
| |
| |
| # metaclass to call post_init |
class HasPostInit(type):
    """
    Metaclass that runs ``__post_init__`` on each newly constructed instance.

    The hook receives the same positional and keyword arguments that were
    passed to the constructor.
    """

    def __call__(cls, *args, **kwargs):
        # Normal construction first (__new__ followed by __init__) ...
        instance = type.__call__(cls, *args, **kwargs)
        # ... then the post-construction hook, with the identical arguments.
        instance.__post_init__(*args, **kwargs)
        return instance
| |
| |
class VariableTracker(metaclass=HasPostInit):
    """
    Base class for tracked locals and stack values.

    Instances are meant to be treated as immutable: the "mutating" helpers
    below (``add_guard``, ``add_guards``, ``add_options``) return a modified
    copy produced by ``clone`` rather than editing ``self``.

    Most of the query/operation methods here are default implementations
    that raise ``NotImplementedError`` or call ``unimplemented``; subclasses
    are expected to override the ones they support.
    """

    # fields to leave unmodified in apply()
    _nonvar_fields = ["value"]

    @staticmethod
    def propagate(*vars: List[List["VariableTracker"]]):
        """
        Merge the guards of many VariableTrackers into constructor kwargs.

        Arguments may be VariableTracker instances or arbitrarily nested
        lists, tuples, or dict value views containing them.

        Returns:
            A dict ``{"guards": <set>}`` suitable for ``**`` expansion into
            the constructor of a new instance.
        """
        collected = set()
        containers = (list, tuple, dict_values, odict_values)

        def walk(node):
            if type(node) not in containers:
                # Leaf: must be a tracker; contribute its guards.
                assert isinstance(node, VariableTracker), typestr(node)
                collected.update(node.guards)
                return
            for child in node:
                walk(child)

        walk(vars)
        return {"guards": collected}

    def clone(self, **kwargs):
        """Return a shallow copy, with any fields in ``kwargs`` overridden."""
        # The instance __dict__ doubles as the constructor keyword arguments.
        fields = {**self.__dict__, **kwargs}
        return self.__class__(**fields)

    @classmethod
    def copy(cls, value):
        """Deeper (but not full) copy, leaving FX and user objects alone"""
        return cls.apply(identity, value)

    @classmethod
    def apply(
        cls,
        fn: Callable[["VariableTracker"], "VariableTracker"],
        value,
        cache=None,
        skip_fn=lambda _: False,  # Whether we should skip applying to this var
    ):
        """
        Rebuild ``value`` with ``fn`` applied to every VariableTracker in it.

        The walk descends into VariableTracker fields (except those named
        in ``_nonvar_fields``) and into exact-type lists, tuples,
        OrderedDicts and dicts. Anything else is returned unchanged.

        Args:
            fn: called on each (re-cloned) VariableTracker; its return value
                replaces that tracker in the result.
            value: object to walk.
            cache: memo keyed by ``id()`` so shared sub-objects are handled
                once; created fresh when omitted.
            skip_fn: when it returns True for a tracker, ``fn`` is applied
                to that tracker directly without descending into its fields.
        """
        if cache is None:
            cache = {}

        key = id(value)
        if key in cache:
            return cache[key][0]

        def recurse(item):
            return cls.apply(fn, item, cache, skip_fn)

        if isinstance(value, VariableTracker):
            if skip_fn(value):
                result = fn(value)
            else:
                fields = dict(value.__dict__)
                # Only existing keys are reassigned, so iterating is safe.
                for name in fields:
                    if name not in value._nonvar_fields:
                        fields[name] = recurse(fields[name])
                result = fn(value.clone(**fields))
        elif istype(value, list):
            result = [recurse(item) for item in value]
        elif istype(value, tuple):
            result = tuple(recurse(item) for item in value)
        elif istype(value, collections.OrderedDict):
            # Each (key, value) pair is walked as a tuple, so keys pass
            # through the walk here as well (unlike the plain-dict branch).
            result = collections.OrderedDict(
                recurse(pair) for pair in value.items()
            )
        elif istype(value, dict):
            result = {k: recurse(v) for k, v in list(value.items())}
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[key] = (result, value)
        return result

    def add_guard(self, guard):
        """Return a copy whose guard set additionally contains ``guard``."""
        merged = set.union(self.guards, {guard})
        return self.clone(guards=merged)

    def add_guards(self, guards):
        """
        Return a copy whose guard set is the union with ``guards``.

        ``None`` is accepted and means "nothing to add": ``self`` is
        returned as-is in that case.
        """
        if guards is None:
            return self
        assert isinstance(guards, set)
        merged = set.union(self.guards, guards)
        return self.clone(guards=merged)

    def add_options(self, options, *more):
        """
        Return a copy with the guards from one or more option sources added.

        Each source is either a VariableTracker (its ``guards`` are used) or
        a dict (its optional ``"guards"`` entry is used).
        """
        if more:
            # Fold the first source in, then handle the rest on the result.
            first = self.add_options(options)
            return first.add_options(*more)
        if isinstance(options, VariableTracker):
            extra = options.guards
        else:
            assert isinstance(options, dict)
            extra = options.get("guards", set())
        return self.add_guards(extra)

    def __str__(self):
        return f"{self.__class__.__name__}()"

    def __repr__(self):
        return str(self)

    def python_type(self):
        """Default: no Python type is known for this variable."""
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def is_python_constant(self):
        """Return True iff ``as_python_constant`` succeeds."""
        try:
            self.as_python_constant()
        except NotImplementedError:
            return False
        return True

    def as_specialized(self, tx):
        """
        For specialized variables, return itself,
        For unspecialized variables, convert to constant variable and return.

        The base implementation returns ``self`` unchanged.
        """
        return self

    def can_make_guard(self):
        """Return True iff ``make_guard`` does not raise NotImplementedError."""
        try:
            # Probe with a placeholder function; only the outcome matters.
            self.make_guard(None)
        except NotImplementedError:
            return False
        return True

    def make_guard(self, fn):
        """
        Build a guard for ``fn`` via this variable's source.

        Raises:
            NotImplementedError: if the variable has no (truthy) source.
        """
        if not self.source:
            raise NotImplementedError()
        return self.source.make_guard(fn)

    def replace_guards(self, guards, *fns):
        """
        Return a new guard set in which this source's guards are replaced.

        Guards from ``guards`` (may be None) whose ``name`` equals this
        variable's source name are dropped; one guard per function in
        ``fns`` is then created from the source and added.
        """
        own_name = self.source.name()
        kept = set()
        for existing in guards or []:
            if existing.name != own_name:
                kept.add(existing)
        for fn in fns:
            kept.add(self.source.make_guard(fn))
        return kept

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError()

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """
        getattr(self, name) returning a new variable

        The attribute is resolved through ``const_getattr`` and wrapped in
        a ``ConstantVariable`` carrying this variable's guards and, when a
        source exists, an ``AttrSource`` derived from it.

        Raises:
            NotImplementedError: if the resolved value is not a literal
                according to ``ConstantVariable.is_literal``.
        """
        options = VariableTracker.propagate(self)
        attr_value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(attr_value):
            raise NotImplementedError()
        if self.source:
            options["source"] = AttrSource(self.source, name)
        return variables.ConstantVariable(attr_value, **options)

    def is_proxy(self):
        """Return True iff ``as_proxy`` succeeds."""
        try:
            self.as_proxy()
        except NotImplementedError:
            return False
        return True

    def as_proxy(self):
        """Default: this variable has no proxy representation."""
        raise NotImplementedError(str(self))

    def reconstruct(self, codegen):
        """Default: reconstruction is not supported."""
        raise NotImplementedError()

    def unpack_var_sequence(self, tx):
        """Default: this variable cannot be unpacked as a sequence."""
        raise NotImplementedError()

    def has_unpack_var_sequence(self, tx):
        """Return True iff ``unpack_var_sequence`` succeeds."""
        try:
            self.unpack_var_sequence(tx)
        except NotImplementedError:
            return False
        return True

    def num_parameters(self):
        """Default: reported through ``unimplemented``."""
        unimplemented(f"num_parameters: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Default: reported through ``unimplemented``."""
        unimplemented(f"hasattr: {self}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Default: reported through ``unimplemented``."""
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Generic handling for a method call on this variable.

        Two cases are supported here:

        * ``__len__`` on anything that can be unpacked as a sequence
          returns the length as a ``ConstantVariable``.
        * ``__getattr__`` with a single constant argument is routed to
          ``var_getattr``.

        Everything else is reported through ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            length = len(self.unpack_var_sequence(tx))
            return variables.ConstantVariable(
                length, **VariableTracker.propagate(self)
            )

        is_const_getattr = (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        )
        if is_const_getattr:
            attr_name = args[0].as_python_constant()
            return self.var_getattr(tx, attr_name).add_options(self, args[0])

        # NOTE(review): other methods call `unimplemented` without `raise`,
        # so it presumably raises on its own; the `raise` is kept as-is.
        raise unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def __init__(
        self,
        guards: Optional[Set] = None,
        source: Source = None,
        mutable_local: MutableLocal = None,
        recursively_contains: Optional[Set] = None,
    ):
        """
        Args:
            guards: guards attached to this variable; a fresh empty set is
                used when omitted (or when an empty/falsy value is given).
            source: where the value came from, if known.
            mutable_local: marker set when the value was built locally.
            recursively_contains: precomputed set of contained mutable
                markers; when None it is computed in ``__post_init__``.
        """
        super().__init__()
        self.guards = guards or set()
        self.source = source
        self.mutable_local = mutable_local
        # provides hint to replace_all when replacing vars
        self.recursively_contains = recursively_contains

    def __post_init__(self, *args, **kwargs):
        """
        Fill in ``recursively_contains`` when the constructor was not given it.

        Invoked by the ``HasPostInit`` metaclass after ``__init__``.
        """
        if self.recursively_contains is None:
            self.recursively_contains = set()

            def aggregate_mutables(var):
                # Fold in what `var` already knows, plus its own marker.
                self.recursively_contains.update(var.recursively_contains)
                if var.mutable_local is not None:
                    self.recursively_contains.add(var.mutable_local)

                return var

            def is_not_self(var):
                # Only descend into `self`; children are visited directly
                # and contribute their already-aggregated sets.
                return var is not self

            # This must stay inside the `if`: apply() clones `self`, and the
            # clone arrives with the set already populated, which is what
            # stops its own __post_init__ from recursing again.
            VariableTracker.apply(aggregate_mutables, self, skip_fn=is_not_self)

        assert None not in self.recursively_contains
| |
| |
def typestr(*objs):
    """
    Describe one or more objects for use in diagnostic messages.

    A single VariableTracker is rendered with ``str()``; any other single
    object is rendered as its type name. With zero or several arguments the
    individual descriptions are joined with spaces.
    """
    if len(objs) != 1:
        return " ".join(map(typestr, objs))

    (obj,) = objs
    if isinstance(obj, VariableTracker):
        return str(obj)
    return type(obj).__name__