import collections
import contextlib
import cProfile
import functools
import itertools
import logging
import os.path
import pstats
import shutil
import subprocess
import sys
from typing import Any, List
from unittest.mock import patch

from functorch.compile import (
    config as functorch_config,
    draw_graph,
    get_aot_graph_name,
    get_graph_being_compiled,
)

import torch
from torch import fx as fx

from torch._dynamo import config as dynamo_config
from torch._dynamo.debug_utils import save_graph_repro, wrap_compiler_debug
from torch._dynamo.utils import get_debug_dir, init_logging
from torch.fx.graph_module import GraphModule
from torch.fx.passes.shape_prop import TensorMetadata
from torch.fx.passes.tools_common import legalize_graph

from . import config, ir  # noqa: F811, this is needed
from .scheduler import (
    BaseSchedulerNode,
    FusedSchedulerNode,
    NopKernelSchedulerNode,
    OutputNode,
    SchedulerNode,
)
from .virtualized import V

log = logging.getLogger(__name__)


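# Example (sketch, not the canonical entry point): the artifacts produced by
# this module are typically enabled by setting TORCH_COMPILE_DEBUG=1 in the
# environment or by flipping the `config.trace` flags checked below.
# `my_model` and `example_input` are placeholders, and `graph_diagram` is one
# of the trace flags implied by DebugFormatter further down:
#
#     import torch
#     import torch._inductor.config as inductor_config
#
#     inductor_config.trace.enabled = True        # master switch for DebugContext
#     inductor_config.trace.graph_diagram = True  # also needs graphviz's `dot`
#     compiled = torch.compile(my_model)
#     compiled(example_input)                     # artifacts land under get_debug_dir()

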
@functools.lru_cache(None)
def has_dot():
    """Check whether the graphviz `dot` binary is available on PATH."""
    try:
        subprocess.check_output(["which", "dot"], stderr=subprocess.PIPE)
        return True
    except (FileNotFoundError, subprocess.SubprocessError):
        # `which` itself may be missing (e.g. on Windows) or `dot` not installed
        return False


def draw_buffers(nodes, print_graph=False, fname=None):
    """
    Draw the graph of scheduler nodes (a list of SchedulerNode objects) and
    save it to fname.svg.
    """
    if not has_dot():
        log.warning("draw_buffers() requires the graphviz `dot` binary")
        return

    if fname is None:
        fname = get_graph_being_compiled()

    graph = create_fx_from_snodes(nodes)

    for node in graph.nodes:
        if "fusion_meta" not in node.meta:
            continue
        group = node.meta["fusion_meta"].group
        if isinstance(group, tuple):
            group = group[1]

        # gather metadata: the dtype lives on the scheduler node's underlying
        # ir.ComputedBuffer, not on the synthetic fx.Node itself
        dtype = None
        buf = getattr(node.meta["fusion_meta"].snodes[0], "node", None)
        if isinstance(buf, ir.ComputedBuffer):
            dtype = buf.data.dtype

        # reuse TensorMetadata so the group shows up in the rendered graph
        metadata = TensorMetadata(group, dtype, None, None, None, None, None)
        node.meta["tensor_meta"] = metadata

    if print_graph:
        print(graph)

    gm = GraphModule({}, graph)
    legalize_graph(gm)
    gm.graph.lint()
    draw_graph(gm, fname, clear_meta=False)


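# Usage sketch for draw_buffers (assumes you already have the scheduler's list
# of BaseSchedulerNode objects in hand; `scheduler.nodes` is an assumption
# about where that list lives, and the `dot` binary must be on PATH):
#
#     from torch._inductor.debug import draw_buffers
#
#     draw_buffers(scheduler.nodes, print_graph=True, fname="/tmp/buffers")

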
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
    """
    Creates an FX Graph from a list of SchedulerNode objects.
    """

    def get_fake_func(name):
        # placeholder callables used only as call_function targets for display
        def func1(*args):
            return 0

        func1.__name__ = name
        return func1

    def in_output(snode):
        if isinstance(snode, FusedSchedulerNode):
            return any(in_output(x) for x in snode.snodes)
        return any(isinstance(user.node, OutputNode) for user in snode.users)

    FusionMeta = collections.namedtuple("FusionMeta", ["group", "snodes", "type"])

    # "template" is included so snode.is_template() nodes below don't hit a KeyError
    func_dict = {
        s: get_fake_func(s)
        for s in ["extern", "template", "nop", "compute", "fused"]
    }
    buf_to_fx_node = {}
    graph = torch.fx.Graph()
    first_node = None

    outputs = []
    group: Any = None
    # create call_function node for each Buffer and Kernel
    for snode in snodes:
        if snode.is_extern():
            node_type = "extern"
            group = node_type
        elif snode.is_template():
            node_type = "template"
            group = node_type
        elif isinstance(snode, NopKernelSchedulerNode):
            node_type = "nop"
            group = node_type
        elif isinstance(snode, SchedulerNode):
            node_type = "compute"
            group = snode.group
        elif isinstance(snode, FusedSchedulerNode):
            node_type = "fused"
            group = snode.group
        else:
            raise RuntimeError("Unknown node type")
        node_func = func_dict[node_type]
        fx_node = graph.call_function(node_func, args=(), kwargs=None)

        if in_output(snode):
            outputs.append(fx_node)
        name = snode.get_name()
        fx_node.name = name

        fx_node.meta["fusion_meta"] = FusionMeta(group, [snode], node_type)

        if isinstance(snode, FusedSchedulerNode):
            for x in snode.snodes:
                buf_to_fx_node[x.get_name()] = fx_node
        buf_to_fx_node[name] = fx_node

        if first_node is None:
            first_node = fx_node

    # create edges between nodes
    for snode in snodes:
        name = snode.get_name()
        deps = snode.read_writes.reads

        fx_node = buf_to_fx_node[name]
        new_args = []
        for dep in deps:
            if dep.name in buf_to_fx_node:
                dep_node = buf_to_fx_node[dep.name]
            else:
                # reads of buffers produced outside snodes become placeholders
                with graph.inserting_before(first_node):
                    dep_node = graph.placeholder(dep.name)
                    buf_to_fx_node[dep.name] = dep_node
            new_args.append(dep_node)

        fx_node.args = tuple(new_args)

    graph.output(outputs[0] if len(outputs) == 1 else tuple(outputs))
    return graph


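# Sketch: the graph returned above is a synthetic FX graph whose nodes carry a
# FusionMeta record in node.meta["fusion_meta"]; it can be inspected directly:
#
#     graph = create_fx_from_snodes(snodes)  # `snodes`: List[BaseSchedulerNode]
#     for node in graph.nodes:
#         meta = node.meta.get("fusion_meta")
#         if meta is not None:
#             print(node.name, meta.type, meta.group)

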
@contextlib.contextmanager
def enable_aot_logging():
    """
    Route AOT Autograd graph logging to stdout when the debug flags request it
    and, when TORCH_COMPILE_DEBUG=1, also capture it into a per-graph log file.
    """
    # treat only "1" as enabled so e.g. TORCH_COMPILE_DEBUG=0 does not turn this on
    compile_debug = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"
    debug_graphs = functorch_config.debug_graphs
    debug_joint_graphs = functorch_config.debug_joint

    import torch._functorch.aot_autograd

    log = logging.getLogger(torch._functorch.aot_autograd.__name__)

    stack = contextlib.ExitStack()
    stack.enter_context(patch("functorch.compile.config.log_level", logging.DEBUG))
    # if the user asked to see graphs via either debug flag, mirror the log to stdout
    if debug_graphs or debug_joint_graphs:
        stdout_handler = logging.StreamHandler(sys.stdout)
        log.addHandler(stdout_handler)
        stack.callback(lambda: log.removeHandler(stdout_handler))

    if not compile_debug:
        try:
            yield
        finally:
            stack.close()
        return

    # Enable all graphs to be logged to a file by setting the flags to True
    # and the log level of the file logger to DEBUG
    stack.enter_context(patch("functorch.compile.config.debug_partitioner", True))
    stack.enter_context(patch("functorch.compile.config.debug_graphs", True))
    stack.enter_context(patch("functorch.compile.config.debug_joint", True))

    path = os.path.join(get_debug_dir(), "aot_torchinductor")
    os.makedirs(path, exist_ok=True)

    fh = logging.FileHandler(
        os.path.join(
            path,
            f"aot_{get_aot_graph_name()}_debug.log",
        )
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(
        logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
    )
    log.addHandler(fh)
    try:
        yield
    finally:
        log.removeHandler(fh)
        stack.close()


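# Sketch: enable_aot_logging is meant to wrap the AOT Autograd compilation
# call site so the forward/backward/joint graphs are logged while it runs;
# `run_aot_compilation` is a placeholder for that call:
#
#     with enable_aot_logging():
#         compiled_fn = run_aot_compilation(gm, example_inputs)
#
# With TORCH_COMPILE_DEBUG=1 the same output also lands in
# <debug_dir>/aot_torchinductor/aot_<graph_name>_debug.log.

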
class DebugContext:
    """Collects the per-graph Inductor debug artifacts gated by config.trace."""

    _counter = itertools.count()

    @staticmethod
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            with DebugContext():
                return fn(*args, **kwargs)

        return wrap_compiler_debug(inner, compiler_name="inductor")

    @staticmethod
    def create_debug_dir(folder_name):
        # pick the first unused <folder_name>.<n> directory
        for n in DebugContext._counter:
            dirname = os.path.join(
                get_debug_dir(),
                "aot_torchinductor",
                f"{folder_name}.{n}",
            )
            if not os.path.exists(dirname):
                os.makedirs(dirname)
                return dirname

    def __init__(self):
        self._prof = None
        self._path = None
        self._stack = contextlib.ExitStack()

    def rename(self, new_path: str):
        if not self._path:
            return
        assert new_path.endswith(".debug"), new_path
        if os.path.exists(new_path):
            shutil.rmtree(new_path)
        try:
            os.rename(self._path, new_path)
            self._path = new_path
        except OSError:
            # some platforms have trouble renaming a directory with open files
            pass

    def fopen(self, filename):
        assert self._path
        return open(os.path.join(self._path, filename), "w")

    def filename(self, suffix):
        return os.path.join(self._path, suffix)

    def upload_tar(self):
        if config.trace.upload_tar is not None:
            import tarfile

            assert self._path
            tar_file = os.path.join(
                self._path, f"{os.path.basename(self._path)}.tar.gz"
            )
            with tarfile.open(tar_file, "w:gz") as tar:
                tar.add(self._path, arcname=os.path.basename(self._path))
            config.trace.upload_tar(tar_file)

    def __enter__(self):
        log = logging.getLogger("torch._inductor")
        if not log.handlers:
            init_logging()

        if config.debug:

            def reset_log_level(level):
                dynamo_config.log_level = level

            self._stack.callback(reset_log_level, dynamo_config.log_level)
            dynamo_config.log_level = logging.DEBUG

        self._stack.enter_context(V.set_debug_handler(self))

        if not config.trace.enabled:
            return

        self._path = self.create_debug_dir(get_aot_graph_name())

        if config.trace.debug_log:
            self._setup_log_capture("debug.log", logging.DEBUG)
        if config.trace.info_log:
            self._setup_log_capture("info.log", logging.INFO)
        if config.trace.compile_profile:
            self._prof = cProfile.Profile()
            self._prof.enable()

    def _setup_log_capture(self, filename, level):
        log = logging.getLogger("torch._inductor")
        fd = self._stack.enter_context(self.fopen(filename))
        ch = logging.StreamHandler(fd)
        ch.setLevel(level)
        ch.setFormatter(
            logging.Formatter("[%(filename)s:%(lineno)d %(levelname)s] %(message)s")
        )
        log.addHandler(ch)
        log.setLevel(min(log.level, level))
        self._stack.callback(log.removeHandler, ch)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._prof:
            self._prof.disable()
            self._save_profile_data()

        if self._path:
            self.upload_tar()
            log.warning("%s debug trace: %s", get_graph_being_compiled(), self._path)
        self._stack.close()

    def _save_profile_data(self):
        self._prof.dump_stats(self.filename("compile.prof"))
        with self.fopen("compile.stats") as fd:
            stats = pstats.Stats(self._prof, stream=fd)
            stats.strip_dirs()
            stats.sort_stats("cumtime")
            stats.print_stats(100)
            stats.sort_stats("tottime")
            stats.print_stats(100)

    def __getattr__(self, name):
        if config.trace.enabled and getattr(config.trace, name):
            try:
                return getattr(DebugFormatter(self), name)
            except Exception:
                log.warning("Ignoring exception in debug code", exc_info=True)

        # fall back to a no-op when tracing is disabled for this attribute or
        # when building the formatter failed (previously this returned None)
        def ignored(*args, **kwargs):
            pass

        return ignored


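# Sketch of how DebugContext is consumed elsewhere in inductor: the decorator
# installs a fresh context per compiled graph, and while a context is active
# it is reachable through V.debug, whose attribute lookups go through the
# __getattr__ above (no-ops unless the matching config.trace flag is set).
# `my_compile`, `gm`, `example_inputs` and `scheduler_nodes` are placeholders:
#
#     @DebugContext.wrap
#     def my_compile(gm, example_inputs):
#         V.debug.fx_graph(gm, example_inputs)    # writes fx_graph_*.py if enabled
#         V.debug.ir_pre_fusion(scheduler_nodes)  # gated by config.trace.ir_pre_fusion
#         ...

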
SchedulerNodeList = List[Any]


class DebugFormatter:
    """Writes the individual debug artifacts requested via config.trace."""

    def __init__(self, handler):
        self.fopen = handler.fopen
        self.filename = handler.filename
        self.handler = handler

    def fx_graph(self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
        with self.fopen("fx_graph_runnable.py") as fd:
            save_graph_repro(fd, gm, inputs, "inductor")

        with self.fopen("fx_graph_readable.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def fx_graph_transformed(
        self, gm: torch.fx.GraphModule, inputs: List[torch.Tensor]
    ):
        with self.fopen("fx_graph_transformed.py") as fd:
            fd.write(gm.print_readable(print_output=False))

    def ir_pre_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_pre_fusion.txt", nodes)

    def ir_post_fusion(self, nodes: SchedulerNodeList):
        self._write_ir("ir_post_fusion.txt", nodes)

    def _write_ir(self, filename: str, nodes: SchedulerNodeList):
        with self.fopen(filename) as fd:
            for node in nodes:
                fd.write(node.debug_str())
                fd.write("\n\n\n")

    def graph_diagram(self, nodes: SchedulerNodeList):
        draw_buffers(nodes, fname=self.filename("graph_diagram.svg"))

    def output_code(self, filename):
        shutil.copy(filename, self.filename("output_code.py"))
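

# For reference, when config.trace.enabled is set the per-graph debug directory
# created by DebugContext ends up containing (depending on which trace flags
# are on) exactly the files written above:
#
#     fx_graph_runnable.py, fx_graph_readable.py, fx_graph_transformed.py
#     ir_pre_fusion.txt, ir_post_fusion.txt
#     graph_diagram.svg, output_code.py
#     compile.prof, compile.stats, debug.log, info.log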