| # mypy: allow-untyped-defs |
| import copy |
| import functools |
| import heapq |
| import itertools |
| import logging |
| import math |
| import operator |
| import os |
| from collections import defaultdict |
| from dataclasses import dataclass, replace |
| from typing import Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union |
| |
| import torch |
| import torch._inductor.inductor_prims |
| import torch.fx as fx |
| import torch.utils._pytree as pytree |
| from torch.fx.experimental._backward_state import BackwardState |
| from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types |
| from torch.fx.experimental.sym_node import magic_methods, method_to_operator |
| from torch.fx.experimental.symbolic_shapes import ( |
| find_symbol_binding_fx_nodes, |
| free_symbols, |
| hint_int, |
| is_symbol_binding_fx_node, |
| ) |
| from torch.fx.passes import graph_drawer |
| from torch.utils.checkpoint import CheckpointPolicy |
| from . import config |
| from ._aot_autograd.logging_utils import get_aot_graph_name |
| from .compile_utils import fx_graph_cse, get_aten_target |
| |
| if TYPE_CHECKING: |
| import sympy |
| |
| |
| AOT_PARTITIONER_DEBUG = config.debug_partitioner |
| log = logging.getLogger(__name__) |
| |
| aten = torch.ops.aten |
| prims = torch.ops.prims |
| |
| |
| @dataclass |
| class OpTypes: |
| """Class for keeping track of different operator categories""" |
| |
| fusible_ops: Set[Callable] |
| compute_intensive_ops: Set[Callable] |
| random_ops: Set[Callable] |
| view_ops: Set[Callable] |
| recomputable_ops: Set[Callable] |
| |
| def is_fusible(self, node: fx.Node): |
| return get_aten_target(node) in self.fusible_ops |
| |
| def is_compute_intensive(self, node: fx.Node): |
| return get_aten_target(node) in self.compute_intensive_ops |
| |
| def is_random(self, node: fx.Node): |
| return get_aten_target(node) in self.random_ops |
| |
| def is_view(self, node: fx.Node): |
| return get_aten_target(node) in self.view_ops |
| |
| def is_recomputable(self, node: fx.Node): |
| return get_aten_target(node) in self.recomputable_ops |
| |
| |
| @dataclass |
| class NodeInfo: |
| # Be careful about iterating over these explicitly, as their order may not |
| # be deterministic |
| inputs: List[fx.Node] |
| _required_fw_nodes: Set[fx.Node] |
| required_bw_nodes: Set[fx.Node] |
| unclaimed_nodes: Set[fx.Node] |
| fw_order: Dict[fx.Node, int] |
| |
| @functools.cached_property |
| def required_fw_nodes(self) -> List[fx.Node]: |
| return sorted( |
| (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n] |
| ) |
| |
| def is_required_fw(self, n: fx.Node) -> bool: |
| return n in self._required_fw_nodes |
| |
| def is_required_bw(self, n: fx.Node) -> bool: |
| return n in self.required_bw_nodes |
| |
| def is_unclaimed(self, n: fx.Node) -> bool: |
| return n in self.unclaimed_nodes |
| |
| def get_fw_order(self, n: fx.Node) -> int: |
| assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!" |
| return self.fw_order[n] |
| |
| |
| @dataclass |
| class MinCutOptions: |
| ban_if_used_far_apart: bool |
| ban_if_long_fusible_chains: bool |
| ban_if_materialized_backward: bool |
| ban_if_not_in_allowlist: bool |
| ban_if_reduction: bool |
| |
| |
| def must_recompute(node: fx.Node) -> bool: |
| return node.meta.get("recompute", None) in [ |
| CheckpointPolicy.MUST_RECOMPUTE, |
| CheckpointPolicy.PREFER_RECOMPUTE, |
| ] |
| |
| |
| def has_recomputable_ops(fx_g: fx.GraphModule) -> bool: |
| for node in fx_g.graph.nodes: |
| if must_recompute(node): |
| return True |
| return False |
| |
| |
| def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool: |
| for node in fx_g.graph.nodes: |
| if ( |
| must_recompute(node) |
| and hasattr(node.target, "tags") |
| and torch.Tag.nondeterministic_seeded in node.target.tags |
| ): |
| return True |
| return False |
| |
| |
| def sym_node_size(node: fx.Node) -> int: |
| if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)): |
| return 1 |
| assert isinstance(node.meta["val"], torch.SymFloat) |
| return 4 |
| |
| |
| class InvalidNodeBase: |
| def __repr__(self): |
| return "Invalid Node" |
| |
| |
| InvalidNode = InvalidNodeBase() |
| |
| |
| def _extract_graph_with_inputs_outputs( |
| joint_graph: fx.Graph, inputs: List[fx.Node], outputs: List[fx.Node] |
| ) -> fx.Graph: |
| """ |
| Given a graph, extracts out a subgraph that takes the specified nodes as |
| inputs and returns the specified outputs. |
| |
| This includes specifying non-placeholder nodes as inputs. |
| |
| The general strategy is to initialize all inputs with proxies as we |
| encounter them, and trace through the graph, only keeping values which take |
| in valid proxies. Then, all dead code is eliminated. |
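| 
| For example (illustrative), if ``a`` was produced by a call_function node in the
| joint graph and we extract with ``inputs=[a]`` and ``outputs=[b]``, then ``a``
| becomes a placeholder, everything that was only needed to produce ``a`` is
| dropped, and only the ops on the path from ``a`` to ``b`` survive dead code
| elimination.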
| """ |
| new_graph = fx.Graph() |
| env = {} |
| |
| # Add new placeholder nodes in the order specified by the inputs |
| for node in inputs: |
| new_node = new_graph.placeholder(node.name) |
| # Can't use node_copy here as we may be turning previous call_function nodes into placeholders
| new_node.meta = node.meta |
| env[node] = new_node |
| |
| for node in joint_graph.nodes: |
| if node in env: |
| # Node must be one of our inputs. (Any member of env which wasn't an |
| # input to start must have been created by this loop and won't be in |
| # joint_graph.nodes). |
| continue |
| elif node.op == "placeholder": |
| env[node] = InvalidNode |
| elif node.op == "call_function": |
| all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) |
| all_args = [ |
| isinstance(env[x], InvalidNodeBase) |
| for x in all_args |
| if isinstance(x, fx.Node) |
| ] |
| if any(all_args): |
| env[node] = InvalidNode |
| continue |
| env[node] = new_graph.node_copy(node, lambda x: env[x]) |
| elif node.op == "get_attr": |
| env[node] = new_graph.node_copy(node, lambda x: env[x]) |
| elif node.op == "output": |
| pass |
| output_values = [] |
| for x in outputs: |
| if isinstance(x, fx.Node): |
| if x not in env: |
| raise RuntimeError(f"Node {x} couldn't be found in env") |
| assert not isinstance( |
| env[x], InvalidNodeBase |
| ), f"Node {x} was invalid, but is output" |
| output_values.append(env[x]) |
| else: |
| output_values.append(x) |
| new_graph.output(output_values) |
| |
| new_graph.eliminate_dead_code() |
| new_graph.lint() |
| return new_graph |
| |
| |
| def _is_primal(node: fx.Node) -> bool: |
| return ( |
| node.op == "placeholder" |
| and "tangents" not in str(node.target) |
| and not _is_bwd_seed_offset(node) |
| and not _is_fwd_seed_offset(node) |
| ) |
| |
| |
| def _is_tangent(node: fx.Node) -> bool: |
| return node.op == "placeholder" and "tangents" in str(node.target) |
| |
| |
| def _is_bwd_seed_offset(node: fx.Node) -> bool: |
| return node.op == "placeholder" and ( |
| "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target) |
| ) |
| |
| |
| def _is_fwd_seed_offset(node: fx.Node) -> bool: |
| return node.op == "placeholder" and ( |
| "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target) |
| ) |
| |
| |
| def _is_backward_state(node: fx.Node) -> bool: |
| return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState) |
| |
| |
| def _extract_fwd_bwd_outputs( |
| joint_module: fx.GraphModule, *, num_fwd_outputs |
| ) -> Tuple[List[fx.Node], List[fx.Node]]: |
| outputs = pytree.arg_tree_leaves( |
| *(node.args for node in joint_module.graph.find_nodes(op="output")) |
| ) |
| fwd_outputs = outputs[:num_fwd_outputs] |
| bwd_outputs = outputs[num_fwd_outputs:] |
| return fwd_outputs, bwd_outputs |
| |
| |
| def _remove_by_name(saved_values: List[fx.Node], name: str): |
| for saved_value in saved_values: |
| if saved_value.name == name: |
| saved_values.remove(saved_value) |
| break |
| |
| |
| def _extract_fwd_bwd_modules( |
| joint_module: fx.GraphModule, |
| saved_values: List[fx.Node], |
| saved_sym_nodes: List[fx.Node], |
| *, |
| num_fwd_outputs: int, |
| ) -> Tuple[fx.GraphModule, fx.GraphModule]: |
| fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( |
| joint_module, num_fwd_outputs=num_fwd_outputs |
| ) |
| placeholders = joint_module.graph.find_nodes(op="placeholder") |
| primal_inputs = [*filter(_is_primal, placeholders)] |
| tangent_inputs = [*filter(_is_tangent, placeholders)] |
| fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] |
| bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] |
| backward_state_inputs = [*filter(_is_backward_state, placeholders)] |
| |
| bwd_graph = _extract_graph_with_inputs_outputs( |
| joint_module.graph, |
| saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, |
| bwd_outputs, |
| ) |
| |
| for node in bwd_graph.find_nodes(op="placeholder"): |
| # This is to filter out saved values that don't actually end up being used by the backwards pass |
| if not node.users: |
| _remove_by_name(saved_values, node.name) |
| _remove_by_name(saved_sym_nodes, node.name) |
| elif _is_backward_state(node): |
| # BackwardState is saved directly |
| _remove_by_name(saved_values, node.name) |
| assert backward_state_inputs |
| |
| # Now that we have the finalized list of saved values, we need to ensure |
| # we propagate all symbols which are referenced by backwards inputs. |
| # These are not directly used in the graph but are required for downstream |
| # sizevar assignment |
| saved_symbols: Set[sympy.Symbol] = set() |
| saved_sym_nodes_binding = [] |
| saved_sym_nodes_derived = [] |
| |
| # Some symbols may already be bound in the directly saved_sym_nodes, |
| # keep track of them so we don't re-bind them |
| for node in saved_sym_nodes: |
| symbol = is_symbol_binding_fx_node(node) |
| if symbol: |
| saved_symbols.add(symbol) |
| saved_sym_nodes_binding.append(node) |
| else: |
| saved_sym_nodes_derived.append(node) |
| |
| # Now go through all of the prospective backward inputs and track any |
| # other symbols we need to bind |
| symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) |
| for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs): |
| if "val" not in node.meta: |
| continue |
| new_symbols = free_symbols(node.meta["val"]) - saved_symbols |
| # NB: Deterministic order please! |
| for s in sorted(new_symbols, key=lambda s: s.name): |
| # NB: For well formed graphs, the symbol should always be present, |
| # but we also have ways to produce ill-formed graphs, e.g., direct |
| # make_fx usages, so don't choke in this case |
| if s not in symbol_bindings: |
| continue |
| saved_sym_nodes_binding.append(symbol_bindings[s]) |
| saved_symbols |= new_symbols |
| |
| # Update saved_sym_nodes that are now reordered to have all bindings at |
| # front. This can also be used later on to figure out the position of saved |
| # sym nodes in the output of fwd graph. |
| saved_sym_nodes.clear() |
| saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) |
| |
| # Now, we re-generate the fwd/bwd graphs. |
| # NB: This might increase compilation time, but I doubt it matters |
| fwd_graph = _extract_graph_with_inputs_outputs( |
| joint_module.graph, |
| primal_inputs + fwd_seed_offset_inputs, |
| fwd_outputs + saved_values + saved_sym_nodes, |
| ) |
| bwd_graph = _extract_graph_with_inputs_outputs( |
| joint_module.graph, |
| saved_sym_nodes |
| + saved_values |
| + tangent_inputs |
| + bwd_seed_offset_inputs |
| + backward_state_inputs, |
| bwd_outputs, |
| ) |
| |
| fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) |
| bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) |
| return fwd_module, bwd_module |
| |
| |
| def default_partition( |
| joint_module: fx.GraphModule, _joint_inputs, *, num_fwd_outputs |
| ) -> Tuple[fx.GraphModule, fx.GraphModule]: |
| """ |
| Partitions the :attr:`joint_module` in a manner that closely resembles the |
| behavior observed in the original ``.forward()`` and ``.backward()`` of the |
| callable, i.e., the resulting forward graph contains those operators that |
| are executed in the original ``.forward()`` callable passed to |
| :func:`aot_function`. |
| |
| The default partitioner collects the operators that are between the forward |
| inputs and the forward outputs. This helps in finding the tensors which have |
| to be stashed for the backward pass. These stashed tensors become the output |
| of the generated forward graph. The remaining operators are then placed in |
| the backward graph. |
| |
| .. warning:: |
| This API is experimental and likely to change. |
| |
| Args: |
| joint_module(fx.GraphModule): The joint forward and backward graph. This |
| is the result of AOT Autograd tracing. |
| |
| Returns: |
| Returns the generated forward and backward Fx graph modules. |
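| 
| Example::
| 
| >>> # Illustrative sketch only: ``aot_function`` and ``nop`` are assumed to be
| >>> # the functorch.compile convenience wrappers, not part of this module.
| >>> from functorch.compile import aot_function, default_partition, nop
| >>> fn = lambda x: x.cos().cos()
| >>> aot_fn = aot_function(
| ...     fn, fw_compiler=nop, bw_compiler=nop, partition_fn=default_partition
| ... )
| >>> out = aot_fn(torch.randn(4, requires_grad=True))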
| """ |
| if has_recomputable_ops(joint_module): |
| return min_cut_rematerialization_partition( |
| joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs |
| ) |
| primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) |
| fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes)) |
| inputs = primal_inputs + fwd_seed_offset_inputs |
| fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( |
| joint_module, num_fwd_outputs=num_fwd_outputs |
| ) |
| forward_only_graph = _extract_graph_with_inputs_outputs( |
| joint_module.graph, inputs, fwd_outputs |
| ) |
| forward_node_names = { |
| node.name for node in forward_only_graph.nodes if node.op != "output" |
| } |
| saved_values = [] |
| saved_sym_nodes = [] |
| |
| for node in joint_module.graph.nodes: |
| if node.name not in forward_node_names: |
| continue |
| if is_sym_node(node): |
| # Symints must be kept separate from tensors so that PythonFunction only calls |
| # save_for_backward on tensors and stashes symints in autograd .ctx |
| saved_sym_nodes.append(node) |
| elif "tensor_meta" not in node.meta and node.op == "call_function": |
| # Since we can't save a tuple of tensor values, we need to flatten out what we're saving
| users = node.users |
| assert all(user.target == operator.getitem for user in users) |
| saved_values.extend(users) |
| else: |
| backward_usages = [ |
| n for n in node.users if n.name not in forward_node_names |
| ] |
| if "tensor_meta" in node.meta and all( |
| is_sym_node(n) for n in backward_usages |
| ): |
| # If we have a tensor in the forward, where only its sizes/strides are needed in the backward, |
| # and not the actual tensor data, |
| # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor. |
| # |
| # Note that saving the tensor could also cause compilation problems: |
| # If the user mutated an input in the forward and uses its sizes/strides in the backward, |
| # then we would be obligated to clone the input before saving it to appease autograd. |
| # (This is how we originally found this bug). |
| saved_sym_nodes.extend(backward_usages) |
| else: |
| saved_values.append(node) |
| saved_values = list(dict.fromkeys(saved_values).keys()) |
| saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) |
| |
| return _extract_fwd_bwd_modules( |
| joint_module, |
| saved_values, |
| saved_sym_nodes=saved_sym_nodes, |
| num_fwd_outputs=num_fwd_outputs, |
| ) |
| |
| |
| INT_INF = int(1e6) |
| |
| |
| def _tensor_nbytes(numel: int, dtype) -> int: |
| return numel * dtype.itemsize |
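| 
| # Worked example (illustrative; itemsize values are the standard ones):
| #   _tensor_nbytes(1024, torch.float16) -> 2048
| #   _tensor_nbytes(1024, torch.float32) -> 4096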
| |
| |
| def _size_of(node: fx.Node) -> int: |
| if "val" in node.meta: |
| val = node.meta["val"] |
| if isinstance(val, py_sym_types): |
| return 1 |
| # NB: The fallback values here are meaningless, maybe we should respect |
| # torch._inductor.config.unbacked_symint_fallback (but this is a |
| # layering violation) |
| elif isinstance(val, (list, tuple)): |
| return sum( |
| _tensor_nbytes(hint_int(n.numel(), fallback=4096), n.dtype) |
| for n in val |
| if isinstance(n, torch.Tensor) |
| ) |
| elif isinstance(val, torch.Tensor): |
| return _tensor_nbytes(hint_int(val.numel(), fallback=4096), val.dtype) |
| |
| raise RuntimeError(f"Unknown metadata type {type(val)}") |
| if node.op == "get_attr": |
| return 0 |
| raise RuntimeError("We should always have `val` metadata on the nodes") |
| |
| |
| # Used for some investigative purposes |
| def _count_ops(graph: fx.Graph): |
| cnt: Dict[str, int] = defaultdict(int) |
| for node in graph.nodes: |
| if node.op == "call_function": |
| cnt[node.target.__name__] += 1 |
| print(sorted(cnt.items(), key=lambda x: x[1], reverse=True)) |
| |
| |
| @functools.lru_cache(None) |
| def pointwise_ops(): |
| ops = [] |
| for attr_name in dir(torch.ops.aten): |
| opoverloadpacket = getattr(torch.ops.aten, attr_name) |
| if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket): |
| continue |
| |
| for overload in opoverloadpacket.overloads(): |
| op_overload = getattr(opoverloadpacket, overload) |
| if torch.Tag.pointwise in op_overload.tags: |
| # currently aot autograd uses packet not overload |
| ops.append(opoverloadpacket) |
| break |
| |
| return ops |
| |
| |
| def sort_depths(args, depth_map: Dict[fx.Node, int]) -> List[Tuple[fx.Node, int]]: |
| arg_depths = { |
| arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node) |
| } |
| return sorted(arg_depths.items(), key=lambda x: x[1], reverse=True) |
| |
| |
| def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule: |
| """ |
| This pass finds the first bwd node in the graph (by looking at users of |
| tangents) and then reorders the graph by walking from this node to all the |
| way to the end of the graph. At each op in this traversal, we insert the op
| into a new graph and pull in only the subgraph of non-bwd nodes that the op
| actually needs. This closely mimics the behavior of the autograd engine.
| |
| Why is this pass required in the first place? |
| |
| This is an artifact of how partitioners work today. The starting point of the
| partitioner is the joint graph, which is the fwd graph followed by the bwd
| graph. In the case of checkpointing, we keep portions of the fwd graph in their
| original place in the joint graph while obtaining a bwd graph. As a result, the
| resulting bwd graph has copies of the recomputed fwd subgraphs followed by the
| original bwd graph. If we run this naively, it leads to a poor memory footprint,
| because the fwd subgraphs stay live for much longer than necessary. This pass
| reorders the operations so that we prioritize the ops of the original bwd graph,
| realizing ops from the fwd graph only when they are necessary at any given
| point in the graph.
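| 
| As an illustrative example, if checkpointing leaves the backward graph looking
| like [recomputed fw1, recomputed fw2, bw2, bw1], running it as-is keeps both
| recomputed chunks alive at the same time; this pass instead produces an order
| closer to [recomputed fw2, bw2, recomputed fw1, bw1], materializing each
| recomputed chunk just before its backward consumer needs it.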
| """ |
| |
| new_graph = fx.Graph() |
| env: Dict[fx.Node, fx.Node] = {} |
| |
| # Add new placeholder nodes in the order specified by the inputs |
| for node in gm.graph.find_nodes(op="placeholder"): |
| env[node] = new_graph.node_copy(node, lambda x: env[x]) |
| |
| order = {} |
| for idx, node in enumerate(gm.graph.nodes): |
| order[node] = idx |
| |
| def insert_node_in_graph(node): |
| cur_nodes = [node] |
| insertable_nodes = set() |
| while len(cur_nodes) > 0: |
| node = cur_nodes.pop() |
| if node in insertable_nodes or node in env: |
| continue |
| insertable_nodes.add(node) |
| |
| # Bias traversal towards the nodes that have higher depth - prioritizes |
| # critical path first. |
| cur_nodes += node.all_input_nodes |
| |
| insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n]) |
| for node in insertable_nodes: |
| env[node] = new_graph.node_copy(node, lambda x: env[x]) |
| |
| # Find first bwd node in the graph |
| tangent_inputs = list(filter(_is_tangent, gm.graph.nodes)) |
| first_node_in_bwd = None |
| minimum_order = math.inf |
| for tangent in tangent_inputs: |
| for user in tangent.users: |
| if order[user] < minimum_order: |
| minimum_order = order[user] |
| first_node_in_bwd = user |
| |
| # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass" |
| if first_node_in_bwd is None: |
| return gm |
| |
| # Build the graph op-by-op by starting from the node all the way to the end |
| for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]: |
| insert_node_in_graph(node) |
| |
| # The output node is already built by the traversal. |
| new_gm = torch.fx.GraphModule(gm, new_graph) |
| return new_gm |
| |
| |
| def functionalize_rng_ops( |
| joint_module: fx.GraphModule, |
| fw_module: fx.GraphModule, |
| bw_module: fx.GraphModule, |
| num_sym_nodes: int, |
| ) -> Tuple[fx.GraphModule, fx.GraphModule]: |
| # During user-driven activation checkpointing, we have to ensure that a rng |
| # op in fwd yields the same output as the recomputed rng op in the bwd. To |
| # do this, we use functionalize wrappers to wrap the random ops and share |
| # rng state between the fwd and bwd graphs. |
| |
| # There are 3 main steps to do this |
| # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. |
| # Step 2 - Modify the fwd pass such that |
| #   1) Replace rand with the run_and_save_rng_state wrapper.
| #   2) Replace the users of the original op with output[1] of this op.
| #   3) Collect all the rng states - output[0] of each op - and make them
| #      output nodes. Special care needs to be taken here because the fwd
| #      outputs have symints at the very end.
| # Step 3 - Modify the bwd pass such that |
| #   1) Add input nodes, just before the tangents, for the stashed rng states.
| #   2) Replace rand with the run_with_rng_state wrapper.
| #   3) Use the stashed states as inputs to these ops.
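| 
| # Illustrative sketch of the rewrite (names abbreviated, not literal graph code):
| #   fwd:  y = rand(...)                 ->  state, y = run_and_save_rng_state(rand, ...)
| #   bwd:  y = rand(...)  (recomputed)   ->  y = run_with_rng_state(state, rand, ...)
| # where `state` is threaded from a new fwd output to a new bwd placeholder.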
| |
| # Unique id to generate name |
| uid = itertools.count() |
| |
| def get_rng_ops(gmod): |
| random_nodes = {} |
| for node in gmod.graph.nodes: |
| if ( |
| node.op == "call_function" |
| and hasattr(node.target, "tags") |
| and torch.Tag.nondeterministic_seeded in node.target.tags |
| ): |
| random_nodes[node.name] = node |
| return random_nodes |
| |
| def get_device(node): |
| """ |
| Check the example value of the node outputs to find the device type. |
| """ |
| if "val" not in node.meta: |
| return None |
| |
| candidates = node.meta["val"] |
| if not isinstance(candidates, tuple): |
| candidates = (candidates,) |
| |
| for candidate in candidates: |
| if isinstance(candidate, torch.Tensor): |
| if candidate.device.type == "cuda": |
| return "cuda" |
| |
| return "cpu" |
| |
| def get_sample_rng_state(device): |
| if device == "cuda": |
| return torch.cuda.get_rng_state() |
| return torch.get_rng_state() |
| |
| # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd. |
| joint_graph_rng_ops = get_rng_ops(joint_module) |
| fw_graph_rng_ops = get_rng_ops(fw_module) |
| bw_graph_rng_ops = get_rng_ops(bw_module) |
| recomputable_rng_ops_map = dict() |
| for node in joint_module.graph.nodes: |
| if ( |
| must_recompute(node) |
| and hasattr(node.target, "tags") |
| and torch.Tag.nondeterministic_seeded in node.target.tags |
| ): |
| base_node = joint_graph_rng_ops[node.name] |
| fw_node = fw_graph_rng_ops[node.name] |
| bw_node = bw_graph_rng_ops[node.name] |
| recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node} |
| |
| run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state |
| run_with_rng_state = torch._prims.rng_prims.run_with_rng_state |
| bw_tangent_start_node = None |
| for node in bw_module.graph.find_nodes(op="placeholder"): |
| if "tangent" in node.name: |
| bw_tangent_start_node = node |
| break |
| if bw_tangent_start_node is None: |
| raise RuntimeError( |
| "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this" |
| ) |
| |
| fw_rng_state_outputs = [] |
| for base_node, node_pair in recomputable_rng_ops_map.items(): |
| # Step 2 - Modify the fwd pass such that |
| fw_node = node_pair["fwd"] |
| bw_node = node_pair["bwd"] |
| fw_graph = fw_module.graph |
| with fw_graph.inserting_before(fw_node): |
| functional_fw_node = fw_graph.create_node( |
| "call_function", |
| run_and_save_rng, |
| args=(fw_node.target, *fw_node.args), |
| kwargs=fw_node.kwargs, |
| ) |
| state = fw_graph.create_node( |
| "call_function", |
| operator.getitem, |
| args=(functional_fw_node, 0), |
| kwargs={}, |
| ) |
| rng_output = fw_graph.create_node( |
| "call_function", |
| operator.getitem, |
| args=( |
| functional_fw_node, |
| 1, |
| ), |
| kwargs={}, |
| ) |
| fw_node.replace_all_uses_with(rng_output) |
| fw_graph.erase_node(fw_node) |
| fw_rng_state_outputs.append(state) |
| |
| # Step 3 - Modify the bwd pass such that |
| bw_graph = bw_module.graph |
| with bw_graph.inserting_before(bw_tangent_start_node): |
| state_name = f"rng_state_output_{next(uid)}" |
| bw_rng_state_node = bw_graph.placeholder(state_name) |
| bw_rng_state_node.meta["val"] = get_sample_rng_state(get_device(fw_node)) |
| |
| with bw_graph.inserting_before(bw_node): |
| rng_output = bw_graph.create_node( |
| "call_function", |
| run_with_rng_state, |
| args=(bw_rng_state_node, bw_node.target, *bw_node.args), |
| kwargs=bw_node.kwargs, |
| ) |
| |
| bw_node.replace_all_uses_with(rng_output) |
| bw_graph.erase_node(bw_node) |
| |
| # Add the rng states in the output of the fwd graph. AOT Autograd assumes |
| # that symints are at the end of forward graph outputs. So, insert the new |
| # rng states accordingly. |
| fw_output_node = next(iter(fw_module.graph.find_nodes(op="output"))) |
| fw_outputs = fw_output_node.args[0] |
| sym_node_start_idx = len(fw_outputs) - num_sym_nodes |
| outputs = ( |
| fw_outputs[:sym_node_start_idx] |
| + fw_rng_state_outputs |
| + fw_outputs[sym_node_start_idx:] |
| ) |
| fw_module.graph.output(outputs) |
| fw_module.graph.erase_node(fw_output_node) |
| fw_module.recompile() |
| bw_module.recompile() |
| return fw_module, bw_module |
| |
| |
| def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule: |
| """ |
| If there are two consecutive checkpointed blocks with no operator in
| between, we still want to stash the tensor at the boundary of the
| checkpointed blocks. The following pass marks the boundary node as
| MUST_SAVE (i.e. non-recomputable) to allow for that.
| """ |
| for node in joint_module.graph.nodes: |
| if must_recompute(node): |
| for user in node.users: |
| if ( |
| must_recompute(user) |
| and user.meta["ac_graph_id"] > node.meta["ac_graph_id"] |
| ): |
| node.meta["recompute"] = CheckpointPolicy.MUST_SAVE |
| return joint_module |
| |
| |
| def solve_min_cut( |
| joint_graph: fx.Graph, |
| node_info: NodeInfo, |
| min_cut_options: MinCutOptions, |
| dont_ban=None, |
| ): |
| if dont_ban is None: |
| dont_ban = set() |
| op_types = get_default_op_list() |
| |
| if AOT_PARTITIONER_DEBUG: |
| joint_module_ops = { |
| str(node.target._overloadpacket) |
| for node in joint_graph.nodes |
| if node.op == "call_function" and hasattr(node.target, "_overloadpacket") |
| } |
| ops_ignored = joint_module_ops - {str(i) for i in op_types.recomputable_ops} |
| print("Ops banned from rematerialization: ", ops_ignored) |
| print() |
| |
| def is_fusible(a, b): |
| # We can perform "memory fusion" into a cat, but cat cannot be a |
| # producer to a fusion |
| if get_aten_target(b) == aten.cat: |
| return True |
| return op_types.is_fusible(a) and op_types.is_fusible(b) |
| |
| try: |
| import networkx as nx |
| except ImportError as e: |
| raise RuntimeError( |
| "Need networkx installed to perform smart recomputation " "heuristics" |
| ) from e |
| |
| def is_materialized_backwards(node): |
| if op_types.is_view(node): |
| return False |
| cur_nodes = {node} |
| while len(cur_nodes) > 0: |
| cur = cur_nodes.pop() |
| for user in cur.users: |
| if not node_info.is_required_fw(user) and not is_fusible(cur, user): |
| return True |
| if op_types.is_view(user): |
| cur_nodes.add(user) |
| |
| return False |
| |
| def should_ban_recomputation(node): |
| if node.op != "call_function": |
| return False |
| if node.target == operator.getitem: |
| return False |
| if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE: |
| return True |
| if config.recompute_views and op_types.is_view(node): |
| return False |
| if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]: |
| return False |
| |
| if min_cut_options.ban_if_not_in_allowlist: |
| if not op_types.is_recomputable(node): |
| return True |
| else: |
| if op_types.is_random(node) or op_types.is_compute_intensive(node): |
| return True |
| |
| # If a node *must* be materialized in the backwards pass, then we |
| # should never recompute it. This is a pretty subtle point. In |
| # general, the assumption we make is that recomputing a node in the |
| # backwards pass is "free". However, if a node must be materialized |
| # in the backwards pass, then recomputing it is never free. |
| if min_cut_options.ban_if_materialized_backward and is_materialized_backwards( |
| node |
| ): |
| log.info("materialized backwards: %s %s", node, tuple(node.users)) |
| return True |
| |
| # Arbitrary hack that sometimes seems to help things. The above |
| # modification appears to have made this heuristic a lot less critical |
| # for performance. |
| # NB: As of PR #121692, this hack no longer seems necessary. |
| if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw: |
| return True |
| |
| # If the output of an op is 4x smaller (arbitrary choice), |
| # then we don't allow recomputation. The idea here is that for |
| # things like reductions, saving the output of the reduction is very |
| # cheap/small, and it makes sure we don't do things like recompute |
| # normalizations in the backwards. |
| if min_cut_options.ban_if_reduction: |
| input_tensors_size = sum( |
| _size_of(i) for i in node.args if isinstance(i, fx.Node) |
| ) |
| output_size = _size_of(node) |
| return output_size * 4 < input_tensors_size |
| return False |
| |
| def is_materialized(node): |
| if node.op == "placeholder": |
| return True |
| |
| return not all(is_fusible(node, user) for user in node.users) |
| |
| def get_node_weight(node) -> float: |
| mem_sz = _size_of(node) |
| if config.recompute_views and op_types.is_view(node): |
| # If `config.recompute_views=True`, we don't save views. This is generally |
| # a good idea since views are free to recompute, and it makes it a bit simpler |
| # to analyze. |
| # NB: If they're not free to recompute (e.g. nested tensors)... I |
| # think we should modify checks for view_ops to `is_view` and check |
| # that. Basically, with nested tensors, `aten.view` is not a "view |
| # op". |
| return math.inf |
| |
| if isinstance(node.meta["val"], py_sym_types): |
| # We never want to save symfloats |
| if not isinstance(node.meta["val"], torch.SymInt): |
| return INT_INF |
| |
| # Heuristic to bias towards nodes closer to the backwards pass |
| # Complete guess about current value |
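| # Worked example (illustrative): a 1 MB tensor with dist_from_bw == 10 gets
| # weight int(1e6 * 1.1**10), i.e. about 2.59e6; the weight is then doubled
| # below when every use of the node is fusible (i.e. is_materialized is False).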
| mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))) |
| if is_materialized(node): |
| return mem_sz |
| else: |
| return mem_sz * 2 |
| |
| nx_graph = nx.DiGraph() |
| banned_nodes = set() |
| |
| def ban_recomputation_if_allowed(node): |
| if op_types.is_view(node): |
| return False |
| if node in dont_ban: |
| return False |
| # This bans recomputation of the node unless we've been forced not to by |
| # user annotation |
| if must_recompute(node): |
| return False |
| |
| if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat): |
| return False |
| |
| banned_nodes.add(node) |
| # A node will only ever be recomputed if there is a path from an |
| # ancestor of this node to the backwards path through this node that |
| # doesn't go through any saved value. If this node is saved, then that |
| # condition is not possible. |
| nx_graph.add_edge("source", node.name + "_in", capacity=math.inf) |
| return True |
| |
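| # Sketch of the flow-network encoding built below (descriptive only): every
| # candidate node X is split into X_in -> X_out with capacity equal to the cost
| # of saving X (see get_node_weight), and every use X -> Y adds an infinite
| # capacity edge X_out -> Y_in. Nodes required by the backward (or annotated
| # MUST/PREFER_RECOMPUTE) attach to the sink; banned nodes and primal inputs
| # attach to the source. The edges crossing the resulting min cut are exactly
| # the X_in -> X_out edges of the values that get saved.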
| for node in joint_graph.nodes: |
| if node.op == "output": |
| continue |
| |
| if node in node_info.required_bw_nodes: |
| if node not in node_info.inputs: |
| nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) |
| continue |
| # If someone saves an input for backward as-is and the backward
| # returns that tensor as-is as a grad input, then the node x would
| # be both a required_bw_node and an input. In this case we
| # (1) connect x_in to the source, (2) connect x_out to the sink, and
| # (3) assign the proper weight to the x_in-x_out edge, so that
| # x becomes part of the cut nodes. A case where this happens is when
| # NestedTensor saves an offset tensor as part of the singleton int
| # in sizes.
| nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf) |
| |
| if must_recompute(node): |
| # If user explicitly says they want to recompute a node, we honor it |
| # by adding an inf-capacity edge from X_in to the sink. |
| # This way, X_in node is guaranteed to be part of the subgraph that contains "sink" |
| # after the cut, thus guaranteeing that X op will be recomputed. |
| nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf) |
| continue |
| |
| if _is_primal(node) or _is_fwd_seed_offset(node): |
| ban_recomputation_if_allowed(node) |
| |
| # If a node can't be recomputed (too expensive or involves randomness), |
| # we prevent it from being recomputed by adding an inf edge to the source |
| # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed. |
| if node_info.is_required_fw(node) and should_ban_recomputation(node): |
| ban_recomputation_if_allowed(node) |
| |
| # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors. |
| is_non_tensor_node = ( |
| "val" not in node.meta and "tensor_meta" not in node.meta |
| ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor)) |
| |
| if is_sym_node(node): |
| weight = float(sym_node_size(node)) |
| elif is_non_tensor_node: |
| weight = ( |
| 0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf |
| ) |
| else: |
| weight = get_node_weight(node) |
| # Creates the weights on the "node" edge |
| nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight) |
| for user in node.users: |
| nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf) |
| |
| # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute. |
| # Some example models to look at where this helps perf: poolformer_m36, |
| # mixer_b16_224, cait_m36_384 |
| |
| # The "rough" idea here is that if you have some node that is used by both a |
| # node nearby downstream as well as a node far downstream, if we recompute |
| # both of the downstream nodes, we're unlikely to be able to fuse both |
| # downstream nodes together. |
| |
| # Thus, we shouldn't aim to recompute far downstream nodes that depend on |
| # this node. That intuition of "far downstream" is captured by whether |
| # there's an unfusible op along the chain somewhere |
| |
| # It could probably be improved by properly analyzing what's going on in the |
| # backwards pass instead of only relying on whether it's unfusible in the |
| # forwards. |
| |
| def find_first_unfusible(start_nodes: List[fx.Node], max_range: int) -> int: |
| """ |
| Finds the first unfusible node in the chain of nodes starting from |
| `start_nodes` and returns its position. |
| """ |
| sorted_nodes: List[Tuple[int, fx.Node, bool]] = [] |
| for n in start_nodes: |
| heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True)) |
| |
| while len(sorted_nodes) > 0: |
| _, node, node_is_fusible = heapq.heappop(sorted_nodes) |
| if not node_is_fusible: |
| return node_info.get_fw_order(node) |
| for user in node.users: |
| if node_info.is_required_fw(user): |
| if node_info.get_fw_order(user) > max_range: |
| continue |
| heapq.heappush( |
| sorted_nodes, |
| (node_info.get_fw_order(user), user, is_fusible(node, user)), |
| ) |
| return max_range |
| |
| if min_cut_options.ban_if_used_far_apart: |
| for used_node in node_info.required_fw_nodes: |
| orders = [ |
| node_info.get_fw_order(user) |
| for user in used_node.users |
| if node_info.is_required_fw(user) |
| ] |
| fw_users = [ |
| user for user in used_node.users if node_info.is_required_fw(user) |
| ] |
| if len(orders) > 0: |
| first_unfusible_use = find_first_unfusible(fw_users, max(orders)) |
| for user in tuple(used_node.users): |
| if ( |
| node_info.is_required_fw(user) |
| and node_info.get_fw_order(user) > first_unfusible_use |
| and is_fusible(used_node, user) |
| ): |
| if user in banned_nodes: |
| continue |
| log.info( |
| "used above/below fusible %s:(%s) -> %s -> %s:(%s)", |
| used_node, |
| node_info.get_fw_order(used_node), |
| first_unfusible_use, |
| user, |
| node_info.get_fw_order(user), |
| ) |
| ban_recomputation_if_allowed(user) |
| |
| # This heuristic is fairly straightforward. The idea is that although it is |
| # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation |
| # where we have a long chain of pointwise ops from the beginning to the end |
| # of the model (like say, residual connections) |
| |
| # todo: I'm not totally sure why this heuristic matters. It's possible that this is |
| # working around Inductor fusion decisions, or that it's a patch over |
| # suboptimal partitioning decisions |
| |
| # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36 |
| |
| if min_cut_options.ban_if_long_fusible_chains: |
| visited = set() |
| for start_node in joint_graph.nodes: |
| if not node_info.is_required_fw(start_node): |
| continue |
| fusible = [(node_info.get_fw_order(start_node), start_node)] |
| start_order = node_info.get_fw_order(start_node) |
| while len(fusible) > 0: |
| _, cur = heapq.heappop(fusible) |
| if cur in visited: |
| continue |
| visited.add(cur) |
| # 100 is arbitrary choice to try and prevent degenerate cases |
| if ( |
| node_info.get_fw_order(cur) > start_order + 100 |
| and len(fusible) == 0 |
| ): |
| log.info( |
| "too long %s %s %s %s", |
| cur, |
| start_node, |
| node_info.get_fw_order(cur), |
| node_info.get_fw_order(start_node), |
| ) |
| ban_recomputation_if_allowed(cur) |
| break |
| |
| for user in cur.users: |
| if ( |
| node_info.is_required_fw(user) |
| and is_fusible(cur, user) |
| and user not in banned_nodes |
| ): |
| heapq.heappush(fusible, (node_info.get_fw_order(user), user)) |
| |
| try: |
| cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink") |
| except Exception: |
| print("Failed to compute min-cut on following graph:") |
| print("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))) |
| visualize_min_cut_graph(nx_graph) |
| raise |
| |
| reachable, non_reachable = partition |
| cutset: Set[Tuple[str, str]] = set() |
| for u, nbrs in ((n, nx_graph[n]) for n in reachable): |
| cutset.update((u, v) for v in nbrs if v in non_reachable) |
| |
| cut_nodes = set() |
| for node_in, node_out in cutset: |
| assert node_in[:-3] == node_out[:-4] |
| node_name = node_in[:-3] |
| cut_nodes.add(node_name) |
| |
| name_to_node = get_name_to_node(joint_graph) |
| # To make this stuff deterministic |
| node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)} |
| saved_values = sorted( |
| (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x] |
| ) |
| return saved_values, banned_nodes |
| |
| |
| def visualize_min_cut_graph(nx_graph): |
| import networkx as nx |
| import pydot |
| |
| dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string() |
| dot_graph = pydot.graph_from_dot_data(dot_format)[0] |
| for edge in dot_graph.get_edges(): |
| weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"] |
| # Set edge label to weight |
| edge.set_label(str(weight)) |
| # Color edges with weight 'inf' as red |
| if weight == float("inf"): |
| edge.set_color("red") |
| print("Visualizing the failed graph to min_cut_failed.svg") |
| dot_graph.write_svg("min_cut_failed.svg") |
| |
| |
| def get_default_op_list() -> OpTypes: |
| default_recomputable_ops: List[Callable] = [ |
| aten.add, |
| aten.sub, |
| aten.div, |
| aten.atan2, |
| aten.mul, |
| aten.max, |
| aten.min, |
| aten.pow, |
| aten.remainder, |
| aten.fmod, |
| aten.__and__, |
| aten.__or__, |
| aten.__xor__, |
| aten.__lshift__, |
| aten.__rshift__, |
| aten.eq, |
| aten.ne, |
| aten.ge, |
| aten.gt, |
| aten.le, |
| aten.lt, |
| aten.abs, |
| aten.bitwise_not, |
| aten.ceil, |
| aten.floor, |
| aten.frac, |
| aten.neg, |
| aten.relu, |
| aten.round, |
| aten.silu, |
| aten.trunc, |
| aten.log, |
| aten.log10, |
| aten.log1p, |
| aten.log2, |
| aten.lgamma, |
| aten.exp, |
| aten.expm1, |
| aten.erf, |
| aten.erfc, |
| aten.cos, |
| aten.acos, |
| aten.cosh, |
| aten.sin, |
| aten.asin, |
| aten.sinh, |
| aten.tan, |
| aten.atan, |
| aten.tanh, |
| aten.atanh, |
| aten.sqrt, |
| aten.rsqrt, |
| aten.reciprocal, |
| aten.sigmoid, |
| aten.softplus, |
| aten.threshold, |
| aten.threshold_backward, |
| aten.clamp, |
| aten.where, |
| aten.lerp, |
| aten.addcmul, |
| aten.gelu, |
| aten.gelu_backward, |
| aten.sum, |
| aten.mean, |
| aten._grad_sum_to_size, |
| aten.sum_to_size, |
| aten.amax, |
| aten.to, |
| aten.type_as, |
| operator.getitem, |
| aten.squeeze, |
| aten.unsqueeze, |
| aten.rsub, |
| aten._to_copy, |
| ] # noqa: E501,B950 |
| recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] |
| recomputable_view_ops += [ |
| aten.view, |
| aten.slice, |
| aten.t, |
| prims.broadcast_in_dim, |
| aten.expand, |
| aten.as_strided, |
| aten.permute, |
| ] |
| view_ops = recomputable_view_ops |
| default_recomputable_ops += [ |
| prims.div, |
| prims.convert_element_type, |
| aten.clone, |
| aten._to_copy, |
| aten.full_like, |
| prims.var, |
| prims.sum, |
| aten.var, |
| aten.std, |
| prims.broadcast_in_dim, |
| aten.select, |
| aten._unsafe_view, |
| aten.view, |
| aten.expand, |
| aten.slice, |
| aten.reshape, |
| aten.broadcast_tensors, |
| aten.scalar_tensor, |
| aten.ones, |
| aten.new_zeros, |
| aten.lift_fresh_copy, |
| aten.arange, |
| aten.triu, |
| aten.var_mean, |
| aten.isinf, |
| aten.any, |
| aten.full, |
| aten.as_strided, |
| aten.zeros, |
| aten.argmax, |
| aten.maximum, |
| prims.iota, |
| prims._low_memory_max_pool2d_offsets_to_indices, |
| ] # noqa: E501,B950 |
| # Natalia said that we should allow recomputing indexing :) |
| default_recomputable_ops += [aten.index, aten.gather] |
| default_recomputable_ops += view_ops |
| |
| default_recomputable_ops += pointwise_ops() |
| |
| default_recomputable_ops += [ |
| aten.zeros_like, |
| ] |
| |
| default_recomputable_ops += [method_to_operator(m) for m in magic_methods] |
| recomputable_ops = set(default_recomputable_ops) |
| |
| random_ops = [aten.native_dropout, aten.rand_like, aten.randn_like] |
| compute_intensive_ops = [ |
| aten.mm, |
| aten.convolution, |
| aten.convolution_backward, |
| aten.bmm, |
| aten.addmm, |
| aten._scaled_dot_product_flash_attention, |
| aten._scaled_dot_product_efficient_attention, |
| aten._flash_attention_forward, |
| aten._efficient_attention_forward, |
| aten.upsample_bilinear2d, |
| ] # noqa: E501,B950 |
| |
| fusible_ops = recomputable_ops | set(random_ops) |
| return OpTypes( |
| set(fusible_ops), |
| set(compute_intensive_ops), |
| set(random_ops), |
| set(view_ops), |
| set(recomputable_ops), |
| ) |
| |
| |
| def get_name_to_node(graph: fx.Graph): |
| name_to_node = {} |
| for node in graph.nodes: |
| name_to_node[node.name] = node |
| return name_to_node |
| |
| |
| def greedy_knapsack( |
| memory: List[float], runtimes: List[float], max_memory: float |
| ) -> Tuple[float, List[int], List[int]]: |
| n = len(runtimes) |
| items = list(range(n)) |
| |
| # Sort items based on the ratio of runtime to memory in descending order |
| items = sorted(items, key=lambda i: runtimes[i] / memory[i], reverse=True) |
| |
| total_memory = 0.0 |
| total_runtime = 0.0 |
| items_to_save = [] |
| items_to_allow_recomputing = [] |
| |
| for i in items: |
| if total_memory + memory[i] <= max_memory: |
| total_memory += memory[i] |
| total_runtime += runtimes[i] |
| items_to_save.append(i) |
| else: |
| items_to_allow_recomputing.append(i) |
| return total_runtime, items_to_save, items_to_allow_recomputing |
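| 
| # Worked example (made-up numbers): with memory=[0.25, 0.5, 0.25],
| # runtimes=[3.0, 1.0, 2.0], and max_memory=0.5, the runtime/memory ratios are
| # [12, 2, 8], so items 0 and 2 are saved (total memory 0.5, total runtime 5.0)
| # and item 1 is left to be recomputed:
| #   greedy_knapsack([0.25, 0.5, 0.25], [3.0, 1.0, 2.0], 0.5) == (5.0, [0, 2], [1])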
| |
| |
| def ilp_knapsack( |
| memory: List[float], runtimes: List[float], max_memory: float |
| ) -> Tuple[float, List[int], List[int]]: |
| import numpy as np |
| |
| try: |
| from scipy.optimize import Bounds, LinearConstraint, milp |
| except ImportError: |
| raise RuntimeError( |
| "To use the ILP for memory budget checkpointing you need to install scipy" |
| ) from None |
| |
| np_memory = np.array(memory) |
| np_runtimes = np.array(runtimes) |
| c = -np_runtimes # type: ignore[operator] |
| |
| memory_constraint = LinearConstraint(A=np_memory, ub=np.array(max_memory)) |
| constraints = [memory_constraint] |
| |
| integrality = np.ones_like(c) |
| res = milp( |
| c=c, constraints=constraints, integrality=integrality, bounds=Bounds(0, 1) |
| ) |
| if not res.success: |
| raise RuntimeError("Somehow scipy solving failed") |
| |
| items_to_save = [] |
| items_to_allow_recomputing = [] |
| for idx, i in enumerate(res.x): |
| if i == 1: |
| items_to_save.append(idx) |
| else: |
| items_to_allow_recomputing.append(idx) |
| return -res.fun, items_to_save, items_to_allow_recomputing |
| |
| |
| def dp_knapsack( |
| memory: List[float], runtimes: List[float], max_memory: float |
| ) -> Tuple[float, List[int], List[int]]: |
| # Scaling factor to convert floating point weights to integers |
| S = 10000 |
| |
| # Quantize the memory weights |
| quantized_memory = torch.tensor( |
| [int(round(m * S)) for m in memory], dtype=torch.long, device="cpu" |
| ) |
| runtimes = torch.tensor(runtimes, dtype=torch.float32, device="cpu") |
| |
| # Quantized pseudopolynomial DP for 0-1 Knapsack |
| quantized_max_memory = int(round(max_memory * S)) |
| |
| n = len(memory) |
| |
| # Initialize the DP table |
| # TODO(chilli): I think if needed, this memory can be optimized with sliding |
| # window trick + Hirschberg trick: |
| # https://codeforces.com/blog/entry/47247?#comment-316200 |
| dp = torch.zeros( |
| (n + 1, quantized_max_memory + 1), dtype=torch.float32, device="cpu" |
| ) |
| |
| for i in range(1, n + 1): |
| current_memory = quantized_memory[i - 1] |
| current_runtime = runtimes[i - 1] |
| |
| # Copy the previous row |
| dp[i, :] = dp[i - 1, :] |
| |
| # Update dp[i, j] for all j >= current_memory |
| if current_memory == 0: |
| dp[i, :] = dp[i - 1, :] + current_runtime |
| else: |
| dp[i, current_memory:] = torch.maximum( |
| dp[i - 1, current_memory:], |
| dp[i - 1, :-current_memory] + current_runtime, |
| ) |
| |
| # Backtrack to find the items included in the knapsack |
| saved_items = [] |
| recomputable_items = [] |
| j: int = quantized_max_memory |
| for i in range(n, 0, -1): |
| if dp[i][j] != dp[i - 1][j]: |
| saved_items.append(i - 1) # Include this item (indexing from 0) |
| j -= int(quantized_memory[i - 1].item()) |
| else: |
| recomputable_items.append(i - 1) |
| |
| saved_items.reverse() # To get items in the order they were added |
| |
| # The maximum runtime that can be achieved within the max_memory constraint |
| max_runtime = dp[n][quantized_max_memory].item() |
| |
| return max_runtime, saved_items, recomputable_items |
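| 
| # Worked example (made-up numbers) where the DP beats the greedy heuristic:
| # with memory=[0.6, 0.5, 0.5], runtimes=[7.0, 4.0, 4.0], and max_memory=1.0,
| # greedy saves item 0 first (best ratio) and gets runtime 7.0, while
| #   dp_knapsack([0.6, 0.5, 0.5], [7.0, 4.0, 4.0], 1.0) == (8.0, [1, 2], [0])
| # saves items 1 and 2 for a total runtime of 8.0 within the same budget.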
| |
| |
| def _optimize_runtime_with_given_memory( |
| memory: List[float], |
| runtimes: List[float], |
| max_memory: float, |
| ) -> Tuple[float, List[int], List[int]]: |
| SOLVER = config.activation_memory_budget_solver |
| if SOLVER == "greedy": |
| return greedy_knapsack(memory, runtimes, max_memory) |
| elif SOLVER == "ilp": |
| return ilp_knapsack(memory, runtimes, max_memory) |
| elif SOLVER == "dp": |
| return dp_knapsack(memory, runtimes, max_memory) |
| else: |
| raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") |
| |
| |
| from torch.utils._mode_utils import no_dispatch |
| |
| |
| def estimate_runtime(node): |
| RUNTIME_MODE = config.activation_memory_budget_runtime_estimator |
| |
| def materialize_arg(x): |
| if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): |
| shape = list(x.meta["val"].shape) |
| |
| def realize_symbol(d): |
| return hint_int(d, fallback=4096) |
| |
| shape = [realize_symbol(s) for s in shape] |
| return x.meta["val"].new_zeros(shape) |
| elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): |
| return hint_int(x.meta["val"], fallback=4096) |
| elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): |
| return 1.0 |
| elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): |
| return True |
| else: |
| return x |
| |
| if RUNTIME_MODE == "testing": |
| return 1 |
| |
| elif RUNTIME_MODE == "profile": |
| from triton.testing import do_bench |
| |
| with no_dispatch(): |
| args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) |
| ms = do_bench(lambda: node.target(*args, **kwargs)) |
| return ms |
| |
| elif RUNTIME_MODE == "flops": |
| # todo(chilli): Normalize this to also return ms |
| from torch.utils.flop_counter import FlopCounterMode |
| |
| args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs)) |
| with FlopCounterMode(display=False) as mode: |
| node.target(*args, **kwargs) |
| counted_flops = mode.get_total_flops() |
| return max(counted_flops, 1) |
| else: |
| raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}") |
| |
| |
| def choose_saved_values_set( |
| joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 |
| ) -> List[fx.Node]: |
| if memory_budget > 1 or memory_budget < 0: |
| raise RuntimeError( |
| f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}" |
| ) |
| min_cut_options = MinCutOptions( |
| ban_if_used_far_apart=config.ban_recompute_used_far_apart, |
| ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains, |
| ban_if_materialized_backward=config.ban_recompute_materialized_backward, |
| ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist, |
| ban_if_reduction=config.ban_recompute_reductions, |
| ) |
| |
| if config.aggressive_recomputation: |
| min_cut_options = replace( |
| min_cut_options, |
| ban_if_used_far_apart=False, |
| ban_if_long_fusible_chains=False, |
| ban_if_materialized_backward=False, |
| ban_if_not_in_allowlist=False, |
| ) |
| if memory_budget == 0: |
| return node_info.inputs |
| |
| runtime_optimized_saved_values, _ = solve_min_cut( |
| joint_graph, |
| node_info, |
| min_cut_options, |
| ) |
| if memory_budget == 1: |
| return runtime_optimized_saved_values |
| |
| def estimate_activations_size(saved_values: List[fx.Node]) -> float: |
| return sum([_size_of(i) for i in saved_values]) / 1e9 |
| |
| min_act_size = estimate_activations_size(node_info.inputs) |
| max_act_size = estimate_activations_size(runtime_optimized_saved_values) |
| # The optimized choice is smaller than the inputs anyways |
| if max_act_size <= min_act_size: |
| return runtime_optimized_saved_values |
| |
| def get_normalized_size(sz): |
| return (sz / 1e9) / (max_act_size - min_act_size) |
| |
| def get_mem_ratio(activations: List[fx.Node]): |
| return (estimate_activations_size(activations) - min_act_size) / ( |
| max_act_size - min_act_size |
| ) |
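| 
| # Illustrative numbers: if the inputs total 1.0 GB (min_act_size) and the
| # runtime-optimal saved set totals 3.0 GB (max_act_size), then a candidate set
| # of 2.0 GB has a mem ratio of (2.0 - 1.0) / (3.0 - 1.0) = 0.5, which is what
| # gets compared against memory_budget below.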
| |
| more_aggressive_options = replace( |
| min_cut_options, |
| ban_if_used_far_apart=False, |
| ban_if_long_fusible_chains=False, |
| ban_if_materialized_backward=False, |
| ) |
| more_aggressive_saved_values, _ = solve_min_cut( |
| joint_graph, node_info, more_aggressive_options |
| ) |
| if get_mem_ratio(more_aggressive_saved_values) < memory_budget: |
| return more_aggressive_saved_values |
| |
| aggressive_options = replace( |
| more_aggressive_options, |
| ban_if_not_in_allowlist=False, |
| ) |
| aggressive_recomputation_saved_values, banned_nodes = solve_min_cut( |
| joint_graph, node_info, aggressive_options |
| ) |
| |
| if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget: |
| return aggressive_recomputation_saved_values |
| |
| from torch._inductor.fx_utils import get_node_storage |
| |
| input_storages = {get_node_storage(node) for node in node_info.inputs} |
| |
| def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: |
| return [ |
| i |
| for i in banned_nodes |
| if ( |
| # Only allow recomputing nodes that are actually required for BW |
| i.dist_from_bw < int(1e9) # type: ignore[attr-defined] |
| and get_node_storage(i) not in input_storages |
| ) |
| ] |
| |
| recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes) |
| |
| # default: runtime_optimized_saved_values |
| # more aggressive: more_aggressive_saved_values |
| # full aggressive: aggressive_recomputation_saved_values |
| |
| all_recomputable_banned_nodes = sorted( |
| recomputable_banned_nodes, key=_size_of, reverse=True |
| ) |
| if len(all_recomputable_banned_nodes) == 0: |
| return node_info.inputs |
| memories_banned_nodes = [ |
| get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes |
| ] |
| runtimes_banned_nodes = [ |
| estimate_runtime(node) for node in all_recomputable_banned_nodes |
| ] |
| from torch.utils._mode_utils import no_dispatch |
| |
| def get_saved_values_knapsack(memory_budget): |
| with no_dispatch(): |
| ( |
| expected_runtime, |
| saved_node_idxs, |
| recomputable_node_idxs, |
| ) = _optimize_runtime_with_given_memory( |
| memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) |
| ) |
| dont_ban = set() |
| for idx in recomputable_node_idxs: |
| dont_ban.add(all_recomputable_banned_nodes[idx]) |
| assert dont_ban.issubset(all_recomputable_banned_nodes) |
| |
| saved_values, _ = solve_min_cut( |
| joint_graph, |
| node_info, |
| aggressive_options, |
| dont_ban, |
| ) |
| return saved_values, expected_runtime |
| |
| if config.visualize_memory_budget_pareto: |
| options = [] |
| for sweep_memory_budget in range(100, -1, -5): |
| saved_values, expected_runtime = get_saved_values_knapsack( |
| sweep_memory_budget / 100 |
| ) |
| options.append( |
| ( |
| sweep_memory_budget, |
| sum(runtimes_banned_nodes) - expected_runtime, |
| get_mem_ratio(saved_values), |
| ) |
| ) |
| |
| import matplotlib.pyplot as plt |
| |
| x_values = [item[2] for item in options] |
| y_values = [item[1] for item in options] |
| |
| # Plotting the values with updated axis labels and chart title |
| plt.figure(figsize=(10, 6)) |
| plt.plot(x_values, y_values, marker="o") |
| |
| # Adding labels for each point |
| for i, txt in enumerate(x_values): |
| plt.annotate( |
| f"{txt:.2f}", |
| (txt, y_values[i]), |
| textcoords="offset points", |
| xytext=(0, 10), |
| ha="center", |
| ) |
| |
| plt.xlabel("Memory Budget") |
| plt.ylabel("Runtime of Recomputed Components") |
| plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime") |
| plt.grid(True) |
| fig = plt.gcf() |
| plt.show() |
| fig_name = f"memory_budget_pareto_{get_aot_graph_name()}.png" |
| fig.savefig(fig_name) |
| log.warning("Generated Pareto frontier curve at %s", fig_name) |
| |
| # todo(chilli): Estimated doesn't align exactly with actual - actual is
| # usually less memory than estimated. I'm guessing (though quite
| # unsure about this) that's because the estimate only includes
| # tensors we actually banned from recompute, but there may be other
| # tensors that we choose to save.
| |
| return get_saved_values_knapsack(memory_budget=memory_budget)[0] |
| |
| |
| def min_cut_rematerialization_partition( |
| joint_module: fx.GraphModule, |
| _joint_inputs, |
| compiler="inductor", |
| *, |
| num_fwd_outputs, |
| ) -> Tuple[fx.GraphModule, fx.GraphModule]: |
| """ |
| Partitions the joint graph such that the backward recomputes the forward. |
| Recomputing helps in trading off memory bandwidth with computation. |
| |
| To create the fwd and bwd graphs, we copy the joint graph, manually set the
| outputs to just the original forward or backward outputs, and then run the
| resulting graphs through dead code elimination.
| |
| .. warning:: |
| This API is experimental and likely to change. |
| |
| Args: |
| joint_module(fx.GraphModule): The joint forward and backward graph. This |
| is the result of AOT Autograd tracing. |
| _joint_inputs: The inputs to the joint graph. This is unused. |
| compiler: Historically this selected the default set of recomputable ops
| (``nvfuser`` vs ``inductor``); it is currently unused.
| num_fwd_outputs: The number of outputs from the forward graph. |
| |
| Returns: |
| Returns the generated forward and backward Fx graph modules. |
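| 
| Example::
| 
| >>> # Illustrative sketch only: ``aot_function`` and ``nop`` are assumed to be
| >>> # the functorch.compile convenience wrappers, not part of this module.
| >>> from functorch.compile import aot_function, nop
| >>> from functorch.compile import min_cut_rematerialization_partition
| >>> f = lambda x: (x.sin() * x).cos().sum()
| >>> aot_f = aot_function(
| ...     f, fw_compiler=nop, bw_compiler=nop,
| ...     partition_fn=min_cut_rematerialization_partition,
| ... )
| >>> out = aot_f(torch.randn(8, requires_grad=True))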
| """ |
| |
| joint_module.graph.eliminate_dead_code() |
| joint_module.recompile() |
| |
| fx_g = joint_module.graph |
| |
| # add the CSE pass |
| if config.cse: |
| cse_graph = fx_graph_cse(fx_g) |
| joint_module.graph = cse_graph |
| joint_graph = joint_module.graph |
| |
| graph_has_recomputable_ops = has_recomputable_ops(joint_module) |
| graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module) |
| if graph_has_recomputable_ops: |
| joint_module = cleanup_recompute_tags(joint_module) |
| |
| def classify_nodes(joint_module): |
| name_to_node = get_name_to_node(joint_module.graph) |
| required_bw_nodes = set() |
| for node in joint_module.graph.nodes: |
| if node.op == "placeholder" and "tangents" in node.target: |
| required_bw_nodes.add(node) |
| if node in required_bw_nodes: |
| for user in node.users: |
| required_bw_nodes.add(user) |
| |
| primal_inputs = list(filter(_is_primal, joint_module.graph.nodes)) |
| fwd_seed_offset_inputs = list( |
| filter(_is_fwd_seed_offset, joint_module.graph.nodes) |
| ) |
| inputs = primal_inputs + fwd_seed_offset_inputs |
| fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs( |
| joint_module, num_fwd_outputs=num_fwd_outputs |
| ) |
| required_bw_nodes.update( |
| o for o in bwd_outputs if o is not None and o.op != "output" |
| ) |
| forward_only_graph = _extract_graph_with_inputs_outputs( |
| joint_module.graph, inputs, fwd_outputs |
| ) |
| required_fw_nodes: Set[fx.Node] = { |
| name_to_node[node.name] |
| for node in forward_only_graph.nodes |
| if node.op != "output" |
| } |
| unclaimed_nodes = { |
| node |
| for node in joint_module.graph.nodes |
| if node not in required_fw_nodes and node not in required_bw_nodes |
| } |
| fw_cnt = 0 |
| fw_order = {} |
| for node in joint_module.graph.nodes: |
| if node in required_fw_nodes: |
| fw_order[node] = fw_cnt |
| fw_cnt += 1 |
| return NodeInfo( |
| inputs, required_fw_nodes, required_bw_nodes, unclaimed_nodes, fw_order |
| ) |
| |
| node_info = classify_nodes(joint_module) |
| |
| # networkx blows up on graphs with no required backward nodes |
| # Since there's nothing to partition anyway, and the default partitioner can "handle" |
| # this case, send our graph over to the default partitioner. |
| if len(node_info.required_bw_nodes) == 0: |
| return default_partition( |
| joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs |
| ) |
| |
| for node in reversed(joint_module.graph.nodes): |
| if node.op == "output": |
| node.dist_from_bw = int(1e9) |
| elif not node_info.is_required_fw(node): |
| node.dist_from_bw = 0 |
| else: |
| node.dist_from_bw = int(1e9) |
| for user in node.users: |
| node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1) |
| |
| memory_budget = config.activation_memory_budget |
| for node in joint_graph.nodes: |
| if isinstance(node.meta.get("memory_budget", None), float): |
| memory_budget = node.meta["memory_budget"] |
| break |
| # print("Memory Budget: ", memory_budget) |
| saved_values = choose_saved_values_set( |
| joint_graph, node_info, memory_budget=memory_budget |
| ) |
| # Symints must be kept separate from tensors so that PythonFunction only calls
| # save_for_backward on tensors and stashes symints in autograd .ctx
| saved_sym_nodes = list(filter(is_sym_node, saved_values)) |
| saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) |
| |
| # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols |
| fw_module, bw_module = _extract_fwd_bwd_modules( |
| joint_module, |
| saved_values, |
| saved_sym_nodes=saved_sym_nodes, |
| num_fwd_outputs=num_fwd_outputs, |
| ) |
| |
| if graph_has_recomputable_ops: |
| if graph_has_recomputable_rng_ops: |
| fw_module, bw_module = functionalize_rng_ops( |
| joint_module, fw_module, bw_module, len(saved_sym_nodes) |
| ) |
| bw_module = reordering_to_mimic_autograd_engine(bw_module) |
| |
| if AOT_PARTITIONER_DEBUG: |
| from torch._inductor.fx_utils import get_node_storage |
| |
| storages = {get_node_storage(node) for node in saved_values} |
| print( |
| "Theoretical Activations Stored: ", |
| sum(_size_of(i) for i in saved_values) / 1e9, |
| ) |
| sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values]) |
| fw_module_nodes = { |
| node.name for node in fw_module.graph.nodes if node.op == "call_function" |
| } |
| bw_module_nodes = { |
| node.name for node in bw_module.graph.nodes if node.op == "call_function" |
| } |
| remat_nodes = fw_module_nodes & bw_module_nodes |
| |
| counts: Dict[str, int] = defaultdict(int) |
| for node in fw_module.graph.nodes: |
| if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"): |
| counts[str(node.target._overloadpacket)] += 1 |
| print( |
| f"# remat/fw/bw: {len(remat_nodes)}/{len(fw_module_nodes)}/{len(bw_module_nodes)}" |
| ) |
| print( |
| "Count of Ops Rematerialized: ", |
| sorted(counts.items(), key=lambda x: x[1], reverse=True), |
| ) |
| return fw_module, bw_module |
| |
| |
| def draw_graph( |
| traced: torch.fx.GraphModule, |
| fname: str, |
| figname: str = "fx_graph", |
| clear_meta: bool = True, |
| prog: Optional[Union[str, List[str]]] = None, |
| parse_stack_trace: bool = False, |
| dot_graph_shape: Optional[str] = None, |
| ) -> None: |
| if clear_meta: |
| new_graph = copy.deepcopy(traced.graph) |
| traced = fx.GraphModule(traced, new_graph) |
| for node in traced.graph.nodes: |
| node.meta = {} |
| base, ext = os.path.splitext(fname) |
| if not ext: |
| ext = "." + config.torch_compile_graph_format |
| print(f"Writing FX graph to file: {base}{ext}") |
| g = graph_drawer.FxGraphDrawer( |
| traced, |
| figname, |
| parse_stack_trace=parse_stack_trace, |
| dot_graph_shape=dot_graph_shape, |
| ) |
| x = g.get_main_dot_graph() |
| write_method = getattr(x, "write_" + ext.lstrip(".")) |
| fname = f"{base}{ext}" |
| if prog is None: |
| write_method(fname) |
| else: |
| write_method(fname, prog=prog) |