# blob: 2d2919052152bfd931bc8570ea6043bd68839301 [file] [log] [blame]
import collections
import dataclasses
import re
import sys
import types
from typing import List
import torch.nn
from .bytecode_transformation import (
cell_and_freevars_offset,
create_call_function,
create_dup_top,
create_instruction,
create_load_global,
create_rot_n,
Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, istype, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
SymNodeVariable,
TensorVariable,
TensorWithTFOverrideVariable,
UnspecializedPythonVariable,
)
@dataclasses.dataclass
class GraphOutputEntry:
    """One recorded graph output: its position and the tracker stored for it."""

    # Position assigned when the entry was created.
    index: int
    # Tracker currently associated with this output; rebound by merge().
    variable: VariableTracker

    def merge(self, other: VariableTracker):
        """Fold the options of ``other`` into the stored variable."""
        # add_options returns the combined tracker (this is how any extra
        # guards carried by ``other`` are merged in), so rebind to its result.
        combined = self.variable.add_options(other)
        self.variable = combined
class PyCodegen:
"""
Helper class used for constructing Python bytecode.
"""
def __init__(
    self,
    tx=None,
    root: torch.nn.Module = None,
    graph_output_var: str = None,
    tempvars=None,
):
    """Create an empty instruction buffer bound to the translator ``tx``.

    Args:
        tx: Translator object.  ``tx.output.code_options``,
            ``tx.cell_and_freevars`` and ``tx.output.new_var`` are read
            here, so despite the ``None`` default a real object is required.
        root: Module that is loaded as a constant when an
            ``NNModuleVariable`` key does not start with a local name.
        graph_output_var: Name of the variable that holds the graph outputs.
        tempvars: Optional mapping from values to temporary variable names.
            The mapping passed in is used as-is, even when it is empty; a
            fresh dict is created only when ``None`` is given.
    """
    self.root = root
    self.top_of_stack = None
    self.uses = collections.Counter()
    self.graph_outputs = collections.OrderedDict()
    self._output: List[Instruction] = []
    # Compare against None instead of relying on truthiness: ``tempvars or {}``
    # would silently swap an empty caller-supplied mapping for a private dict.
    self.tempvars = {} if tempvars is None else tempvars
    self.tx = tx
    self.graph_output_var = graph_output_var
    # Shortcuts to translator state used by the create_* helpers.
    self.code_options = self.tx.output.code_options
    self.cell_and_freevars = self.tx.cell_and_freevars
    self.new_var = self.tx.output.new_var
def cell_and_freevars_offset(self, i):
    """Apply the module-level ``cell_and_freevars_offset`` helper to index ``i``.

    The helper is called with this codegen's ``code_options``.
    """
    # Inside a method body this name resolves to the imported module-level
    # function, not to this method.
    options = self.code_options
    return cell_and_freevars_offset(options, i)
def graph_output_vars(self):
    """Return the ``variable`` of every recorded graph output, in stored order."""
    trackers = []
    for entry in self.graph_outputs.values():
        trackers.append(entry.variable)
    return trackers
def __call__(self, value, allow_cache=True):
    """Generate code such that top-of-stack (TOS) is set to value

    ``value`` is either a ``Source`` (whose ``reconstruct`` output is emitted
    directly) or a ``VariableTracker``.  For a tracker the first applicable
    strategy below is used: duplicate the current TOS, load a cached
    temporary, reconstruct from its source, load a safe constant, index into
    the graph outputs, walk attributes to an nn.Module, or fall back to
    ``value.reconstruct``.

    ``allow_cache=False`` skips the temporary-variable lookups, the
    source-based reconstruction and the caching of the fallback result.
    """
    if isinstance(value, Source):
        # Nothing to remember as TOS here: the pushed object is not
        # associated with a VariableTracker.
        self._output.extend(value.reconstruct(self))
        self.clear_tos()
        return
    # Guards attached to the value are recorded on the output before any
    # code is generated for it.
    self.tx.output.guards.update(value.guards)
    assert isinstance(value, VariableTracker)
    output = self._output
    graph_outputs = self.graph_outputs
    if self.top_of_stack is value:
        # Same object as the last thing pushed: duplicate instead of rebuilding.
        output.append(create_dup_top())
        return
    if allow_cache:
        # Reuse a temporary if one was recorded, either under the value's
        # mutable_local or under the value itself.
        if value.mutable_local and value.mutable_local in self.tempvars:
            output.append(self.create_load(self.tempvars[value.mutable_local]))
            self.top_of_stack = value
            return
        # An entry whose stored name is None does not count as cached.
        if self.tempvars.get(value) is not None:
            output.append(self.create_load(self.tempvars[value]))
            self.top_of_stack = value
            return
    if value.source is not None and allow_cache:
        output.extend(value.source.reconstruct(self))
    elif value.is_python_constant() and is_safe_constant(
        value.as_python_constant()
    ):
        output.append(self.create_load_const(value.as_python_constant()))
    elif isinstance(
        value,
        (
            TensorVariable,
            SymNodeVariable,
            TensorWithTFOverrideVariable,
            UnspecializedPythonVariable,
        ),
    ):
        if isinstance(value, TensorWithTFOverrideVariable):
            # unwrap back to tensor
            value = value.tensor_variable
        # Graph outputs are de-duplicated by the identity of the proxy: the
        # first occurrence gets the next index, later ones are merged in.
        graph_outputs_key = id(value.proxy)
        if graph_outputs_key not in graph_outputs:
            graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(graph_outputs), value
            )
        else:
            graph_outputs[graph_outputs_key].merge(value)
        # Emits the equivalent of ``<graph_output_var>[index]``.
        output.append(self.create_load(self.graph_output_var))
        output.append(
            self._create_load_const(graph_outputs[graph_outputs_key].index)
        )
        output.append(create_instruction("BINARY_SUBSCR"))
        if isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:
            # Follow up with a zero-argument call of the ``item`` attribute.
            output.extend(
                [self.create_load_attr("item")] + create_call_function(0, True)
            )
    elif isinstance(value, NNModuleVariable):
        parts = value.module_key.split(".")
        if parts[0] in self.code_options["co_varnames"]:
            # Key starts with a local variable: begin the attribute chain there.
            output.append(self.create_load(parts[0]))
            parts = parts[1:]
        else:
            # Otherwise begin at the root module, loaded as a constant.
            assert self.root is not None
            output.append(self.create_load_output(self.root))
        for part in parts:
            output.append(self.create_load_attr(part))
    else:
        # Count how often each value goes through the generic path.
        self.uses[value] += 1
        try:
            output.extend(value.reconstruct(self))
        except NotImplementedError:
            unimplemented(f"reconstruct: {value}")
        if allow_cache and value in self.tempvars:
            # Keep one copy on the stack and store the other in a temporary.
            self._output.append(create_dup_top())
            self.add_cache(value)
    self.top_of_stack = value
def add_cache(self, value):
    """Allocate a new temporary for ``value`` and emit the store into it.

    The temporary's name is recorded in ``tempvars`` under ``value`` and,
    when it is truthy, under ``value.mutable_local`` as well.
    """
    temp_name = self.new_var()
    self.tempvars[value] = temp_name
    local_key = value.mutable_local
    if local_key:
        self.tempvars[local_key] = temp_name
    store = self.create_store(temp_name)
    self._output.append(store)
def foreach(self, items):
    """Invoke this codegen on each element of ``items``, in iteration order."""
    for item in items:
        self(item)
def setup_globally_cached(self, name, value, push_null):
    """Store value in a new global

    ``name`` is sanitized first: every run of characters outside
    ``[a-zA-Z0-9_]`` becomes a single underscore.  Returns a one-element
    list holding the instruction that loads the global.
    """
    global_name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
    frame_globals = self.tx.f_globals
    if global_name not in frame_globals:
        frame_globals[global_name] = value
    else:
        # The name is already taken: it must refer to this very object.
        assert frame_globals[global_name] is value
    load = self.create_load_global(global_name, push_null, add=True)
    return [load]
def clear_tos(self):
    """Forget which value was last recorded as being on top of the stack."""
    self.top_of_stack = None
def append_output(self, inst):
    """Append one ``Instruction`` to the buffer and reset the TOS record."""
    assert isinstance(inst, Instruction)
    buffer = self._output
    buffer.append(inst)
    # Externally supplied code may have changed what is on the stack.
    self.clear_tos()
def extend_output(self, insts):
    """Append several ``Instruction`` objects and reset the TOS record.

    Args:
        insts: Any iterable of ``Instruction``; it may be a one-shot
            iterator or generator.
    """
    # Materialize first: the type check below iterates ``insts`` and would
    # otherwise exhaust a generator before anything reaches the buffer.
    insts = list(insts)
    assert all(isinstance(x, Instruction) for x in insts)
    self._output.extend(insts)
    self.clear_tos()
def get_instructions(self):
    """Return the instruction buffer itself (the live list, not a copy)."""
    instructions = self._output
    return instructions
def create_load(self, name):
    """Build the instruction that loads variable ``name``.

    Names found in ``cell_and_freevars()`` get a ``LOAD_DEREF``; any other
    name must be listed in ``co_varnames`` and gets a ``LOAD_FAST``.
    """
    if name not in self.cell_and_freevars():
        local_names = self.code_options["co_varnames"]
        assert name in local_names, f"{name} missing"
        return create_instruction("LOAD_FAST", local_names.index(name), name)
    position = self.cell_and_freevars().index(name)
    return create_instruction(
        "LOAD_DEREF", self.cell_and_freevars_offset(position), name
    )
def create_load_closure(self, name):
    """Build a ``LOAD_CLOSURE`` for ``name``, which must be a cell or free variable."""
    assert name in self.cell_and_freevars()
    position = self.cell_and_freevars().index(name)
    slot = self.cell_and_freevars_offset(position)
    return create_instruction("LOAD_CLOSURE", slot, name)
def create_store(self, name):
    """Build the instruction that stores into variable ``name``.

    Names found in ``cell_and_freevars()`` get a ``STORE_DEREF``; any other
    name must be listed in ``co_varnames`` and gets a ``STORE_FAST``.
    """
    if name not in self.cell_and_freevars():
        local_names = self.code_options["co_varnames"]
        assert name in local_names
        return create_instruction("STORE_FAST", local_names.index(name), name)
    position = self.cell_and_freevars().index(name)
    return create_instruction(
        "STORE_DEREF", self.cell_and_freevars_offset(position), name
    )
def create_load_global(self, name, push_null, add=False):
    """Build the load of global ``name`` via the module-level helper.

    With ``add=True`` the name is first registered through
    ``tx.output.update_co_names``; otherwise it must already be present in
    ``co_names``.
    """
    if add:
        self.tx.output.update_co_names(name)
    # Read co_names only after the optional registration above.
    known_names = self.code_options["co_names"]
    assert name in known_names, f"{name} not in co_names"
    # This call resolves to the imported module-level function of the same name.
    return create_load_global(name, known_names.index(name), push_null)
def create_load_const(self, value):
    """Checked constant load: ``value`` must pass ``is_safe_constant``."""
    is_safe = is_safe_constant(value)
    assert is_safe, f"unsafe constant {value}"
    return self._create_load_const(value)
@staticmethod
def get_const_index(code_options, value):
    """Return the index of ``value`` in ``code_options["co_consts"]``.

    An existing entry is reused only when it is interchangeable with
    ``value`` as a constant; otherwise ``value`` is appended (the tuple
    stored in ``code_options`` is replaced) and the new index is returned.

    Plain ``type(v) is type(value) and v == value`` is not strict enough:
    ``0.0 == -0.0`` and ``(1,) == (1.0,) == (True,)`` all hold, so a lookup
    could hand back the slot of a different constant.  The comparison below
    therefore also checks the sign of floats (and of both parts of complex
    numbers) and compares tuples element by element with the same rules.
    """
    import math  # local import: the module does not import math at top level

    def same_constant(a, b):
        """Return whether ``a`` may be reused where constant ``b`` is wanted."""
        if a is b:
            # Identity short-circuit, as container comparison does; also
            # avoids calling ``==`` on objects that do not return a bool.
            return True
        if type(a) is not type(b):
            return False
        if type(a) is float:
            return a == b and math.copysign(1.0, a) == math.copysign(1.0, b)
        if type(a) is complex:
            return same_constant(a.real, b.real) and same_constant(a.imag, b.imag)
        if type(a) is tuple:
            return len(a) == len(b) and all(
                same_constant(x, y) for x, y in zip(a, b)
            )
        return a == b

    co_consts = code_options["co_consts"]
    assert istype(co_consts, tuple)
    for index, existing in enumerate(co_consts):
        if same_constant(existing, value):
            return index
    # Not present: the new constant goes at the end.
    code_options["co_consts"] = co_consts + (value,)
    return len(co_consts)
def _create_load_const(self, value):
    """Build a ``LOAD_CONST`` for ``value`` without any safety check.

    The constant's index comes from ``get_const_index``, which registers the
    value in ``co_consts`` when it is not there yet.
    """
    const_index = self.get_const_index(self.code_options, value)
    return create_instruction("LOAD_CONST", const_index, value)


# Second name for the same unchecked constant load.
create_load_output = _create_load_const
def create_load_attr(self, name):
    """Build a ``LOAD_ATTR`` for ``name``, adding it to ``co_names`` if missing."""
    known_names = self.code_options["co_names"]
    if name not in known_names:
        known_names = known_names + (name,)
        self.code_options["co_names"] = known_names
    return create_instruction("LOAD_ATTR", known_names.index(name), name)
def create_load_attrs(self, names):
    """Build one attribute load per component of the dotted string ``names``."""
    loads = []
    for attr in names.split("."):
        loads.append(self.create_load_attr(attr))
    return loads
def load_function_name(self, fn_name, push_null, num_on_stack=0):
    """Load the global fn_name on the stack num_on_stack down

    When ``push_null`` is set and the interpreter is 3.11 or newer, a
    ``PUSH_NULL`` is emitted first and rotated down the same way.
    """
    depth = num_on_stack + 1
    instructions = []
    if push_null and sys.version_info >= (3, 11):
        instructions.append(create_instruction("PUSH_NULL"))
        instructions.extend(self.rot_n(depth))
    instructions.append(self.create_load_global(fn_name, False, add=True))
    # A second, separate rot_n call: each rotation needs its own instructions.
    instructions.extend(self.rot_n(depth))
    return instructions
def rot_n(self, n):
    """Return instructions for an ``n``-element rotation of the stack.

    ``create_rot_n`` is tried first; if it raises ``AttributeError`` an
    equivalent sequence is assembled by hand.
    """
    try:
        return create_rot_n(n)
    except AttributeError:
        # desired rotate bytecode doesn't exist, generate equivalent bytecode:
        # pack the top n items into a tuple, load the helper, swap the two,
        # call the helper with the tuple as its arguments and unpack the
        # result into n stack items.
        # NOTE(review): presumably rot_n_helper(n) returns a callable that
        # yields its arguments in rotated order - confirm in utils.
        fallback = [
            create_instruction("BUILD_TUPLE", n),
            self._create_load_const(rot_n_helper(n)),
        ]
        fallback.extend(create_rot_n(2))
        fallback.append(create_instruction("CALL_FUNCTION_EX", 0))
        fallback.append(create_instruction("UNPACK_SEQUENCE", n))
        return fallback
def pop_null(self):
    """Return instructions that remove a NULL from the stack (3.11+ only)."""
    # POP_TOP doesn't work for null, so we pop nulls by pushing in a
    # nop function, calling it (which consumes the null), and popping the result.
    assert sys.version_info >= (3, 11)
    sequence = [self._create_load_const(lambda: None)]
    sequence.extend(create_call_function(0, False))
    sequence.append(create_instruction("POP_TOP"))
    return sequence
def make_function_with_closure(
    self, fn_name: str, code: types.CodeType, num_on_stack=0
):
    """Emit code that builds a function from ``code`` with a closure tuple.

    Every name in ``code.co_freevars`` must be a cell or free variable of
    the current frame.  After ``MAKE_FUNCTION`` the result is rotated with
    ``rot_n(num_on_stack + 1)``.
    """
    freevars = code.co_freevars
    assert freevars
    emit = self._output.append
    for var in freevars:
        # create_load_closure performs the same membership assertion and
        # builds the same LOAD_CLOSURE instruction.
        emit(self.create_load_closure(var))
    emit(create_instruction("BUILD_TUPLE", len(freevars)))
    emit(self.create_load_const(code))
    if sys.version_info < (3, 11):
        # Before 3.11 the function name is pushed as an additional constant.
        emit(self.create_load_const(fn_name))
    # 0x08 flags that a closure tuple is supplied.
    emit(create_instruction("MAKE_FUNCTION", 0x08))
    self._output.extend(self.rot_n(num_on_stack + 1))
    self.clear_tos()
def create_load_python_module(self, mod, push_null):
    """
    Generate a LOAD_GLOBAL instruction to fetch a given python module.

    If the root globals already bind the module under the last component of
    its dotted name, that name is used.  Otherwise a name mangled with
    ``id(mod)`` is used, installing the module under it first when needed.
    """
    output_graph = self.tx.output
    short_name = re.sub(r"^.*[.]", "", mod.__name__)
    if output_graph.root_globals.get(short_name) is not mod:
        # Short name is missing or bound to something else: use an alias.
        mangled_name = f"___module_{short_name}_{id(mod)}"
        if mangled_name not in output_graph.root_globals:
            output_graph.install_global(mangled_name, mod)
        return self.create_load_global(mangled_name, push_null, add=True)
    return self.create_load_global(short_name, push_null, add=True)
def make_call_generated_code(self, fn_name: str) -> None:
    """Call the generated code function stored in fn_name

    Emits the load of ``fn_name``, then the load of every graph argument,
    then the call itself.  The instructions are appended to this codegen's
    output; nothing is returned.
    """
    self.extend_output(self.load_function_name(fn_name, True))
    graphargs = self.tx.output.graphargs
    for arg in graphargs:
        if arg.is_unspecialized:
            # Emit ``torch.tensor(<arg>)``: load the callable, load the
            # argument, then call with a single positional argument.
            self.extend_output(
                [
                    self.create_load_python_module(torch, True),
                    self.create_load_attr("tensor"),
                ]
            )
            self.extend_output(arg.load(self))
            self.extend_output(create_call_function(1, False))
        else:
            self.extend_output(arg.load(self))
    # One positional argument per graph argument.
    self.extend_output(create_call_function(len(graphargs), False))
def load_import_from(self, module_name, object_name):
    """Emit code that loads attribute ``object_name`` of module ``module_name``.

    The module's source comes from ``tx.import_source``.
    """
    module_source = self.tx.import_source(module_name)
    attr_source = AttrSource(module_source, object_name)
    self.extend_output(attr_source.reconstruct(self))
def create_call_function_kw(self, nargs, kw_names, push_null):
    """Return the instructions for a call that passes keyword arguments.

    Before 3.11 this is a load of ``kw_names`` followed by
    ``CALL_FUNCTION_KW`` (``push_null`` is not used on that path).  From
    3.11 on, a ``KW_NAMES`` instruction is inserted in front of the
    ``PRECALL`` found second-to-last in the ``create_call_function`` output.
    """
    if sys.version_info < (3, 11):
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", nargs),
        ]
    call_insts = create_call_function(nargs, push_null)
    assert call_insts[-2].opname == "PRECALL"
    kw_index = self.get_const_index(self.code_options, kw_names)
    call_insts.insert(-2, create_instruction("KW_NAMES", kw_index))
    return call_insts