| # mypy: allow-untyped-defs |
| import contextlib |
| import functools |
| from typing import Dict, List, Optional, TYPE_CHECKING |
| |
| import torch |
| from torch._dynamo.external_utils import ( |
| call_backward, |
| call_hook, |
| FakeCompiledAutogradEngine, |
| ) |
| from torch._dynamo.source import GetItemSource, LocalSource |
| from torch._dynamo.utils import counters, lazy_format_graph_code, set_locals_to_steal |
| from torch._logging import getArtifactLogger, trace_structured |
| from torch._prims_common import clone_preserve_strides |
| from torch._subclasses import FakeTensorMode |
| from torch.fx import GraphModule |
| from torch.fx.experimental._backward_state import BackwardState |
| from torch.fx.experimental.proxy_tensor import ( |
| decompose, |
| disable_autocast_cache, |
| disable_proxy_modes_tracing, |
| fetch_object_proxy, |
| ProxyTorchDispatchMode, |
| PythonKeyTracer, |
| track_tensor_tree, |
| ) |
| from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv |
| from torch.fx.traceback import preserve_node_meta, set_stack_trace |
| from torch.utils._traceback import CapturedTraceback |
| |
| if TYPE_CHECKING: |
| from torch.fx.proxy import Proxy |
| |
# Logger for the "compiled_autograd" logging artifact; end_capture logs the
# captured graph through it at INFO level.
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
# Logger for the "compiled_autograd_verbose" artifact; end_capture logs the
# device-annotated graph through it and cpp_verbose_log_fn forwards to it.
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
| |
| |
| def snapshot_verbose_logging_enabled(): |
| return torch._logging._internal.log_state.is_artifact_enabled( |
| "compiled_autograd_verbose" |
| ) |
| |
| |
def cpp_verbose_log_fn(msg: str) -> None:
    """Forward ``msg`` to ``verbose_log`` at DEBUG level.

    ``enable`` registers this function via ``set_verbose_logger`` when verbose
    logging is on.
    """
    verbose_log.debug(msg)
| |
| |
def snapshot_cudagraph_enabled():
    """Return the current value of the inductor ``triton.cudagraphs`` config flag."""
    triton_config = torch._inductor.config.triton
    return triton_config.cudagraphs
| |
| |
def maybe_clone(x):
    """Return a stride-preserving clone of ``x``; ``None`` is passed through unchanged."""
    if x is None:
        return x
    return clone_preserve_strides(x)
| |
| |
class AutogradCompilerInstance:
    """Records a backward pass into an FX graph and compiles it with ``compiler_fn``.

    ``begin_capture`` resets the tracer and enters the tracing modes, the
    ``proxy_call_*`` / ``*_hook`` methods append nodes to the graph, and
    ``end_capture`` finalizes the graph and passes it to ``compiler_fn``.
    """

    def __init__(self, compiler_fn) -> None:
        """Create the tracer, shape environment and modes used during capture.

        Args:
            compiler_fn: callable invoked by ``end_capture`` with the captured
                ``GraphModule``.
        """
        self.compiler_fn = compiler_fn
        # Context managers entered in begin_capture live on this stack;
        # ``close`` is exposed so they can be unwound from outside.
        self.stack = contextlib.ExitStack()
        self.close = self.stack.close
        self.shape_env = ShapeEnv()
        self.fake_tensor_mode = FakeTensorMode(
            shape_env=self.shape_env,
            allow_non_fake_inputs=True,
            allow_fallback_kernels=True,
        )
        self.fx_tracer = PythonKeyTracer()
        self.proxy_mode = ProxyTorchDispatchMode(self.fx_tracer, "symbolic")
        # Proxy of the "hooks" placeholder; populated by begin_capture.
        self.hooks_proxy: Optional[Proxy] = None

    def wrap_fake(self, x, source):
        """Convert the real tensor ``x`` into a fake tensor tagged with ``source``."""
        assert isinstance(x, torch.Tensor)
        fake_mode = self.fake_tensor_mode
        return fake_mode.from_tensor(x, source=source)

    @staticmethod
    def source(name, idx) -> GetItemSource:
        """Build the source describing ``<local name>[idx]``."""
        base = LocalSource(name)
        return GetItemSource(base, idx)

    def begin_capture(self, inputs: List[torch.Tensor], sizes: List[int]):
        """Start a new capture.

        Resets the tracer's root/graph, creates the ``inputs``, ``sizes`` and
        ``hooks`` placeholders, binds fake tensors and symbolic ints to their
        proxies, and enters the tracing modes on ``self.stack``.

        Returns:
            A ``(fake_inputs, symbolic_sizes)`` pair of lists.
        """
        counters["compiled_autograd"]["captures"] += 1
        tracer = self.fx_tracer
        tracer.root = torch.nn.Module()
        tracer.graph = torch.fx.Graph(tracer_cls=PythonKeyTracer)
        tracer.tensor_attrs = {}
        # Placeholder order is relied upon by move_graph_nodes_to_cuda:
        # inputs, sizes, hooks.
        args_proxy = tracer.create_proxy("placeholder", "inputs", (), {})
        sizes_proxy = tracer.create_proxy("placeholder", "sizes", (), {})
        self.hooks_proxy = tracer.create_proxy("placeholder", "hooks", (), {})

        # Real tensor inputs become fake tensors, each bound to inputs[i].
        fake_inputs = [
            self.wrap_fake(tensor, self.source("inputs", idx))
            for idx, tensor in enumerate(inputs)
        ]
        input_proxies = [args_proxy[i] for i in range(len(fake_inputs))]
        self.bind_tensors_to_proxies(fake_inputs, input_proxies)

        # Integer sizes become dynamic symbolic ints, each bound to sizes[i].
        symbolic_sizes = [
            self.shape_env.create_unspecified_symint_and_symbol(
                val,
                self.source("sizes", idx),
                DimDynamic.DYNAMIC,
            )
            for idx, val in enumerate(sizes)
        ]
        self.bind_tensors_to_proxies(symbolic_sizes, sizes_proxy)

        # TODO(jansel): are all these modes needed?
        enter = self.stack.enter_context
        enter(decompose({}))
        enter(self.fake_tensor_mode)
        enter(self.proxy_mode.sym_mode)
        enter(self.proxy_mode)
        enter(disable_autocast_cache())
        enter(preserve_node_meta())
        return fake_inputs, symbolic_sizes

    def proxy_call_backward(
        self,
        inputs,
        output_metadatas,
        saved_tensors,
        backward_idx: int,
    ):
        """Add a ``call_backward`` node and return placeholder gradient tensors.

        Args:
            inputs: values passed through ``to_proxy`` and splatted into the call.
            output_metadatas: one entry per output; either ``None`` or a
                ``(layout, device, dtype, size)`` tuple.
            saved_tensors: values passed through ``to_proxy`` as a single argument.
            backward_idx: index into the ``hooks`` placeholder of the backward
                function to call.

        Returns:
            A tuple with ``None`` for each ``None`` metadata and an empty tensor
            of the described layout/device/dtype/size otherwise, each bound to
            the corresponding output of the new node.
        """
        assert self.hooks_proxy is not None
        backward_c_function = self.hooks_proxy[backward_idx]  # type: ignore[index]
        saved_proxy = self.to_proxy(saved_tensors)
        input_proxies = self.to_proxy(inputs)
        proxies = self.fx_tracer.create_proxy(
            kind="call_function",
            target=call_backward,
            args=(backward_c_function, saved_proxy, *input_proxies),
            kwargs={},
        )

        with disable_proxy_modes_tracing():
            # Allocate stand-in tensors matching each output's metadata.
            grad_ins: List[Optional[torch.Tensor]] = []
            for metadata in output_metadatas:
                if metadata is not None:
                    layout, device, dtype, size = metadata
                    grad = torch.empty(
                        size=size, dtype=dtype, layout=layout, device=device
                    )
                else:
                    grad = None
                grad_ins.append(grad)
            self.bind_tensors_to_proxies(grad_ins, proxies)
        return tuple(grad_ins)

    def proxy_call_hook(self, hook, *args):
        """Add a ``call_hook(hook, *args)`` node and return its proxy."""
        proxied_args = [self.to_proxy(arg) for arg in args]
        return self.fx_tracer.create_proxy(
            "call_function",
            call_hook,
            (hook, *proxied_args),
            {},
        )

    def tensor_pre_hook(self, inputs, hook_id, i: int):
        """Trace a hook call on ``inputs[i]`` and rebind that slot to the result.

        ``inputs[i]`` is replaced in place by a clone bound to the hook's
        output proxy; the (mutated) ``inputs`` is returned.
        """
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxy = self.proxy_call_hook(hook, inputs[i])
        with disable_proxy_modes_tracing():
            cloned = maybe_clone(inputs[i])
            inputs[i] = cloned
            self.bind_tensors_to_proxies([cloned], [proxy])
        return inputs

    def pre_hook(self, inputs, hook_id):
        """Trace a hook call on all ``inputs``; return clones bound to its outputs."""
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(hook, inputs)
        with disable_proxy_modes_tracing():
            cloned_inputs = [maybe_clone(tensor) for tensor in inputs]
            self.bind_tensors_to_proxies(cloned_inputs, proxies)
        return cloned_inputs

    def post_hook(self, outputs, inputs, hook_id):
        """Trace a hook call on ``(outputs, inputs)``; return clones of ``outputs``
        bound to the hook's outputs."""
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(hook, outputs, inputs)
        with disable_proxy_modes_tracing():
            cloned_outputs = [maybe_clone(tensor) for tensor in outputs]
            self.bind_tensors_to_proxies(cloned_outputs, proxies)
        return cloned_outputs

    def post_acc_grad_hook(self, input, hook_id):
        """Trace a hook call on the tensor ``input``.

        Returns a one-element list holding a clone of ``input`` bound to the
        hook's output.
        """
        assert isinstance(input, torch.Tensor)
        assert self.hooks_proxy is not None
        hook = self.hooks_proxy[hook_id]  # type: ignore[index]
        proxies = self.proxy_call_hook(hook, input)
        with disable_proxy_modes_tracing():
            cloned = [maybe_clone(input)]
            self.bind_tensors_to_proxies(cloned, proxies)
        return cloned

    # Note: [Compiled autograd and cudagraphs]
    # Eager autograd backward implements scalars as 0-dim tensors, see DivBackward0::other_.
    # When compiled autograd traces those nodes, it lifts the scalar tensors, resulting in a
    # graph with some cpu 0-dim tensor inputs. To prevent the entire graph from skipping
    # cudagraphs, we move the scalar tensors to cuda. This works because ATen/prims ops
    # accept cuda 0-dim tensors too.
    def move_graph_nodes_to_cuda(self, graph) -> List[int]:
        """Retag cpu 0-dim graph inputs as cuda and report which ones were moved.

        Only inputs whose users are all ``aten``/``prims`` overloads are
        eligible, and nothing is moved unless the graph has a cuda input.

        Returns:
            Indices into the runtime ``inputs`` list that must be moved to cuda
            before calling the compiled function (empty if nothing was moved).
        """
        all_nodes = list(graph.nodes)
        inputs_node = all_nodes[0]
        assert inputs_node.target == "inputs"
        inputs_users = list(inputs_node.users.keys())
        # The node order should always be
        # [inputs, sizes, hooks, getitem, getitem1, ...] where getitem<i>
        # accesses inputs[i].
        first_getitem_idx = 3
        assert all_nodes[first_getitem_idx] == inputs_users[0]
        last_getitem_idx = first_getitem_idx + len(inputs_users) - 1
        assert all_nodes[last_getitem_idx] == inputs_users[-1]

        to_move: Dict[int, torch.fx.Node] = {}
        has_cuda_inputs = False
        for i, node in enumerate(inputs_users):
            val = node.meta["val"]
            if not has_cuda_inputs and val.device.type == "cuda":
                has_cuda_inputs = True
                continue

            is_cpu = val.device.type == "cpu"
            is_scalar = len(val.size()) == 0
            if not (is_cpu and is_scalar):
                continue

            users_are_aten_or_prims = all(
                isinstance(user.target, torch._ops.OpOverload)
                and user.target.namespace in ("prims", "aten")
                for user in list(node.users.keys())
            )
            if users_are_aten_or_prims:
                # every consumer is a prims/aten op, so moving is safe
                to_move[i] = node

        # Only move cpu scalars when the graph had cuda activations; this
        # covers cudagraphs being enabled on a cpu-only graph.
        if not has_cuda_inputs:
            return []

        for node in to_move.values():
            node.meta["val"] = node.meta["val"].cuda()

        # runtime indices that need to be moved to cuda
        return list(to_move.keys())

    def end_capture(self, outputs):
        """Finish the capture and compile the graph.

        Appends the final-callbacks stub and the output node, unwinds the
        tracing modes, reorders ``accumulate_grad_`` nodes, optionally retags
        cpu scalars for cudagraphs, logs the graph, and calls ``compiler_fn``.

        Returns:
            ``(runtime_wrapper, compiled_fn)`` where ``runtime_wrapper`` is a
            function taking ``(compiled_fn, inputs, sizes, hooks)``.
        """
        tracer = self.fx_tracer
        tracer.create_proxy(
            "call_function",
            FakeCompiledAutogradEngine._exec_final_callbacks_stub,
            (),
            {},
        )
        self.stack.close()
        output_arg = tracer.create_arg(self.to_proxy(outputs))
        tracer.create_node("output", "output", (output_arg,), {})
        self.reorder_accumulate_grad_nodes()

        runtime_inputs_to_move: List[int] = []
        if snapshot_cudagraph_enabled():
            runtime_inputs_to_move = self.move_graph_nodes_to_cuda(tracer.graph)

        graph = GraphModule(tracer.root, tracer.graph, "CompiledAutograd")
        set_locals_to_steal(graph, ["inputs"])
        compiled_autograd_log.info(
            "%s",
            lazy_format_graph_code("Compiled autograd graph", graph, colored=True),
        )
        verbose_log.debug(
            "%s",
            lazy_format_graph_code(
                "Compiled autograd graph", graph, include_device=True, colored=True
            ),
        )
        trace_structured(
            "compiled_autograd_graph",
            payload_fn=lambda: graph.print_readable(print_output=False),
        )

        def runtime_wrapper(compiled_fn, inputs, sizes, hooks):
            # Flags the region for the duration of the compiled call and moves
            # the selected cpu scalar inputs to cuda first.
            global in_compiled_autograd_region
            in_compiled_autograd_region = True
            try:
                for idx in runtime_inputs_to_move:
                    inputs[idx] = inputs[idx].pin_memory().cuda(non_blocking=True)

                return compiled_fn(inputs, sizes, hooks)
            finally:
                in_compiled_autograd_region = False

        return runtime_wrapper, self.compiler_fn(graph)

    def reorder_accumulate_grad_nodes(self):
        """
        Usage of AOTAutograd causes all the accumulate_grad_ nodes to get pushed to the end of
        the graph. This differs from eager mode, which schedules them as soon as possible. This
        pass attempts to reorder the graph to mimic eager behavior.
        """
        accumulate_nodes = self.fx_tracer.graph.find_nodes(
            op="call_function", target=torch.ops.inductor.accumulate_grad_.default
        )
        for node in accumulate_nodes:
            arg = max(node.args)  # last arg
            # Already directly after its last argument, or that argument is a
            # placeholder: leave the node where it is.
            if arg is node.prev or arg.op == "placeholder":
                continue
            arg.append(node)

    def to_proxy(self, t):
        """Map tensors/SymInts to their tracked proxies.

        ``None`` maps to ``None``; lists and tuples are converted element-wise,
        preserving the container type.
        """
        if t is None:
            return None
        if isinstance(t, list):
            return [self.to_proxy(item) for item in t]
        if isinstance(t, tuple):
            return tuple(self.to_proxy(item) for item in t)
        assert isinstance(t, (torch.Tensor, torch.SymInt))
        tracked = fetch_object_proxy(self.fx_tracer)(t)
        return tracked.proxy

    def bind_tensors_to_proxies(self, tensors, proxies):
        """Associate each element of ``tensors`` with the matching proxy.

        A single ``Proxy`` is expanded by indexing it once per tensor.
        """
        if isinstance(proxies, torch.fx.Proxy):
            proxies = [proxies[i] for i in range(len(tensors))]
        assert len(tensors) == len(proxies)
        track_tensor_tree(tensors, proxies, constant=None, tracer=self.fx_tracer)

    def bind_backward_state(self, index: int):
        """Create a ``BackwardState`` bound to ``hooks[index]`` and return it."""
        assert self.hooks_proxy is not None
        state_proxy = self.hooks_proxy[index]  # type: ignore[index]
        bw_state = BackwardState()
        track_tensor_tree(bw_state, state_proxy, constant=None, tracer=self.fx_tracer)
        return bw_state

    def set_node_origin(self, node_name, node_index):
        """Set the FX stack trace to a frame labelled ``<node_name> (NodeCall <node_index>)``."""
        # NOTE: the next line must stay textually identical to the string
        # passed to ``replace`` below; the replacement swaps this line's own
        # source text (as it appears in the formatted frame) for the label.
        raw_stack_trace = CapturedTraceback.extract().format()[-1]
        label = f"{node_name} (NodeCall {node_index})"
        labelled_trace = raw_stack_trace.replace(
            "raw_stack_trace = CapturedTraceback.extract().format()[-1]", label
        )
        set_stack_trace(labelled_trace)
| |
| |
# State of the autograd engine dispatch, kept in sync by the enable/disable
# context managers below.
compiled_autograd_enabled: bool = False

# Global flag that is True only while runtime_wrapper (built in end_capture) is
# running a compiled function, i.e. while graphs produced from a compiled
# autograd graph are being processed.
in_compiled_autograd_region: bool = False
| |
| |
@contextlib.contextmanager
def enable(compiler_fn):
    """Context manager that installs compiled autograd with ``compiler_fn``.

    Registers a factory for ``AutogradCompilerInstance`` as the autograd
    compiler, hooks up the verbose logger when verbose logging is enabled, and
    runs the body with autograd multithreading disabled. On exit the previously
    installed compiler is restored, and ``compiled_autograd_enabled`` is
    cleared only if there was no previous compiler.
    """
    global compiled_autograd_enabled
    bindings = torch._C._dynamo.compiled_autograd
    make_instance = functools.partial(AutogradCompilerInstance, compiler_fn)
    prior = bindings.set_autograd_compiler(make_instance)
    if snapshot_verbose_logging_enabled():
        bindings.set_verbose_logger(cpp_verbose_log_fn)
    compiled_autograd_enabled = True
    try:
        with torch.autograd.set_multithreading_enabled(False):
            yield
    finally:
        if not prior:
            compiled_autograd_enabled = False
        bindings.set_autograd_compiler(prior)
| |
| |
@contextlib.contextmanager
def disable():
    """Context manager that turns compiled autograd off for the body.

    Uninstalls the autograd compiler and clears ``compiled_autograd_enabled``.
    On exit the previous compiler is reinstalled, and the flag is set again
    only if a compiler had been installed before.
    """
    global compiled_autograd_enabled
    bindings = torch._C._dynamo.compiled_autograd
    prior = bindings.set_autograd_compiler(None)
    compiled_autograd_enabled = False
    try:
        yield
    finally:
        if prior:
            compiled_autograd_enabled = True
        bindings.set_autograd_compiler(prior)
| |
| |
# return to starting state of a new process
def reset() -> None:
    """Restore this module's compiled-autograd state to its initial values.

    Clears the module-level ``compiled_autograd_enabled`` flag, uninstalls the
    autograd compiler, and removes the verbose logger.

    Raises:
        AssertionError: if called while ``in_compiled_autograd_region`` is set.
    """
    # Must be declared global: a plain assignment would only bind a
    # function-local name and leave the module-level flag untouched.
    global compiled_autograd_enabled
    compiled_autograd_enabled = False
    assert not in_compiled_autograd_region
    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)