# blob: 47a032cedfd988874c1f5ed598b45ab2cdfce2fd [file] [log] [blame]
import collections
import copy
import functools
import itertools
import logging
import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, OrderedDict, Set, Union
import torch.nn
from torch import fx
from torch._guards import Checkpointable, Guard, GuardsCheckpointState, TracingContext
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
create_call_function,
create_instruction,
Instruction,
unique_id,
)
from .codegen import PyCodegen
from .exc import BackendCompilerFailed, unimplemented
from .guards import GuardBuilder
from .mutation_guard import is_dynamic_nn_module
from .side_effects import SideEffects
from .source import (
ConstantSource,
is_constant_source,
LocalInputSource,
LocalSource,
ShapeEnvSource,
)
from .utils import (
assert_no_fake_params_or_buffers,
checkpoint_params,
CleanupHook,
clone_inputs,
count_calls,
counters,
dynamo_timed,
format_graph_tabular,
same,
)
from .variables.base import VariableTracker
from .variables.builder import GraphArg, TrackedFake, VariableBuilder, wrap_fx_proxy
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
SymNodeVariable,
TensorVariable,
UnspecializedPythonVariable,
)
log = logging.getLogger(__name__)
class OutputGraphState(NamedTuple):
    """Checkpoint tuple produced by ``OutputGraph.copy_graphstate``.

    The field order matters: ``OutputGraph.restore_graphstate`` unpacks the
    tuple positionally.
    """

    graphargs: List[GraphArg]
    tracked_fakes: List[TrackedFake]
    guard_state: GuardsCheckpointState
    nn_modules: Optional[Dict[str, torch.nn.Module]]
    side_effects: SideEffects
    timestamp: int

    def diff(self, other: "OutputGraphState", *, prefix: str = "") -> Optional[str]:
        """Describe the first field that differs from ``other``.

        ``guard_state`` and ``side_effects`` are compared through their own
        ``diff`` methods and their message is returned unchanged; every other
        field is compared with ``!=`` and reported with ``prefix`` prepended.

        Returns:
            A description of the first mismatch, or ``None`` if no field
            differs.
        """
        for field_name in self._fields:
            mine = getattr(self, field_name)
            theirs = getattr(other, field_name)
            if field_name in ("guard_state", "side_effects"):
                # These fields know how to diff themselves.
                delta = mine.diff(theirs)
                if delta is not None:
                    return delta
            elif mine != theirs:
                return f"{prefix}{field_name} mismatch: {mine} != {theirs}"
        return None

    # Back compat .guards api
    @property
    def guards(self):
        """The ``dynamo_guards`` held by the checkpointed guard state."""
        return self.guard_state.dynamo_guards
@functools.lru_cache(maxsize=None)
def _step_logger():
    """Return the step logger bound to this module's ``log``.

    The result is memoized, so ``get_step_logger`` is called at most once.
    """
    return torchdynamo_logging.get_step_logger(log)
@dataclass
class GraphCompileReason:
    """Records why an output graph was compiled, i.e. what caused the graph break."""

    # Human-readable description of the cause.
    reason: str
    # Frame summaries associated with the graph break.
    user_stack: List[traceback.FrameSummary]
def _get_gen_rand_values_fn(random_calls):
def _gen_rand_values():
return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
return _gen_rand_values
class FakeRootModule(torch.nn.Module):
    """Trick the constructor of fx.GraphModule"""

    def __init__(self, nn_modules: Dict[str, torch.nn.Module]):
        """Expose every entry of ``nn_modules`` as an attribute of this module."""
        super().__init__()
        for attr_name in nn_modules:
            setattr(self, attr_name, nn_modules[attr_name])

    def __repr__(self):
        # Fixed short repr instead of printing all registered entries.
        return "FakeRootModule(...)"
class WrapperBackend:
    """Wrap a backend compiler and optionally check its output for correctness.

    When ``config.verify_correctness`` is set, the compiled candidate is run
    against the original ``gm.forward`` on cloned example inputs and the two
    results are compared with ``same``.
    """

    def __init__(self, backend: CompilerFn, original_example_inputs):
        """
        Args:
            backend: the compiler function being wrapped.
            original_example_inputs: inputs handed to ``backend`` and cloned
                for every verification run.
        """
        self.backend: CompilerFn = backend
        self.original_example_inputs = original_example_inputs

    @property
    def example_inputs(self):
        """A fresh clone of the original example inputs on every access."""
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        """Compile ``gm`` with the wrapped backend.

        Note that the ``example_inputs`` argument is not used; the backend is
        always called with ``self.original_example_inputs``.

        Returns:
            ``gm.forward`` if the backend returned ``None`` or ``gm.forward``
            itself; otherwise the backend's candidate (after verification
            when ``config.verify_correctness`` is enabled).

        Raises:
            RuntimeError: if verification is enabled and the candidate's
                result differs from the reference result.
        """
        self.restore = checkpoint_params(gm)
        self.gm = gm
        # The backend works on a deep copy so that self.gm stays intact.
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        if self.candidate is None or self.candidate is self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            # The original code had a `return self.gm.forward` after this
            # raise; it could never execute and has been removed.
            raise RuntimeError(f"incorrect results of backend {self}")
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            # Invoke the restore callback returned by checkpoint_params
            # whether verification succeeded or failed.
            self.restore()
class OutputGraph(fx.Tracer, Checkpointable[OutputGraphState]):
"""
Wrapper class to hold outputs of InstructionTranslator. Mainly the
generated fx.Graph.
"""
def __init__(
self,
f_globals: Dict[str, Any],
code_options: Dict[str, Any],
compiler_fn: CompilerFn,
root_tx,
export: bool,
):
# Set up an empty FX graph plus the bookkeeping used while tracing.
#
# Args:
#   f_globals: stored as ``root_globals``; install_global and
#     register_attr_or_module write into this dict.
#   code_options: shallow-copied into ``self.code_options`` (new_var and
#     update_co_names extend its "co_varnames" / "co_names" entries).
#   compiler_fn: backend compiler, invoked from call_user_compiler.
#   root_tx: root translator; ``current_tx`` falls back to it when no
#     translator has been pushed with push_tx.
#   export: stored on the instance and forwarded to ShapeEnv as
#     ``strict_mark_dyn``.
super().__init__()
self.graph = torch.fx.Graph()
self.graphargs: List[GraphArg] = []
self.export = export
# In export mode, we force the shape_env to strictly disallow any constraining
# of the user marked dynamic dims
# A ShapeEnv is only created when config.dynamic_shapes is set; otherwise
# the fake mode is constructed with shape_env=None.
fake_mode = torch._subclasses.FakeTensorMode(
shape_env=ShapeEnv(
allow_scalar_outputs=config.capture_scalar_outputs,
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
strict_mark_dyn=export,
assume_static_by_default=config.assume_static_by_default,
)
if config.dynamic_shapes
else None,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)
if config.dynamic_shapes:
# Register a SHAPE_ENV guard to make sure we setup shape guards
# that show up in ShapeEnv
# (self.guards is a property backed by self.tracing_context, so it must
# be used only after tracing_context has been assigned above.)
self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
# tracked_fakes says where any tensor that was wrapped to fake came
# from. It is similar to GraphArg, in that all GraphArgs will get
# added to TrackedFakes, but TrackedFakes also contains
# GraphArgs that got pruned, and things like Tensor attributes which
# aren't explicit graph inputs. Used by shape guard
self.tracked_fakes: List[TrackedFake] = []
# Although we prune unused graphargs before sending graphs to
# compilers, we may have legitimately triggered shape guards
# on "unused" inputs that we must keep track of. So after
# remove_unused_graphargs is called, orig_graphargs and
# graphargs no longer alias; orig_graphargs is the original
# graphargs, and graphargs is the pruned list. Guard creation
# should use original graphargs.
self.orig_graphargs: List[GraphArg] = self.graphargs
self.nn_modules: Optional[Dict[str, torch.nn.Module]] = dict()
self.side_effects = SideEffects()
self.code_options = dict(code_options)
self.output_instructions: List[Instruction] = []
# used to track nodes that are added between calls of copy_graphstate
# and restore_graphstate
self.timestamp = 0
# Node => computed real value (see utils.get_real_value)
self.real_value_cache: Dict[fx.Node, torch.Tensor] = {}
# Not checkpointed
self.compiler_fn: CompilerFn = compiler_fn
self.root_globals = f_globals
self.root_tx = root_tx
# Imported inside the function rather than at module top
# (presumably to avoid a circular import — TODO confirm).
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
# Stack of translators managed by push_tx / pop_tx.
self._current_tx: List[InstructionTranslatorBase] = []
# CleanupHook objects appended by install_global.
self.cleanups: List[CleanupHook] = []
# Set to True by add_output_instructions.
self.should_exit = False
# Name of the local that stores generated random values; assigned in
# compile_subgraph when the translator recorded random calls.
self.random_values_var = None
self.initial_random_state = ()
self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {}
# Maps the source arg position to the grapharg position
self.pos_to_arg: Dict[int, int] = {}
# Enables creating unique node names by tracking
# all current placeholder node names
# NOTE(review): create_graph_input stores ``proxy.node`` here, so the
# values look like fx.Node rather than fx.Proxy — confirm the annotation.
self.name_to_input: OrderedDict[
str, Optional[fx.Proxy]
] = collections.OrderedDict()
@property
def output(self):
    """Return this object itself, so ``x.output`` also works on the graph."""
    return self
@property
def fake_mode(self):
    """The ``fake_mode`` attribute of the root translator."""
    root = self.root_tx
    return root.fake_mode
@property
def shape_env(self):
    """The ``shape_env`` of the tracing context's fake mode."""
    mode = self.tracing_context.fake_mode
    return mode.shape_env
@property
def guards(self) -> Set[Guard]:
    """The ``dynamo_guards`` set held by the tracing context's guards context."""
    guards_context = self.tracing_context.guards_context
    return guards_context.dynamo_guards
def push_tx(self, tx):
    """Push ``tx`` onto the translator stack, making it the current one."""
    stack = self._current_tx
    stack.append(tx)
def pop_tx(self):
    """Remove and return the most recently pushed translator."""
    stack = self._current_tx
    return stack.pop()
@property
def current_tx(self):
    """The innermost pushed translator, or ``root_tx`` if none is pushed."""
    if self._current_tx:
        return self._current_tx[-1]
    return self.root_tx
def copy_graphstate(self) -> OutputGraphState:
    """Snapshot the current state and advance the timestamp.

    ``graphargs``, ``tracked_fakes`` and ``nn_modules`` are shallow-copied,
    ``side_effects`` is cloned, and the guard state is copied by the tracing
    context's guards context. The snapshot records the timestamp value from
    before the increment.
    """
    assert self.nn_modules is not None
    guard_snapshot = self.tracing_context.guards_context.copy_graphstate()
    snapshot = OutputGraphState(
        graphargs=list(self.graphargs),
        tracked_fakes=list(self.tracked_fakes),
        guard_state=guard_snapshot,
        nn_modules=dict(self.nn_modules),
        side_effects=self.side_effects.clone(),
        timestamp=self.timestamp,
    )
    self.timestamp += 1
    return snapshot
def restore_graphstate(self, state: OutputGraphState):
    """Roll back to a checkpoint produced by ``copy_graphstate``.

    The checkpointed fields are re-installed on ``self``, the guard state is
    handed back to the tracing context, and every graph node whose
    ``creation_timestamp`` is newer than the restored timestamp is removed.
    """
    (
        graphargs,
        tracked_fakes,
        guards_state,
        nn_modules,
        side_effects,
        timestamp,
    ) = state
    self.graphargs = graphargs
    self.tracked_fakes = tracked_fakes
    self.nn_modules = nn_modules
    self.side_effects = side_effects
    self.timestamp = timestamp
    self.tracing_context.guards_context.restore_graphstate(guards_state)

    # FX deepcopy doesn't work for a partially created graph, so nodes
    # created after the checkpoint are erased instead (newest first).
    removed_nodes = 0
    for node in reversed(list(self.graph.nodes)):
        if node.meta["creation_timestamp"] <= self.timestamp:
            continue
        # Erasing the node does not drop its meta, so release the example
        # tensor explicitly.
        node.meta.pop("example_value", None)
        self.remove_node(node)
        self.real_value_cache.pop(node, None)
        removed_nodes += 1
    log.debug(f"restore_graphstate: removed {removed_nodes} nodes")
def add_grapharg(self, arg: GraphArg):
    """Append ``arg`` to ``graphargs``.

    If its source is a ``LocalInputSource``, also record the mapping from
    the source's ``pos`` to the index of ``arg`` in ``graphargs``.
    """
    index = len(self.graphargs)
    self.graphargs.append(arg)
    source = arg.source
    if isinstance(source, LocalInputSource):
        self.pos_to_arg[source.pos] = index
def count_calls(self):
    """Return the module-level ``count_calls`` helper's result for ``self.graph``."""
    graph = self.graph
    return count_calls(graph)
def get_submodule(self, keys):
    """Resolve a dotted path starting from ``self.nn_modules``.

    Each path component is looked up by key when the current object is a
    dict and by attribute otherwise.

    Args:
        keys: non-empty dotted path such as ``"a.b.c"``.
    """
    assert keys
    current = self.nn_modules
    for part in keys.split("."):
        current = current[part] if isinstance(current, dict) else getattr(current, part)
    return current
def create_graph_input(self, name, type_expr=None):
    """Create a placeholder proxy with a name unique among graph inputs.

    If ``name`` is already taken, the first free ``{name}_{i}`` (``i`` from
    0) is used. The placeholder is inserted after the most recently created
    input, or via ``inserting_before(None)`` when there are no inputs yet.

    Returns:
        The proxy of the new placeholder; its node is recorded in
        ``name_to_input``.
    """
    known_inputs = self.name_to_input
    if name in known_inputs:
        base = name
        candidates = (f"{base}_{i}" for i in itertools.count())
        name = next(c for c in candidates if c not in known_inputs)

    if not known_inputs:
        insertion_ctx = self.graph.inserting_before(None)
    else:
        last_input = next(reversed(known_inputs))
        insertion_ctx = self.graph.inserting_after(known_inputs[last_input])

    with insertion_ctx:
        proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
        known_inputs[name] = proxy.node
        return proxy
def new_var(self, name="tmp"):
    """Reserve and return a fresh local variable name.

    The name has the form ``___{name}_{i}`` with the smallest ``i >= 0`` not
    already in ``code_options["co_varnames"]``; it is appended to
    ``co_varnames`` before being returned.
    """
    taken = set(self.code_options["co_varnames"])
    candidates = (f"___{name}_{i}" for i in itertools.count())
    fresh = next(c for c in candidates if c not in taken)
    self.code_options["co_varnames"] += (fresh,)
    return fresh
def update_co_names(self, name):
    """Append ``name`` to ``code_options["co_names"]`` unless already present."""
    if name in self.code_options["co_names"]:
        return
    self.code_options["co_names"] += (name,)
@staticmethod
def module_has_hooks(mod, only_check_unsupported=False):
    """Report whether ``mod`` has any non-empty hook container.

    Args:
        mod: object to inspect; hook attributes it does not have are skipped.
        only_check_unsupported: when true, the forward pre-hooks and forward
            hooks (the "supported" ones) are not considered.

    Returns:
        True if at least one inspected hook attribute has length > 0.
    """
    hook_attrs = [
        "_backward_pre_hooks",
        "_backward_hooks",
        "_state_dict_pre_hooks",
        "_state_dict_hooks",
        "_load_state_dict_pre_hooks",
        "_load_state_dict_post_hooks",
    ]
    if not only_check_unsupported:
        hook_attrs.extend(["_forward_pre_hooks", "_forward_hooks"])

    for attr in hook_attrs:
        if hasattr(mod, attr) and len(getattr(mod, attr)) > 0:
            return True
    return False
def register_attr_or_module(
    self,
    target: Union[torch.nn.Module, torch.Tensor, Any],
    *names,
    **options,
):
    """Register ``target`` under a unique key and return a variable for it.

    Dynamic nn.Modules (per ``is_dynamic_nn_module``) are returned directly
    as ``UnspecializedNNModuleVariable`` without being registered. Otherwise
    ``target`` is stored in ``self.nn_modules`` — reusing the existing key
    if the very same object is already registered — and wrapped according
    to its type (tensor, nn.Module, SymInt/SymFloat, or anything else).

    Args:
        target: the object to register.
        *names: components joined and sanitized into the registration key.
        **options: variable options; must include ``"source"``. ``"guards"``
            is copied into a fresh set before guards are added.
    """
    if is_dynamic_nn_module(target):
        return variables.UnspecializedNNModuleVariable(target, **options)

    options = dict(options)
    options["guards"] = set(options.get("guards", []))
    assert "source" in options
    source = options["source"]

    # Choose how a registration key is turned into a variable.
    if isinstance(target, torch.Tensor):
        if not is_constant_source(source):
            options["guards"].add(source.make_guard(GuardBuilder.TENSOR_MATCH))

        def make_variable(key):
            return wrap_fx_proxy(
                self.root_tx,
                self.create_proxy("get_attr", key, (), {}),
                example_value=target,
                **options,
            )

    elif isinstance(target, torch.nn.Module):
        if self.module_has_hooks(target, only_check_unsupported=True):
            log.warning(
                "nn.Module hooks are not fully supported, they may be ignored"
            )
        options["guards"].add(source.make_guard(GuardBuilder.NN_MODULE))

        def make_variable(key):
            return NNModuleVariable(type(target), key, **options)

    elif isinstance(target, (torch.SymInt, torch.SymFloat)):
        # HACK: symbolic numbers piggyback on the attribute registration
        # machinery and therefore end up in self.nn_modules. Tensor and
        # symint attributes should eventually get their own storage.
        def make_variable(key):
            return SymNodeVariable.create(
                self,
                self.create_proxy("get_attr", key, (), {}),
                sym_num=target,
                **options,
            )

    else:

        def make_variable(key):
            # Other objects are exposed as a global constant under ``key``.
            self.output.update_co_names(key)
            self.root_globals[key] = target
            return VariableBuilder(self, ConstantSource(source_name=key))(
                target
            )

    assert self.nn_modules is not None
    for existing_key, existing in self.nn_modules.items():
        if existing is target:
            # Already registered: reuse its key.
            return make_variable(existing_key)

    # Derive an identifier-like base name from ``names``.
    base = "_".join(str(part) for part in names)
    # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
    base = re.sub(r"\[(\d+)\]", r"_\g<1>", base)
    # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
    base = re.sub(r"[^a-zA-Z0-9]", "_", base)
    if not base or not base[0].isalpha():
        base = "sub" + base

    # Try the base name first, then base_0, base_1, ... until one is free.
    key = base
    suffixes = itertools.count()
    while key in self.nn_modules:
        key = f"{base}_{next(suffixes)}"
    self.nn_modules[key] = target
    return make_variable(key)
def compile_subgraph(
self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None
):
"""
Generate a subgraph to continue execution on user code.
Automatically restore live variables.
"""
# Args:
#   tx: translator whose stack, block stack, symbolic locals and random
#     calls are read below.
#   partial_convert: stored on self as ``partial_convert``.
#   reason: stored as ``compile_subgraph_reason`` and later attached to the
#     generated GraphModule by compile_and_call_fx_graph.
#
# The generated bytecode is appended through add_output_instructions.
from .eval_frame import disable
self.partial_convert = partial_convert
self.compile_subgraph_reason = reason
log.debug(f"COMPILING GRAPH due to {reason}")
# Give up (via unimplemented) if any active block cannot be restored;
# otherwise exit every block, walking the block stack in reverse.
if not all(block.can_restore() for block in tx.block_stack):
unimplemented("compile_subgraph with block_depth != 0")
for block in reversed(tx.block_stack):
block.exit(tx)
tx.prune_dead_locals()
stack_values = list(tx.stack)
assert self.nn_modules is not None
root = FakeRootModule(self.nn_modules)
# Add all the local vars to the "stack" so restore at the end
restore_vars = []
# Maps each distinct variable tracker to the local names bound to it.
val_to_names: OrderedDict[
VariableTracker, List[str]
] = collections.OrderedDict()
if stack_values:
# Register the top-of-stack value first, with no names yet.
val_to_names[stack_values[-1]] = list()
for k, v in tx.symbolic_locals.items():
if isinstance(v.source, LocalSource) and v.source.name() == k:
continue # no need to restore initial state
if v not in val_to_names:
val_to_names[v] = list()
val_to_names[v].append(k)
# Push one copy of each value per name bound to it; restore_vars lists the
# names in the same order and is consumed in reverse by the stores at the end.
for v in val_to_names.keys():
restore_vars.extend(val_to_names[v])
stack_values.extend([v] * len(val_to_names[v]))
# to handle random calls
if len(tx.random_calls) > 0:
random_calls_instructions = []
self.random_values_var = self.new_var("random_values")
rand_fn_name = unique_id("__gen_rand_values")
# The replay function is wrapped with ``disable`` and installed as a global
# so that the generated bytecode can load it by name.
rand_fn = disable(_get_gen_rand_values_fn(tx.random_calls))
self.install_global(rand_fn_name, rand_fn)
codegen = PyCodegen(tx, root)
# Emit: random.setstate(<initial_random_state>)
random_calls_instructions.extend(
[
codegen.create_load_global("random", True, add=True),
codegen.create_load_attr("setstate"),
codegen.create_load_const(tx.output.initial_random_state),
]
+ create_call_function(1, False),
)
# Emit: <random_values_var> = <rand_fn_name>()
random_calls_instructions.extend(
codegen.load_function_name(rand_fn_name, True)
)
random_calls_instructions.extend(create_call_function(0, False))
random_calls_instructions.append(
codegen.create_store(tx.output.random_values_var),
)
self.add_output_instructions(random_calls_instructions)
# Fast path: every stack value is a distinct TensorVariable (none of them
# an UnspecializedPythonVariable) and there are no side effects to replay.
if (
stack_values
and all(
not isinstance(v, UnspecializedPythonVariable) for v in stack_values
)
and all(isinstance(x, TensorVariable) for x in stack_values)
and len(set(stack_values)) == len(stack_values)
and self.side_effects.is_empty()
):
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
+ [create_instruction("UNPACK_SEQUENCE", len(stack_values))]
)
else:
# General path: run codegen twice. pass1 only records how often each value
# is used; pass2 is given tempvars for the values used more than once.
graph_output_var = self.new_var("graph_out")
pass1 = PyCodegen(tx, root, graph_output_var)
self.side_effects.codegen_save_tempvars(pass1)
pass1.foreach(stack_values)
self.side_effects.codegen_update_mutated(pass1)
# one more time now that we have established tempvars
pass2 = PyCodegen(
tx,
root,
graph_output_var,
tempvars={val: None for val, count in pass1.uses.items() if count > 1},
)
self.side_effects.codegen_save_tempvars(pass2)
pass2.foreach(stack_values)
self.side_effects.codegen_update_mutated(pass2)
output = []
# Only compile and call a graph if it has calls or produces outputs.
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
output.extend(
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
)
if len(pass2.graph_outputs) != 0:
output.append(pass2.create_store(graph_output_var))
else:
# No graph outputs are needed: drop the call's return value.
output.append(create_instruction("POP_TOP"))
self.add_output_instructions(output + pass2.get_instructions())
# restore all the live local vars
self.add_output_instructions(
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)]
)
def compile_and_call_fx_graph(self, tx, rv, root):
"""
Generate code from self.graph and return the Instruction()s to
call that generated code.
"""
# Args:
#   tx: translator passed to PyCodegen when emitting the call.
#   rv: list of variables whose proxies become the graph's outputs.
#   root: FakeRootModule used as the root of the fx.GraphModule.
from .eval_frame import disable
assert isinstance(rv, list)
assert isinstance(root, FakeRootModule)
# Guards carried by the returned variables become guards of the graph.
for output in rv:
self.guards.update(output.guards)
# Close the graph with an "output" node returning a tuple of all of rv.
self.create_node(
"output", "output", (self.create_arg(tuple(x.as_proxy() for x in rv)),), {}
)
self.remove_unused_graphargs()
ncalls = count_calls(self.graph)
counters["stats"]["calls_captured"] += ncalls
counters["stats"]["fusions_possible"] += ncalls - 1
# free a bit of memory
for node in self.graph.nodes:
if "example_value" in node.meta:
del node.meta["example_value"]
self.real_value_cache.clear()
gm = fx.GraphModule(root, self.graph)
gm.recompile()
gm.compile_subgraph_reason = self.compile_subgraph_reason
name = unique_id("__compiled_fn")
assert_no_fake_params_or_buffers(gm)
# The backend's result is wrapped with ``disable`` and installed as a
# global under ``name`` so that the generated bytecode can call it.
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)
counters["stats"]["unique_graphs"] += 1
self.install_global(name, compiled_fn)
try:
# the call to tabulate can cause a lot of memory to be allocated
if config.log_level <= logging.INFO and config.output_code:
graph_str = (
gm.print_readable()
if config.output_graph_code
else format_graph_tabular(gm.graph)
)
log.log(
logging.INFO,
f"TRACED GRAPH\n {name} {gm.forward.__code__.co_filename} {graph_str}\n",
)
except ImportError:
# Graph logging is best-effort: a missing ``tabulate`` only warns.
log.warning(
"Unable to print graph: `format_graph_tabular` relies on the library `tabulate`, "
"which could not be found on this machine. Run `pip "
"install tabulate` to install the library."
)
# Emit the bytecode that calls the installed compiled function.
cg = PyCodegen(tx)
cg.make_call_generated_code(name)
return cg.get_instructions()
@dynamo_timed(phase_name="backend_compile")
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
    """Invoke the backend compiler on ``gm`` and return the compiled callable.

    Raises:
        BackendCompilerFailed: wrapping any exception raised while running
            the backend, including the assertion that the result is callable.
    """
    op_count = sum(
        1
        for node in gm.graph.nodes
        if node.op in ("call_function", "call_method", "call_module")
    )
    torch._dynamo.utils.increment_op_count(op_count)

    try:
        name = getattr(self.compiler_fn, "__name__", "")
        _step_logger()(logging.INFO, f"calling compiler function {name}")

        compiler_fn = self.compiler_fn
        # WrapperBackend needs real inputs, for now, to verify correctness
        if config.verify_correctness:
            compiler_fn = WrapperBackend(compiler_fn, self.example_inputs())

        # NOTE: [Real Tensors in Accuracy Evaluation]
        #
        # Backends normally receive fake tensors at compile time (the
        # fake_example_inputs() call below) and real tensors at runtime.
        # That should be a strong invariant across all backends, and
        # generally it is. Accuracy evaluation is the exception: for now it
        # needs real tensors at compile time, because comparison is invoked
        # as a backend in two different ways:
        #
        # (1) Less bad, but still worth rewriting: WrapperBackend above takes
        #     real inputs in its constructor (see config.verify_correctness).
        #
        # (2) More bad, and very worth rewriting: the minifier installs
        #     accuracy comparison as a true backend, which therefore has to
        #     be compiled with real inputs. Since the minifier spawns new
        #     processes, a global flag MINIFIER_SPAWNED exists that should be
        #     set IF AND ONLY IF this run was spawned as part of accuracy
        #     minification. The flag is not a contract and ideally will not
        #     be here long.
        #
        # The longer term plan of record is to:
        # (A) Make the minifier accuracy evaluation and verify_correctness
        #     share one correctness/accuracy implementation.
        # (B) Refactor the minifier accuracy backend to compare entirely at
        #     runtime, so it no longer needs real tensors at compile time.
        is_top_level_minifying = (
            config.repro_after is not None and config.repro_level == 4
        )
        use_real_inputs = (
            torch._dynamo.debug_utils.MINIFIER_SPAWNED
            or is_top_level_minifying
            or config.DO_NOT_USE_legacy_non_fake_example_inputs
        )
        if use_real_inputs:
            compiled_fn = compiler_fn(gm, self.example_inputs())
        else:
            compiled_fn = compiler_fn(gm, self.fake_example_inputs())

        _step_logger()(logging.INFO, f"done compiler function {name}")
        assert callable(compiled_fn), "compiler_fn did not return callable"
    except Exception as err:
        raise BackendCompilerFailed(self.compiler_fn, err).with_traceback(
            err.__traceback__
        ) from None
    return compiled_fn
def fake_example_inputs(self) -> List[torch.Tensor]:
    """Collect example inputs for all graphargs, preferring fake examples.

    For each grapharg, ``get_fake_examples()`` is used unless it returns
    ``None``, in which case ``get_examples()`` is the fallback (fake_tensor
    was not set, particularly for graph args that are not tensors).
    """
    collected = []
    for arg in self.graphargs:
        fakes = arg.get_fake_examples()
        collected.extend(arg.get_examples() if fakes is None else fakes)
    return collected
def example_inputs(self) -> List[torch.Tensor]:
    """Concatenate ``get_examples()`` of every grapharg, in order."""
    return list(
        itertools.chain.from_iterable(arg.get_examples() for arg in self.graphargs)
    )
def remove_unused_graphargs(self) -> None:
    """Prune dead nodes and graph inputs that nothing uses.

    First, user-less ``get_attr`` nodes and user-less ``operator.getitem``
    calls are erased (walking the graph backwards). Then each grapharg's
    ``uses`` is recomputed from the users of its placeholder nodes, and
    graphargs with zero uses are dropped together with their placeholders.
    """
    for node in reversed(list(self.graph.nodes)):
        if node.users:
            continue
        is_attr = node.op == "get_attr"
        is_getitem = node.op == "call_function" and node.target is operator.getitem
        if is_attr or is_getitem:
            self.remove_node(node)

    # One entry per placeholder: a grapharg is repeated len(arg) times.
    per_placeholder = []
    for arg in self.graphargs:
        per_placeholder.extend([arg] * len(arg))
        arg.uses = 0

    # Materialize the pairing once; it is reused while nodes are erased.
    pairs = list(zip(self.graph.nodes, per_placeholder))
    for node, arg in pairs:
        assert node.op == "placeholder"
        arg.uses += len(node.users)

    for node, arg in pairs:
        if arg.uses != 0:
            continue
        log.debug(f"REMOVE UNUSED GRAPHARG {arg.source.name()}")
        node.meta.pop("example_value", None)
        self.remove_node(node)
        self.real_value_cache.pop(node, None)

    self.graphargs = [arg for arg in self.graphargs if arg.uses > 0]
def add_output_instructions(self, prefix: List[Instruction]) -> None:
    """
    Append ``prefix`` to ``output_instructions`` and set ``should_exit``.

    Called when a newly compiled subgraph is inserted before user code.
    """
    instructions = self.output_instructions
    instructions.extend(prefix)
    self.should_exit = True
def install_global(self, name, value) -> None:
    """Create a ``CleanupHook`` for ``name``/``value`` on ``root_globals`` and keep it in ``cleanups``."""
    hook = CleanupHook.create(self.root_globals, name, value)
    self.cleanups.append(hook)
def cleanup(self) -> None:
    """Drop references held by this graph so that objects can be released."""
    # Tracer and OutputGraph reference each other; breaking the cycle stops
    # tensor objects from being kept alive longer than necessary.
    self.root_tx = None

    # Note: generated fx graph will hold a reference to the nn_module,
    # So depending on the backend they may not be released
    self.nn_modules = None

    # Cleanup graphargs
    for arg in self.graphargs:
        arg.erase()

    for node in self.graph.nodes:
        node.meta.pop("example_value", None)

    self.real_value_cache.clear()
    self.name_to_input.clear()
    self.side_effects.keepalive = []
def create_proxy(
    self,
    kind,
    target,
    args,
    kwargs,
    name=None,
    type_expr=None,
    proxy_factory_fn=None,
):
    """Create a proxy via the base tracer and annotate its node.

    The node receives a copy of the current translator's ``nn_module_stack``
    (when non-empty), a ``source_fn`` entry for function, method and module
    calls, and a ``stack_trace`` string built from the frame summaries of the
    current translator and its chain of ``parent`` translators.
    """
    proxy = super().create_proxy(
        kind, target, args, kwargs, name, type_expr, proxy_factory_fn
    )

    tx = self.current_tx
    module_stack = tx.nn_module_stack
    if module_stack:
        proxy.node.meta["nn_module_stack"] = module_stack.copy()

    if kind == "call_module":
        # For modules we store the class
        proxy.node.meta["source_fn"] = proxy.node.meta["nn_module_stack"][target][1]
    elif kind in {"call_function", "call_method"}:
        proxy.node.meta["source_fn"] = target

    # Walk from the current translator up through its parents, then flip
    # the list so that the innermost frame comes last.
    innermost_first: List[traceback.FrameSummary] = []
    while tx:
        innermost_first.append(tx.frame_summary())
        tx = getattr(tx, "parent", None)
    outermost_first = innermost_first[::-1]

    # official from_list stub doesn't have new-style type
    formatted = traceback.StackSummary.from_list(outermost_first).format()  # type: ignore[arg-type]
    proxy.node.stack_trace = "".join(formatted)

    return proxy
def create_node(self, *args, **kwargs):
    """Create a node via the base tracer and stamp it with the current timestamp.

    The ``creation_timestamp`` meta entry is what ``restore_graphstate``
    compares against when deciding which nodes to erase.
    """
    new_node = super().create_node(*args, **kwargs)
    new_node.meta["creation_timestamp"] = self.timestamp
    return new_node
# Note: erase_node itself is not overridden, because self.graph.erase_node
# is also called directly elsewhere.
def remove_node(self, node):
    """Erase ``node`` from the graph and forget its name in ``name_to_input``."""
    graph = self.graph
    graph.erase_node(node)
    self.name_to_input.pop(node.name, None)