# mypy: ignore-errors

import dataclasses
import functools
from importlib import import_module
from typing import Any, List, Optional

from functorch.compile import min_cut_rematerialization_partition

import torch
from torch import _guards
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend

| """ |
| This file contains TorchDynamo backends intended for debugging uses. |
| """ |


@register_backend
def eager(gm, fake_tensor_inputs):
    return gm


@register_backend
def pre_dispatch_eager(gm, fake_tensor_inputs):
    from torch.fx.experimental.proxy_tensor import make_fx

    def runnable_gm(*args):
        return torch.fx.Interpreter(gm).run(*args)

    pre_dispatch_gm = make_fx(runnable_gm, pre_dispatch=True)(*fake_tensor_inputs)
    pre_dispatch_gm.print_readable()

    return pre_dispatch_gm


@register_backend
def eager_debug(gm, fake_tensor_inputs):
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    # We could add more debugging bits here.
    # Right now, this backend can be used to check for and error on
    # custom dispatcher ops that have incorrect schemas.
    def inner(*args):
        with SchemaCheckMode():
            return torch.fx.Interpreter(gm).run(*args)

    return inner
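
# A minimal sketch of exercising eager_debug (``fn`` is a hypothetical function
# that calls a custom dispatcher op); incorrect schemas surface as errors from
# SchemaCheckMode when the compiled function runs:
#
#     torch.compile(fn, backend="eager_debug")(torch.randn(4))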


@register_backend(name="ts")
def torchscript(gm, fake_tensor_inputs):
    return torch.jit.script(gm)


# Use a boxed calling convention so that inputs can be dropped from the
# argument list and freed as soon as they are no longer needed.
def boxed_nop(fx_g, example_inputs):
    def run(args):
        return torch.fx.Interpreter(fx_g).boxed_run(args)

    run._boxed_call = True
    return run
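
# With the boxed convention, the compiled callable takes its inputs as a single
# mutable list (cleared as values are consumed) rather than as positional
# arguments. A minimal sketch (``fx_graph`` and ``t`` are hypothetical):
#
#     compiled = boxed_nop(fx_graph, [t])
#     out = compiled([t])  # a list, not ``compiled(t)``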


# aot_eager runs AOT Autograd with a no-op compiler (boxed_nop) and the min-cut
# rematerialization partitioner; useful for debugging AOT Autograd issues
# without involving a real backend compiler.
aot_eager = aot_autograd(
    fw_compiler=boxed_nop, partition_fn=min_cut_rematerialization_partition
)
register_backend(name="aot_eager", compiler_fn=aot_eager)

# Like aot_eager, but with AOT Autograd's default partitioner instead of
# min-cut rematerialization.
aot_eager_default_partitioner = aot_autograd(fw_compiler=boxed_nop)
register_backend(
    name="aot_eager_default_partitioner", compiler_fn=aot_eager_default_partitioner
)

# aot_eager_decomp_partition uses TorchInductor's decompositions and partitioner
# but replaces Inductor's compiler with a no-op, which helps isolate whether a
# problem comes from Inductor codegen or from the decomposition/partitioning
# stages shared with aot_eager.
aot_eager_decomp_partition = aot_autograd(
    # these are taken from memory_efficient_fusion()
    fw_compiler=boxed_nop,
    bw_compiler=boxed_nop,
    # NB: lambda here is to delay import of inductor
    decompositions=lambda: import_module(
        "torch._inductor.compile_fx"
    ).select_decomp_table(),
    partition_fn=functools.partial(
        min_cut_rematerialization_partition, compiler="inductor"
    ),
)
register_backend(
    name="aot_eager_decomp_partition", compiler_fn=aot_eager_decomp_partition
)
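
# A minimal sketch of bisecting with these backends (``fn`` is a hypothetical
# user function): an error that reproduces with aot_eager_decomp_partition but
# not with aot_eager implicates the Inductor decompositions/partitioning rather
# than Inductor codegen.
#
#     torch.compile(fn, backend="aot_eager")(torch.randn(4))
#     torch.compile(fn, backend="aot_eager_decomp_partition")(torch.randn(4))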

# aot_ts compiles the AOT Autograd forward/backward graphs with TorchScript,
# using the default partitioner. It can be paired with either NNC or nvFuser by
# selecting the relevant fuser via torch.jit.fuser(...).
aot_ts = aot_autograd(fw_compiler=ts_compile)
register_backend(name="aot_ts", compiler_fn=aot_ts)

# The backends below are intentionally buggy: they induce compile-time, runtime,
# and accuracy errors so that we can test our repro extraction / minifier
# scripts.
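
# A minimal sketch of how one of them is exercised (the error surfaces when the
# compiled function is first called):
#
#     fn_opt = torch.compile(torch.relu, backend="relu_compile_error_TESTING_ONLY")
#     fn_opt(torch.randn(4))  # raises ReluCompileError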


class ReluCompileError(Exception):
    pass


class TestingOnlyCompileError(Exception):
    pass


@register_backend
def relu_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            raise ReluCompileError()
    return gm


@register_backend
def relu_runtime_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch._assert
            node.args = (False, "ReluRuntimeError")
    gm.recompile()
    return gm


@register_backend
def relu_accuracy_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target == torch.relu:
            node.target = torch.add
            node.args = (node.args[0], 1)
    gm.recompile()

    return gm


@register_backend
def non_leaf_compile_error_TESTING_ONLY(gm: torch.fx.GraphModule, example_inputs):
    # Require at least one non-trivial thing in the graph,
    # see https://github.com/pytorch/pytorch/issues/102898
    for node in gm.graph.nodes:
        if node.op == "call_function":
            break
    else:
        return gm
    for t in example_inputs:
        if not t.is_leaf:
            raise TestingOnlyCompileError()
    return gm


@dataclasses.dataclass
class ExplainOutput:
    """
    This is the output of :func:`torch._dynamo.explain()`.
    There is no reason to create this class directly.
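
    A hedged sketch of how an instance is typically obtained (``fn`` is a
    hypothetical user function):

        explanation = torch._dynamo.explain(fn)(torch.randn(4))
        print(explanation)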
| """ |
| |
| graphs: List[torch.fx.GraphModule] |
| graph_count: int |
| graph_break_count: int |
| break_reasons: List[ |
| Any |
| ] # Type is GraphCompileReason but doesn't matter for this purpose |
| op_count: int |
| ops_per_graph: Optional[List[torch.fx.Node]] = None |
| out_guards: Optional[List[_guards.Guard]] = None |
| compile_times: Optional[str] = None |

    def __str__(self):
        output = f"Graph Count: {self.graph_count}\n"
        output += f"Graph Break Count: {self.graph_break_count}\n"
        output += f"Op Count: {self.op_count}\n"

        output += "Break Reasons:\n"
        for idx, break_reason in enumerate(self.break_reasons):
            output += f"  Break Reason {idx+1}:\n"
            output += f"    Reason: {break_reason.reason}\n"
            output += "    User Stack:\n"
            for frame_summary in break_reason.user_stack:
                output += f"      {frame_summary}\n"

        if self.ops_per_graph is not None:
            output += "Ops per Graph:\n"
            for idx, ops in enumerate(self.ops_per_graph):
                output += f"  Ops {idx+1}:\n"
                for op in ops:
                    output += f"    {op}\n"

        if self.out_guards is not None:
            output += "Out Guards:\n"
            for i, guard in enumerate(self.out_guards):
                output += f"  Guard {i+1}:\n"
                output += f"        {str(guard)}"

        if self.compile_times is not None:
            output += f"Compile Times: {self.compile_times}\n"
        return output


def _explain_graph_detail(
    gm: torch.fx.GraphModule, graphs, op_count, ops_per_graph, break_reasons
):
    """
    This function is a utility which processes a torch.fx.GraphModule and
    accumulates information about its ops, graph breaks, and other details. It
    is intended to be used by the ExplainWithBackend class and
    `torch._dynamo.explain()` to provide details from Dynamo's graph capture.

    Parameters:
        gm (torch.fx.GraphModule): The GraphModule to be processed.
        graphs (list): A list that accumulates all the GraphModules processed.
        op_count (int): The total count of operations in all GraphModules processed so far.
        ops_per_graph (list): A list that accumulates the operations of each GraphModule.
        break_reasons (list): A list that accumulates the reasons for graph breaks in each GraphModule.

    Returns:
        tuple: (gm, graphs, op_count, ops_per_graph, break_reasons) -- the processed
        GraphModule followed by the updated accumulators, in the same order as the
        input parameters.
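
    A minimal sketch of the accumulation pattern, mirroring how
    ExplainWithBackend drives this helper (``gm`` here is a GraphModule handed
    to a backend by Dynamo):

        graphs, ops_per_graph, break_reasons = [], [], []
        op_count = 0
        gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail(
            gm, graphs, op_count, ops_per_graph, break_reasons
        )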
| """ |
| graphs.append(gm) |
| ops = [node.target for node in gm.graph.nodes if node.op == "call_function"] |
| op_count += len(ops) |
| ops_per_graph.append(ops) |
| if gm.compile_subgraph_reason.graph_break: |
| break_reasons.append(gm.compile_subgraph_reason) |
| |
| return gm, graphs, op_count, ops_per_graph, break_reasons |


class ExplainWithBackend:
    """
    This class is intended to be used as a backend for `torch.compile`. It is
    composable with other backends. When used in this way, it accumulates
    information about graph breaks, ops, and other info and provides a string
    representation summarizing this information.

    Attributes:
        backend (Any): The wrapped backend compiler function, looked up by name
            in the backend registry if a string is given.
        graphs (list): A list of the graphs captured by TorchDynamo.
        op_count (int): The total number of operations in all optimized graphs.
        break_reasons (list): A list of graph break reasons with stack traces.

    Example Usage:
        def fn(x):
            x = torch.sigmoid(x)
            return x

        torch._dynamo.reset()
        eb = ExplainWithBackend("inductor")
        optimized_fn = torch.compile(fn, backend=eb)
        result = optimized_fn(torch.randn(5))
        print(eb.output())
    """

    def __init__(self, backend):
        from .registry import lookup_backend

        self.backend = lookup_backend(backend)
        self.graphs = []
        self.op_count = 0
        self.break_reasons = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        gm, self.graphs, self.op_count, _, self.break_reasons = _explain_graph_detail(
            gm, self.graphs, self.op_count, [], self.break_reasons
        )
        return self.backend(gm, example_inputs)

    def output(self) -> ExplainOutput:
        graph_count = len(self.graphs)
        output = ExplainOutput(
            self.graphs,
            graph_count,
            graph_count - 1,
            self.break_reasons,
            self.op_count,
        )

        return output