| import logging |
| import operator |
| from collections import defaultdict |
| from typing import Set |
| |
| import torch |
| |
| from torch.fx import GraphModule |
| from torch.fx.passes.backends.cudagraphs import partition_cudagraphs |
| from torch.multiprocessing.reductions import StorageWeakRef |
| from torch.nn import Module |
| from torch.utils._pytree import tree_map |
| from .common import aot_autograd |
| from .registry import register_backend |
| |
| log = logging.getLogger(__name__) |
| |
| |
# The outputs of a replayed CUDA graph live in static buffers that are
# overwritten on the next replay, so hand callers a clone instead.
def cloner(t):
    if isinstance(t, torch.Tensor):
        return t.clone()
    else:
        return t
| |
| |
| class CudaGraphModule(Module): |
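    """
    Wraps a GraphModule and runs it as a CUDA graph.  Each instance moves
    through three phases: the first call warms the kernels up on a side
    stream, the second call records a CUDA graph against cloned "static"
    input buffers, and every later call copies its arguments into those
    buffers and replays the graph.
    """
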
| gm: GraphModule |
| mutated_inputs: Set[int] |
| |
| def __init__(self, gm, mutated_inputs): |
| super().__init__() |
| self.gm = gm |
| self.mutated_inputs = mutated_inputs |
| |
| warmed_up = False |
| |
    # these are all None until the graph is recorded, then all filled
| graph = None |
| static_inputs = None |
| static_outputs = None |
| |
| # NB: we override __call__ as we don't need any nn.Module machinery |
| # and to reduce overhead |
| def __call__(self, *args): |
        # TODO: once we've recorded here, we'd like to replace the __call__
        # implementation with compiled bytecode that copies into the static
        # inputs, replays the cuda graph, then copies out.  The first branch
        # below is the hot path and needs optimizing.
| if self.graph is not None: |
| assert len(args) == len(self.static_inputs) |
| for dst, src in zip(self.static_inputs, args): |
| dst.copy_(src) |
| self.graph.replay() |
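            # propagate in-place mutations made by the graph back into the
            # caller's tensors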
| for i in self.mutated_inputs: |
| args[i].copy_(self.static_inputs[i]) |
| return tree_map(cloner, self.static_outputs) |
| |
| elif self.warmed_up: |
            # record: capture the CUDA graph against cloned static input
            # buffers; all later calls copy into these buffers and replay
| self.static_inputs = [x.clone() for x in args] |
| self.graph = torch.cuda.CUDAGraph() |
| with torch.cuda.graph(self.graph): |
| self.static_outputs = self.gm(*self.static_inputs) |
| # NB: recording doesn't actually run the operations, so |
| # now we immediately replay the graph to serve up the result |
| self.graph.replay() |
| for i in self.mutated_inputs: |
| args[i].copy_(self.static_inputs[i]) |
| return tree_map(cloner, self.static_outputs) |
| |
| else: |
            # warmup: run the kernels once on a side stream first, as
            # recommended before CUDA graph capture
| stream = torch.cuda.Stream() |
| stream.wait_stream(torch.cuda.current_stream()) |
| with torch.cuda.stream(stream): |
| r = self.gm(*args) |
| torch.cuda.current_stream().wait_stream(stream) |
| self.warmed_up = True |
| return r |
| |
| |
| # Interpreter versions of these passes can be found at |
| # https://gist.github.com/ezyang/df2d746cac3b2c7d55c181e37c57ef23 |
| |
| |
| def find_input_mutations(g): |
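    """
    Return the set of placeholder indices in ``g`` whose underlying storage
    is written in place by some call_function node, as judged by the
    ``is_write`` alias annotations on the op schemas.
    """
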
| def meta_fk(meta): |
| return meta["val"] if "val" in meta else meta["fake_result"] |
| |
| inputs = defaultdict(set) |
| input_idx = 0 |
| mutated_inputs = set() |
| for n in g.nodes: |
| if n.op == "placeholder": |
| inputs[StorageWeakRef(meta_fk(n.meta)._typed_storage())].add(input_idx) |
| input_idx += 1 |
| elif n.op == "call_function": |
| if n.target is operator.getitem: |
| continue |
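            # from here on we assume the target is an ATen op that carries
            # a ``_schema`` with alias (mutation) annotations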
| schema = n.target._schema |
| for i, arg in enumerate(schema.arguments): |
| if i < len(n.args): |
| argument = n.args[i] |
| else: |
| if arg.name not in n.kwargs: |
| continue |
| argument = n.kwargs[arg.name] |
| mut_arg = False |
| if arg.alias_info: |
| if arg.alias_info.is_write: |
| mut_arg = True |
| if mut_arg: |
| # TODO: not correct for args that contain tensors in a struct |
| # like list |
| mutated_inputs |= inputs[ |
| StorageWeakRef(meta_fk(argument.meta)._typed_storage()) |
| ] |
| # TODO: error on unrecognized nodes |
| return mutated_inputs |
| |
| |
| # Mutates input graph |
| def apply_cuda_graphs(gm): |
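    """Replace every call_module node's submodule with a CudaGraphModule."""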
| for n in gm.graph.nodes: |
| if n.op == "call_module": |
| assert not n.kwargs |
| submod = gm.get_submodule(n.target) |
| gm.delete_submodule(n.target) |
| mutated_inputs = find_input_mutations(submod.graph) |
| gm.add_submodule(n.target, CudaGraphModule(submod, mutated_inputs)) |
| # NB: we didn't actually change the graph, no need for recompile |
| |
| |
| def cudagraphs(model, inputs): |
| model = partition_cudagraphs(model, inputs) |
| apply_cuda_graphs(model) |
| return model |
| |
| |
| aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs) |
| |
# aot_cudagraphs does nothing but apply CUDA graphs to the forward and
# backward graphs produced by AOT Autograd; it is helpful for debugging
# and can serve as a perf baseline.
# TODO(jansel): rename to just "cudagraphs"?
| register_backend(name="cudagraphs", compiler_fn=aot_cudagraphs) |
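
# A minimal usage sketch (not part of this file's API; assumes a CUDA device
# and the standard torch._dynamo entry point):
#
#   import torch
#   import torch._dynamo
#
#   @torch._dynamo.optimize("cudagraphs")
#   def f(x):
#       return x.sin() + x.cos()
#
#   f(torch.randn(8, device="cuda"))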
| |
| |
| def cudagraphs_inner(model, inputs, copy_outputs=True): |
| """This isn't registered as a backend, but is used in some benchmarks""" |
| assert isinstance(inputs, (list, tuple)) |
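    # static input buffers the captured graph will read from; run() copies
    # each call's arguments into them before every replay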
| static_inputs = [torch.zeros_like(x) for x in inputs] |
| |
    # warmup on a side stream, as recommended before CUDA graph capture
| torch.cuda.synchronize() |
| stream = torch.cuda.Stream() |
| stream.wait_stream(torch.cuda.current_stream()) |
| with torch.cuda.stream(stream): |
| model(*inputs) |
| stream.synchronize() |
| torch.cuda.current_stream().wait_stream(stream) |
| torch.cuda.synchronize() |
| |
| # record |
| graph = torch.cuda.CUDAGraph() |
| with torch.cuda.graph(graph, stream=stream): |
| static_outputs = model(*static_inputs) |
| if not isinstance(static_outputs, (list, tuple)): |
| static_outputs = (static_outputs,) |
| |
| def run(*new_inputs): |
| assert len(static_inputs) == len(new_inputs) |
| for dst, src in zip(static_inputs, new_inputs): |
| dst.copy_(src) |
| graph.replay() |
| if copy_outputs: |
| return [x.clone() for x in static_outputs] |
| else: |
| return static_outputs |
| |
| return run |
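
# A usage sketch for cudagraphs_inner (assumes a CUDA device; the model and
# inputs are illustrative):
#
#   model = torch.nn.Linear(4, 4).cuda()
#   example_inputs = [torch.randn(2, 4, device="cuda")]
#   run = cudagraphs_inner(model, example_inputs)
#   out = run(*example_inputs)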