import collections
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional
import torch.nn
from torch import fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config, logging as torchdynamo_logging, variables
from .bytecode_transformation import create_instruction, Instruction, unique_id
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import ConstantSource, LocalSource, Source
from .utils import (
CleanupHook,
count_calls,
counters,
fake_tensors_available,
format_graph_tabular,
)
from .variables.builder import VariableBuilder
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
TensorVariable,
UnspecializedNumpyVariable,
UnspecializedPythonVariable,
)
log = logging.getLogger(__name__)
@functools.lru_cache(None)
def _step_logger():
return torchdynamo_logging.get_step_logger(log)
@dataclass
class GraphCompileReason:
"""Stores why a given output graph was compiled; i.e. what caused the graph break."""
reason: str
user_stack: List[traceback.FrameSummary]
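
# Illustrative construction of GraphCompileReason (values are hypothetical):
#
#   GraphCompileReason(
#       reason="generic_jump",                 # what caused the graph break
#       user_stack=traceback.extract_stack(),  # user frames at the break point
#   )
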
def _get_gen_rand_values_fn(random_calls):
def _gen_rand_values():
return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
return _gen_rand_values
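
# Sketch of how _get_gen_rand_values_fn is consumed (entries illustrative; the
# tracer records (fn, args, kwargs) triples for intercepted random calls):
#
#   gen = _get_gen_rand_values_fn([(random.random, (), {})])
#   gen()  # -> [0.5488...], each recorded call replayed in order
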
class FakeRootModule(torch.nn.Module):
"""Trick the constructor of fx.GraphModule"""
def __init__(self, nn_modules: dict):
        super().__init__()
for k, v in nn_modules.items():
setattr(self, k, v)
def __repr__(self):
return "FakeRootModule(...)"
class OutputGraph(fx.Tracer):
"""
    Wrapper class to hold outputs of InstructionTranslator, mainly the
    generated fx.Graph.
"""
def __init__(
self,
f_globals: Dict[str, Any],
code_options: Dict[str, Any],
compiler_fn: Callable,
root_tx,
):
        super().__init__()
# Mutable state checkpointed by copy_graphstate()
self.graph = torch.fx.Graph()
self.graphargs = []
self.guards = set()
self.nn_modules = dict()
self.side_effects = SideEffects()
self.code_options = dict(code_options)
self.output_instructions = []
# Node => computed real value (see TensorVariable.get_real_value)
self.real_value_cache = {}
# Not checkpointed
self.compiler_fn = compiler_fn
self.root_globals = f_globals
self.root_tx = root_tx
self.cleanups = []
self.should_exit = False
self.random_values_var = None
self.initial_random_state = ()
self.unspec_variable_map = {}
self.shape_env = ShapeEnv() if config.dynamic_shapes else None
self.tensor_id_to_sym_shape_ref = {}
@property
def output(self):
return self
@property
def fake_mode(self):
return self.root_tx.fake_mode
def copy_graphstate(self):
"""Create a checkpoint of the current state by copying everything"""
graph_nodes = set(self.graph.nodes)
return (
graph_nodes,
list(self.graphargs),
set(self.guards),
dict(self.nn_modules),
self.side_effects.clone(),
)
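
    # restore_graphstate() below swaps these shallow copies back in; the node
    # set is used only to detect (and erase) fx nodes created after the
    # checkpoint, since deepcopying a partially built graph is not supported.
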
def restore_graphstate(self, state):
"""Restore a checkpoint created by self.copy_graphstate()"""
(
graph_nodes,
self.graphargs,
self.guards,
self.nn_modules,
self.side_effects,
) = state
# FX deepcopy doesn't work for a partially created graph, so just remove new nodes
for node in reversed(list(self.graph.nodes)):
if node not in graph_nodes:
                # Erasing the node alone does not remove its meta information,
                # so drop the stored example tensor ("example_value") explicitly
if "example_value" in node.meta:
del node.meta["example_value"]
self.graph.erase_node(node)
self.real_value_cache.pop(node, None)
def count_calls(self):
return count_calls(self.graph)
def get_submodule(self, keys):
assert keys
obj = self.nn_modules
for k in keys.split("."):
if isinstance(obj, dict):
obj = obj[k]
else:
obj = getattr(obj, k)
return obj
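
    # Example for get_submodule above (keys are illustrative): with
    # self.nn_modules == {"layer1": mod}, get_submodule("layer1.block")
    # first indexes the dict with "layer1", then getattr's "block" on mod.
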
def create_graph_input(self, name, type_expr=None):
placeholders = [n for n in self.graph.nodes if n.op == "placeholder"]
        # ensure the requested name is unique among existing placeholders
used_names = {n.target for n in placeholders}
if name in used_names:
for i in itertools.count():
if f"{name}_{i}" not in used_names:
name = f"{name}_{i}"
break
if placeholders:
ctx = self.graph.inserting_after(placeholders[-1])
else:
ctx = self.graph.inserting_before(None)
with ctx:
return self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
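
    # Naming sketch for create_graph_input above: two calls with name="x"
    # yield placeholders "x" and "x_0", and new placeholders are inserted
    # after the last existing one so graph inputs stay grouped at the front.
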
def new_var(self, name="tmp"):
existing = set(self.code_options["co_varnames"])
for i in itertools.count():
var = f"___{name}_{i}"
if var not in existing:
self.code_options["co_varnames"] = self.code_options["co_varnames"] + (
var,
)
return var
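
    # e.g. the first new_var("graph_out") returns "___graph_out_0" and appends
    # it to co_varnames so generated bytecode can store into it.
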
def update_co_names(self, name):
"""Ensure self.code_options.co_names contains name"""
if name not in self.code_options["co_names"]:
self.code_options["co_names"] = tuple(self.code_options["co_names"]) + (
name,
)
def register_attr_or_module(self, mod: torch.nn.Module, *names, **options):
if is_dynamic_nn_module(mod):
return variables.UnspecializedNNModuleVariable(mod, **options)
options = dict(options)
options["guards"] = set(options.get("guards", []))
source: Source = options.get("source", None)
if isinstance(mod, torch.Tensor):
if source:
options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))
def wrap_name(module_key):
return TensorVariable.create(
self,
self.create_proxy("get_attr", module_key, tuple(), {}),
example_value=mod,
**options,
)
elif isinstance(mod, torch.nn.Module):
options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))
def wrap_name(module_key):
return NNModuleVariable(type(mod), module_key, **options)
else:
def wrap_name(module_key):
self.output.update_co_names(module_key)
self.root_globals[module_key] = mod
return VariableBuilder(self, ConstantSource(source_name=module_key))(
mod
)
for k, v in self.nn_modules.items():
if v is mod:
# it already exists
return wrap_name(k)
# create a new unique name
name = "_".join(map(str, names))
        # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
# e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
if not name or not name[0].isalpha():
name = "sub" + name
base = name
for i in itertools.count():
if name not in self.nn_modules:
self.nn_modules[name] = mod
return wrap_name(name)
name = f"{base}_{i}"
raise AssertionError("unreachable")
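
    # End-to-end naming sketch for register_attr_or_module (path illustrative):
    #   names = ("self", "layer1[0].conv")  ->  key "self_layer1_0_conv";
    # the module is installed once under that key, an existing identical
    # module reuses its key, and a genuine collision gets a "_<i>" suffix.
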
def compile_subgraph(
self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
):
"""
Generate a subgraph to continue execution on user code.
Automatically restore live variables.
"""
from .eval_frame import disable
self.partial_convert = partial_convert
self.compile_subgraph_reason = reason
if not all(block.can_restore() for block in tx.block_stack):
unimplemented("compile_subgraph with block_depth != 0")
for block in reversed(tx.block_stack):
block.exit(tx)
tx.prune_dead_locals()
stack_values = list(tx.stack)
root = FakeRootModule(self.nn_modules)
        # Add all the local vars to the "stack" so we can restore them at the end
restore_vars = []
val_to_names = collections.OrderedDict()
if stack_values:
val_to_names[stack_values[-1]] = list()
for k, v in tx.symbolic_locals.items():
if isinstance(v.source, LocalSource) and v.source.name() == k:
continue # no need to restore initial state
if v not in val_to_names:
val_to_names[v] = list()
val_to_names[v].append(k)
for v in val_to_names.keys():
restore_vars.extend(val_to_names[v])
stack_values.extend([v] * len(val_to_names[v]))
        # replay any random calls intercepted during tracing
        if tx.random_calls:
random_calls_instructions = []
self.random_values_var = self.new_var("random_values")
rand_fn_name = unique_id("__gen_rand_values")
rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
self.install_global(rand_fn_name, rand_fn)
codegen = PyCodegen(tx, root)
random_calls_instructions.extend(
[
codegen.create_load_global("random", add=True),
codegen.create_load_attr("setstate"),
codegen.create_load_const(tx.output.initial_random_state),
create_instruction("CALL_FUNCTION", 1),
]
)
random_calls_instructions.extend(codegen.load_function_name(rand_fn_name))
random_calls_instructions.extend(
[
create_instruction("CALL_FUNCTION", 0),
codegen.create_store(tx.output.random_values_var),
]
)
self.add_output_instructions(random_calls_instructions)
if (
stack_values
and all(
not isinstance(
v, (UnspecializedNumpyVariable, UnspecializedPythonVariable)
)
for v in stack_values
)
and all(isinstance(x, TensorVariable) for x in stack_values)
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
):
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
+ [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
)
else:
graph_output_var = self.new_var("graph_out")
pass1 = PyCodegen(tx, root, graph_output_var)
self.side_effects.codegen_save_tempvars(pass1)
pass1.foreach(stack_values)
self.side_effects.codegen_update_mutated(pass1)
# one more time now that we have established tempvars
pass2 = PyCodegen(
tx,
root,
graph_output_var,
tempvars={val: None for val, count in pass1.uses.items() if count > 1},
)
self.side_effects.codegen_save_tempvars(pass2)
pass2.foreach(stack_values)
self.side_effects.codegen_update_mutated(pass2)
output = []
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
output.extend(
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
)
if len(pass2.graph_outputs) != 0:
output.append(pass2.create_store(graph_output_var))
else:
output.append(create_instruction("POP_TOP"))
self.add_output_instructions(output + pass2.get_instructions())
# restore all the live local vars
self.add_output_instructions(
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
)
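
    # Rough shape of the bytecode compile_subgraph emits, shown as source
    # (names like __compiled_fn_0 are illustrative):
    #
    #   ___graph_out_0 = __compiled_fn_0(...)  # call the compiled fx graph
    #   ...                                    # replay side effects / rebuilds
    #   x = ___graph_out_0[0]                  # restore live locals and stack
    #
    # In the fast path (all-tensor, duplicate-free stack, no side effects) the
    # results are instead UNPACK_SEQUENCE'd straight onto the stack.
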
def compile_and_call_fx_graph(self, tx, rv, root):
"""
Generate code from self.graph and return the Instruction()s to
call that generated code.
"""
from .eval_frame import disable
assert isinstance(rv, list)
assert isinstance(root, FakeRootModule)
for output in rv:
self.guards.update(output.guards)
self.create_node(
"output", "output", (self.create_arg(tuple(x.as_proxy() for x in rv)),), {}
)
self.remove_unused_graphargs()
ncalls = count_calls(self.graph)
counters["stats"]["calls_captured"] += ncalls
counters["stats"]["fusions_possible"] += ncalls - 1
if config.dynamic_propagation:
# free a bit of memory
for node in self.graph.nodes:
if "example_value" in node.meta:
del node.meta["example_value"]
self.real_value_cache.clear()
gm = fx.GraphModule(root, self.graph)
gm.recompile()
gm.compile_subgraph_reason = self.compile_subgraph_reason
name = unique_id("__compiled_fn")
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)
counters["stats"]["unique_graphs"] += 1
self.install_global(name, compiled_fn)
try:
# the call to tabulate can cause a lot of memory to be allocated
if config.log_level <= logging.INFO:
log.log(
torchdynamo_logging.CODE,
f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {format_graph_tabular(gm.graph)}\n",
)
except ImportError:
log.warning(
"Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
cg = PyCodegen(tx)
cg.make_call_generated_code(name)
return cg.get_instructions()
def call_user_compiler(self, gm):
try:
name = (
self.compiler_fn.__name__
if hasattr(self.compiler_fn, "__name__")
else ""
)
_step_logger()(logging.INFO, f"calling compiler function {name}")
compiled_fn = self.compiler_fn(gm, self.example_inputs())
_step_logger()(logging.INFO, f"done compiler function {name}")
assert callable(compiled_fn), "compiler_fn did not return callable"
except Exception as e:
log.warning("-" * 40 + "\n")
log.warning("TORCHDYNAMO: backend compiler failed\n")
log.warning(e, exc_info=True)
log.warning("-" * 40 + "\n")
raise BackendCompilerFailed(self.compiler_fn, e) from e
return compiled_fn
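
    # compiler_fn contract, per call_user_compiler above: any callable taking
    # (fx.GraphModule, example_inputs) and returning a callable. A minimal
    # pass-through backend (sketch):
    #
    #   def my_compiler(gm, example_inputs):
    #       return gm.forward  # just run the captured graph eagerly
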
def example_inputs(self):
result = []
for arg in self.graphargs:
result.extend(arg.get_examples())
return result
def remove_unused_graphargs(self):
for node in reversed(list(self.graph.nodes)):
            if not node.users:
if node.op == "get_attr":
self.graph.erase_node(node)
elif node.op == "call_function" and node.target is operator.getitem:
self.graph.erase_node(node)
expanded_graphargs = []
for arg in self.graphargs:
expanded_graphargs.extend([arg] * len(arg))
arg.uses = 0
for node, arg in zip(self.graph.nodes, expanded_graphargs):
assert node.op == "placeholder"
arg.uses += len(node.users)
for node, arg in list(zip(self.graph.nodes, expanded_graphargs)):
if arg.uses == 0:
if "example_value" in node.meta:
del node.meta["example_value"]
self.graph.erase_node(node)
self.real_value_cache.pop(node, None)
self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
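
    # The zips in remove_unused_graphargs rely on graphargs expanding, in
    # order, to exactly the graph's placeholders; e.g. an arg with
    # len(arg) == 2 is assumed to account for two consecutive placeholders.
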
def add_output_instructions(self, prefix: List[Instruction]):
"""
We call this on the creation of a new compiled subgraph that is inserted
before user code.
"""
self.output_instructions.extend(prefix)
self.should_exit = True
def install_global(self, name, value):
self.cleanups.append(CleanupHook.create(self.root_globals, name, value))
def cleanup(self):
# There is a reference cycle between tracer and OutputGraph, causing
# some of the tensor objects to be held alive for longer than necessary.
# Clear cache for conversion of real -> fake tensors
if fake_tensors_available:
self.root_tx.fake_mode.fake_tensor_converter = None
self.root_tx = None
        # Note: the generated fx graph will hold a reference to the nn_modules,
        # so depending on the backend they may not be released
self.nn_modules = None
        # Clean up graphargs
for graph_arg in self.graphargs:
graph_arg.erase()
for node in self.graph.nodes:
if "example_value" in node.meta:
del node.meta["example_value"]
self.real_value_cache.clear()
def create_proxy(
self,
kind,
target,
args,
kwargs,
name=None,
type_expr=None,
proxy_factory_fn=None,
current_tx=None,
):
rv = super().create_proxy(
kind, target, args, kwargs, name, type_expr, proxy_factory_fn
)
# append stack trace to fx node
tx = current_tx if current_tx else self.root_tx
nn_module_stack = tx.nn_module_stack
if nn_module_stack:
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
frame_summaries: List[traceback.FrameSummary] = []
while tx:
frame_summaries.append(tx.frame_summary())
tx = getattr(tx, "parent", None)
msgs = traceback.StackSummary.from_list(frame_summaries).format()
        # Carry the module stack along in node.stack_trace so the existing
        # stack-trace propagation infrastructure can be reused
nn_module_stack_str = f"Module stack: {nn_module_stack}\n"
rv.node.stack_trace = nn_module_stack_str + " | ".join(msgs)
return rv
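
    # Schematic stack_trace produced above (contents illustrative):
    #   'Module stack: {"self_layer1": <class ...>}\n  File "user.py", ... | ...'
    # i.e. the module-stack header followed by the "|"-joined formatted frames.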