| import copy |
| import logging |
| import os |
| import pickle |
| import random |
| from contextlib import contextmanager |
| from functools import partial |
| from typing import Callable, Union |
| import sympy |
| |
| import torch |
| from torch import SymInt |
| import torch.fx as fx |
| import torch.nn as nn |
| from torch._decomp import get_decompositions |
| from torch.fx.experimental.symbolic_shapes import bind_symbols |
| |
| from .aot_autograd import aot_function, aot_module, make_boxed_compiler |
| from .compile_utils import strip_overloads |
| from .partitioners import ( |
| default_partition, |
| draw_graph, |
| min_cut_rematerialization_partition, |
| ) |
| import torch.utils._pytree as pytree |
| |
| |
| log = logging.getLogger(__name__) |
| |
| |
| # These canonicalizations are needed here (and not as decompositions), as the ops
| # we're trying to canonicalize are CompositeImplicitAutograd.
| def _canonicalize(fx_g): |
| for node in fx_g.graph.nodes: |
| if node.target == torch.ops.aten._to_copy: |
| node.target = torch.ops.aten.to |
| fx_g.recompile() |
| return fx_g |
| |
| |
| @contextmanager |
| def _disable_jit_autocast(): |
| old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False) |
| try: |
| yield |
| finally: |
| torch._C._jit_set_autocast_mode(old_jit_autocast_flag) |
| |
| |
| @make_boxed_compiler |
| def ts_compile(fx_g: fx.GraphModule, inps) -> Callable: |
| """ |
| Compiles the :attr:`fx_g` with the TorchScript compiler.
| |
| .. warning:: |
| This API is experimental and likely to change. |
| |
| Args: |
| fx_g (fx.GraphModule): The input FX graph module to be compiled.
| |
| Returns: |
| Torch scripted model. |
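|
| Example (a minimal sketch; ``fn`` and its input are placeholders, and
| ``ts_compile`` is handed to :func:`aot_function` as the compiler for both
| the forward and the backward graph)::
|
|     import torch
|     from functorch.compile import aot_function, ts_compile
|
|     def fn(x):
|         return x.cos().sin()
|
|     compiled_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=ts_compile)
|     out = compiled_fn(torch.randn(4, requires_grad=True))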
| """ |
| |
| with _disable_jit_autocast(): |
| strip_overloads(fx_g) |
| |
| for node in fx_g.graph.nodes: |
| if ( |
| node.target == torch.ops.aten._to_copy |
| and len(node.args) == 1 |
| and len(node.kwargs) == 1 |
| and "dtype" in node.kwargs |
| ): |
| node.target = torch.ops.aten.to |
| |
| for node in fx_g.graph.nodes: |
| new_kwargs = {} |
| for k, v in node.kwargs.items(): |
| if isinstance(v, torch.device): |
| v = v.type |
| new_kwargs[k] = v |
| node.kwargs = new_kwargs |
| |
| fx_g.graph.lint() |
| |
| fx_g.recompile() |
| |
| f = torch.jit.script(fx_g) |
| |
| torch._C._jit_pass_remove_mutation(f.graph) |
| |
| f = torch.jit.freeze(f.eval()) |
| f = torch.jit.optimize_for_inference(f) |
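| # Warm-up call with the real inputs (skipped when they are fake tensors,
| # which cannot be executed); presumably this lets the TorchScript profiling
| # executor specialize the frozen graph before it is handed back.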
| if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps): |
| f(*inps) |
| return f |
| |
| |
| def _draw_graph_compile(fx_g, _, name, clear_meta=True): |
| print(fx_g.code) |
| draw_graph(fx_g, name, clear_meta=clear_meta) |
| return fx_g |
| |
| |
| def draw_graph_compile(name): |
| return make_boxed_compiler( |
| partial(_draw_graph_compile, name=name) |
| ) |
| |
| |
| @make_boxed_compiler |
| def nop(fx_g: fx.GraphModule, _) -> Callable: |
| """ |
| Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler |
| and can be used to check accuracy. |
| |
| .. warning:: |
| This API is experimental and likely to change. |
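|
| Example (a minimal sketch; ``fn`` is a placeholder, and ``nop`` is used as
| both the forward and backward compiler so that any numerical mismatch can
| be attributed to the tracing itself rather than to a downstream compiler)::
|
|     import torch
|     from functorch.compile import aot_function, nop
|
|     def fn(x):
|         return (x * x).sum()
|
|     fn_traced = aot_function(fn, fw_compiler=nop, bw_compiler=nop)
|     fn_traced(torch.randn(3, requires_grad=True)).backward()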
| |
| """ |
| return fx_g |
| |
| class DebugInterpreter(fx.Interpreter): |
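| """
| An ``fx.Interpreter`` that, for every node with a recorded ``val`` meta,
| checks that the dtype, sizes, and significant strides observed at runtime
| match the values recorded at trace time.
| """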
| def run(self, *args): |
| self.symbol_mapping = bind_symbols(self.module, *args) |
| return super().run(*args)
| |
| def run_node(self, n): |
| |
| def subst_symint(ni): |
| if not isinstance(ni, SymInt): |
| return ni |
| r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping)) |
| assert len(r.free_symbols) == 0, r |
| return int(r) |
| |
| def subst_symint_tuple(nis): |
| return tuple(subst_symint(ni) for ni in nis) |
| |
| def check_significant_strides(a, b): |
| if subst_symint(a.numel()) > 0: |
| for idx in range(a.ndim): |
| if subst_symint(a.stride(idx)) != b.stride(idx) and subst_symint(a.size(idx)) > 1: |
| return False |
| return True |
| |
| def check(nv, rv, desc): |
| assert callable(desc) |
| assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}" |
| assert subst_symint_tuple(nv.size()) == rv.size(), \ |
| f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}" |
| same_strides = check_significant_strides(nv, rv) |
| assert same_strides, f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}" |
| |
| r = super().run_node(n) |
| if 'val' in n.meta: |
| n_vals, n_spec = pytree.tree_flatten(n.meta['val']) |
| r_vals, r_spec = pytree.tree_flatten(r) |
| # TODO: There is some sort of problem where we record that an |
| # operator returned a tuple/list, and then later it turns out the |
| # real version of the operator returned a list/tuple. Need to |
| # figure out what's actually going on here, the error itself is |
| # harmless enough as we only getitem out the outputs. |
| # assert n_spec == r_spec, f"{n_spec} != {r_spec}" |
| assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}" |
| for i, (nv, rv) in enumerate(zip(n_vals, r_vals)):
| if not isinstance(rv, torch.Tensor): |
| continue |
| check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}") |
| return r |
| |
| |
| @make_boxed_compiler |
| def debug_nop(fx_g: fx.GraphModule, _) -> Callable: |
| """ |
| Returns a (slow) interpreter over the FX graph module that also checks
| various debugging properties (e.g., that the strides recorded during
| tracing match the real strides).
| """ |
| return DebugInterpreter(fx_g).run |
| |
| @make_boxed_compiler |
| def simple_ts_compile(fx_g, _): |
| strip_overloads(fx_g) |
| f = torch.jit.script(fx_g) |
| f = torch.jit.freeze(f.eval()) |
| return f |
| |
| |
| def nnc_jit(f): |
| return aot_function(f, simple_ts_compile) |
| |
| |
| aten = torch.ops.aten |
| default_decompositions = { |
| aten.detach, |
| aten.gelu_backward, |
| aten.leaky_relu_backward, |
| aten.sigmoid_backward, |
| aten.threshold_backward, |
| aten.hardtanh_backward, |
| aten.hardsigmoid_backward, |
| aten.hardswish_backward, |
| aten.tanh_backward, |
| aten.silu_backward, |
| aten.elu_backward, |
| aten.cudnn_batch_norm, |
| aten.cudnn_batch_norm_backward, |
| aten.masked_fill.Scalar, |
| aten.masked_fill.Tensor, |
| aten.elu, |
| aten.leaky_relu, |
| aten.hardtanh, |
| aten.hardswish, |
| aten.hardsigmoid, |
| aten.conj_physical, |
| aten.is_same_size, |
| } |
| |
| default_decompositions = get_decompositions(default_decompositions) |
| |
| |
| @make_boxed_compiler |
| def print_compile(fx_g, _): |
| print(fx_g.code) |
| return fx_g |
| |
| |
| def memory_efficient_fusion( |
| fn: Union[Callable, nn.Module], |
| **kwargs, |
| ): |
| """ |
| Wrapper function over :func:`aot_function` and :func:`aot_module` to perform |
| memory efficient fusion. It uses the |
| :func:`min_cut_rematerialization_partition` partitioner to perform efficient |
| recomputation. It uses NVFuser to compile the generated forward and backward |
| graphs. |
| |
| .. warning:: |
| This API is experimental and likely to change. |
| |
| Args: |
| fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module`` |
| that takes one or more arguments. Must return one or more Tensors.
| **kwargs: Any other overrides you want to make to the settings.
| |
| Returns: |
| Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior |
| of the original :attr:`fn`, but whose forward and backward graphs have |
| gone through recomputation optimizations, and the graphs have been |
| compiled with nvfuser. |
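|
| Example (a minimal sketch; ``MyModule`` is a placeholder model, and a CUDA
| device is assumed since the compiled graphs target nvfuser)::
|
|     import torch
|     import torch.nn as nn
|     from functorch.compile import memory_efficient_fusion
|
|     class MyModule(nn.Module):
|         def forward(self, x):
|             return torch.nn.functional.gelu(x).sum()
|
|     mod = memory_efficient_fusion(MyModule().cuda())
|     out = mod(torch.randn(8, 8, device="cuda", requires_grad=True))
|     out.backward()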
| |
| """ |
| config = { |
| "fw_compiler": ts_compile, |
| "bw_compiler": ts_compile, |
| "partition_fn": min_cut_rematerialization_partition, |
| "decompositions": default_decompositions, |
| } |
| config.update(kwargs) |
| if isinstance(fn, torch.nn.Module): |
| return aot_module(fn, **config) |
| else: |
| return aot_function(fn, **config) |
| |
| |
| def debug_compile(fx_g, inps): |
| fx_g.to_folder("foo") |
| print( |
| f""" |
| ############################################################## |
| # To minimize FX graph, copy and paste the below and run it # |
| ############################################################## |
| |
| import torch |
| import torch.fx as fx |
| from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess |
| |
| inps = {[(i.shape, i.dtype) for i in inps]} |
| inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps] |
| from foo import FxModule |
| mod = FxModule().cuda() |
| |
| with torch.jit.fuser("fuser2"): |
| # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess |
| minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess) |
| """ |
| ) |
| from foo import FxModule |
| |
| FxModule().cuda()(*inps) |
| |
| return ts_compile(fx_g, inps) |
| |
| |
| graph_index = 0 |
| |
| |
| def get_inputs(input_data_path): |
| """ |
| Return random inputs for the given input metadata generated by _save_fx_default.
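|
| Example (the path below is hypothetical; it follows the
| {folder_name}/{current_name}/... layout written by _save_fx_default)::
|
|     inputs = get_inputs("dump/mymod/mymod_forward_0/mymod_forward_0.input")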
| """ |
| inputs = []
| with open(input_data_path, "rb") as f:
| inputs_meta = pickle.load(f)
| for meta in inputs_meta:
| if len(meta) == 1:
| # meta is a one-element tuple, (int,) or (float,)
| (type,) = meta
| input = type(random.random())
| else:
| type, shape, stride, dtype, device = meta
| if dtype in { |
| torch.int, |
| torch.int32, |
| torch.int64, |
| torch.bool, |
| torch.uint8, |
| int, |
| float, |
| }: |
| input = torch.randint(0, 1, shape, dtype=dtype, device=device) |
| else: |
| input = torch.rand(shape, dtype=dtype, device=device) |
| inputs.append(input) |
| return inputs |
| |
| |
| def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs): |
| """ |
| The forward, backward, and joint computation graphs will be stored in
| {folder_name}/{current_name}/{current_name}_forward_{graph_index},
| {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
| {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
| The input shapes of the graphs will be stored in the .input files.
| These files can be loaded with pickle; each one is a list of tuples of the
| form (type, shape, stride, dtype, device).
| In the case of type = int or float, the tuple is just (type,).
| For the joint graph input, it is a nested list [[], []]
| where the two inner lists have the same format.
| If dump_example_input is True, example_inputs will be stored in a .pt file.
| Since each function might produce multiple graphs,
| graph_index is used to distinguish different graphs.
| """ |
| from functorch.compile import aot_module_simplified |
| |
| def get_input_meta(args): |
| input_meta = [] |
| if len(args) > 0 and isinstance(args[0], tuple): # joint input |
| input_meta += get_input_meta(args[0]) |
| input_meta += get_input_meta(args[1]) |
| return input_meta |
| for arg in args: |
| if type(arg) == int or type(arg) == float: |
| input_meta.append((type(arg),)) |
| else: |
| input_meta.append( |
| (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device) |
| ) |
| return input_meta |
| |
| def graph_saver_helper(gm_to_save, args, type_name): |
| global graph_index |
| if len(gm_to_save.graph.nodes) == 0: |
| log.warning(
| "No nodes in graph %s_%s_%s.",
| current_name,
| type_name,
| graph_index,
| )
| return |
| |
| gm = copy.deepcopy(gm_to_save) |
| gm.graph.set_codegen(torch.fx.graph.CodeGen()) # remove codegen |
| gm.recompile() |
| |
| input_meta = get_input_meta(args) |
| |
| os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
| gm.to_folder( |
| f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}" |
| ) |
| input_meta_path = f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input"  # noqa: B950
| with open(input_meta_path, "wb") as input_meta_file:
| pickle.dump(input_meta, input_meta_file)
| if dump_example_input: |
| torch.save( |
| args, |
| f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950 |
| ) # noqa: E501 |
| |
| def graph_saver_forward(gm, fw_args): |
| graph_saver_helper(gm, fw_args, "forward") |
| return gm |
| |
| def graph_saver_backward(gm, bw_args): |
| graph_saver_helper(gm, bw_args, "backward") |
| global graph_index |
| graph_index += 1 |
| return gm |
| |
| def graph_saver_joint(gm, joint_args): |
| graph_saver_helper(gm, joint_args, "joint") |
| return default_partition(gm, joint_args) |
| |
| return aot_module_simplified( |
| gm, |
| example_inputs, |
| fw_compiler=graph_saver_forward, |
| bw_compiler=graph_saver_backward, |
| partition_fn=graph_saver_joint, |
| decompositions=default_decompositions, |
| ) |
| |
| |
| # WARNING: This isn't tested anywhere!! |
| def graph_dumper_aot(current_name, folder_name, dump_example_input=False): |
| """ |
| Dump the forward, backward, and joint computation graph. |
| Example Usage: |
| save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input=False)
| optimize_ctx = torchdynamo.optimize(save_fx_func)
| with torch.enable_grad(): |
| with optimize_ctx: |
| result = forward_and_backward_pass(model, example_inputs) |
| """ |
| global graph_index |
| graph_index = 0 |
| return partial(_save_fx_default, current_name, folder_name, dump_example_input) |