import copy
import functools
import getpass
import logging
import os
import shutil
import subprocess
import tempfile
import textwrap
import uuid
from collections import Counter
from importlib import import_module
from tempfile import TemporaryFile

import torch
import torch.fx as fx
from torch._prims_common import is_float_dtype

from . import config
from .backends.registry import lookup_backend, register_debug_backend
from .utils import clone_inputs, get_debug_dir

log = logging.getLogger(__name__)


inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()


extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/fb/custom_ops/sparsenn:sparsenn-all_operators",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


BUCK_CMD_PREFIX = ["buck2", "run", "@mode/dev-nosan"]


class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(os.path.abspath(filename))
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning(f"Wrote isolation TARGETS file at {target_file}")
        cmd_split = BUCK_CMD_PREFIX + [self.cmd_line_path]
        if print_msg:
            log.warning(
                f"Found an example that reproduces the error. Run this cmd to repro - {' '.join(cmd_split)}"
            )
        return cmd_split
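
# A minimal usage sketch (hypothetical paths; the path munging in __init__
# assumes the file lives under an fbcode checkout):
#
#   writer = BuckTargetWriter("/data/users/me/fbcode/scripts/repro.py")
#   cmd = writer.write(print_msg=False)  # writes scripts/TARGETS next to repro.py
#   # cmd == ["buck2", "run", "@mode/dev-nosan", "//scripts:repro"]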


def minifier_dir():
    debug_dir = get_debug_dir()
    if debug_dir is not None:
        path = os.path.join(debug_dir, "minifier")
    else:
        # Fall back to a per-user temp directory if no debug dir is configured.
        path = os.path.join(tempfile.gettempdir(), f"minifier_{getpass.getuser()}")
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning(f"We have not tested reprs of some modules - {cant_convert}")
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            if torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))"
            if param.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f" {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
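
# For intuition, convert() on a GraphModule with a single Linear child emits
# source roughly like the following (illustrative output; the forward body
# comes from gm.code):
#
#   from torch.nn import *
#   class Repro(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fc = Linear(in_features=8, out_features=4, bias=True)
#
#       def forward(self, x):
#           fc = self.fc(x)
#           return (fc,)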


@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
        cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
        cuda_version_out = "".join(
            [f"# {s} \n" for s in cuda_version_lines if s not in [""]]
        )
        model_str += f"{cuda_version_out}\n"
    except FileNotFoundError:
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str
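
# Sample of the comment block returned on a CUDA machine (illustrative; the
# nvcc banner and GPU names depend on the system):
#
#   # CUDA Info:
#   # nvcc: NVIDIA (R) Cuda compiler driver
#   # Cuda compilation tools, release 11.8, V11.8.89
#   # GPU Hardware Info:
#   # NVIDIA A100-SXM4-40GB : 8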


def generate_config_string():
    import torch._functorch.config
    import torch._inductor.config

    return textwrap.dedent(
        f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config
torch._dynamo.config.load_config({repr(torch._dynamo.config.save_config())})
torch._inductor.config.load_config({repr(torch._inductor.config.save_config())})
torch._functorch.config.load_config({repr(torch._functorch.config.save_config())})
"""
    )


TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES"


def generate_compiler_repro_string(gm, args):
    model_str = textwrap.dedent(
        f"""
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

{generate_config_string()}

{TEST_REPLACEABLE_COMMENT}
{extra_imports}

"""
    )
    model_str += f"# torch version: {torch.version.__version__}\n"
    if hasattr(torch.version, "cuda"):
        model_str += f"# torch cuda version: {torch.version.cuda}\n"
    if hasattr(torch.version, "git_version"):
        model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
    model_str += _cuda_system_info_comment()

    model_str += NNModuleToString.convert(gm)

    model_str += f"args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type) for a in args]!r}\n"
    model_str += (
        "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n"
    )
    # TODO: fake may be better for performance here
    tracing_mode = "real"
    if config.dynamic_shapes:
        tracing_mode = "symbolic"
    model_str += f"mod = make_fx(Repro(), tracing_mode={repr(tracing_mode)})(*args)\n"
    return model_str
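
# The generated script reconstructs inputs purely from (shape, stride, dtype,
# device) metadata. A sketch of the emitted round-trip (hypothetical tensor):
#
#   from torch._dynamo.testing import rand_strided
#   args = [((4, 8), (8, 1), torch.float32, "cpu")]
#   args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
#   # args[0]: random 4x8 fp32 CPU tensor with the original layout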


INDUCTOR_IMPORT = """
from torch._inductor.compile_fx import compile_fx_inner
from torch._dynamo.debug_utils import same_two_models
"""

COMPILER_REPRO_OPTIONS = {
    "inductor": (INDUCTOR_IMPORT, "compile_fx_inner", "inductor_fails"),
    "inductor_accuracy": (
        INDUCTOR_IMPORT,
        "compile_fx_inner",
        "inductor_accuracy_fails",
    ),
}


def dump_compiler_graph_state(gm, args, compiler_name):
    subdir = os.path.join(minifier_dir(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
    log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
    with open(file_name, "w") as fd:
        save_graph_repro(fd, gm, args, compiler_name)
    curdir = os.getcwd()
    repro_path = os.path.join(curdir, "repro.py")
    try:
        shutil.copyfile(file_name, repro_path)
        log.warning(f"Copying repro file for convenience to {repro_path}")
        if use_buck:
            BuckTargetWriter(file_name).write()
    except OSError:
        log.warning(f"No write permissions for {repro_path}")


def save_graph_repro(fd, gm, args, compiler_name):
    sync_line = ""
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.is_cuda:
            sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced"
            break

    if "inductor" in compiler_name:
        fd.write("import torch._inductor.overrides\n")
    fd.write(generate_compiler_repro_string(gm, args))
    fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
    if "_accuracy" in compiler_name:
        fd.write(
            textwrap.dedent(
                f"""
                compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
                class AccuracyError(Exception):
                    pass
                if not same_two_models(mod, compiled, args, only_fwd=True):
                    raise AccuracyError("Bad accuracy detected")
                """
            )
        )
    else:
        fd.write(
            textwrap.dedent(
                f"""
                compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
                ref = compiled(args)
                {sync_line}
                """
            )
        )


def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None):
    if env is None:
        env = {}
    subdir = os.path.join(os.getcwd(), "isolate")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
    with open(file_name, "w") as fd:
        repro_code = generate_compiler_repro_string(fx_g, args)
        if patch_code is not None:
            repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code)
        fd.write(repro_code)
        fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
        fd.write(
            textwrap.dedent(
                f"""
                from {__name__} import {fail_fn}
                """
            )
        )
        fd.write(
            textwrap.dedent(
                f"""
                if {fail_fn}(mod, args):
                    exit(1)
                else:
                    exit(0)
                """
            )
        )
    # with open(file_name, "r") as fd:
    #     print(fd.read())
    new_env = os.environ.copy()
    new_env = {**new_env, **env}
    stdout, stderr = TemporaryFile(), TemporaryFile()

    if use_buck:
        cmd = BuckTargetWriter(file_name).write(print_msg=False)
    else:
        cmd = ["python", file_name]

    p = subprocess.Popen(
        cmd,
        cwd=subdir,
        stdout=stdout,
        stderr=stderr,
        env=new_env,
    )
    p.wait()

    if p.returncode != 0:
        stdout.seek(0)
        stderr.seek(0)
        print(textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "))
        print(textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "))
        # print(f"Isolated test failed - {file_name}")
        return True
    return False
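
# isolate_fails runs the repro in a fresh subprocess so that hard crashes
# (segfaults, CUDA asserts) in the compiled module cannot take down the
# minifier itself; the child's exit code is the pass/fail signal. E.g.
# (hypothetical call):
#
#   failed = isolate_fails(fx_g, args, "inductor", env={"CUDA_VISIBLE_DEVICES": "0"})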


def inductor_fails(fx_g, args, check_str=None):
    has_cuda = False
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.is_cuda:
            has_cuda = True
            break

    def sync():
        if has_cuda:
            # Ensures that segfaults are surfaced
            torch.cuda.synchronize()

    from torch._inductor.compile_fx import compile_fx_inner

    try:
        result = fx_g(*args)
        assert isinstance(result, (tuple, list))
        assert not any(isinstance(x, (tuple, list)) for x in result)
    except Exception:
        return False

    sync()

    try:
        compile_mod = compile_fx_inner(fx_g, args)
        compile_mod(args)
        sync()
    except Exception as e:
        if check_str is not None and check_str not in repr(e):
            return False
        print(repr(e))
        return True
    return False


def inductor_accuracy_fails(fx_g, args, check_str=None):
    from torch._inductor.compile_fx import compile_fx_inner

    return backend_aot_accuracy_fails(fx_g, args, compile_fx_inner)


def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning(f"Writing minified repro to {minified_repro_path}")

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)

    except OSError as e:
        log.exception(e)
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e


def dump_to_minify(gm, args, compiler_name: str):
    favored_device = 1 if torch.cuda.device_count() >= 2 else 0

    contents = textwrap.dedent(
        f"""
isolate_fails_code_str = None

{generate_compiler_repro_string(gm, args)}

from functools import partial
from {__name__} import (
    isolate_fails,
    dump_compiler_graph_state,
)
from functorch.compile import minifier

env_variables = {{"CUDA_VISIBLE_DEVICES": "{favored_device}"}}

minifier(
    mod,
    args,
    module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str),
    dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"),
)
"""
    )
    return helper_for_dump_minify(contents)


class AccuracyError(Exception):
    pass


def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both
    the forward and backward calls separately with the backend compiler_fn -
    e.g., inductor or nvfuser. Intercepting after Aot Autograd presents a neat
    abstraction, where all the params are lifted as graph inputs, making it
    easy to save the graph as a string.
    """

    @functools.wraps(unconfigured_compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        from torch._subclasses import FakeTensorMode

        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)

        orig_graph = copy.deepcopy(gm.graph)
        assert config.repro_after in ("dynamo", "aot", None)
        inner_compiled_fn = None

        def deferred_for_real_inputs(real_inputs):
            """
            Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
            example_inputs can be fake tensors. We can call compiler_fn (which is
            inductor or nvfuser) with fake tensors, but the actual compiled_fn
            should be called with real tensors. Therefore, the actual invocation
            is deferred.
            """
            # Avoid re-compiling when we call the compiled function twice. This happens
            # when we run the model inference or training in a for loop like here
            # https://github.com/pytorch/torchdynamo/issues/1687#issuecomment-1280040633
            nonlocal inner_compiled_fn
            # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor
            # because inductor clears the tensor list in its codegen. And example_inputs
            # are available only for the first invocation.
            fake_mode = FakeTensorMode()
            copy_tensor_attrs = [fake_mode.from_tensor(x) for x in real_inputs]
            if config.repro_level == 3:
                # Always dump the original module in case we have segfaults
                dump_to_minify(
                    fx.GraphModule(gm, orig_graph), real_inputs, compiler_name
                )

            if config.repro_level == 4:
                if compiler_name != "inductor":
                    raise NotImplementedError(
                        "Accuracy minification is supported for inductor only"
                    )
                if inner_compiled_fn is None:
                    inner_compiled_fn = compiler_fn(gm, example_inputs)
                if backend_aot_accuracy_fails(gm, real_inputs, compiler_fn):
                    log.warning("Accuracy failed for the AOT Autograd graph")
                    dump_compiler_graph_state(
                        fx.GraphModule(gm, orig_graph),
                        copy_tensor_attrs,
                        f"{compiler_name}_accuracy",
                    )
                    dump_to_minify(
                        fx.GraphModule(gm, orig_graph),
                        copy_tensor_attrs,
                        f"{compiler_name}_accuracy",
                    )
                    raise AccuracyError("Bad accuracy detected")
                else:
                    # Call the compiled function with real inputs
                    return inner_compiled_fn(real_inputs)
            else:
                try:
                    # Call the compiler_fn - which is either aot_autograd or inductor
                    # with fake inputs
                    if inner_compiled_fn is None:
                        inner_compiled_fn = compiler_fn(gm, example_inputs)
                    # Call the compiled function with real inputs
                    return inner_compiled_fn(real_inputs)
                except Exception:
                    if config.repro_level == 1:
                        dump_compiler_graph_state(
                            fx.GraphModule(gm, orig_graph),
                            copy_tensor_attrs,
                            compiler_name,
                        )
                    elif config.repro_level == 2:
                        dump_to_minify(
                            fx.GraphModule(gm, orig_graph),
                            copy_tensor_attrs,
                            compiler_name,
                        )
                    log.error("CompilerError")
                    raise

        if config.repro_after == "aot":
            compiled_fn = deferred_for_real_inputs
            compiled_fn._boxed_call = True
        else:
            compiled_fn = compiler_fn(gm, example_inputs)

        return compiled_fn

    return debug_wrapper
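
# Sketch of how this wrapper is applied (hypothetical setup; the wrapped
# function is then used as an AOT Autograd fw/bw compiler):
#
#   from torch._inductor.compile_fx import compile_fx_inner
#   debug_compile_fx_inner = wrap_compiler_debug(compile_fx_inner, "inductor")
#   # With config.repro_after == "aot", a failing compile at repro_level 1/2
#   # dumps repro.py / minifier_launcher.py instead of only raising.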


def run_fwd_maybe_bwd(gm, args, only_fwd=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.
    """
    from torch._functorch.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    new_args = clone_inputs(args)
    # Set the requires_grad field explicitly because clone_inputs only sets
    # requires_grad for leaf tensors.
    for narg, arg in zip(new_args, args):
        narg.requires_grad_(arg.requires_grad)
    args = new_args

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # TorchInductor returned callable expects lists. So, boxing the call.
    orig_named_parameters = getattr(gm, "named_parameters", None)
    orig_named_buffers = getattr(gm, "named_buffers", None)
    if not hasattr(gm, "_boxed_call") and (
        orig_named_parameters is not None or orig_named_buffers is not None
    ):
        gm = make_boxed_func(gm)
        if orig_named_parameters is not None:
            gm.named_parameters = orig_named_parameters
        if orig_named_buffers is not None:
            gm.named_buffers = orig_named_buffers

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)
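
# Typical use in the accuracy checks below (args must be real tensors, since
# backward needs them):
#
#   ref = run_fwd_maybe_bwd(mod, args)                 # fwd + bwd, grads collected
#   out = run_fwd_maybe_bwd(mod, args, only_fwd=True)  # fwd only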


def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
    """
    Check whether two models have the same accuracy.
    """
    from .eval_frame import OptimizedModule
    from .testing import (
        named_buffers_for_optimized_module,
        named_parameters_for_optimized_module,
    )
    from .utils import same

    if isinstance(gm, OptimizedModule):
        gm.named_parameters = named_parameters_for_optimized_module(gm)
        gm.named_buffers = named_buffers_for_optimized_module(gm)

    if isinstance(opt_gm, OptimizedModule):
        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    try:
        fp64_model, fp64_examples = cast_to_fp64(
            copy.deepcopy(gm), clone_inputs(example_inputs)
        )
        fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
    except Exception:
        log.warning("Could not generate fp64 outputs")
        fp64_ref = None

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            (
                "While minifying the program in accuracy minification mode, "
                "ran into a runtime exception which is likely an unrelated issue."
                " Skipping this graph."
            )
        )
        return True

    passing = same(ref, res, fp64_ref, tol=config.repro_tolerance, equal_nan=True)
    return passing


def cast_convert_element_type_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
    model.graph.lint()
    model.recompile()
    return model


def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # take care of convert_element_type explicitly
        model = cast_convert_element_type_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)
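
# Sketch of the fp64 reference path used by same_two_models (hypothetical
# mod/args):
#
#   fp64_model, fp64_args = cast_to_fp64(copy.deepcopy(mod), clone_inputs(args))
#   fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_args)
#   # same(ref, res, fp64_ref, ...) then uses fp64_ref as high-precision ground
#   # truth to decide whether eager or compiled is the one that drifted.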


def generate_dynamo_fx_repro_string(
    model_str, args, compiler_name, check_accuracy=False
):
    """
    Generate a repro string for backend-agnostic minified version.
    """

    run_code = textwrap.dedent(
        f"""
with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)
"""
    )

    if config.repro_level == 4 or check_accuracy:
        run_code = textwrap.dedent(
            f"""
mod.eval()
opt_mod.eval()

class AccuracyError(Exception):
    pass

with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
    assert same_two_models(mod, mod, args), "Eager itself failed"
    if not same_two_models(mod, opt_mod, args):
        raise AccuracyError("Dynamo failed")
"""
        )

    return textwrap.dedent(
        f"""
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models

{generate_config_string()}

{TEST_REPLACEABLE_COMMENT}
{extra_imports}

args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

{model_str}

mod = Repro()
opt_mod = torch._dynamo.optimize("{compiler_name}")(mod)

{run_code}
"""
    )


def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
    """
    Saves the repro to a repro.py file
    """
    curdir = os.getcwd()
    subdir = os.path.join(os.getcwd(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py")
    log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")

    model_str = NNModuleToString.convert(gm)
    with open(file_name, "w") as fd:
        fd.write(
            generate_dynamo_fx_repro_string(
                model_str, args, compiler_name, check_accuracy
            )
        )
    latest_repro = os.path.join(curdir, "repro.py")
    log.warning(f"Copying {file_name} to {latest_repro} for convenience")

    if use_buck:
        BuckTargetWriter(latest_repro).write()

    shutil.copyfile(file_name, latest_repro)


# TODO - Commented because we are assuming that nn.Modules can be safely repr'd
# If that does not work, we might have to bring this code back. So, keeping it
# as it is for now.
# def dump_backend_repro_as_tarfile(gm, args, compiler_name):
#     """
#     Saves the repro in repro.tar.gz, as opposed to a file. This is used for
#     cases, where we can't convert a Fx GraphModule to a string, and therefore
#     fallback to to_folder for serialization. We accompany this with a repro.py
#     script that imports the saved module, sets it up and runs the model to repro
#     the error.
#     """
#     import tarfile

#     subdir = os.path.join(minifier_dir(), "checkpoints")
#     if not os.path.exists(subdir):
#         os.makedirs(subdir, exist_ok=True)

#     tmp_dir = os.path.join(subdir, f"{len(gm.graph.nodes)}")
#     if os.path.exists(tmp_dir):
#         shutil.rmtree(tmp_dir)
#     os.makedirs(tmp_dir, exist_ok=True)

#     file_name = os.path.join(tmp_dir, "repro.py")
#     gm_dir = os.path.join(tmp_dir, "module")
#     if not os.path.exists(gm_dir):
#         os.makedirs(gm_dir, exist_ok=True)
#     for node in gm.graph.nodes:
#         new_kwargs = {}
#         for k, v in node.kwargs.items():
#             if isinstance(v, torch.device):
#                 v = v.type
#             new_kwargs[k] = v
#         node.kwargs = new_kwargs
#     gm.recompile()

#     print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
#     with open(file_name, "w") as fd:
#         # TODO - Add the readable version of to_folder when available
#         gm.to_folder(gm_dir, "Repro")
#         fd.write(
#             generate_dynamo_fx_repro_string(
#                 "from module import Repro", args, compiler_name
#             )
#         )

#     local_dir = os.path.join(config.base_dir, "repro")
#     if os.path.exists(local_dir):
#         shutil.rmtree(local_dir)
#     shutil.copytree(tmp_dir, local_dir)
#     local_tar_file = os.path.join(config.base_dir, "repro.tar.gz")
#     print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
#     with tarfile.open(local_tar_file, "w:gz") as tar:
#         tar.add(local_dir, arcname=os.path.basename(local_dir))


def dump_backend_state(gm, args, compiler_name, check_accuracy=False):
    """
    Dumps the dynamo graph to repro the issue.
    1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
    repro.py file.
    2) If we can't convert Fx GraphModule to a string, we use to_folder to save
    the module and save a tar file.
    """
    assert NNModuleToString.can_convert_to_string(gm)
    return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy)
    # return dump_backend_repro_as_tarfile(gm, args, compiler_name)


def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False):
    try:
        compiled_gm = compiler_fn(copy.deepcopy(gm), clone_inputs(example_inputs))
    except Exception:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            (
                "While minifying the program in accuracy minification mode, "
                "ran into a runtime exception which is likely an unrelated issue."
                " Skipping this graph."
            )
        )
        return False

    return not same_two_models(gm, compiled_gm, example_inputs, only_fwd)


backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True)

# Please see NOTE: [Real Tensors in Accuracy Evaluation]
MINIFIER_SPAWNED = False


def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
    """
    The minifier uses this function to identify whether the minified graph
    module fails with the same error.

    One caveat is that the minifier can potentially go in a wrong direction when
    the resulting graph module fails for a different reason. To avoid this, we
    save the string for the original exception and check the similarity between
    the new and old exceptions. They can be somewhat different in some cases,
    when the exception string depends on the failing node information. So, we
    use a loose similarity metric to guide the minifier path.
    """
    from difflib import SequenceMatcher

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
        return False
    except Exception as e:
        new_failure = str(e)
        if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
            return True
        return False
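
# For intuition on the 0.5 cutoff: SequenceMatcher.ratio() returns a similarity
# in [0, 1], and two exception strings that differ only in node details score
# high, e.g. (hypothetical error strings):
#
#   from difflib import SequenceMatcher
#   a = "LoweringException: unsupported operator at node add_1"
#   b = "LoweringException: unsupported operator at node mul_3"
#   SequenceMatcher(None, a, b).ratio()  # well above 0.5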


def dump_to_minify_after_dynamo(gm, args, compiler_name):
    model_str = NNModuleToString.convert(gm)

    minifier_backend = "dynamo_minifier_backend"
    if config.repro_level == 4:
        minifier_backend = "dynamo_accuracy_minifier_backend"

    custom_compiler_error = (
        textwrap.dedent(
            """\
            raise RuntimeError(
                'Compiler name is None - this likely means that a custom compiler '
                'was called by torchdynamo. Please remove this error, import your '
                'custom compiler function, and replace the compiler_name="None" '
                'line below to compiler_name=<my_imported_custom_function>'
            )
            """
        )
        if compiler_name is None
        else ""
    )

    contents = textwrap.dedent(
        f"""
import os
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import functools
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.backends.registry import lookup_backend
from torch._dynamo.testing import rand_strided

{generate_config_string()}

{TEST_REPLACEABLE_COMMENT}
{extra_imports}

args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

{model_str}
mod = Repro()

# Setup debug minifier compiler
torch._dynamo.debug_utils.MINIFIER_SPAWNED = True
compiler_fn = lookup_backend("{minifier_backend}")
{custom_compiler_error}
dynamo_minifier_backend = functools.partial(
    compiler_fn,
    compiler_name="{compiler_name}",
)
opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
    opt_mod(*args)
"""
    )
    helper_for_dump_minify(contents)


def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
    As opposed to wrap_compiler_debug, this wrapper intercepts at the
    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to some
    level, e.g., it is useful for minifying issues related to Aot Autograd
    tracing. If an error is found, we minify and save the minified repro as a
    repro.py file.
    """

    @functools.wraps(unconfigured_compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
        assert config.repro_after in ("dynamo", "aot", None)

        if config.repro_after == "dynamo":

            def add_paths(exc):
                exc.minifier_path = os.path.join(minifier_dir(), "minifier_launcher.py")
                if use_buck:
                    exc.buck_command = " ".join(
                        BUCK_CMD_PREFIX
                        + [BuckTargetWriter(exc.minifier_path).cmd_line_path]
                    )

            if config.repro_level == 3:
                dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)

            # Check for either accuracy (level 4) or other type of failures.
            if config.repro_level == 4:
                # Check Accuracy
                compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                if backend_accuracy_fails(gm, example_inputs, compiler_fn):
                    log.warning(
                        "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
                    )
                    dump_to_minify_after_dynamo(
                        fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                        example_inputs,
                        compiler_name,
                    )
                    exc = AccuracyError("Bad accuracy detected.")
                    add_paths(exc)
                    raise exc
            else:
                try:
                    compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                    run_fwd_maybe_bwd(compiled_gm, example_inputs)
                except Exception as exc:
                    log.warning(
                        "Compiled Fx GraphModule failed. Creating script to minify the error."
                    )
                    if config.repro_level == 1:
                        dump_state_fn = functools.partial(
                            dump_backend_state, compiler_name=compiler_name
                        )
                        dump_state_fn(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
                        )
                    elif config.repro_level == 2:
                        dump_to_minify_after_dynamo(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                            example_inputs,
                            compiler_name,
                        )
                    add_paths(exc)
                    raise
        else:
            compiled_gm = compiler_fn(gm, example_inputs)

        return compiled_gm

    debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn

    return debug_wrapper
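
# TorchDynamo applies this wrapper itself when resolving a backend, so user
# code rarely calls it directly; a hand-rolled sketch (hypothetical backend):
#
#   def my_backend(gm, example_inputs):
#       return gm.forward
#
#   debug_backend = wrap_backend_debug(my_backend, "my_backend")
#   opt_mod = torch._dynamo.optimize(debug_backend)(mod)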


@register_debug_backend
def dynamo_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, example_inputs)
        raise ValueError("No issue was detected")
    except Exception as exc:
        orig_failure = str(exc)
        log.warning(
            "Compiled Fx GraphModule failed. Creating script to minify the error."
        )
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        fails_fn = functools.partial(
            backend_fails,
            compiler_fn=compiler_fn,
            orig_failure=orig_failure,
        )
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    return gm


@register_debug_backend
def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    # Set the eval mode to remove randomness.
    gm.eval()

    # Check Accuracy
    if backend_accuracy_fails(
        gm, example_inputs, compiler_fn, only_fwd=config.repro_forward_only
    ):
        log.warning("Accuracy failed for the TorchDynamo produced graph")
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name, check_accuracy=True
        )
        fails_fn = functools.partial(
            backend_accuracy_fails,
            compiler_fn=compiler_fn,
            only_fwd=config.repro_forward_only,
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    else:
        log.error("Input graph does not fail accuracy testing")
    return gm