# blob: 0b5cfae69363c1e678e8e48b8b793858a61c891b (header left over from the source viewer)
import collections
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import logging
import operator
import sys
import traceback
import types
import typing
import weakref
from typing import Any, Dict, Iterable, List
from unittest.mock import patch
import torch
from . import (
allowed_functions,
config,
exc,
logging as torchdynamo_logging,
side_effects,
skipfiles,
variables,
)
from .allowed_functions import is_allowed, is_builtin_callable, is_builtin_constant
from .bytecode_analysis import livevars_analysis
from .bytecode_transformation import (
cleaned_instructions,
create_instruction,
Instruction,
is_generator,
unique_id,
)
from .codegen import PyCodegen
from .exc import unimplemented, Unsupported
from .guards import GuardBuilder
from .output_graph import GraphCompileReason, OutputGraph
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
AttrSource,
GetItemSource,
GlobalSource,
GlobalWeakRefSource,
LocalSource,
)
from .utils import (
counters,
fake_tensors_available,
graph_break_dup_warning_checker,
istype,
)
from .variables.base import MutableLocal, typestr, VariableTracker
from .variables.builder import VariableBuilder
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable
from .variables.dicts import ConstDictVariable
from .variables.functions import (
BaseUserFunctionVariable,
NestedUserFunctionVariable,
UserFunctionVariable,
)
from .variables.lists import (
BaseListVariable,
ListIteratorVariable,
ListVariable,
SliceVariable,
TupleVariable,
)
from .variables.misc import (
ClosureVariable,
ContextWrappingVariable,
GetAttrVariable,
GradModeVariable,
PythonModuleVariable,
UnknownVariable,
WithExitFunctionVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import TensorVariable
from .variables.torch import TorchVariable
from .variables.user_defined import UserDefinedVariable
# Module-level logger used by the translators defined in this file.
log = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _step_logger():
    """Return the step logger built from ``log``; the result is memoized."""
    step_logger = torchdynamo_logging.get_step_logger(log)
    return step_logger
@dataclasses.dataclass
class BlockStackEntry:
    """One entry of the translator's block stack.

    ``target`` is the instruction the block jumps to.  ``stack_index`` and
    ``with_context`` default to ``None``; they are only filled in when the
    block was opened by a ``with`` statement that can be re-entered.
    """

    target: Instruction
    stack_index: int = None
    with_context: ContextWrappingVariable = None

    def can_restore(self):
        """Return True when a context manager is attached to this entry."""
        has_context = self.with_context is not None
        return has_context

    def resume_fn(self):
        """Build the ``ReenterWith`` helper for the recorded stack index."""
        index = self.stack_index
        assert index is not None
        return ReenterWith(index)

    def exit(self, tx):
        """Delegate to the attached context manager's ``exit``."""
        context = self.with_context
        return context.exit(tx)
def stack_op(fn: typing.Callable):
    """Build an opcode handler that applies ``fn`` to the top stack values.

    The number of operands popped equals the number of parameters in the
    signature of ``fn``; the result of the symbolic call is pushed back.
    """
    arity = len(inspect.signature(fn).parameters)
    builtin_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        operands = self.popn(arity)
        outcome = builtin_var.call_function(self, operands, {})
        self.push(outcome)

    return impl
def generic_jump(truth_fn: typing.Callable, push: bool):
    """Create the handler for a conditional-jump opcode.

    ``truth_fn`` is applied to the popped value (or to the length of an
    unpackable sequence) to decide whether the jump is taken.  When ``push``
    is true the value is pushed back before the jump is followed.
    """

    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        self.output.guards.update(value.guards)
        is_tensor = isinstance(value, TensorVariable)

        if value.is_python_constant():
            # The condition is known while tracing, so follow the edge directly.
            if truth_fn(value.as_python_constant()):
                if push:
                    self.push(value)
                self.jump(inst)
        elif is_tensor and self.should_compile_partial_graph():
            # compile a partial subgraph prefix then jump into user code
            self.push(value)
            reason = GraphCompileReason(
                f"generic_jump {typestr(value)}", [self.frame_summary()]
            )
            self.output.compile_subgraph(self, reason=reason)
            self.pop()
            # Continuation for the fall-through path, built with the value popped.
            if_next = self.create_call_resume_at(self.next_instruction)
            if push:
                self.push(value)
            # Continuation for the jump target, built after the optional re-push.
            if_jump = self.create_call_resume_at(inst.target)
            branch = create_instruction(inst.opname, target=if_jump[0])
            self.output.add_output_instructions([branch] + if_next + if_jump)
        elif not is_tensor and value.has_unpack_var_sequence(self):
            # Sequences are judged by their length.
            if truth_fn(len(value.unpack_var_sequence(self))):
                if push:
                    self.push(value)
                self.jump(inst)
        else:
            unimplemented(f"generic_jump {typestr(value)}")

    return inner
# Module-level flag; when true the graph-break warning below is suppressed.
explain = False


def break_graph_if_unsupported(*, push):
    """Decorator factory for opcode handlers that may hit unsupported code.

    The wrapped handler runs normally.  If it raises ``Unsupported`` and
    partial graphs are allowed, the translator state is rolled back, the
    graph traced so far is compiled, the original instruction is emitted
    unchanged and instructions to resume after it are appended.

    ``push`` is the number of values the instruction leaves on the stack;
    that many ``UnknownVariable`` placeholders are pushed in their place.
    """

    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            # Snapshot taken up front so a failed attempt can be undone.
            state = self.copy_graphstate()
            reason = None
            try:
                return inner_fn(self, inst)
            except Unsupported as exc:
                if not self.should_compile_partial_graph():
                    raise
                user_stack = [self.frame_summary()] + list(reversed(exc.real_stack))
                user_stack_formatted = "".join(traceback.format_list(user_stack))
                frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
                # torchdynamo.explain() formats this a little nicer, and presents a slightly
                # more actionable user code pointer
                if (
                    config.print_graph_breaks
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                ):
                    log.warning(
                        f"Graph break: {exc} from user code at {user_stack_formatted}"
                    )
                # Re-file the exception under the "graph_break" statistic.
                exc.remove_from_stats()
                exc.add_to_stats("graph_break")
                reason = GraphCompileReason(exc.msg, user_stack)
            # Only reached after an Unsupported exception was handled above.
            self.restore_graphstate(state)
            self.output.compile_subgraph(self, reason=reason)
            # Drop the instruction's inputs: it pops (push - net stack effect) values.
            self.popn(push - dis.stack_effect(inst.opcode, inst.arg))
            for _ in range(push):
                self.push(UnknownVariable())
            resume_call_insts = self.create_call_resume_at(self.next_instruction)
            # Check if there is a block stack entry with GradModeVariable. And
            # wrap the instruction causing the graph break inside a try..finally
            # block. See more details at
            # https://github.com/pytorch/torchdynamo/issues/207
            cleanup = []
            if len(self.block_stack) == 1 and isinstance(
                self.block_stack[0].with_context, GradModeVariable
            ):
                ctx_variable = self.block_stack[0].with_context
                cg = PyCodegen(self)
                setup_finally, cleanup = ctx_variable.reconstruct(
                    cg, resume_call_insts[0]
                )
                self.output.add_output_instructions(setup_finally)
            self.output.add_output_instructions([inst])
            # Add the cleanup instructions from try..finally block
            self.output.add_output_instructions(cleanup)
            self.output.add_output_instructions(
                resume_call_insts,
            )

        return wrapper

    return decorator
class InstructionTranslatorBase(object):
def cell_and_freevars(self):
    """Return ``co_cellvars + co_freevars`` as a tuple, cached on the instance."""
    if hasattr(self, "_cell_and_freevars"):
        return self._cell_and_freevars
    cell_names = tuple(self.code_options["co_cellvars"] or [])
    free_names = tuple(self.code_options["co_freevars"] or [])
    self._cell_and_freevars = cell_names + free_names
    return self._cell_and_freevars
def prune_dead_locals(self):
    """Drop symbolic locals that liveness analysis says are never read again.

    Cell and free variables are always kept.  The side-effect tracker is
    then asked to prune newly created objects that are no longer needed.
    """
    live = livevars_analysis(self.instructions, self.current_instruction)
    # Open questions kept from the original author: implicit use of
    # "__class__" by super(), and output variables, are not added here.
    live = live | set(self.cell_and_freevars())
    survivors = collections.OrderedDict()
    for name, var in self.symbolic_locals.items():
        if name in live:
            survivors[name] = var
    self.symbolic_locals = survivors
    self.output.side_effects.prune_dead_object_new(self)
def call_function(
    self,
    fn: VariableTracker,
    args: List[VariableTracker],
    kwargs: Dict[str, VariableTracker],
):
    """Symbolically call ``fn`` with ``args``/``kwargs`` and push the result.

    The argument types are asserted first: ``fn`` and every argument must be
    a ``VariableTracker``, ``args`` a list and ``kwargs`` a dict.
    """
    assert isinstance(fn, VariableTracker)
    assert isinstance(args, list)
    assert isinstance(kwargs, dict)
    for operand in itertools.chain(args, kwargs.values()):
        assert isinstance(operand, VariableTracker)
    outcome = fn.call_function(self, args, kwargs)
    self.push(outcome)
def update_locals_and_stack(self, oldvar: VariableTracker, newvar: VariableTracker):
    """Swap ``newvar`` in wherever a tracked variable shares ``oldvar``'s ``mutable_local``.

    The substitution is applied to the side-effect tracker, the stack and
    the symbolic locals, sharing one cache across all three.
    """

    def swap(candidate: VariableTracker):
        if candidate.mutable_local is oldvar.mutable_local:
            return newvar
        return candidate

    memo = {}
    self.output.side_effects.apply(swap, memo)
    self.stack = [VariableTracker.apply(swap, entry, memo) for entry in self.stack]
    for name, entry in self.symbolic_locals.items():
        self.symbolic_locals[name] = VariableTracker.apply(swap, entry, memo)
def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
    """Replace ``oldvar`` with ``newvar`` everywhere and return the final variable.

    Variables tracked as side effects are routed through the side-effect
    tracker; otherwise ``newvar`` is cloned with a fresh ``MutableLocal``.
    """
    tracked_as_side_effect = isinstance(
        oldvar.mutable_local, side_effects.MutableSideEffects
    )
    if not tracked_as_side_effect:
        assert isinstance(oldvar.mutable_local, variables.base.MutableLocal)
        newvar = newvar.clone(mutable_local=variables.base.MutableLocal())
    else:
        newvar = self.output.side_effects.mutation(oldvar, newvar)
    self.update_locals_and_stack(oldvar, newvar)
    return newvar
def inline_user_function_return(self, fn, args, kwargs):
    """
    Inline a call to a user-defined function and return its result.

    The translator state is snapshotted first and restored if inlining
    raises; the exception is then re-raised unchanged.
    """
    snapshot = self.copy_graphstate()
    try:
        inlined = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
        self.output.guards.update(fn.guards)
    except Exception:
        self.restore_graphstate(snapshot)
        raise
    else:
        return inlined
def step(self):
    """Process exactly one instruction.

    Returns True while tracing should continue and False once a
    ``RETURN_VALUE`` has been processed.  If the instruction is unsupported
    and a non-empty checkpoint exists, code is generated from that
    checkpoint and the method falls off the end (returning ``None``, which
    also stops the loop in ``run``).
    """
    inst = self.instructions[self.instruction_pointer]
    self.current_instruction = inst
    self.instruction_pointer += 1
    if self.instruction_pointer < len(self.instructions):
        self.next_instruction = self.instructions[self.instruction_pointer]
    else:
        # Ran past the last instruction: a None pointer ends the run loop.
        self.instruction_pointer = None
        self.next_instruction = None
    if inst.starts_line and self.lineno != inst.starts_line:
        self.lineno = inst.starts_line
        log.debug(f"TRACE starts_line {self.f_code.co_filename}:{self.lineno}")
    # An empty stack is a point the translator can later fall back to.
    if len(self.stack) == 0 and self.should_compile_partial_graph():
        self.checkpoint = inst, self.copy_graphstate()
    log.debug(f"TRACE {inst.opname} {inst.argval} {self.stack}")
    try:
        # Opcode handlers are methods named after the opcode.
        if not hasattr(self, inst.opname):
            unimplemented(f"missing: {inst.opname}")
        getattr(self, inst.opname)(inst)
        return inst.opname != "RETURN_VALUE"
    except Unsupported as exc:
        exc.real_stack.append(self.frame_summary())
        # Without a usable checkpoint there is nothing to fall back to.
        if self.empty_checkpoint():
            raise
    except Exception as exc:
        # Attach the user frame to any other exception, then re-raise.
        real_stack = getattr(exc, "real_stack", [])
        real_stack.append(self.frame_summary())
        exc.real_stack = real_stack
        raise
    # generate code from checkpoint
    assert not self.output.output_instructions
    continue_inst, state = self.checkpoint
    self.restore_graphstate(state)
    self.output.compile_subgraph(self, partial_convert=True)
    self.output.add_output_instructions(
        [create_instruction("JUMP_ABSOLUTE", target=continue_inst)]
        + self.instructions
    )
def run(self):
    """Step through instructions until tracing finishes or must stop.

    The loop ends when the instruction pointer is ``None``, when the output
    graph asks to exit, or when ``step`` returns a falsy value.
    """
    try:
        while True:
            if self.instruction_pointer is None:
                break
            if self.output.should_exit:
                break
            if not self.step():
                break
    except Exception as err:
        if config.replay_record_enabled:
            err.exec_record = self.exec_recorder.get_record()
        raise
    finally:
        # Cleanup the outputGraph to delete the held tensors. We perform the
        # cleanup only for InstructionTranslator and not
        # InliningInstructionTranslator. The InliningInstructionTranslator
        # mutates the output object and is restored to original state if
        # there was an exception.
        if isinstance(self, InstructionTranslator):
            self.output.cleanup()
def push(self, val):
    """Append ``val`` to the stack; it must be ``None`` or a ``VariableTracker``."""
    if val is not None:
        assert isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
    self.stack.append(val)
def push_many(self, vals: List[TensorVariable]):
    """Push every element of ``vals`` in order, so the last one ends up on top."""
    for item in vals:
        self.push(item)
def pop(self) -> TensorVariable:
    """Remove and return the value on top of the stack."""
    top = self.stack.pop()
    return top
def popn(self, n: int) -> List[TensorVariable]:
    """Pop ``n`` values and return them in stack order (deepest first)."""
    assert n >= 0
    popped = [self.pop() for _ in range(n)]
    popped.reverse()
    return popped
def LOAD_FAST(self, inst):
    """Push the symbolic value of the local named by ``inst.argval``."""
    name = inst.argval
    frame_locals = self.f_locals
    if name in frame_locals and config.replay_record_enabled:
        self.exec_recorder.add_local_var(name, frame_locals[name])

    if name.startswith(".") and name not in self.symbolic_locals:
        # This happens in dict/list comprehensions
        name = name.replace(".", "implicit")

    assert name not in self.cell_and_freevars()
    if name not in self.symbolic_locals:
        unimplemented("undefined LOAD_FAST")

    self.push(self.symbolic_locals[name])
    # Locals whose name starts with "___stack" are removed once loaded.
    if name.startswith("___stack"):
        del self.symbolic_locals[name]
def LOAD_DEREF(self, inst):
    """Push the symbolic value of the cell/free variable named by ``inst.argval``."""
    name = inst.argval
    assert name in self.cell_and_freevars()
    if name in self.f_locals and config.replay_record_enabled:
        self.exec_recorder.add_local_var(name, self.f_locals[name])
    if name not in self.symbolic_locals:
        unimplemented(f"undefined LOAD_DEREF {name}")
    self.push(self.symbolic_locals[name])
def STORE_FAST(self, inst):
    """Pop the top of the stack and bind it to the local named by ``inst.argval``."""
    value = self.pop()
    self.symbolic_locals[inst.argval] = value


def DELETE_FAST(self, inst):
    """Remove the local named by ``inst.argval`` (KeyError if it is absent)."""
    name = inst.argval
    del self.symbolic_locals[name]


# Cell/free variables are kept in the same mapping as ordinary locals here,
# so storing to one is the same operation.
STORE_DEREF = STORE_FAST
def LOAD_CLOSURE(self, inst):
    """Push a ``ClosureVariable`` naming the cell given by ``inst.argval``."""
    cell = ClosureVariable(name=inst.argval)
    self.push(cell)
def LOAD_CONST(self, inst):
    """Push ``inst.argval`` wrapped in a ``ConstantVariable``."""
    constant = ConstantVariable(value=inst.argval)
    self.push(constant)
def get_global_source(self, name):
    """Return a source object describing how to reach global ``name``.

    Three cases, tried in order: this frame shares the root globals; the
    frame's globals belong to a named module; or the globals dict is
    anonymous and gets installed in the root globals under a mangled name.
    """
    if self.output.root_globals is self.f_globals:
        return GlobalSource(name)

    if "__name__" in self.f_globals:
        module_source = self.import_source(self.f_globals["__name__"])
        return AttrSource(module_source, name)

    mangled_name = f"___unnamed_scope_{id(self.f_globals)}"
    if mangled_name not in self.output.root_globals:
        self.output.install_global(mangled_name, self.f_globals)
    return GetItemSource(GlobalSource(mangled_name), name)
def LOAD_GLOBAL(self, inst):
    """Push the symbolic value of global (or builtin) ``inst.argval``.

    Globals already written during tracing are read back through the
    side-effect tracker; names missing from ``f_globals`` fall back to
    ``load_builtin``.
    """
    name = inst.argval

    if config.replay_record_enabled:
        if name in self.f_globals:
            self.exec_recorder.add_global_var(name, self.f_globals[name])
        else:
            assert name in self.f_builtins
            self.exec_recorder.builtins[name] = self.f_builtins[name]

    if name in self.symbolic_globals:
        tracked = self.output.side_effects[self.symbolic_globals[name]]
        self.push(self.output.side_effects.load_global(tracked, name))
        return

    try:
        value = self.f_globals[name]
    except KeyError:
        return self.load_builtin(inst)

    source = self.get_global_source(name)
    builder = VariableBuilder(self, source)
    self.push(builder(value))
def STORE_GLOBAL(self, inst):
    """Pop a value and record the write to global ``inst.argval`` as a side effect."""
    value = self.pop()
    name = inst.argval
    source = self.get_global_source(name)
    if name not in self.symbolic_globals:
        # sentinel object: a unique key identifying this global in the tracker
        self.symbolic_globals[name] = object()
    sentinel = self.symbolic_globals[name]
    tracked = self.output.side_effects.track_global_existing(source, sentinel)
    self.output.side_effects.store_global(tracked, name, value)
def import_source(self, module_name):
    """Create an alias to a module for use in guards.

    The module is imported, stored in the root globals under
    ``__import_<name with dots replaced by _dot_>`` and returned as a
    ``GlobalSource`` for that alias.
    """
    module = importlib.import_module(module_name)
    alias = "__import_" + module_name.replace(".", "_dot_")
    root_globals = self.output.root_globals
    # An existing alias must already point at the same module object.
    assert alias not in root_globals or root_globals[alias] is module
    root_globals[alias] = module
    self.output.update_co_names(alias)
    return GlobalSource(alias)
def resolve_name(self, name, package, level):
    """
    Resolve a relative module name to an absolute one.

    Copied from the CPython implementation of __import__:
    https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902

    Raises ImportError when ``level`` climbs above the top-level package.
    """
    parts = package.rsplit(".", level - 1)
    if len(parts) < level:
        raise ImportError("attempted relative import beyond top-level package")
    anchor = parts[0]
    if not name:
        return anchor
    return f"{anchor}.{name}"
def calc_package(self):
    """
    Work out the package that relative imports in this frame resolve against.

    Copied from the CPython implementation of __import__:
    https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090

    Lookup order: ``__package__``, then ``__spec__.parent``, then
    ``__name__`` (with its last component stripped when the module has no
    ``__path__``).  Raises KeyError if none of these globals exist.
    """
    package = self.f_globals.get("__package__")
    spec = self.f_globals.get("__spec__")
    if package is not None:
        if spec is not None and package != spec.parent:
            # CPython passes ImportWarning/stacklevel to warnings.warn here.
            # With the logging API an extra positional argument is treated as
            # a %-format argument, so the values are passed lazily instead and
            # the category argument is dropped.
            log.warning(
                "__package__ != __spec__.parent (%r != %r)",
                package,
                spec.parent,
            )
        return package
    if spec is not None:
        return spec.parent
    log.warning(
        "can't resolve package from __spec__ or __package__, "
        "falling back on __name__ and __path__"
    )
    package = self.f_globals["__name__"]
    if "__path__" not in self.f_globals:
        # A plain module: its package is everything before the last dot.
        package = package.rpartition(".")[0]
    return package
def IMPORT_NAME(self, inst):
    """Handle an ``import`` statement for the module named by ``inst.argval``.

    Pops ``level`` and ``fromlist`` (both must be constants), obtains the
    module either from a previously recorded replay entry in ``f_globals``
    or by a real ``__import__`` call, and pushes a variable wrapping it.
    """
    level, fromlist = self.popn(2)
    level = level.as_python_constant()
    fromlist = fromlist.as_python_constant()
    module_name = inst.argval
    # Are we replaying? if so, load recorded module
    recorded_name = (
        f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
    )
    if recorded_name in self.f_globals:
        value = self.f_globals[recorded_name]
        source = GlobalSource(recorded_name)
    else:
        value = __import__(
            module_name,
            fromlist=fromlist,
            level=level,
            globals=self.f_globals,
        )
        # Relative import: turn the name into an absolute one for the source.
        if level != 0:
            pkg = self.calc_package()
            module_name = self.resolve_name(module_name, pkg, level)
        # For __import__, when the name variable is of the form package.module,
        # normally, the top-level package (the name up till the first dot) is
        # returned, not the module named by module_name. However, when a
        # non-empty fromlist argument is given, the module named by name is
        # returned. Therefore, we set the source correctly here.
        if not fromlist:
            top_level_module_name = module_name.partition(".")[0]
            source = self.import_source(top_level_module_name)
        else:
            source = self.import_source(module_name)
    if config.replay_record_enabled:
        self.exec_recorder.add_local_mod(recorded_name, value)
    # Pick the wrapper type from what was actually imported.
    if is_allowed(value):
        self.push(TorchVariable(value, source=source))
    elif istype(value, (types.ModuleType, DummyModule)):
        self.push(PythonModuleVariable(value, source=source))
    else:
        unimplemented(f"IMPORT_NAME {typestr(value)}")
def IMPORT_FROM(self, inst):
    """Load attribute ``inst.argval`` from the module on top of the stack.

    The module is duplicated first so it remains on the stack underneath
    the loaded attribute.
    """
    for handler in (self.DUP_TOP, self.LOAD_ATTR):
        handler(inst)
def load_builtin(self, inst):
    """Push the builtin named by ``inst.argval``.

    Callables are wrapped through ``VariableBuilder`` with a global source;
    anything else must be a builtin constant and becomes a ``ConstantVariable``.
    """
    name = inst.argval
    assert name in self.f_builtins
    val = self.f_builtins[name]

    if not callable(val):
        assert is_builtin_constant(val)
        self.push(ConstantVariable(value=val))
        return

    assert is_builtin_callable(val)
    builder = VariableBuilder(self, GlobalSource(name))
    self.push(builder(val))
def jump(self, inst):
    """Move the instruction pointer to the index of ``inst.target``.

    ``indexof`` is keyed by the ``id()`` of each instruction object.
    """
    target_key = id(inst.target)
    self.instruction_pointer = self.indexof[target_key]
# Unconditional jumps share the plain ``jump`` implementation.
JUMP_FORWARD = jump
JUMP_ABSOLUTE = jump
# Conditional jumps are built by generic_jump(truth_fn, push): ``truth_fn``
# decides whether the jump is taken, and ``push`` says whether the tested
# value is put back on the stack when it is (the ``*_OR_POP`` variants).
POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
def SETUP_LOOP(self, inst):
    """Open a block whose target is ``inst.target``."""
    # only exists in python<=3.7
    entry = BlockStackEntry(inst.target)
    self.block_stack.append(entry)
def SETUP_EXCEPT(self, inst):
    """Open a block whose target is ``inst.target``."""
    # only exists in python<=3.7
    entry = BlockStackEntry(inst.target)
    self.block_stack.append(entry)
def POP_BLOCK(self, inst):
    """Discard the innermost entry of the block stack."""
    del self.block_stack[-1]
def SETUP_WITH(self, inst):
    """Enter a ``with`` block whose context manager is on top of the stack.

    Pushes the exit-function variable followed by the result of entering
    the context.  Only ``ContextWrappingVariable`` managers are supported.
    """
    ctx = self.pop()
    if not isinstance(ctx, ContextWrappingVariable):
        unimplemented(f"SETUP_WITH {ctx}")
    self.output.guards.update(ctx.guards)

    if isinstance(self, InstructionTranslator):
        entry = BlockStackEntry(inst.target, len(self.stack), ctx)
    else:
        # can't restore this while inlining
        entry = BlockStackEntry(inst.target)
    self.block_stack.append(entry)

    exit_fn = WithExitFunctionVariable(
        ctx,
        inst.target,
        **VariableTracker.propagate(ctx),
    )
    self.push(exit_fn)
    self.push(ctx.enter(self))
def SETUP_FINALLY(self, inst):
    """Open a block whose target is ``inst.target``."""
    entry = BlockStackEntry(inst.target)
    self.block_stack.append(entry)
def BEGIN_FINALLY(self, inst):
    """Push a bare ``None`` marker onto the stack."""
    marker = None
    self.push(marker)
def WITH_CLEANUP_START(self, inst):
    """Call the ``with`` exit function with three ``None`` arguments.

    Only the no-exception case is handled: the marker popped alongside the
    exit function must be ``None`` (a ``None`` constant before Python 3.8).
    The marker is pushed back, followed by the result of the exit call.
    """
    exit_fn, exc_marker = self.popn(2)
    if sys.version_info >= (3, 8):
        assert exc_marker is None
    else:
        assert exc_marker.is_python_constant()
        assert exc_marker.as_python_constant() is None
    self.push(exc_marker)
    none_args = [ConstantVariable(None)] * 3
    self.push(exit_fn.call_function(self, none_args, {}))
def WITH_CLEANUP_FINISH(self, inst):
    """Drop the two values left by ``WITH_CLEANUP_START`` and push ``None``."""
    _discarded = self.popn(2)
    self.push(None)
def END_FINALLY(self, inst):
    """Pop the marker ending a ``finally`` block; it must represent ``None``."""
    tos = self.pop()
    if sys.version_info >= (3, 8):
        assert tos is None
        return
    # python3.7 and 3.8 can have END_FINALLY without BEGIN_FINALLY
    assert tos is None or (
        tos.is_python_constant() and tos.as_python_constant() is None
    )
def FOR_ITER(self, inst):
    """Advance the list iterator on top of the stack.

    On success the advanced iterator and the next value are pushed; when the
    iterator is exhausted the loop-exit jump is taken instead.  Only
    ``ListIteratorVariable`` is supported.
    """
    it = self.pop()
    if not isinstance(it, ListIteratorVariable):
        unimplemented(f"FOR_ITER {typestr(it)}")
    else:
        self.output.guards.update(it.guards)
        try:
            val, next_iter = it.next_variables()
            self.replace_all(it, next_iter)
            self.push(next_iter)
            self.push(val)
        except StopIteration:
            self.jump(inst)
def COMPARE_OP(self, inst):
    """Evaluate the comparison named by ``inst.argval`` on the top two values.

    Branches are tried in order: a known non-None object compared with
    ``None``; a comparison involving a tensor; two Python constants (folded
    at trace time); ``in`` / ``not in``; anything else is unsupported.
    """
    left, right = self.popn(2)
    left = left.as_specialized(self)
    right = right.as_specialized(self)
    options = VariableTracker.propagate([left, right])
    op = inst.argval
    # Operators usable when comparing against the constant None.
    supported_is_const = {
        "is": operator.is_,
        "is not": operator.is_not,
        "==": operator.eq,
        "!=": operator.ne,
    }
    # Operators that are applied to tensor proxies.
    supported_tensors = {
        ">": operator.gt,
        "<": operator.lt,
        ">=": operator.ge,
        "<=": operator.le,
        "==": operator.eq,
        "!=": operator.ne,
    }
    # Union of the two tables, used for constant folding.
    supported_any = dict(
        itertools.chain(supported_tensors.items(), supported_is_const.items())
    )
    if (
        isinstance(
            left,
            (
                TensorVariable,
                NNModuleVariable,
                BaseListVariable,
                UserDefinedVariable,
                BaseUserFunctionVariable,
                ConstDictVariable,
            ),
        )
        and isinstance(right, ConstantVariable)
        and right.value is None
        and op in supported_is_const
    ):
        # <non-None> is None
        # A fresh object() stands in for the left operand.
        self.push(
            ConstantVariable(
                supported_is_const[op](object(), right.value), **options
            )
        )
    elif (
        isinstance(left, TensorVariable) or isinstance(right, TensorVariable)
    ) and op in supported_tensors:
        self.push(
            TensorVariable.create(
                self,
                supported_tensors[op](left.as_proxy(), right.as_proxy()),
                **options,
            )
        )
    elif (
        left.is_python_constant()
        and right.is_python_constant()
        and op in supported_any
    ):
        # constant fold
        self.push(
            ConstantVariable(
                supported_any[op](
                    left.as_python_constant(), right.as_python_constant()
                ),
                **options,
            )
        )
    elif op in ("in", "not in"):
        self.push(right.call_method(self, "__contains__", [left], {}))
        # "not in" is the containment test followed by a logical not.
        if op == "not in":
            self.UNARY_NOT(inst)
    else:
        unimplemented(f"COMPARE_OP {typestr(left)} {op} {typestr(right)}")
def GET_ITER(self, inst):
    """Replace the top of the stack with the result of calling ``iter`` on it."""
    iterable = self.pop()
    self.call_function(BuiltinVariable(iter), [iterable], {})
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
    """Call a function with ``inst.argval`` positional arguments from the stack."""
    positional = self.popn(inst.argval)
    callee = self.pop()
    self.call_function(callee, positional, {})
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION_EX(self, inst):
    """Handle ``fn(*args)`` and ``fn(*args, **kwargs)`` calls.

    ``inst.argval`` is 1 when a keyword mapping is on the stack and 0 when
    only the positional sequence is; other values are unsupported.  Both
    containers must resolve to static list/dict variables.
    """
    flag = inst.argval
    if flag == 0:
        kwargsvars = ConstDictVariable({}, dict)
        argsvars = self.pop()
    elif flag == 1:
        kwargsvars = self.pop()
        argsvars = self.pop()
    else:
        unimplemented("CALL_FUNCTION_EX")
    fn = self.pop()
    self.output.guards.update(argsvars.guards)
    self.output.guards.update(kwargsvars.guards)

    is_tensor_view_call = (
        isinstance(fn, GetAttrVariable)
        and isinstance(fn.obj, TensorVariable)
        and fn.name == "view"
        and isinstance(argsvars, (ConstantVariable, TensorVariable))
    )
    if is_tensor_view_call:
        # Hack to handle special case in some bert models.  Converts
        # x.view(*shape) into x.view(shape), which is correct for view()
        # but not generally.  See test_transpose_for_scores().
        argsvars = TupleVariable([argsvars])

    if not isinstance(
        argsvars, BaseListVariable
    ) and argsvars.has_unpack_var_sequence(self):
        argsvars = TupleVariable(argsvars.unpack_var_sequence(self))

    is_static_call = isinstance(argsvars, BaseListVariable) and isinstance(
        kwargsvars, ConstDictVariable
    )
    if not is_static_call:
        unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")
    self.call_function(fn, argsvars.items, kwargsvars.items)
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION_KW(self, inst):
    """Call a function with positional and keyword arguments from the stack.

    The top of the stack is a constant tuple of keyword names.  Below it are
    ``inst.argval`` argument values; the last ``len(names)`` of them are the
    keyword values, the rest are positional.  The callable sits below those.
    """
    argnames = self.pop()
    args = self.popn(inst.argval)
    fn = self.pop()
    assert isinstance(argnames, ConstantVariable)
    argnames = argnames.value
    # Split on an explicit index.  A negative slice bound breaks when there
    # are no keyword names: ``args[:-0]`` is empty, which would silently drop
    # every positional argument.
    split = len(args) - len(argnames)
    assert split >= 0
    positional, keyword_values = args[:split], args[split:]
    kwargs = dict(zip(argnames, keyword_values))
    # Fails when a keyword name is repeated.
    assert len(kwargs) == len(argnames)
    self.call_function(fn, positional, kwargs)
def LOAD_METHOD(self, inst):
    """Load an attribute like ``LOAD_ATTR`` and push a ``None`` placeholder above it.

    ``CALL_METHOD`` expects and discards that placeholder.
    """
    self.LOAD_ATTR(inst)
    loaded = self.pop()
    self.push(loaded)
    self.push(None)
def CALL_METHOD(self, inst):
    """Call the method set up by ``LOAD_METHOD`` with ``inst.argval`` arguments."""
    positional = self.popn(inst.argval)
    placeholder = self.pop()
    # LOAD_METHOD always pushes None in this slot.
    assert placeholder is None
    callee = self.pop()
    self.call_function(callee, positional, {})
def LOAD_ATTR(self, inst):
    """Replace the top of the stack with its attribute ``inst.argval``.

    The lookup is a symbolic call to the builtin ``getattr``.
    """
    obj = self.pop()
    attr_name = ConstantVariable(inst.argval)
    getter = BuiltinVariable(getattr)
    self.push(getter.call_function(self, [obj, attr_name], {}))
def STORE_ATTR(self, inst):
    """Handle ``obj.<inst.argval> = val``.

    The store is first attempted symbolically through the builtin
    ``setattr``.  If that is unsupported and partial graphs are allowed,
    the state is rolled back, the graph so far is compiled and the original
    instruction is emitted, followed by instructions to resume tracing.
    """
    # Snapshot taken before the operands are popped, for the fallback path.
    prior = self.copy_graphstate()
    val, obj = self.popn(2)
    try:
        # No result is pushed, so the guards of the call are kept on the output.
        self.output.guards.update(
            BuiltinVariable(setattr)
            .call_function(self, [obj, ConstantVariable(inst.argval), val], {})
            .guards
        )
        return
    except Unsupported as e:
        if not self.should_compile_partial_graph():
            raise
        # Re-file the exception under the "graph_break" statistic.
        e.remove_from_stats()
        e.add_to_stats("graph_break")
        self.restore_graphstate(prior)
    # break the graph
    self.output.compile_subgraph(
        self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
    )
    self.output.add_output_instructions([inst])
    # The emitted instruction consumes both operands, so drop them here too.
    self.popn(2)
    self.output.add_output_instructions(
        self.create_call_resume_at(self.next_instruction)
    )
@break_graph_if_unsupported(push=0)
def STORE_SUBSCR(self, inst):
    """Handle ``obj[key] = val`` via a symbolic ``__setitem__`` call."""
    val, obj, key = self.popn(3)
    setitem_result = obj.call_method(self, "__setitem__", [key, val], {})
    # no result is pushed, so need to lift the guards to global
    self.output.guards.update(setitem_result.guards)
def BUILD_TUPLE(self, inst):
    """Pop ``inst.argval`` values and push them as a ``TupleVariable``."""
    elements = self.popn(inst.argval)
    inherited = VariableTracker.propagate(elements)
    self.push(TupleVariable(elements, **inherited))
def BUILD_SLICE(self, inst):
    """Pop ``inst.argval`` slice components and push a ``SliceVariable``.

    Each component is passed through ``as_specialized`` first.
    """
    components = self.popn(inst.argval)
    inherited = VariableTracker.propagate(components)
    specialized = [part.as_specialized(self) for part in components]
    self.push(SliceVariable(specialized, **inherited))
def BUILD_LIST(self, inst):
    """Pop ``inst.argval`` values and push them as a mutable ``ListVariable``."""
    elements = self.popn(inst.argval)
    inherited = VariableTracker.propagate(elements)
    new_list = ListVariable(elements, mutable_local=MutableLocal(), **inherited)
    self.push(new_list)
def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
    """Concatenate ``inst.argval`` unpackable sequences into one ``cls`` container."""
    sequences = self.popn(inst.argval)
    inherited = VariableTracker.propagate(sequences)
    flattened = []
    for seq in sequences:
        try:
            flattened.extend(seq.unpack_var_sequence(self))
        except NotImplementedError:
            unimplemented(f"BUILD_LIST_UNPACK {seq}")
    self.push(cls(flattened, mutable_local=MutableLocal(), **inherited))
def BUILD_TUPLE_UNPACK(self, inst):
    """Same as ``BUILD_LIST_UNPACK`` but producing a ``TupleVariable``."""
    return self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)


# The *_WITH_CALL flavour is handled identically.
BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
def BUILD_MAP(self, inst):
    """Pop ``inst.argval`` key/value pairs and push a mutable dict variable.

    Keys must be constants or tensors with a specialized value, and must be
    distinct after conversion through ``ConstDictVariable.get_key``.
    """
    items = self.popn(inst.argval * 2)
    options = VariableTracker.propagate(items)
    result = {}
    keys, values = items[::2], items[1::2]
    for key_var, value_var in zip(keys, values):
        is_const_key = isinstance(key_var, ConstantVariable)
        assert is_const_key or (
            isinstance(key_var, TensorVariable)
            and key_var.specialized_value is not None
        )
        result[ConstDictVariable.get_key(key_var)] = value_var
    assert len(result) == len(items) / 2
    new_dict = ConstDictVariable(
        result, dict, mutable_local=MutableLocal(), **options
    )
    self.push(new_dict)
def BUILD_CONST_KEY_MAP(self, inst):
    """Build a dict from a constant tuple of keys and ``inst.argval`` values."""
    keys_var = self.pop()
    values = self.popn(inst.argval)
    options = VariableTracker.propagate([keys_var] + values)
    assert isinstance(keys_var, ConstantVariable)
    keys = keys_var.value
    assert istype(keys, tuple)
    assert len(keys) == len(values)
    mapping = dict(zip(keys, values))
    new_dict = ConstDictVariable(
        mapping,
        dict,
        mutable_local=MutableLocal(),
        **options,
    )
    self.push(new_dict)
def MAP_ADD(self, inst):
    """Add one key/value pair to the dict ``inst.arg`` slots down the stack.

    The dict variable is not mutated; an updated copy replaces it everywhere
    through ``replace_all``.
    """
    lower, upper = self.popn(2)
    # The operand order of MAP_ADD was swapped in Python 3.8.
    if sys.version_info < (3, 8):
        v, k = lower, upper
    else:
        k, v = lower, upper
    assert inst.argval > 0
    obj = self.stack[-inst.arg]
    assert isinstance(obj, ConstDictVariable)
    assert obj.mutable_local
    updated = dict(obj.items)
    updated[k.as_python_constant()] = v
    replacement = ConstDictVariable(
        updated,
        obj.user_cls,
        **VariableTracker.propagate([obj, k, v]),
    )
    self.replace_all(obj, replacement)
def LIST_APPEND(self, inst):
    """Append the popped value to the list ``inst.arg`` slots down the stack.

    The list variable is not mutated; a new list with the extra item
    replaces it everywhere through ``replace_all``.
    """
    appended = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg]
    assert isinstance(target, ListVariable)
    assert target.mutable_local
    replacement = ListVariable(
        target.items + [appended],
        **VariableTracker.propagate([target, appended]),
    )
    self.replace_all(target, replacement)
def MAKE_FUNCTION(self, inst):
    """Build a nested function variable from the values on the stack.

    The name and code object are always popped.  The bits of ``inst.arg``
    select which optional parts follow: 0x08 closure, 0x04 annotations,
    0x02 keyword-only defaults, 0x01 positional defaults.
    """
    flags = inst.arg
    stack_before = list(self.stack)
    fn_name = self.pop()
    code = self.pop()
    # The pop order below is fixed by the bytecode layout; do not reorder.
    closure = self.pop() if flags & 0x08 else None
    annotations = self.pop() if flags & 0x04 else None
    kwdefaults = self.pop() if flags & 0x02 else None
    defaults = self.pop() if flags & 0x01 else None
    # Propagate options from exactly the values consumed above.
    consumed = stack_before[len(self.stack) :]
    options = VariableTracker.propagate(consumed)
    nested_fn = NestedUserFunctionVariable(
        fn_name,
        code,
        self.f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope=self,
        **options,
    )
    self.push(nested_fn)
def UNPACK_SEQUENCE(self, inst):
    """Unpack the top of the stack into ``inst.argval`` values.

    Values are pushed in reverse so the first element ends up on top.
    """
    # TODO(jansel): rewrite this using unpack_var_sequence
    seq = self.pop()
    options = VariableTracker.propagate([seq])
    count = inst.argval
    if isinstance(seq, BaseListVariable):
        assert len(seq.items) == count
        self.output.guards.update(seq.guards)
        for element in reversed(seq.items):
            self.push(element)
    elif seq.is_python_constant() and isinstance(seq, ConstantVariable):
        val = seq.as_python_constant()
        assert len(val) == count
        for element in reversed(val):
            self.push(ConstantVariable(element, **options))
    elif isinstance(seq, TensorVariable):
        proxy = seq.as_proxy()
        for index in range(count - 1, -1, -1):
            self.push(TensorVariable.create(self, proxy[index], **options))
    elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
        # x, y = a.shape
        proxy = getattr(seq.obj.as_proxy(), seq.name)
        for index in range(count - 1, -1, -1):
            self.push(TensorVariable.create(self, proxy[index], **options))
    else:
        unimplemented(f"UNPACK_SEQUENCE {seq}")
def UNPACK_EX(self, inst):
    """Handle starred unpacking such as ``a, *rest, z = seq``.

    The low byte of ``inst.argval`` is the number of targets before the
    star, the high byte the number after it.  The starred part is pushed as
    a ``TupleVariable``.
    """
    assert 0 <= inst.argval <= 0xFFFF
    prefix = inst.argval & 0xFF  # low byte
    suffix = inst.argval >> 8  # high byte
    seq = self.pop()
    options = VariableTracker.propagate(seq)
    if not seq.has_unpack_var_sequence(self):
        unimplemented(f"UNPACK_EX {seq}")
    else:
        vals = list(seq.unpack_var_sequence(self))
        assert len(vals) >= prefix + suffix
        tail_start = len(vals) - suffix
        head = vals[:prefix]
        middle = vals[prefix:tail_start]
        tail = vals[tail_start:]
        for item in reversed(tail):
            self.push(item.add_options(options))
        self.push(TupleVariable(middle, **options))
        for item in reversed(head):
            self.push(item.add_options(options))
def NOP(self, inst):
    """Do nothing."""
    return None
def POP_TOP(self, inst):
    """Discard the value on top of the stack."""
    _discarded = self.pop()
def ROT_TWO(self, inst):
    """Swap the two topmost stack values."""
    top = self.pop()
    second = self.pop()
    for value in (top, second):
        self.push(value)
def ROT_THREE(self, inst):
    """Move the top value down two slots, lifting the second and third up one."""
    top = self.pop()
    second = self.pop()
    third = self.pop()
    for value in (top, third, second):
        self.push(value)
def ROT_FOUR(self, inst):
    """Move the top value down three slots, lifting the next three up one."""
    top = self.pop()
    second = self.pop()
    third = self.pop()
    fourth = self.pop()
    for value in (top, fourth, third, second):
        self.push(value)
def DUP_TOP(self, inst):
    """Duplicate the reference on top of the stack."""
    top = self.pop()
    for _ in range(2):
        self.push(top)
def DUP_TOP_TWO(self, inst):
    """Duplicate the two topmost references, keeping their order."""
    top = self.pop()
    second = self.pop()
    for value in (second, top, second, top):
        self.push(value)
def FORMAT_VALUE(self, inst):
    """Format one f-string field.

    Bit 0x04 of ``inst.arg`` means a format spec is on the stack; the low
    two bits select an optional ``str``/``repr``/``ascii`` conversion.  The
    result is produced by a symbolic ``str.format`` call on ``"{:<spec>}"``.
    """
    flags = inst.arg
    has_spec = (flags & 0x04) == 0x04
    fmt_spec = self.pop() if has_spec else ConstantVariable("")
    value = self.pop()

    conversions = {0x01: str, 0x02: repr, 0x03: ascii}
    conversion = conversions.get(flags & 0x03)
    if conversion is not None:
        value = BuiltinVariable(conversion).call_function(self, [value], {})

    template = "{:" + fmt_spec.as_python_constant() + "}"
    fmt_var = ConstantVariable(template).add_options(fmt_spec)
    self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
def BUILD_STRING(self, inst):
    """Concatenate ``inst.arg`` constant strings from the stack into one constant."""
    joined = ""
    remaining = inst.arg
    while remaining > 0:
        piece = self.pop()
        assert isinstance(piece, ConstantVariable)
        # Pieces come off the stack last-first, so prepend each one.
        joined = piece.value + joined
        remaining -= 1
    self.push(ConstantVariable(value=joined))
def IS_OP(self, inst):
    """Handle ``is`` (argval 0) and ``is not`` (argval 1) by delegating to ``COMPARE_OP``."""
    assert inst.argval == 0 or inst.argval == 1
    comparison = "is" if inst.argval == 0 else "is not"
    equivalent = create_instruction("COMPARE_OP", argval=comparison)
    self.COMPARE_OP(equivalent)
def CONTAINS_OP(self, inst):
    """Handle ``in`` (argval 0) and ``not in`` (argval 1).

    The test is a symbolic ``__contains__`` call on the right operand,
    negated through ``UNARY_NOT`` for ``not in``.
    """
    assert inst.argval == 0 or inst.argval == 1
    needle, haystack = self.popn(2)
    negate = inst.argval == 1
    self.push(haystack.call_method(self, "__contains__", [needle], {}))
    if negate:
        self.UNARY_NOT(inst)
def LIST_EXTEND(self, inst):
    """Extend the list ``inst.arg`` slots down the stack with the popped value."""
    extra = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg]
    assert isinstance(target, ListVariable)
    assert target.mutable_local
    target.call_method(self, "extend", [extra], {})
def LIST_TO_TUPLE(self, inst):
    """Replace the top of the stack with the result of calling ``tuple`` on it."""
    source_list = self.pop()
    converted = BuiltinVariable(tuple).call_function(self, [source_list], {})
    self.push(converted)
def DICT_MERGE(self, inst):
    """Update the dict ``inst.arg`` slots down the stack with the popped value."""
    extra = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg]
    assert isinstance(target, ConstDictVariable)
    assert target.mutable_local
    target.call_method(self, "update", [extra], {})
def GEN_START(self, inst):
    """Discard the value on top of the stack."""
    _discarded = self.pop()
def GET_LEN(self, inst):
    """Push the length of the top-of-stack value without popping it.

    Constants are measured directly; other values go through a symbolic
    ``__len__`` call.
    """
    tos = self.stack[-1]
    if not tos.is_python_constant():
        length = tos.call_method(self, "__len__", [], {})
    else:
        length = ConstantVariable(len(tos.as_python_constant()))
    self.push(length)
def MATCH_MAPPING(self, inst):
    """Push a constant telling whether the top of the stack holds a mapping.

    The top value must be a ``ConstDictVariable``; it is left in place.
    """
    tos = self.stack[-1]
    assert isinstance(tos, ConstDictVariable)
    is_mapping = isinstance(tos.items, collections.abc.Mapping)
    if is_mapping:
        self.push(ConstantVariable(True))
    else:
        self.push(ConstantVariable(False))
def MATCH_SEQUENCE(self, inst):
    """Push a constant telling whether the top of the stack is a sequence.

    The top value must be a Python constant; ``str``, ``bytes`` and
    ``bytearray`` do not count as sequences.  It is left in place.
    """
    tos = self.stack[-1]
    assert tos.is_python_constant()
    tos_value = tos.as_python_constant()
    if not isinstance(tos_value, collections.abc.Sequence) or isinstance(
        tos_value, (str, bytes, bytearray)
    ):
        self.push(ConstantVariable(False))
    else:
        self.push(ConstantVariable(True))
def MATCH_KEYS(self, inst):
    """Look up the constant keys on top of the stack in the dict below them.

    When every key is present, a tuple of the matching values and ``True``
    are pushed; otherwise ``None`` and ``False``.  Nothing is popped.
    """
    tos = self.stack[-1]
    assert tos.is_python_constant()
    keys = tos.as_python_constant()
    tos1 = self.stack[-2]
    assert isinstance(tos1, ConstDictVariable)
    match_obj = tos1.items
    if not all(key in match_obj for key in keys):
        self.push(ConstantVariable(None))
        self.push(ConstantVariable(False))
        return
    self.push(TupleVariable([match_obj[key] for key in keys]))
    self.push(ConstantVariable(True))
# Arithmetic and bitwise opcodes are generated by stack_op, which takes the
# operand count from the signature of the ``operator`` function it wraps.

# Unary operators.
UNARY_POSITIVE = stack_op(operator.pos)
UNARY_NEGATIVE = stack_op(operator.neg)
UNARY_NOT = stack_op(operator.not_)
UNARY_INVERT = stack_op(operator.invert)
# Binary operators.
BINARY_POWER = stack_op(operator.pow)
BINARY_MULTIPLY = stack_op(operator.mul)
BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
BINARY_MODULO = stack_op(operator.mod)
BINARY_ADD = stack_op(operator.add)
BINARY_SUBTRACT = stack_op(operator.sub)
# Subscripting is the only operator here wrapped with a graph-break fallback.
BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
BINARY_LSHIFT = stack_op(operator.lshift)
BINARY_RSHIFT = stack_op(operator.rshift)
BINARY_AND = stack_op(operator.and_)
BINARY_OR = stack_op(operator.or_)
BINARY_XOR = stack_op(operator.xor)
# In-place (augmented assignment) operators.
INPLACE_POWER = stack_op(operator.ipow)
INPLACE_MULTIPLY = stack_op(operator.imul)
INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
INPLACE_MODULO = stack_op(operator.imod)
INPLACE_ADD = stack_op(operator.iadd)
INPLACE_SUBTRACT = stack_op(operator.isub)
INPLACE_LSHIFT = stack_op(operator.ilshift)
INPLACE_RSHIFT = stack_op(operator.irshift)
INPLACE_AND = stack_op(operator.iand)
INPLACE_XOR = stack_op(operator.ixor)
INPLACE_OR = stack_op(operator.ior)
def copy_graphstate(self):
    """Snapshot the translator's mutable state.

    Returns an 8-tuple, in this order: the output graph's own snapshot,
    shallow copies of ``symbolic_locals`` (as an ``OrderedDict``), ``stack``
    and ``block_stack``, then the instruction pointer, current instruction,
    next instruction and line number. ``restore_graphstate`` unpacks the
    same layout.
    """
    output_snapshot = self.output.copy_graphstate()
    # Containers are copied shallowly so later pushes/pops do not leak into
    # the snapshot; the contained objects themselves are shared.
    locals_snapshot = collections.OrderedDict(self.symbolic_locals)
    stack_snapshot = [*self.stack]
    blocks_snapshot = [*self.block_stack]
    return (
        output_snapshot,
        locals_snapshot,
        stack_snapshot,
        blocks_snapshot,
        self.instruction_pointer,
        self.current_instruction,
        self.next_instruction,
        self.lineno,
    )
def restore_graphstate(self, state):
    """Reinstate a snapshot produced by ``self.copy_graphstate()``.

    ``state`` must be the 8-tuple in the layout ``copy_graphstate`` returns.
    The translator's own fields are rebound first, then the output graph is
    asked to restore its part of the snapshot.
    """
    (
        output_state,
        saved_locals,
        saved_stack,
        saved_blocks,
        saved_pointer,
        saved_current,
        saved_next,
        saved_lineno,
    ) = state

    self.symbolic_locals = saved_locals
    self.stack = saved_stack
    self.block_stack = saved_blocks
    self.instruction_pointer = saved_pointer
    self.current_instruction = saved_current
    self.next_instruction = saved_next
    self.lineno = saved_lineno

    self.output.restore_graphstate(output_state)
def empty_checkpoint(self):
    """Report whether the stored checkpoint holds no non-empty containers.

    Returns True when ``self.checkpoint`` is None. Otherwise the saved state
    at ``self.checkpoint[1]`` is flattened — the elements of its first item
    (the output-graph snapshot) followed by its remaining items — and the
    result is False as soon as any ``Iterable`` element has a non-zero
    ``len``. Non-iterable elements are ignored.
    """
    if self.checkpoint is None:
        return True

    saved = self.checkpoint[1]
    flattened = (*saved[0], *saved[1:])
    has_content = any(
        isinstance(entry, Iterable) and len(entry) != 0 for entry in flattened
    )
    return not has_content
def format_frame_summary(self, additional_stack_frames=None):
    """Render this frame, plus optional extra frames, as traceback text.

    The frame from ``self.frame_summary()`` comes first, followed by
    ``additional_stack_frames`` in reversed order. The entries are formatted
    with ``traceback.format_list`` and joined into a single string.
    """
    extra_frames = [] if additional_stack_frames is None else additional_stack_frames
    frames = [self.frame_summary(), *reversed(extra_frames)]
    return "".join(traceback.format_list(frames))
def frame_summary(self):
    """Build a ``traceback.FrameSummary`` for the current trace position.

    The file and function names come from ``self.f_code`` (falling back to
    ``"<unknown>"`` when the attribute is missing) and the line number from
    ``self.lineno``. ``lookup_line=False`` defers reading the source line.
    """
    code = self.f_code
    filename = getattr(code, "co_filename", "<unknown>")
    function_name = getattr(code, "co_name", "<unknown>")
    return traceback.FrameSummary(
        filename, self.lineno, function_name, lookup_line=False
    )
def store_dict_key(self, name, value):
    """Guard on, and if needed install, a weakly referenced global.

    A ``WEAKREF_ALIVE`` guard built from ``GlobalWeakRefSource(name)`` is
    always added to the output guards. ``weakref.ref(value)`` is installed
    as global ``name`` only when ``name`` is not already in
    ``self.output.root_globals``.
    """
    alive_guard = GlobalWeakRefSource(name).make_guard(GuardBuilder.WEAKREF_ALIVE)
    self.output.guards.add(alive_guard)

    if name in self.output.root_globals:
        return
    self.output.install_global(name, weakref.ref(value))
# Read-only accessor for the mode object stored on ``self._fake_mode``.
# NOTE(review): ``_fake_mode`` appears to be assigned in ``__init__`` only when
# ``fake_tensors_available`` is true — confirm callers never reach this
# property otherwise.
@property
def fake_mode(self):
return self._fake_mode
def find_symbolic_locals_name(self, tensor_variable):
    """Return the local name bound to ``tensor_variable``, or None.

    ``self.symbolic_locals`` is searched in iteration order and entries are
    compared by identity (``is``), not equality; the first match wins.
    """
    return next(
        (
            name
            for name, tracked in self.symbolic_locals.items()
            if tracked is tensor_variable
        ),
        None,
    )
def __init__(
    self,
    output: OutputGraph,
    instructions: List[Instruction],
    f_locals: Dict[str, Any],
    f_globals: Dict[str, Any],
    f_builtins: Dict[str, Any],
    code_options: Dict[str, Any],
    symbolic_locals: Dict[str, VariableTracker],
    symbolic_globals: Dict[str, VariableTracker],
    f_code: types.CodeType,
):
    """Set up translator state for one code object.

    Stores the output graph, the instruction list with an id-to-index map,
    the frame's locals/globals/builtins, the code options and code object,
    and initialises an empty value stack and block stack. The starting line
    number is read from ``code_options["co_firstlineno"]``.
    """
    super(InstructionTranslatorBase, self).__init__()

    # --- Mutable state checkpointed by copy_graphstate() ---
    self.output: OutputGraph = output
    self.symbolic_locals: Dict[str, VariableTracker] = symbolic_locals
    self.symbolic_globals: Dict[str, VariableTracker] = symbolic_globals
    self.stack: List[VariableTracker] = []
    self.instruction_pointer: int = 0
    # Placeholder instruction until the first real one is dispatched.
    self.current_instruction: Instruction = create_instruction("NOP")
    self.next_instruction: typing.Optional[Instruction] = None
    self.block_stack: List[BlockStackEntry] = []
    self.lineno: int = code_options.get("co_firstlineno")

    # --- Properties of the input/output code ---
    self.instructions: List[Instruction] = instructions
    # Keyed by object identity so an Instruction can be mapped to its index.
    self.indexof: Dict[int, int] = {
        id(instruction): position
        for position, instruction in enumerate(instructions)
    }
    # needed for recording accessed locals for replay
    self.f_locals: Dict[str, Any] = f_locals
    self.f_globals: Dict[str, Any] = f_globals
    self.f_builtins: Dict[str, Any] = f_builtins
    self.code_options: Dict[str, Any] = code_options
    self.f_code: types.CodeType = f_code

    # Execution record for replaying errors
    self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)

    # Stack of module being parsed, current nn.module is at the end of ordered dict
    self.nn_module_stack: Dict[str, str] = {}

    if fake_tensors_available:
        # The mode is entered and immediately exited; only the mode object
        # itself is retained for later use.
        with torch._subclasses.FakeTensorMode(
            throw_on_data_dependent_ops=True
        ) as fake_mode:
            pass
        self._fake_mode = fake_mode

    self.checkpoint = None
    self.random_calls: List[tuple] = []

    if sys.version_info >= (3, 10):
        from .resume_execution import (
            CO_ASYNC_GENERATOR,
            CO_COROUTINE,
            CO_GENERATOR,
            CO_ITERABLE_COROUTINE,
        )

        generator_like_flags = (
            CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
        )
        # On 3.10+ a generator/coroutine-flagged code object starts with one
        # extra value on the stack.
        if f_code.co_flags & generator_like_flags:
            self.push(BuiltinVariable(None))
class InstructionTranslator(InstructionTranslatorBase):
    """Translator that owns the ``OutputGraph`` for the frame being traced.

    Builds symbolic locals from the frame's real locals, can emit a call to a
    resume function at a graph break, and compiles the subgraph when the
    frame returns.
    """

    def __init__(
        self,
        instructions: List[Instruction],
        f_code,
        f_locals,
        f_globals,
        f_builtins,
        code_options,
        compiler_fn,
        one_graph,
        export,
    ):
        """Create the output graph and wrap the frame's locals.

        ``one_graph`` disables partial-graph compilation (see
        ``should_compile_partial_graph``); ``export`` requires ``one_graph``.
        """
        super().__init__(
            output=OutputGraph(f_globals, code_options, compiler_fn, self),
            instructions=instructions,
            f_locals=f_locals,
            f_globals=f_globals,
            f_builtins=f_builtins,
            code_options=code_options,
            symbolic_locals=collections.OrderedDict(),  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals=collections.OrderedDict(),
            f_code=f_code,
        )
        self.one_graph: bool = one_graph
        self.export = export
        if self.export:
            assert (
                self.one_graph
            ), "Export without one graph - something has gone wrong."

        # Candidate names: co_varnames first, then any cell/free variable
        # that has not been listed yet.
        local_names = list(code_options["co_varnames"])
        for cell_name in self.cell_and_freevars():
            if cell_name not in local_names:
                local_names.append(cell_name)

        # Wrap only the names that actually have a value in the frame. The
        # mapping is built aside and bound once complete.
        traced_locals = collections.OrderedDict()
        for local_name in local_names:
            if local_name in f_locals:
                builder = VariableBuilder(self, LocalSource(local_name))
                traced_locals[local_name] = builder(f_locals[local_name])
        self.symbolic_locals = traced_locals

        # Each Variable built above carries its own guards. Installing all
        # of them now would add guard overhead for locals the frame never
        # touches, so guards are normally added to output.guards lazily as
        # bytecode is executed.
        #
        # List/dict variables are the exception: their element guards use
        # indexed access (e.g. a tensor match on args[0]). If ``args`` is
        # never used in this frame, the LIST_LENGTH check (len(args) == 2)
        # would be missed, and on a later call where ``args`` is not a list,
        # args[0] fails at runtime. So the length/keys guards of container
        # locals are installed eagerly here.
        container_types = (ListIteratorVariable, BaseListVariable, ConstDictVariable)
        index_guard_fns = (
            GuardBuilder.LIST_LENGTH,
            GuardBuilder.DICT_KEYS,
            GuardBuilder.ODICT_KEYS,
            GuardBuilder.TUPLE_ITERATOR_LEN,
        )
        for tracked in self.symbolic_locals.values():
            if not isinstance(tracked, container_types):
                continue
            candidate_guards = VariableTracker.propagate(tracked)["guards"]
            self.output.guards.update(
                [g for g in candidate_guards if g.create_fn in index_guard_fns]
            )

        # id() of each free variable's value, consulted by match_nested_cell.
        self._freevars_ids = {
            free_name: id(f_locals[free_name])
            for free_name in self.code_options["co_freevars"]
            if free_name in f_locals
        }

    def run(self):
        """Log the start of tracing, then run the base translator loop."""
        log_step = _step_logger()
        log_step(logging.INFO, f"torchdynamo start tracing {self.f_code.co_name}")
        super().run()

    def match_nested_cell(self, name, cell):
        """Match a cell in this method to one in a function we are inlining"""
        contents = cell.cell_contents
        # TODO(jansel): check the id of the cell rather than the contents
        if self._freevars_ids.get(name) == id(contents):
            return self.symbolic_locals[name]
        return None

    def should_compile_partial_graph(self):
        """True when every block can be restored and ``one_graph`` is off."""
        blocks_restorable = all(entry.can_restore() for entry in self.block_stack)
        return blocks_restorable and not self.one_graph

    def create_call_resume_at(self, inst):
        """Emit bytecode that calls a resume function starting at ``inst``.

        Stops translation of this frame (``instruction_pointer = None``).
        For a ``RETURN_VALUE`` target a bare return is emitted. Otherwise a
        continuation code object is looked up, made callable, and called
        with the current stack values plus the live, non-cell locals.
        """
        self.instruction_pointer = None

        if inst.opname == "RETURN_VALUE":
            return [create_instruction("RETURN_VALUE")]

        live_reads = livevars_analysis(self.instructions, inst)
        # Live locals are passed as arguments; cell/free variables are not.
        argnames = tuple(
            local_name
            for local_name in self.symbolic_locals.keys()
            if local_name in live_reads and local_name not in self.cell_and_freevars()
        )
        stack_depth = len(self.stack)
        nargs = stack_depth + len(argnames)

        name = unique_id(f"__resume_at_{inst.offset}")
        resume_fns = tuple(entry.resume_fn() for entry in self.block_stack)
        new_code: types.CodeType = ContinueExecutionCache.lookup(
            self.f_code,
            self.lineno,
            inst.offset,
            stack_depth,
            argnames,
            resume_fns,
        )

        cg = PyCodegen(self)
        if new_code.co_freevars:
            cg.make_function_with_closure(name, new_code, stack_depth)
        else:
            resume_function = types.FunctionType(new_code, self.f_globals, name)
            self.output.install_global(name, resume_function)
            cg.extend_output(cg.load_function_name(name, stack_depth))

        cg.extend_output([cg.create_load(arg) for arg in argnames])
        call_and_return = [
            create_instruction("CALL_FUNCTION", nargs),
            create_instruction("RETURN_VALUE"),
        ]
        cg.extend_output(call_and_return)
        return cg.get_instructions()

    def RETURN_VALUE(self, inst):
        """Finish tracing the frame and compile what was captured.

        Raises ``exc.SkipFrame`` when the output graph counted no calls and
        this is not an export.
        """
        nothing_captured = self.output.count_calls() == 0
        if nothing_captured and not self.export:
            raise exc.SkipFrame()

        self.instruction_pointer = None
        log_step = _step_logger()
        log_step(logging.INFO, f"torchdynamo done tracing {self.f_code.co_name}")
        self.output.compile_subgraph(self)
        self.output.add_output_instructions([create_instruction("RETURN_VALUE")])
class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method"""

    @classmethod
    def inline_call(cls, parent, func, args, kwargs):
        """Inline ``func`` into ``parent`` and return its symbolic result.

        While inlining, ``counters["unimplemented"]`` is temporarily replaced
        by ``counters["inline_call"]`` via ``patch.dict``; the original entry
        is restored on exit.
        """
        with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
            return cls.inline_call_(parent, func, args, kwargs)

    @staticmethod
    def inline_call_(parent, func, args, kwargs):
        """Validate that ``func`` can be inlined, trace it, return the result.

        Calls ``unimplemented(...)`` when the function is bound to
        ``__self__``, is named ``patched_init``, is explicitly disallowed,
        lives in a skipped file, cannot bind the given arguments, receives
        an argument that is not a ``VariableTracker``, or is a
        ``__setitem__`` / ``__setattr__`` implementation.

        For generator code the yielded items are returned as a
        ``ListIteratorVariable``; otherwise the tracer's symbolic result is
        returned unchanged.
        """
        assert isinstance(
            func,
            (UserFunctionVariable, NestedUserFunctionVariable),
        )
        if func.has_self():
            unimplemented("inline with __self__")

        if func.get_name() == "patched_init":
            unimplemented("Patched init cannot be inlined.")

        try:
            if id(func.get_function()) in allowed_functions._disallowed_function_ids:
                unimplemented(f"inlining disallowed: {func.get_function()}")
        except NotImplementedError:
            pass  # closures

        if skipfiles.check(
            func.get_filename()
        ) and not skipfiles.is_torch_inline_allowed(func.get_filename()):
            unimplemented(
                f"inline in skipfiles: {func.get_name()} {func.get_filename()}"
            )

        try:
            sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
        except TypeError as err:
            # get_function() can raise NotImplementedError for closures (see
            # the disallowed-function check above). Fall back to the name so
            # the mismatch is still reported through unimplemented() instead
            # of escaping as NotImplementedError.
            try:
                func_desc = func.get_function()
            except NotImplementedError:
                func_desc = func.get_name()
            log.warning(f"{func.get_filename()} {func_desc} {args} {kwargs} {err}")
            unimplemented("arg mismatch inlining")

        for v in itertools.chain(sub_locals.values(), closure_cells.values()):
            if not isinstance(v, VariableTracker):
                unimplemented(f"unconverted arg {v}")

        code: types.CodeType = func.get_code()
        if code.co_name in ("__setitem__", "__setattr__"):
            unimplemented(f"inline {code.co_name}")

        # Disassembling is comparatively expensive; only do it when the
        # message will actually be emitted.
        if log.isEnabledFor(logging.DEBUG):
            log.debug(f"INLINING {code} \n {dis.Bytecode(code).dis()} \n")

        if is_generator(code):
            tracer = InliningGeneratorInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )
        else:
            tracer = InliningInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )

        tracer.run()
        assert tracer.symbolic_result is not None
        func.export_freevars(parent, tracer)

        if tracer.f_globals is parent.f_globals:
            # Merge symbolic_globals back if parent and child are in the same namespace
            parent.symbolic_globals.update(tracer.symbolic_globals)

        log.debug(f"DONE INLINING {code}")

        if is_generator(code):
            # The generator's own return value must be the constant None; the
            # caller receives the collected yielded items instead.
            assert tracer.symbolic_result.as_python_constant() is None
            return ListIteratorVariable(
                tracer.generated_items,
                mutable_local=MutableLocal(),
                **VariableTracker.propagate(tracer.symbolic_result),
            )
        else:
            return tracer.symbolic_result

    def __init__(
        self,
        parent: InstructionTranslatorBase,
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ):
        """Create a child translator for ``code`` that shares ``parent.output``.

        Globals come from ``funcvar.get_globals()`` and builtins from their
        ``__builtins__`` entry. ``f_locals`` is empty and ``code_options`` is
        built from every attribute listed by ``dir(code)``.
        """
        f_globals = funcvar.get_globals()
        f_builtins = f_globals["__builtins__"]
        # __builtins__ may be either a dict or a module; normalise to a dict.
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        super(InliningInstructionTranslator, self).__init__(
            output=parent.output,
            f_locals={},
            f_globals=f_globals,
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            instructions=cleaned_instructions(code),
            code_options={k: getattr(code, k) for k in dir(code)},
            f_code=code,
        )
        self.parent = parent
        self.symbolic_result = None
        self.closure_cells = closure_cells
        # Copied so changes made while inlining do not touch the parent's dict.
        self.nn_module_stack = parent.nn_module_stack.copy()

    @property
    def fake_mode(self):
        """Delegate to the parent translator's fake mode."""
        return self.parent.fake_mode

    def STORE_DEREF(self, inst):
        """Store the top of stack into the cell named by ``inst.argval``.

        Known closure cells are written either into the root translator's
        symbolic locals (``ClosureVariable``) or through
        ``side_effects.store_cell``. Otherwise only a ``NewCellVariable``
        local can be written; anything else is unimplemented.
        """
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            val = self.pop()
            if isinstance(cell, ClosureVariable):
                self.output.root_tx.symbolic_locals[cell.name] = val
            else:
                self.output.side_effects.store_cell(cell, val)
        else:
            if isinstance(
                self.symbolic_locals.get(inst.argval),
                variables.NewCellVariable,
            ):
                self.output.side_effects.store_cell(
                    self.symbolic_locals[inst.argval], self.pop()
                )
            else:
                unimplemented("write to __closure__ while inlining")

    def LOAD_DEREF(self, inst):
        """Push the contents of the cell named by ``inst.argval``.

        Mirrors ``STORE_DEREF``; when the name is neither a known closure
        cell nor a ``NewCellVariable`` local, the base-class handler is used.
        """
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            if isinstance(cell, ClosureVariable):
                self.push(self.output.root_tx.symbolic_locals[cell.name])
            else:
                self.push(self.output.side_effects.load_cell(cell))
        else:
            maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
            if isinstance(maybe_sym_local, variables.NewCellVariable):
                self.push(self.output.side_effects.load_cell(maybe_sym_local))
            else:
                super().LOAD_DEREF(inst)

    def LOAD_CLOSURE(self, inst):
        """Push the closure cell variable registered under ``inst.argval``."""
        assert inst.argval in self.cell_and_freevars()
        self.push(self.closure_cells[inst.argval])

    def replace_all(self, oldvar: VariableTracker, newvar: VariableTracker):
        """Replace ``oldvar`` with ``newvar`` here and in every ancestor.

        Returns the variable returned by the base-class ``replace_all``.
        """
        newvar = super().replace_all(oldvar, newvar)
        # recursively check and update parent's locals and stack in case oldvar is from parent
        translator = self
        while hasattr(translator, "parent"):
            translator = translator.parent
            translator.update_locals_and_stack(oldvar, newvar)
        return newvar

    def should_compile_partial_graph(self):
        """Never compile a partial graph from inside an inlined call."""
        return False  # inlining functions is all-or-nothing

    def create_call_resume_at(self, offset):
        """Resuming is not supported while inlining; always unimplemented."""
        unimplemented("cant resume while inlining")

    def RETURN_VALUE(self, inst):
        """Record the popped value as the result and stop this tracer."""
        self.symbolic_result = self.pop()
        self.instruction_pointer = None
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
    """Inlining translator that collects the values a generator yields."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent class; start with no items."""
        super().__init__(*args, **kwargs)
        self.generated_items = []

    def YIELD_VALUE(self, inst: Instruction):
        """Record the yielded value and leave a ``None`` constant on the stack."""
        collected = self.generated_items
        collected.append(self.pop())
        # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
        self.push(ConstantVariable(None))