| import contextlib |
| import copy |
| import functools |
| import inspect |
| import logging |
| import os |
| import sys |
| import threading |
| import traceback |
| import types |
| import warnings |
| from importlib import import_module |
| from unittest.mock import patch |
| |
| import torch |
| import torch.utils._pytree as pytree |
| from torch.fx.experimental.proxy_tensor import make_fx |
| from torch.nn.parallel.distributed import DistributedDataParallel |
| |
| from . import config, convert_frame, skipfiles, utils |
| from .exc import ResetRequired |
| from .mutation_guard import install_generation_tagging_init |
| from .optimizations.distributed import DDPOptimizer |
| from .utils import checkpoint_params, clone_inputs, compile_times, same |
| |
log = logging.getLogger(__name__)

# Optional import: fall back to None when this torch build does not provide
# torch.fx.experimental.proxy_tensor (TorchPatcher.patch checks for None).
try:
    from torch.fx.experimental import proxy_tensor
except ImportError:
    proxy_tensor = None

# Aliases for the C-level frame-evaluation hooks in torch._C._dynamo.eval_frame.
_eval_frame = torch._C._dynamo.eval_frame
# Installs an eval-frame callback and returns the previously installed one;
# callers keep the returned value and pass it back later to restore it.
set_eval_frame = _eval_frame.set_eval_frame
# Used by remove_from_cache to drop cached compiled code for a code object.
reset_code = _eval_frame.reset_code
unsupported = _eval_frame.unsupported
# Used by skip() to mark a code object as one Dynamo should skip.
skip_code = _eval_frame.skip_code
set_guard_fail_hook = _eval_frame.set_guard_fail_hook
set_guard_error_hook = _eval_frame.set_guard_error_hook
# Code objects of functions wrapped by an optimizing context; filled in by
# _TorchDynamoContext.__call__ so those functions are never skipped.
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext
# Sentinel for _TorchDynamoContext.prior meaning "no saved eval-frame callback".
unset = object()
# Reentrant lock serializing frame conversion in catch_errors_wrapper.
compile_lock = threading.RLock()
# Compiler function of the most recently entered OptimizeContext; entering one
# with a different backend raises ResetRequired. explain()/export() patch it
# to None for the duration of their run.
most_recent_backend = None
| |
| |
def remove_from_cache(f):
    """
    Drop the cached compilation of ``f`` so the next call recompiles it.

    ``f`` may be a code object, anything with ``__code__``, or an object whose
    ``forward`` has ``__code__`` (e.g. an ``nn.Module``). If no code object can
    be found, all TorchDynamo state is reset instead and a warning is logged.
    """
    if isinstance(f, types.CodeType):
        target_code = f
    elif hasattr(f, "__code__"):
        target_code = f.__code__
    elif hasattr(getattr(f, "forward", None), "__code__"):
        target_code = f.forward.__code__
    else:
        # No code object to evict: fall back to a global reset.
        from . import reset

        reset()
        log.warning("could not determine __code__ for %s", f)
        return
    reset_code(target_code)
| |
| |
def nothing():
    """No-op default for the ``on_enter`` and ``patch_fn`` hooks."""
    return None
| |
| |
def innermost_fn(fn):
    """
    Unwrap nested _TorchDynamoContext wrappers and return the original callable.

    Each wrapper stores the callable it wraps in ``_torchdynamo_orig_callable``.
    TorchDynamo caches on ``fn.__code__``, so optimize/run/disable must operate
    on the innermost function rather than on a wrapper.
    """
    current = fn
    while True:
        try:
            current = current._torchdynamo_orig_callable
        except AttributeError:
            break
    assert callable(current)
    return current
| |
| |
class _TorchDynamoContext:
    """
    Context manager / decorator that installs a TorchDynamo eval-frame callback.

    ``callback`` is what gets passed to ``set_eval_frame``: a callable (compile
    frames), ``False`` (used by RunOnlyContext) or ``None`` (used by
    DisableContext). The previous callback is restored on exit, even when the
    optional backend context (``backend_ctx_ctor()``) fails to start.
    """

    def __init__(
        self,
        callback,
        on_enter=nothing,
        backend_ctx_ctor=null_context,
        patch_fn=nothing,
        first_ctx=False,
    ):
        super().__init__()
        assert callable(callback) or callback is False or callback is None
        self.callback = callback
        # `unset` means no prior eval-frame callback is currently saved.
        self.prior = unset
        self.on_enter = on_enter
        self.extra_ctx_ctor = backend_ctx_ctor
        self.first_ctx = first_ctx
        patch_fn()

    def __enter__(self):
        if config.raise_on_ctx_manager_usage:
            raise RuntimeError(
                "torchdynamo.optimize(...) is used with a context manager. "
                "Please refer to https://github.com/pytorch/torchdynamo#usage-example "
                "to use torchdynamo.optimize(...) as an annotation/decorator. "
            )
        # on_enter may raise (e.g. ResetRequired); run it before any setup so
        # there is nothing to undo. This matches the order used by _fn below.
        self.on_enter()
        utils.debug_dir.setup()
        self.prior = set_eval_frame(self.callback)
        try:
            self.backend_ctx = self.extra_ctx_ctor()
            self.backend_ctx.__enter__()
        except BaseException:
            # __exit__ will not run if __enter__ raises: undo our setup here so
            # the eval-frame callback does not stay installed.
            utils.debug_dir.clear()
            set_eval_frame(self.prior)
            self.prior = unset
            raise

    def __exit__(self, exc_type, exc_val, exc_tb):
        utils.debug_dir.clear()
        set_eval_frame(self.prior)
        self.prior = unset
        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, fn):
        """
        Wrap ``fn`` so each call runs with this context's callback installed.

        For an ``nn.Module`` only ``forward`` is wrapped; other attribute
        lookups are forwarded to the original module.
        """
        fn = innermost_fn(fn)
        # Optimize the forward method of torch.nn.Module object
        if isinstance(fn, torch.nn.Module):
            mod = fn
            optimized_forward = self(mod.forward)

            class TorchDynamoNNModuleWrapper:
                """
                A wrapper that redirects the forward call to the optimized
                forward, while for rest it redirects the calls to the original
                module.
                """

                def __getattr__(self, name):
                    return getattr(mod, name)

                def forward(self, *args, **kwargs):
                    return optimized_forward(*args, **kwargs)

                def __call__(self, *args, **kwargs):
                    return self.forward(*args, **kwargs)

            new_mod = TorchDynamoNNModuleWrapper()
            # Save the function pointer to find the original callable while nesting
            # of decorators.
            new_mod._torchdynamo_orig_callable = mod
            return new_mod

        assert callable(fn)
        callback = self.callback
        on_enter = self.on_enter
        backend_ctx_ctor = self.extra_ctx_ctor

        @functools.wraps(fn)
        def _fn(*args, **kwargs):
            on_enter()
            utils.debug_dir.setup()
            prior = set_eval_frame(callback)
            try:
                backend_ctx = backend_ctx_ctor()
                backend_ctx.__enter__()
            except BaseException:
                # Restore state if the backend context cannot be entered;
                # the finally below is not reached in that case.
                utils.debug_dir.clear()
                set_eval_frame(prior)
                raise
            try:
                return fn(*args, **kwargs)
            finally:
                utils.debug_dir.clear()
                set_eval_frame(prior)
                backend_ctx.__exit__(None, None, None)

        # hooks to properly handle inlining
        if isinstance(self, DisableContext):
            _fn._torchdynamo_disable = True
        else:
            _fn._torchdynamo_inline = fn

        # Save the function pointer to find the original callable while nesting
        # of decorators.
        _fn._torchdynamo_orig_callable = fn

        # If the function is called using torchdynamo.optimize decorator, we
        # should prevent any type of skipping.
        if callback not in (None, False):
            always_optimize_code_objects[fn.__code__] = True

        return _fn
| |
| |
class OptimizeContext(_TorchDynamoContext):
    """
    Context / decorator that compiles frames with ``callback``.

    On entry it records the innermost compiler function in the module-level
    ``most_recent_backend``. Entering with a different backend than the
    recorded one raises ``ResetRequired``. Construction applies
    ``TorchPatcher.patch``.
    """

    def __init__(self, callback, backend_ctx_ctor, first_ctx=False):
        compiler_fn = innermost_fn(callback)

        def on_enter():
            global most_recent_backend
            previous = most_recent_backend
            if previous is not None and previous is not compiler_fn:
                raise ResetRequired()
            most_recent_backend = compiler_fn
            install_generation_tagging_init()

        super().__init__(
            callback=callback,
            on_enter=on_enter,
            backend_ctx_ctor=backend_ctx_ctor,
            patch_fn=TorchPatcher.patch,
            first_ctx=first_ctx,
        )
| |
| |
class RunOnlyContext(_TorchDynamoContext):
    """Context that reuses prior optimizations without compiling new frames."""

    def __init__(self):
        super(RunOnlyContext, self).__init__(callback=False)
| |
| |
class DisableContext(_TorchDynamoContext):
    """Context that turns TorchDynamo off by installing a ``None`` callback."""

    def __init__(self):
        super(DisableContext, self).__init__(callback=None)
| |
| |
def catch_errors_wrapper(callback):
    """
    Wrap a frame callback so that it filters frames, serializes, and logs errors.

    Frames that should not be converted are filtered out and ``None`` is
    returned for them. Conversion runs under ``compile_lock``. Any exception
    is logged and then re-raised.

    Args:
        callback: called as ``callback(frame, cache_size)``. When
            ``config.optimize_ddp`` is on it must expose
            ``_torchdynamo_orig_callable`` (the backend compile function).

    Returns:
        ``catch_errors(frame, cache_size)``, with ``_torchdynamo_orig_callable``
        set to ``callback`` so ``innermost_fn`` can unwrap it.
    """

    @functools.wraps(callback)
    def catch_errors(frame, cache_size):
        try:
            # f_lasti >= 0: the frame has presumably already run bytecode
            # (e.g. a resumed generator), so it is not converted.
            if frame.f_lasti >= 0 or skipfiles.check(frame.f_code.co_filename):
                # Lazy %-args: this runs for every skipped frame, so don't build
                # the message unless debug logging is actually enabled.
                log.debug(
                    "skipping %s %s", frame.f_code.co_name, frame.f_code.co_filename
                )
                return None
            if (
                frame.f_code.co_filename == "<string>"
                and frame.f_code.co_name == "__new__"
            ):
                # namedtuple constructor
                return None
            if config.optimize_ddp:
                ddp_module = DistributedDataParallel._get_active_ddp_module()
                if ddp_module:
                    # Compile through DDPOptimizer (configured from the active
                    # DDP module) instead of `callback`. Hold the same lock as
                    # the regular path for the whole conversion.
                    with compile_lock:
                        ddp_optimizer = DDPOptimizer(
                            bucket_bytes_cap=ddp_module.bucket_bytes_cap,
                            parameters_to_ignore=ddp_module.parameters_to_ignore,
                            backend_compile_fn=callback._torchdynamo_orig_callable,
                        )
                        hijacked_callback = convert_frame.convert_frame(
                            ddp_optimizer.compile_fn, guard_export_fn=None
                        )
                        return hijacked_callback(frame, cache_size)

            with compile_lock:
                return callback(frame, cache_size)
        except Exception:
            log.exception("Error while processing frame")
            raise

    catch_errors._torchdynamo_orig_callable = callback
    return catch_errors
| |
| |
def _optimize_catch_errors(compile_fn, backend_ctx_ctor=null_context):
    """Build a first-level OptimizeContext around ``catch_errors_wrapper(compile_fn)``."""
    guarded_callback = catch_errors_wrapper(compile_fn)
    return OptimizeContext(
        guarded_callback, backend_ctx_ctor=backend_ctx_ctor, first_ctx=True
    )
| |
| |
class WrapperBackend:
    """
    Backend adapter that compiles a deep copy of the FX graph with ``backend``.

    When ``config.verify_correctness`` is set, the candidate's outputs are
    compared against the original graph before the candidate is accepted.
    """

    def __init__(self, backend=None):
        self.backend = backend

    @property
    def example_inputs(self):
        # A fresh clone on every access, so one run cannot mutate the inputs
        # used by the next.
        return clone_inputs(self.original_example_inputs)

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        """
        Compile ``gm`` with the wrapped backend.

        Returns:
            The backend's callable. Returns ``gm.forward`` when the backend
            returns ``None`` or ``gm.forward`` itself.

        Raises:
            RuntimeError: if verification is enabled and the results differ.
        """
        # presumably snapshots gm's parameters; restored after verification.
        self.restore = checkpoint_params(gm)
        self.original_example_inputs = clone_inputs(example_inputs)
        self.gm = gm
        copy_gm = copy.deepcopy(self.gm)
        self.candidate = self.backend(copy_gm, self.original_example_inputs)

        # Each `gm.forward` access creates a new bound method, so an identity
        # check can never match; `==` compares the underlying function and
        # instance instead.
        if self.candidate is None or self.candidate == self.gm.forward:
            return self.gm.forward

        if not config.verify_correctness:
            return self.candidate

        # if verify_correctness=True
        try:
            correct = self.gm.forward(*self.example_inputs)
            result = self.candidate(*self.example_inputs)

            # TODO: replace `same` function with the one in testing
            if same(correct, result):
                return self.candidate

            raise RuntimeError(f"incorrect results of backend {self}")
        except Exception:
            log.exception("error in verify_correctness")
            raise
        finally:
            self.restore()
| |
| |
def get_compiler_fn(compiler_fn):
    """
    Resolve ``compiler_fn`` to a callable and wrap it with ``wrap_backend_debug``.

    A string is treated as a backend name and passed along to the debug wrapper.
    """
    from .debug_utils import wrap_backend_debug

    backend_name = None
    if isinstance(compiler_fn, str):
        backend_name = compiler_fn
    resolved = lookup_backend(compiler_fn)
    return wrap_backend_debug(resolved, backend_name)
| |
| |
@functools.lru_cache(1)
def lookup_backend(compiler_fn):
    """
    Expand backend strings to functions.

    ``"inductor"`` is imported lazily from ``config.inductor_import``. Any other
    string is looked up in ``optimizations.BACKENDS``, and non-strings are
    returned unchanged. Only the most recent lookup is memoized
    (``lru_cache(1)``).
    """
    if compiler_fn == "inductor":
        inductor_module = import_module(f"{config.inductor_import}.compile_fx")
        return inductor_module.compile_fx
    if not isinstance(compiler_fn, str):
        return compiler_fn
    from .optimizations import BACKENDS

    return BACKENDS[compiler_fn]
| |
| |
| class _NullDecorator(contextlib.nullcontext): |
| def __call__(self, fn): |
| assert callable(fn) |
| return fn |
| |
| |
def optimize(
    backend="inductor", *, nopython=False, guard_export_fn=None, disable=False
):
    """
    The main entrypoint of TorchDynamo: capture graphs and hand them to a backend.

    Args:
        backend: either
            - a callable taking a ``torch.fx.GraphModule`` and example inputs
              and returning a Python callable that runs the graph faster. A
              ``backend_ctx_ctor`` attribute on it (e.g. producing
              ``torch.jit.fuser("fuser2")``) supplies an extra context entered
              around execution; see AOTAutogradMemoryEfficientFusionWithContext.
            - or a backend name string from ``torchdynamo.list_backends()``.
        nopython: if True, graph breaks are errors and a single
            whole-program graph is produced.
        guard_export_fn: optional hook forwarded to the frame converter. For
            example, explain() uses it to collect the exported guards.
        disable: if True, this decorator becomes a no-op.

    Returns a no-op decorator when disabled (argument or
    ``TORCHDYNAMO_DISABLE=1``), on Windows, or on Python 3.11+.

    Example Usage:

        @torchdynamo.optimize()
        def toy_example(a, b):
            ...
    """
    if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1":
        return _NullDecorator()

    not_supported_prefix = None
    if sys.platform == "win32":
        not_supported_prefix = "Windows is not currently supported, "
    elif sys.version_info >= (3, 11):
        not_supported_prefix = "Python 3.11+ not yet supported, "
    if not_supported_prefix is not None:
        warnings.warn(
            not_supported_prefix
            + f"{config.dynamo_import}.optimize() will do nothing"
        )
        return _NullDecorator()

    compiler = get_compiler_fn(backend)
    # A backend may carry its own context manager constructor.
    ctx_ctor = getattr(compiler, "backend_ctx_ctor", null_context)

    if nopython:
        return optimize_assert(compiler, guard_export_fn=guard_export_fn)
    frame_converter = convert_frame.convert_frame(
        compiler, guard_export_fn=guard_export_fn
    )
    return _optimize_catch_errors(frame_converter, ctx_ctor)
| |
| |
@patch("torch._dynamo.symbolic_convert.explain", True)
def explain(f, *args, **kwargs):
    """
    Run ``f(*args, **kwargs)`` once under TorchDynamo and report what was captured.

    TorchDynamo state is reset before the run and again afterwards. The second
    reset also happens when the run raises.

    Returns:
        ``(explanation, out_guards, graphs, ops_per_graph, break_reasons)``:
        a human-readable summary; the guards passed to the guard-export hook;
        the captured ``torch.fx.GraphModule`` objects; the ``call_function``
        targets of each graph; and the non-None ``compile_subgraph_reason`` of
        each graph.
    """
    # TODO(voz): Do we want a decorator for this?
    from . import reset

    reset()

    out_guards = []
    graphs = []
    ops_per_graph = []
    op_count = 0
    break_reasons = []

    def dynamo_graph_accumulating_compiler(gm: torch.fx.GraphModule, example_inputs):
        # Record the graph and its ops, then run it unmodified.
        nonlocal op_count

        graphs.append(gm)
        ops = [node.target for node in gm.graph.nodes if node.op == "call_function"]
        op_count += len(ops)
        ops_per_graph.append(ops)
        if gm.compile_subgraph_reason is not None:
            break_reasons.append(gm.compile_subgraph_reason)
        return gm.forward

    def guard_export_print(guards):
        out_guards.append(guards)

    try:
        with patch(f"{__name__}.most_recent_backend", None):
            opt_f = optimize(
                dynamo_graph_accumulating_compiler,
                nopython=False,
                guard_export_fn=guard_export_print,
            )(f)
            # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
            opt_f(*args, **kwargs)

        graph_count = len(graphs)
        # No graphs means no breaks; don't report "-1 graph break".
        graph_break_count = max(graph_count - 1, 0)

        # For the explanation summary, dedupe reasons by the innermost stack frame.
        deduped_reasons = {}
        for reason in break_reasons:
            innermost_frame = reason.user_stack[-1]
            # __repr__ uniquely identifies a FrameSummary so we can use it for deduping
            deduped_reasons[repr(innermost_frame)] = reason

        entries = []
        for idx, break_reason in enumerate(deduped_reasons.values(), start=1):
            formatted_stack = "".join(traceback.format_list(break_reason.user_stack))
            msg = f"{break_reason.reason}\n{formatted_stack}"
            entries.append(f"{idx}. {msg} \n")
        formatted_list = "".join(entries)

        explanation = f"Dynamo produced {graph_count} graphs "
        explanation += f"with {graph_break_count} graph break and {op_count} ops"
        explanation += f"\n Break reasons: \n\n{formatted_list}"

        explanation += compile_times()
    finally:
        # Always clear Dynamo state so the recording backend's compiled code
        # does not linger, even if `f` raised.
        reset()

    return explanation, out_guards, graphs, ops_per_graph, break_reasons
| |
| |
def export(
    f, *args, aten_graph=False, decomposition_table=None, tracing_mode="real", **kwargs
):
    """
    Capture ``f(*args, **kwargs)`` as a single whole-program FX graph.

    ``f`` is run once under ``optimize_assert(..., export=True)``. The backend
    used there records the one captured graph together with the inputs and
    outputs it was run with, and the single guard export. The graph is then
    rewritten so its placeholders are the flattened ``args`` (``arg0``,
    ``arg1``, ...) and its output has the same pytree structure as ``f``'s
    result.

    Args:
        f: callable to export; nested Dynamo wrappers are removed with
            ``innermost_fn``.
        *args: positional example inputs, flattened with ``pytree.tree_flatten``.
        aten_graph: if True, re-trace the captured graph with ``make_fx``.
        decomposition_table: forwarded to ``make_fx``; requires ``aten_graph``.
        tracing_mode: forwarded to ``make_fx``; anything but ``"real"``
            requires ``aten_graph``.
        **kwargs: passed to ``f`` but not part of the exported signature.

    Returns:
        ``(new_graph, out_guards)``.

    Raises:
        AssertionError: on illegal option combinations, if more than one graph
            or guard export is produced, or if the graph's inputs/outputs
            cannot be matched back to ``args`` / the traced result.
    """
    if decomposition_table is not None or tracing_mode != "real":
        assert (
            aten_graph
        ), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True"
    f = innermost_fn(f)

    # Filled in by the nested callbacks below during the single traced run.
    graph = None
    out_guards = None
    graph_captured_input = None
    graph_captured_result = None

    def produce_matching(source_args, candidate_args):
        # For each element of candidate_args, return the index of the element
        # of source_args that is the very same object (matched by id()).
        matched_elements_positions = []
        dict_of_source_args = dict()
        for i in range(0, len(source_args)):
            element_id = id(source_args[i])
            dict_of_source_args[element_id] = i

        for i in range(0, len(candidate_args)):
            arg = candidate_args[i]
            # 1-element tensor arg can be unspec int/float
            if isinstance(arg, torch.Tensor) and torch.numel(arg) == 1:
                if id(arg) in dict_of_source_args:
                    matched_elements_positions.append(dict_of_source_args[id(arg)])
                # NOTE(review): arg.item() builds a new Python scalar, so this id()
                # lookup presumably only hits when CPython hands back a cached
                # object (e.g. small ints) — confirm it is reliable for floats.
                elif id(arg.item()) in dict_of_source_args:
                    matched_elements_positions.append(
                        dict_of_source_args[id(arg.item())]
                    )
                else:
                    raise AssertionError(
                        "Dynamo input/output is not consistent with traced input/output"
                    )
            else:
                assert (
                    id(arg) in dict_of_source_args
                ), "Dynamo input and output is a strict subset of traced input/output"
                matched_elements_positions.append(dict_of_source_args[id(arg)])

        return matched_elements_positions

    def guard_export_print(guards):
        # Guard-export hook: whole-graph export expects exactly one call.
        nonlocal out_guards
        assert out_guards is None, "whole graph export entails exactly one guard export"
        out_guards = guards

    def dynamo_normalization_capturing_compiler(
        gm: torch.fx.GraphModule, example_inputs
    ):
        # Backend: keep the single captured graph and return a wrapper that
        # records the inputs/outputs it is executed with (last call wins).
        nonlocal graph

        assert graph is None, "whole graph export entails exactly one graph"
        graph = gm

        def result_capturing_wrapper(*graph_inputs):
            nonlocal graph_captured_result
            nonlocal graph_captured_input

            graph_captured_input = graph_inputs
            graph_captured_result = graph(*graph_inputs)
            return graph_captured_result

        return result_capturing_wrapper

    # TODO(voz): Handle kwargs properly?
    flat_args, in_spec = pytree.tree_flatten(args)

    # Evict any cached compilation of f before and after the capture run:
    # before so the capturing backend is actually invoked, after presumably so
    # later calls of f do not reuse the export-mode compilation.
    remove_from_cache(f)
    with patch(f"{__name__}.most_recent_backend", None):
        opt_f = optimize_assert(
            dynamo_normalization_capturing_compiler,
            guard_export_fn=guard_export_print,
            export=True,
        )(f)
        # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideffects and reject.
        result_traced = opt_f(*args, **kwargs)
    remove_from_cache(f)

    assert graph is not None, "whole graph export entails exactly one call"
    assert out_guards is not None, "whole graph export entails exactly one guard export"

    # Position in flat_args of every input the captured graph received.
    matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)

    flat_results_traced, out_spec_traced = pytree.tree_flatten(result_traced)

    # Each flattened user-visible output is matched against the graph outputs
    # followed by flat_args (an output may be an input returned as-is).
    flat_both = list(graph_captured_result) + flat_args
    matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)

    class ChangeInputOutputSignature(torch.fx.interpreter.Transformer):
        # Rewrites the graph so its placeholders are arg0..argN-1 (one per
        # element of flat_args) and its output has the pytree structure of f's
        # traced result.
        def __init__(
            self,
            m,
        ):
            super().__init__(m)
            arg_len = len(flat_args)
            # Explicit super(...) form: zero-argument super() does not work
            # inside the comprehension's own scope.
            self.new_args = [
                super(ChangeInputOutputSignature, self).placeholder(f"arg{i}", (), {})
                for i in range(0, arg_len)
            ]
            # Yields, for each original placeholder in turn, the new placeholder
            # for the flat_args position that input was matched to.
            self.old_args_gen = (
                self.new_args[i] for i in matched_input_elements_positions
            )

        def placeholder(self, target, args, kwargs):
            # assumes placeholders are visited in the same order as the values
            # in graph_captured_input — TODO confirm
            return next(self.old_args_gen)

        def output(self, target, args, kwargs):
            # Same indexing scheme as flat_both above: graph outputs first,
            # then the (new) inputs.
            dynamo_result_flat = args[0]
            lookup = [*dynamo_result_flat, *self.new_args]
            new_result_flat = [lookup[i] for i in matched_output_elements_positions]
            new_result = pytree.tree_unflatten(new_result_flat, out_spec_traced)

            return super().output(target, (new_result,), {})

    if aten_graph:
        # Running graph with interpreter is needed for propagating the stack_trace
        def graph_with_interpreter(*args):
            with torch.fx.traceback.override_stack_trace():
                return torch.fx.Interpreter(graph).run(*args)

        # make_fx traces graph_with_interpreter immediately, before `graph`
        # is rebound to the ATen-level result.
        graph = make_fx(
            graph_with_interpreter,
            decomposition_table=decomposition_table,
            tracing_mode=tracing_mode,
        )(*graph_captured_input)

    new_graph = ChangeInputOutputSignature(
        graph,
    ).transform()

    return (new_graph, out_guards)
| |
| |
def assume_constant_result(fn):
    """
    Mark ``fn`` with ``_dynamo_marked_constant`` and return it.

    The marker is read elsewhere in Dynamo, presumably to treat ``fn``'s result
    as a constant.

    Raises:
        AssertionError: if ``config.fake_tensor_propagation`` is enabled. In
            that case ``fn`` is left unmodified.
    """
    # Validate before mutating so a rejected function is not left marked.
    assert (
        not config.fake_tensor_propagation
    ), "Constant result capture is not supported with fake tensors."
    fn._dynamo_marked_constant = True
    return fn
| |
| |
def optimize_assert(backend, *, guard_export_fn=None, export=False):
    """
    The same as `torchdynamo.optimize(backend, nopython=True)`.

    ``export`` is forwarded to ``convert_frame.convert_frame_assert``, and so is
    ``guard_export_fn``.
    """
    compiler = get_compiler_fn(backend)
    # A backend may carry its own context manager constructor.
    ctx_ctor = getattr(compiler, "backend_ctx_ctor", null_context)
    frame_converter = convert_frame.convert_frame_assert(
        compiler, guard_export_fn, export=export
    )
    return _optimize_catch_errors(frame_converter, ctx_ctor)
| |
| |
def run(fn=None):
    """
    Don't do any dynamic compiles, just use prior optimizations.

    ``run(fn)`` returns the unwrapped ``fn`` wrapped in a RunOnlyContext.
    ``run()`` returns a RunOnlyContext for use in a ``with`` statement.
    """
    if fn is None:
        return RunOnlyContext()
    target = innermost_fn(fn)
    assert callable(target)
    return RunOnlyContext()(target)
| |
| |
def disable(fn=None):
    """
    Decorator and context manager to disable TorchDynamo.

    ``disable(fn)`` returns the unwrapped ``fn`` wrapped in a DisableContext.
    ``disable()`` returns a DisableContext for use in a ``with`` statement.
    """
    if fn is None:
        return DisableContext()
    target = innermost_fn(fn)
    assert callable(target)
    return DisableContext()(target)
| |
| |
def skip(fn=None):
    """
    Skip frames associated with the function code, but still process recursively
    invoked frames.

    Works bare (``@skip``) or called (``@skip()``). In the called form ``skip``
    itself is returned so it can then be applied to the function.
    """
    if fn is not None:
        target = innermost_fn(fn)
        assert callable(target)
        code = target.__code__
        skip_code(code)
        setattr(target, "_torchdynamo_disable", True)
        return target
    return skip
| |
| |
class TorchPatcher:
    """Monkey-patches torch internals that TorchDynamo should not trace into."""

    @staticmethod
    @functools.lru_cache(None)
    def patch():
        """
        Apply the patches. Memoized by ``lru_cache``, so the body runs once.

        The patches disable Dynamo on torch.jit / fx / onnx tracing entry points
        and on ``proxy_tensor.dispatch_trace`` (when available). They skip
        DDP's ``_inside_ddp_forward`` hint wrapper, turn off default
        distribution argument validation, and strip/disable optimizer step
        profiling hooks.
        """
        # Disable TorchDynamo on some torch.* compilers generated frames
        torch.jit.trace = disable(torch.jit.trace)
        torch.jit.trace_module = disable(torch.jit.trace_module)
        torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)

        # symbolic_trace creates new frames. We disable Dynamo on such frames
        torch.fx._symbolic_trace.Tracer.trace = disable(
            torch.fx._symbolic_trace.Tracer.trace
        )

        torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
        torch.distributions.Distribution.set_default_validate_args(False)

        if proxy_tensor is not None:
            proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)

        optimizers = [
            opt
            for opt in torch.optim.__dict__.values()
            if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
        ]

        # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
        if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
            DistributedDataParallel._inside_ddp_forward = skip(
                DistributedDataParallel._inside_ddp_forward
            )

        # disable profile hook
        for opt in optimizers:
            opt._cuda_graph_capture_health_check = disable(
                opt._cuda_graph_capture_health_check
            )
            # disable any currently set hooks
            # Note: we only want to disable the profiling hook
            # which is the *last* hook applied, we want to keep the no_grad hook
            hooked = getattr(opt.step, "hooked", False)
            if hooked:
                unwrapped_step = getattr(opt.step, "__wrapped__", None)
                if unwrapped_step:
                    opt.step = unwrapped_step

            # disable future hooking
            opt.step.hooked = True

    @staticmethod
    def suppress_torch_distributed_warnings(fn):
        """
        Wrap ``fn`` so that ``torch.distributed`` UserWarnings are ignored.

        The ignore filter is (re)installed on every call.
        """

        # Copy name/doc/qualname/module (and set __wrapped__) but NOT __dict__
        # (updated=()), so Dynamo marker attributes such as
        # _torchdynamo_orig_callable / _torchdynamo_disable on `fn` are not
        # propagated to the wrapper.
        @functools.wraps(fn, updated=())
        def inner_fn(*args, **kwargs):
            warnings.filterwarnings(
                "ignore", category=UserWarning, module="torch.distributed"
            )
            return fn(*args, **kwargs)

        return inner_fn