import functools
import logging
import operator
from collections import defaultdict
from functools import partial
from importlib import import_module
from typing import Set

import torch
from torch.fx import GraphModule
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn import Module
from torch.utils._pytree import tree_map

from .. import config
from ..debug_utils import wrap_compiler_debug
from ..utils import clone_inputs, count_calls, counters
from .analysis import has_mutation
from .backends import BACKENDS
from .normalize import normalize_ir

log = logging.getLogger(__name__)


def is_aot_autograd_safe_to_run(gm, example_inputs):
    """
    There are some known issues with AOT Autograd. This is a workaround to catch
    such cases and fall back to eager. We should fix these quickly.

    Issues
    1) LSTM - https://github.com/pytorch/torchdynamo/issues/1147
    2) LSTM - https://github.com/pytorch/functorch/issues/586
    3) Input mutation - https://github.com/pytorch/torchdynamo/issues/1301
    """

    def raise_or_warn(reason):
        msg = f"Unable to use Aot Autograd because of presence of {reason}"
        if config.raise_on_unsafe_aot_autograd:
            raise NotImplementedError(msg)
        else:
            log.warning(msg)
        return False

    import functorch.compile

    # 1) LSTM module (tts_angular) - https://github.com/pytorch/functorch/issues/586
    for submod in gm.modules():
        if submod.__class__.__name__ == "LSTM":
            return raise_or_warn("LSTM")

    # 2) Mutation in the graph
    mutated = False
    try:
        if functorch.compile.config.use_functionalize:
            # There are two problematic classes we still exclude for now with
            # functionalization:
            #   - data mutation of inputs (fixed when we stop recording the
            #     copy_ directly into the graph)
            #   - metadata mutation of inputs (fixed if we do an extra partition
            #     to avoid AOT Autograd on the mutated inputs, or if we somehow
            #     get a custom autograd function to reflect metadata changes to
            #     the original tensor)
            mutated = has_mutation(gm, example_inputs, inputs_only=True)
        else:
            mutated = has_mutation(gm, example_inputs)
    except NotImplementedError as e:
        if "SparseTensorImpl" not in str(e):
            # TODO - TorchDynamo mutation analysis cannot handle sparse tensors,
            # so there is a chance that we could call AOT Autograd when it is
            # unsafe. This check only swallows the sparse-tensor case; any other
            # mutation-analysis bug is re-raised here.
            raise e

    if mutated:
        return raise_or_warn("mutation")

    return True


class AotAutogradStrategy(object):
    """Base class for backend strategies that use AOT Autograd"""

    @classmethod
    def compile_fn(cls, gm: torch.fx.GraphModule, example_inputs):
        if count_calls(gm.graph) < 2:
            return gm  # no point for tiny graphs
        return cls(gm, example_inputs).verified_candidate()

    def __init__(self, gm: torch.fx.GraphModule, example_inputs):
        import functorch.compile

        functorch.compile.config.use_functionalize = True
        functorch.compile.config.use_fake_tensor = True

        super(AotAutogradStrategy, self).__init__()
        counters["aot_autograd"]["total"] += 1
        self.use_fallback = False
        self.original_example_inputs = example_inputs
        self.gm = gm

        if not functorch.compile.config.use_functionalize and config.normalize_ir:
            try:
                self.gm = normalize_ir(gm, self.example_inputs)
            except Exception:
                log.debug("TorchDynamo unable to remove mutation")
                self.use_fallback = True

        if not is_aot_autograd_safe_to_run(gm, example_inputs):
            self.use_fallback = True

    @property
    def example_inputs(self):
        return clone_inputs(self.original_example_inputs)

    def verified_candidate(self):
        if self.use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return self.gm
        cg = self.candidate()
        if cg is None:
            counters["aot_autograd"]["not_ok"] += 1
            raise RuntimeError("AOT Autograd failed to compile")
        counters["aot_autograd"]["ok"] += 1
        return cg

    def candidate(self):
        raise NotImplementedError()


class AotNop(AotAutogradStrategy):
    """Useful for debugging purposes"""

    def candidate(self):
        from functorch.compile import nop

        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, fw_compiler=nop)


aot_eager = AotNop.compile_fn


class AotTorchscript(AotAutogradStrategy):
    """
    AOT Autograd with the TorchScript backend and the default partitioner.
    """

    def candidate(self):
        from functorch.compile import ts_compile

        return BACKENDS["aot_autograd"](
            self.gm, self.example_inputs, fw_compiler=ts_compile
        )


aot_ts = AotTorchscript.compile_fn

# Global counter to differentiate between different graphs.
graph_idx = 0


class AotPrint(AotNop):
    """Saves all the gm models so that we can run them separately"""

    def candidate(self):
        global graph_idx
        module_idx = "module_" + str(graph_idx)
        self.gm.to_folder(module_idx, "Bar")
        for idx, x in enumerate(self.example_inputs):
            torch.save(x, module_idx + "_tensor" + str(idx) + ".pt")
        graph_idx += 1
        return super(AotPrint, self).candidate()


aot_print = AotPrint.compile_fn


def mem_efficient_fusion_kwargs(use_decomps):
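    """
    Build the kwargs passed to BACKENDS["aot_autograd"], mirroring functorch's
    memory_efficient_fusion(): ts_compile for both the forward and backward
    compilers, the min-cut rematerialization partitioner, and (when
    use_decomps is True) the default decompositions.
    """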
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs


class AotMemEfficientFusion(AotAutogradStrategy):
    """Use min-cut rematerialization and nvFuser with AOT Autograd"""

    def candidate(self):
        kwargs = mem_efficient_fusion_kwargs(use_decomps=True)
        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)


class AotMemEfficientFusionNoDecomps(AotAutogradStrategy):
    """
    Use min-cut rematerialization and nvFuser with AOT Autograd, but without
    decompositions.
    """

    def candidate(self):
        kwargs = mem_efficient_fusion_kwargs(use_decomps=False)
        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)


class AotInductorDebug(AotAutogradStrategy):
    """
    Uses TorchInductor's AOT Autograd decompositions and partitioner to isolate
    AOT vs Inductor problems.
    """

    def candidate(self):
        from functorch.compile import min_cut_rematerialization_partition, nop

        decompositions = import_module(
            f"{config.inductor_import}.compile_fx"
        ).select_decomp_table()

        kwargs = {
            # these are taken from memory_efficient_fusion()
            "fw_compiler": nop,
            "bw_compiler": nop,
            "decompositions": decompositions,
            "partition_fn": functools.partial(
                min_cut_rematerialization_partition, compiler="inductor"
            ),
        }
        return BACKENDS["aot_autograd"](self.gm, self.example_inputs, **kwargs)


aot_inductor_debug = AotInductorDebug.compile_fn


class AOTMemEfficientFusionWithContext:
    """Pass nvfuser context to TorchDynamo"""

    def __init__(self, use_decomps=True):
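        # Per the class docstring, TorchDynamo enters this context
        # (backend_ctx_ctor) around the compiled graph; "fuser2" selects
        # nvFuser as the TorchScript fuser.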
        self.backend_ctx_ctor = lambda: torch.jit.fuser("fuser2")
        self.use_decomps = use_decomps

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        if self.use_decomps:
            return AotMemEfficientFusion.compile_fn(gm, example_inputs)
        else:
            return AotMemEfficientFusionNoDecomps.compile_fn(gm, example_inputs)


aot_mem_efficient_fusion = AOTMemEfficientFusionWithContext(True)
aot_mem_efficient_fusion_no_decomp = AOTMemEfficientFusionWithContext(False)


class AotPrimsNvfuser(AotAutogradStrategy):
    """
    Use FX graph partitioner + Aten2Prims ref + trace executor + nvFuser
    """

    def __init__(self, gm: torch.fx.GraphModule, example_inputs):
        super(AotPrimsNvfuser, self).__init__(gm, example_inputs)

        from functorch.compile import min_cut_rematerialization_partition

        from torch.fx.passes.backends.nvfuser import NvFuserBackend

        self.nvfuser = NvFuserBackend()
        self.min_cut_rematerialization_partition = min_cut_rematerialization_partition
        self.populate_aten2aten_decomps()

    def populate_aten2aten_decomps(self):
        from torch._decomp import get_decompositions

        aten = torch.ops.aten
        default_decompositions = {
            aten.detach,
            aten.gelu_backward,
            aten.leaky_relu_backward,
            aten.sigmoid_backward,
            aten.threshold_backward,
            aten.hardtanh_backward,
            aten.hardsigmoid_backward,
            aten.hardswish_backward,
            aten.tanh_backward,
            aten.silu_backward,
            aten.elu_backward,
            aten.cudnn_batch_norm,
            aten.cudnn_batch_norm_backward,
            aten.masked_fill.Scalar,
            aten.masked_fill.Tensor,
            aten.elu,
            aten.leaky_relu,
            aten.hardtanh,
            aten.hardswish,
            aten.hardsigmoid,
            aten.rsub,
            aten.native_batch_norm_backward,
        }

        self.aten2aten_decompositions = get_decompositions(default_decompositions)

    def candidate(self):
        return BACKENDS["aot_autograd"](
            self.gm,
            self.example_inputs,
            fw_compiler=wrap_compiler_debug(self.nvfuser, "nvfuser"),
            partition_fn=self.min_cut_rematerialization_partition,
            decompositions=self.aten2aten_decompositions,
        )


aot_prims_nvfuser = AotPrimsNvfuser.compile_fn


def prims_executor(gm, inputs, *, executor):
    # This function is called once per forward/backward pass of a graph in AOT
    # Autograd. We use it to set up the nvFuser-specific FX graph and return
    # an execute function.
    from torch._prims.context import TorchRefsNvfuserCapabilityMode
    from torch._prims.executor import execute
    from torch.fx.experimental.proxy_tensor import make_fx

    # First we trace the graph, conditionally decomposing nodes
    # that can be sent to the nvfuser executor.
    with TorchRefsNvfuserCapabilityMode():
        prim_gm = make_fx(gm)(*inputs)

    # Then we return a callable that executes the "prim_gm" graph.
    return partial(execute, prim_gm, executor=executor)

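
# Note: per the comment above, AOT Autograd invokes prims_executor as the
# fw_compiler / bw_compiler configured in create_nvprims_backend below; it is
# given the traced FX graph plus example inputs and must hand back the callable
# that will execute that graph at runtime.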

def create_nvprims_backend(*, executor):
    class NvPrims(AotAutogradStrategy):
        def __init__(self, gm: torch.fx.GraphModule, example_inputs):
            super(NvPrims, self).__init__(gm, example_inputs)
            self.executor = executor

        def candidate(self):
            return BACKENDS["aot_autograd"](
                self.gm,
                self.example_inputs,
                fw_compiler=partial(prims_executor, executor=self.executor),
                bw_compiler=partial(prims_executor, executor=self.executor),
            )

    return NvPrims


aot_nvprims_nvfuser = create_nvprims_backend(executor="nvfuser").compile_fn
aot_nvprims_aten = create_nvprims_backend(executor="aten").compile_fn


def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t


class CudaGraphModule(Module):
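    """
    Wraps a GraphModule so that it runs under a CUDA graph.

    Lifecycle (see __call__): the first call runs the module eagerly on a side
    stream to warm up; the second call records a CUDA graph into static
    input/output buffers; every later call copies its arguments into the
    static inputs, replays the recorded graph, writes mutated inputs back to
    the caller's tensors, and returns clones of the static outputs.
    """
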
    gm: GraphModule
    mutated_inputs: Set[int]

    def __init__(self, gm, mutated_inputs):
        super().__init__()
        self.gm = gm
        self.mutated_inputs = mutated_inputs

    warmed_up = False

    # these are all None or all filled
    graph = None
    static_inputs = None
    static_outputs = None

    # NB: we override __call__ as we don't need any nn.Module machinery
    # and to reduce overhead
    def __call__(self, *args):
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into the static
        # inputs, replays the cuda graph, then copies out. The first branch is
        # the hot path and needs optimizing.
        if self.graph is not None:
            assert len(args) == len(self.static_inputs)
            for dst, src in zip(self.static_inputs, args):
                dst.copy_(src)
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        elif self.warmed_up:
            # record
            self.static_inputs = [x.clone() for x in args]
            self.graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(self.graph):
                self.static_outputs = self.gm(*self.static_inputs)
            # NB: recording doesn't actually run the operations, so
            # now we immediately replay the graph to serve up the result
            self.graph.replay()
            for i in self.mutated_inputs:
                args[i].copy_(self.static_inputs[i])
            return tree_map(cloner, self.static_outputs)

        else:
            # warmup
            stream = torch.cuda.Stream()
            stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(stream):
                r = self.gm(*args)
            torch.cuda.current_stream().wait_stream(stream)
            self.warmed_up = True
            return r


# Interpreter versions of these passes can be found at
# https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23


def find_input_mutations(g):
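    """
    Return the set of placeholder indices whose storage is written to by some
    node in the graph, based on each call_function target's schema alias_info
    (is_write). Placeholders that share a storage are grouped via
    StorageWeakRef, so a write through any alias marks all of them as mutated.
    """
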
    def meta_fk(meta):
        return meta["val"] if "val" in meta else meta["fake_result"]

    inputs = defaultdict(set)
    input_idx = 0
    mutated_inputs = set()
    for n in g.nodes:
        if n.op == "placeholder":
            inputs[StorageWeakRef(meta_fk(n.meta).storage())].add(input_idx)
            input_idx += 1
        elif n.op == "call_function":
            if n.target is operator.getitem:
                continue
            schema = n.target._schema
            for i, arg in enumerate(schema.arguments):
                if i < len(n.args):
                    argument = n.args[i]
                else:
                    if arg.name not in n.kwargs:
                        continue
                    argument = n.kwargs[arg.name]
                mut_arg = False
                if arg.alias_info:
                    if arg.alias_info.is_write:
                        mut_arg = True
                if mut_arg:
                    # TODO: not correct for args that contain tensors in a struct
                    # like list
                    mutated_inputs |= inputs[
                        StorageWeakRef(meta_fk(argument.meta).storage())
                    ]
        # TODO: error on unrecognized nodes
    return mutated_inputs


# Mutates input graph
def apply_cuda_graphs(gm):
    for n in gm.graph.nodes:
        if n.op == "call_module":
            assert not n.kwargs
            submod = gm.get_submodule(n.target)
            gm.delete_submodule(n.target)
            mutated_inputs = find_input_mutations(submod.graph)
            gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs))
    # NB: we didn't actually change the graph, no need for recompile


def cudagraphs(model, inputs):
    model = partition_cudagraphs(model, inputs)
    apply_cuda_graphs(model)
    return model


def raw_aot_autograd_cudagraphs(model, inputs):
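    """
    Run AOT Autograd (aot_module_simplified) with the cudagraphs compiler for
    both the forward and backward graphs; the backward compiler is wrapped
    with disable() so TorchDynamo does not try to compile the generated
    backward pass.
    """
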
    from functorch.compile import aot_module_simplified  # type: ignore[import]

    from .. import disable

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": cudagraphs,
        "bw_compiler": cudagraphs,
    }

    def _wrapped_bw_compiler(*args, **kwargs):
        # stop TorchDynamo from trying to compile our generated backwards pass
        return disable(bw_compiler(*args, **kwargs))  # type: ignore[operator]

    bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
    kwargs["bw_compiler"] = _wrapped_bw_compiler

    return aot_module_simplified(model, **kwargs)


class AotAutogradCudaGraphs(AotAutogradStrategy):
    def candidate(self):
        return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)


aot_cudagraphs = AotAutogradCudaGraphs.compile_fn


def create_aot_backends():
    """
    Register aliases for the AOT backends.
    """
    # aot_eager uses the AOT Autograd backend with a nop compiler. It is
    # helpful for debugging.
    BACKENDS["aot_eager"] = aot_eager

    # aot_print uses the AOT Autograd backend with a print compiler. It prints
    # the graphs and also saves the graph modules that are sent to AOT
    # Autograd. This is helpful for debugging.
    BACKENDS["aot_print"] = aot_print

    # aot_ts uses the TorchScript backend. We can use this with both nnc and
    # nvfuser by selecting the relevant fuser with torch.jit.fuser(...).
    BACKENDS["aot_ts"] = aot_ts

    # prims_nvfuser uses prims and AOT Autograd to get an FX graph of aten ops,
    # and then lowers directly to nvFuser without relying on TorchScript.
    BACKENDS["prims_nvfuser"] = aot_prims_nvfuser

    # "nvprims" is a subset of PrimTorch primitives that are guaranteed to be
    # supported by nvFuser. This is the preferred backend for nvFuser+PrimTorch.
    BACKENDS["nvprims_nvfuser"] = aot_nvprims_nvfuser
    # This is useful for debugging. Can be removed later.
    BACKENDS["nvprims_aten"] = aot_nvprims_aten

    # aot_nvfuser uses the memory-efficient fusion algorithm from AOT Autograd.
    # It uses the min-cut rematerialization algorithm and nvFuser as the
    # compiler backend. This is the most optimized setting with nvFuser for
    # training.
    BACKENDS["aot_nvfuser"] = aot_mem_efficient_fusion

    # Similar to aot_nvfuser, but disables the decompositions. Decompositions
    # can cause accuracy deviations. This setting allows us to compare accuracy
    # without worrying about the impact of decompositions. More details at
    # https://github.com/pytorch/torchdynamo/issues/611
    BACKENDS["aot_nvfuser_nodecomps"] = aot_mem_efficient_fusion_no_decomp

    # aot_cudagraphs only applies CUDA graphs to the graph. It is also helpful
    # for debugging and can serve as a perf baseline.
    BACKENDS["aot_cudagraphs"] = aot_cudagraphs

    # aot_inductor_debug just replaces the inductor compiler with nop to help
    # isolate inductor vs aot_eager errors.
    BACKENDS["aot_inductor_debug"] = aot_inductor_debug
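

# Usage sketch (illustrative only, not executed here): once create_aot_backends()
# has populated BACKENDS, these backends are typically selected by name through
# TorchDynamo; the exact entry point may differ across versions, e.g.:
#
#     import torchdynamo
#
#     @torchdynamo.optimize("aot_eager")
#     def fn(x):
#         return x.cos() + x.sin()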