# blob: 81ddeb8bfe0ad8467e03eec7c903d31e43526c72 [file] [log] [blame]
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import logging
import operator
from collections import defaultdict
from enum import Enum
from inspect import Parameter, Signature, signature
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
import torch.fx as fx
from torch.distributed import ProcessGroup
from torch.export import ExportedProgram
from torch.export.unflatten import (
_assign_attr,
_AttrKind,
_sink_params,
InterpreterModule,
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module
from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
from ._utils import PipeInfo
from .stage import _PipelineStage
logger = logging.getLogger(__name__)
# TODO:
# 1. investigate gradient sync for shared parameters. how does DDP do it?
# 2. Add parameter movement to split_module
def _find_loss_from_output_and_spec(output_val, spec_val):
if spec_val is False:
return None
if spec_val is True:
if not isinstance(output_val, fx.Node):
raise RuntimeError(
f"Loss spec must specify a dynamic value but got {output_val}"
)
return output_val
if isinstance(spec_val, (tuple, list)):
if not isinstance(output_val, (tuple, list)):
raise RuntimeError(
f"Output value {output_val} must match type of loss specification "
f"{spec_val}"
)
if len(output_val) != len(spec_val):
raise RuntimeError(
f"Output value {output_val} must match length of loss specification "
f"{spec_val}"
)
for out, spec in zip(output_val, spec_val):
loss_val = _find_loss_from_output_and_spec(out, spec)
if loss_val is not None:
return loss_val
raise RuntimeError(f"Did not find loss value in specification {spec_val}")
if isinstance(spec_val, dict):
if not isinstance(output_val, dict):
raise RuntimeError(
f"Output value {output_val} must match type of loss specification "
f"{spec_val}"
)
if set(output_val.keys()) != set(spec_val.keys()):
raise RuntimeError(
f"Output value {output_val} must match keys of loss specification "
f"{spec_val}"
)
for k in spec_val:
loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k])
if loss_val is not None:
return loss_val
raise RuntimeError(f"Did not find loss value in specification {spec_val}")
raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification")
def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec):
    """
    Find the loss node among the values returned by graph ``g``.

    Returns:
        A ``(loss_node, output_node, generated_spec)`` triple. ``loss_node``
        and ``generated_spec`` are ``None`` when no spec is given and the
        output is not a dict containing a ``"loss"`` key.
    """
    output_nodes = [n for n in g.nodes if n.op == "output"]
    assert len(output_nodes) == 1
    (output_node,) = output_nodes
    output_val = output_node.args[0]

    if isinstance(mod, TrivialLossWrapper):
        # TrivialLossWrapper is pre-defined by PiPPy.
        # It has loss as the only output so we can safely assume the first output arg is the loss.
        assert len(output_node.args) == 1
        return output_val, output_node, TrivialLossWrapper.loss_spec

    if output_loss_value_spec is not None:
        # The caller told us where the loss lives.
        loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec)
        return loss_node, output_node, output_loss_value_spec

    # Use default spec, i.e. search for "loss" in output values
    if isinstance(output_val, dict) and "loss" in output_val:
        generated_spec: Any = {k: k == "loss" for k in output_val}
        return output_val["loss"], output_node, generated_spec

    return None, output_node, None
def _insert_stage_symbolic_backward(
    g: fx.Graph,
    loss_node: fx.Node,
    output_node: fx.Node,
):
    """
    Insert `stage_backward` calls into ``g`` for every stage (``call_module``
    node) that can contribute to ``loss_node``.

    The new nodes are inserted right before ``output_node``. The graph is
    modified in place and also returned.
    """
    # Collect metadata about tuple output values. TODO: move this to split_module or FX IR
    # Maps an indexed value to a tuple whose i-th slot is the `getitem` node
    # reading index i (or None when that index is never read).
    tuples: Dict[fx.Node, Tuple] = {}
    for node in reversed(g.nodes):
        if node.op != "call_function":
            continue
        # In the forward pass, only emit placeholder, module calls, and
        # getitem calls. If we have a target other than getitem in this
        # (forward-only) code, there is a bug.
        assert node.target == operator.getitem, (
            "Found non-getitem call in forward pass. "
            "Please report a bug to PiPPy"
        )
        assert (
            len(node.args) == 2
        ), "Found malformed getitem call. Please report a bug to PiPPy"
        indexed_value, node_idx = tuple(node.args)

        # Start from the slots recorded by earlier `getitem`s on the same
        # value, pad with None up to this index, then record this node.
        slots = list(tuples.get(indexed_value, ()))
        slots.extend([None] * (node_idx + 1 - len(slots)))
        slots[node_idx] = node
        tuples[indexed_value] = tuple(slots)

    # Keep track of nodes that dominate the loss node.
    # We will only emit backward operations for nodes that can contribute
    # to the specified loss value.
    live_nodes = {loss_node: None}
    val_to_grad: Dict[fx.Node, Optional[fx.Node]] = {loss_node: None}

    def assign_or_accumulate_grad(forward_node, grad_value):
        # A non-placeholder value that already has a gradient gets the new
        # contribution accumulated into it; otherwise the gradient is set.
        has_grad_already = forward_node in val_to_grad
        if has_grad_already and forward_node.op != "placeholder":
            grad_value = g.call_function(
                _null_coalesce_accumulate,
                (val_to_grad[forward_node], grad_value),
            )
        val_to_grad[forward_node] = grad_value

    def add_to_live_nodes(n):
        live_nodes.setdefault(n, None)

    with g.inserting_before(output_node):
        for node in reversed(g.nodes):
            if node not in live_nodes:
                continue

            # Everything feeding a live node is live as well.
            fx.node.map_arg(node.args, add_to_live_nodes)
            fx.node.map_arg(node.kwargs, add_to_live_nodes)
            if node.op != "call_module":
                continue

            output_grads: Union[Tuple[Optional[fx.Node], ...], Optional[fx.Node]]
            if node in tuples:
                stage_output = tuples[node]
                output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node])
                outputs_with_grads_idxs = [
                    i for i, n in enumerate(tuples[node]) if n in live_nodes
                ]
            else:
                stage_output = (node,)
                output_grads = val_to_grad[node]
                outputs_with_grads_idxs = [0]

            if not isinstance(output_grads, tuple):
                output_grads = (output_grads,)

            grad_call = g.call_function(
                stage_backward,
                kwargs={
                    "stage_output": stage_output,
                    "output_grads": output_grads,
                    "input_values": list(node.all_input_nodes),
                    "outputs_with_grads_idxs": outputs_with_grads_idxs,
                },
            )
            # Insert backward stage debug info
            grad_call.kwargs = dict(grad_call.kwargs)

            # Index into the result of the backward call, one gradient per
            # stage input, in `all_input_nodes` order.
            grads_proxy = fx.Proxy(fx.Proxy(grad_call).node)
            for i, input_node in enumerate(list(node.all_input_nodes)):
                assign_or_accumulate_grad(input_node, grads_proxy[i].node)

    return g
class PipeSequential(torch.nn.Sequential):
    """
    A ``torch.nn.Sequential`` whose forward emits a `pipe_split()` marker
    between consecutive child modules.
    """

    @staticmethod
    def from_sequential(sequential_instance: torch.nn.Sequential):
        """Build a PipeSequential from shallow copies of the given children."""
        shallow_copies = map(copy.copy, sequential_instance)
        return PipeSequential(*shallow_copies)

    def forward(self, input):
        last_idx = len(self) - 1
        for idx, stage in enumerate(self):
            input = stage(input)
            # No split marker after the final child.
            if idx != last_idx:
                pipe_split()
        return input
class LossWrapper(torch.nn.Module):
    """
    LossWrapper is a convenient abstract class that allows you to wrap up both
    your model as well as its loss function and specify the connectivity between
    the inputs, model, loss function, and output value. Example::

        class MyModelWrapper(LossWrapper):
            def forward(self, x, targets):
                model_out = self.module(x)
                loss_value = self.loss_fn(model_out, targets)
                return loss_value

    The above example defines a connectivity where we expect the forward/loss/backward
    training procedure to take two arguments (x and targets), pass x into the module
    to get the output of the feedforward computation, pass the model output and the
    targets value into the loss function, and get and return the loss value, which will
    be backpropagated by PiPPy. The above class would then be instantiated like::

        model = ...  # instantiate the model
        loss_fn = torch.nn.MSELoss()  # for the sake of demonstration

        wrapper = MyModelWrapper(model, loss_fn)
        pipe = Pipe.from_tracing(wrapper, ...)

    Args:
        module: the model to wrap; stored as ``self.module``.
        loss_fn: the loss callable; stored as ``self.loss_fn``.
    """

    def __init__(self, module, loss_fn):
        super().__init__()
        self.module = module
        self.loss_fn = loss_fn

    def forward(self, *args, **kwargs):
        """
        Placeholder that subclasses must override.

        Raises:
            NotImplementedError: always.
        """
        # Adjacent literals are concatenated verbatim, so each fragment
        # carries its own trailing space.
        raise NotImplementedError(
            "This instance of LossWrapper does not have an overridden "
            "forward(). Please implement forward() to specify the arguments, "
            "connection between the module and loss, and loss output "
            "value."
        )
class TrivialLossWrapper(LossWrapper):
    """
    LossWrapper whose forward feeds ``x`` through the module and returns
    ``loss_fn(model_out, targets)`` as its only output.
    """

    # The single output is the loss itself.
    loss_spec = True

    def forward(self, x, targets):
        predictions = self.module(x)
        loss = self.loss_fn(predictions, targets)
        return loss
# Pipe model representation
#
# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies
# a single topological ordering of pipeline "stages" that, when run in series,
# constitutes all of the operations of the program. However, unlike `nn.Sequential`,
# Pipe allows non-local usages of values, so long as those uses still respect
# topological ordering. In particular:
#
# 1. Non-local activations. This type of usage can appear in, for example, skip
# connections. These values will be directly transmitted from the "def" stage
# to all stages that use them skipping intermediate stages. During autograd,
# gradients will be propagated back through this skip connection reverse
# to how activations propagated in the forward pass.
# 2. Non-local parameter/module invocations. This occurs when a parameter is used
# in a stage downstream of where it is resident. These values can be carried
# forward similarly to (1), but in addition one might want to replicate the
# value on multiple stages. Gradients for these shared parameters will be
# accumulated separately on each stage, but there will be an additional
# gradient accumulation before the optimizer step.
# Register `_pipe_split()` as an ATen operator. This is required for Export to
# preserve this marker in the graph.
# Declare the operator schema: it takes no arguments and returns nothing.
torch.library.define("pippy::_pipe_split", "() -> ()")


# Implementation registered under the "BackendSelect" dispatch key; the marker
# does no computation.
@torch.library.impl("pippy::_pipe_split", "BackendSelect")
def _pipe_split():
    return None


# Fake implementation registered for the same operator; it is also a no-op.
# The function name is intentionally reused (hence the `noqa: F811`).
@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef]
def _pipe_split(): # noqa: F811
    return None


# Add an alias for convenience
aten_pipe_split_alias = torch.ops.pippy._pipe_split.default

# Ask Export to preserve the `_pipe_split` op.
# See examples in pytorch/torch/fx/node.py
fx.node._side_effectful_functions.add(aten_pipe_split_alias)
# User facing API
def pipe_split():
    """
    Mark a stage boundary inside a module's ``forward``.

    Calling this invokes the ``pippy::_pipe_split`` operator, which performs
    no computation, so it is a no-op if your annotated module is run eagerly.
    The marker is used to split the module into stages.

    Example:
        >>> # xdoctest: +SKIP
        >>> def forward(self, x):
        >>>     x = torch.mm(x, self.mm_param)
        >>>     x = torch.relu(x)
        >>>     pipe_split()
        >>>     x = self.lin(x)
        >>>     return x

    The above example will be split into two stages.
    """
    split_marker_op = torch.ops.pippy._pipe_split
    return split_marker_op()
class MultiUseParameterConfig(Enum):
    """Strategy options for a parameter that is used by more than one stage."""

    TRANSMIT = 1
    REPLICATE = 2


# Either one strategy, or a mapping from a string key to a strategy
# (presumably keyed by parameter name -- confirm with callers).
MultiUseParamSpec = Union[
    MultiUseParameterConfig,
    Dict[str, MultiUseParameterConfig],
]
class DetachExecutor(fx.Interpreter):
    """
    Special interpreter to run the split_gm in testing that detaches all inputs to
    a module invocation. This is needed so that the values at the boundary are
    leaf tensors in autograd execution.
    """

    def __init__(self, module, garbage_collect_values=True):
        # The argument is accepted for signature compatibility, but values
        # are never garbage collected by this interpreter.
        garbage_collect_values = False
        super().__init__(module, garbage_collect_values)
        # Maps an original tensor to its detached stand-in.
        self.value_remap = {}

    def run(self, *args, initial_env=None):
        # Each run starts with a fresh remap table.
        self.value_remap = {}
        return super().run(*args, initial_env=initial_env)

    def call_module(self, target, args, kwargs):
        def detach_tensors(a):
            needs_detach = isinstance(a, torch.Tensor) and a.requires_grad
            if not needs_detach:
                return a
            if a not in self.value_remap:
                self.value_remap[a] = a.detach().requires_grad_(True)
            return self.value_remap[a]

        # A `dont_traverse_size` filter (skipping `torch.Size` values) used to
        # be passed to `map_aggregate` here; it is currently disabled.
        detached_args = map_aggregate(args, detach_tensors)
        detached_kwargs = map_aggregate(kwargs, detach_tensors)

        return super().call_module(target, detached_args, detached_kwargs)

    def call_function(self, target, args, kwargs):
        # HACK to reroute saved input tensors to point to the detach()ed version
        if target == stage_backward:
            rerouted_inputs = [
                self.value_remap.get(v, v) for v in kwargs["input_values"]
            ]
            kwargs = {**kwargs, "input_values": rerouted_inputs}
        return super().call_function(target, args, kwargs)
class _NodeReference:
def __init__(self, name):
self.name = name
name: str
class _LinearNodeList:
    """
    Flat, graph-free copy of a node list in which references to other nodes
    are replaced by `_NodeReference` name placeholders.
    """

    def __init__(self, node_list):
        def as_reference(n):
            return _NodeReference(n.name)

        self.serialize_node_list = []
        for original in node_list:
            detached = fx.Node(
                graph=None,
                name=original.name,
                op=original.op,
                target=original.target,
                args=fx.node.map_arg(original.args, as_reference),
                kwargs=fx.node.map_arg(original.kwargs, as_reference),
                return_type=original.type,
            )
            detached.meta = copy.copy(original.meta)
            self.serialize_node_list.append(detached)

    def to_graph(self):
        """Rebuild an ``fx.Graph`` by replaying the stored nodes in order."""
        graph = fx.Graph()

        # Nodes rebuilt so far, keyed by name, used to resolve references.
        ref_str_to_node: Dict[str, fx.Node] = {}

        def ref_to_node(arg):
            if not isinstance(arg, _NodeReference):
                return arg
            return ref_str_to_node[arg.name]

        for stored in self.serialize_node_list:
            rebuilt = graph.create_node(
                op=stored.op,
                target=stored.target,
                args=map_aggregate(stored.args, ref_to_node),
                kwargs=map_aggregate(stored.kwargs, ref_to_node),
                name=stored.name,
                type_expr=stored.type,
            )
            ref_str_to_node[stored.name] = rebuilt

        return graph
def _direct_serialization_deserialize(body, nodes):
"""
Custom `__reduce__` method for serialization.
DO AS I SAY -- NOT AS I DO. This violates the principle that
GraphModules serialize via code export & re-tracing. We allow
for this here because **PIPE STAGES SHOULD NOT BE PERSISTED
TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting
these instances to disk will expose internal implementation
details of `fx.Graph` and related data structures and is
NOT advised.
"""
class DummyModule(torch.nn.Module):
def __init__(self, body):
super().__init__()
self.__dict__.update(body)
dummy = DummyModule(body)
return fx.GraphModule(dummy, nodes.to_graph())
def _direct_serialization_reduce(self):
    """
    `__reduce__` replacement: serialize the instance as its attribute dict
    (minus ``_graph``) plus a `_LinearNodeList` of its graph nodes.
    """
    state = self.__dict__.copy()
    # The graph travels separately as a linear node list.
    del state["_graph"]
    linear_nodes = _LinearNodeList(self.graph.nodes)
    return (_direct_serialization_deserialize, (state, linear_nodes))
def _modify_graph_op_device(
gm: torch.fx.GraphModule,
new_device: torch.device,
):
"""
Modify the device argument of all "call_function" nodes in the graph. This
is useful for moving the graph to a different device. In particular for
generator ops, like torch.ones.
"""
modified = False
for node in gm.graph.nodes:
if node.op == "call_function":
if "device" in node.kwargs and node.kwargs["device"] != new_device:
logger.debug(
f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004
)
node.update_kwarg("device", new_device)
modified = True
elif node.op == "call_module":
# Recursively modify "device" in submodules
submod = gm.get_submodule(node.target)
if isinstance(submod, torch.fx.GraphModule):
_modify_graph_op_device(submod, new_device)
elif isinstance(submod, InterpreterModule):
# If unflattening has been performed, we need to access its graph module by `.graph_module`
_modify_graph_op_device(submod.graph_module, new_device)
else:
logger.warning(
f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004
)
if modified:
gm.recompile()
class Pipe(torch.nn.Module):
def __init__(
    self,
    split_gm: fx.GraphModule,
    num_stages: int,
    has_loss_and_backward: bool,
    loss_spec,
):
    """
    Wrap an already-split GraphModule as a Pipe.

    Validates the root graph's node kinds, records parameters shared by more
    than one stage, gives each sharing stage its own copy of those
    parameters, blocks direct calls to ``split_gm.forward``, and installs the
    direct-serialization `__reduce__` on the stage submodule classes.
    """
    # TODO: is there a way not to hard wire init?
    torch.nn.Module.__init__(self)
    self.split_gm: fx.GraphModule = split_gm
    self.executor: DetachExecutor = DetachExecutor(self.split_gm)
    self.num_stages: int = num_stages
    self.has_loss_and_backward = has_loss_and_backward
    self.loss_spec = loss_spec

    # The root graph may only contain these node kinds.
    allowed_ops = {"call_module", "placeholder", "output"}
    allowed_calls = (
        ("call_function", operator.getitem),
        ("call_method", "backward"),
        ("call_function", stage_backward),
        ("call_function", _null_coalesce_accumulate),
    )
    for node in split_gm.graph.nodes:
        assert (
            node.op in allowed_ops or (node.op, node.target) in allowed_calls
        ), node

    # Detect replicated parameters so we know that we have to do an additional allreduce
    # before applying the optimizer
    #
    # Note that this also handles the case where there were multiple calls to a single
    # module from different stages, regardless of whether that module invocation
    # was handled by the logic above.

    # Map parameter value to a dictionary that maps the user pipeline module
    # to the local qualname within that module
    params_to_users: Dict[torch.nn.Parameter, Dict[str, str]] = {}
    for m_qualname, mod in self.split_gm.named_children():
        for p_qualname, param in mod.named_parameters():
            params_to_users.setdefault(param, {})[m_qualname] = p_qualname

    self.replicated_params: List[Dict[str, str]] = [
        users for users in params_to_users.values() if len(users) > 1
    ]

    # We must break the aliasing relationship between the replicated parameters for correct
    # numerics in reference runs. If we do not do this, the autograd tape in separate stages
    # will have a reference to the same tensor value and will erroneously apply gradient
    # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the
    # values so that we have separate instances.
    for param_mapping in self.replicated_params:
        for submod_name, param_qualname in param_mapping.items():
            *parent_atoms, leaf = param_qualname.split(".")
            owner = getattr(self.split_gm, submod_name)
            for atom in parent_atoms:
                owner = getattr(owner, atom)
            setattr(owner, leaf, copy.deepcopy(getattr(owner, leaf)))

    def throw(self, *args, **kwargs):
        raise RuntimeError(
            "To run pipeline locally, invoke the Pipe object directly, not `split_gm`"
        )

    self.split_gm.forward = throw

    # Make submodules use custom direct-serialized GraphModule
    # Walk submod_0, submod_1, ... until the first missing name.
    stage_no = 0
    while True:
        try:
            stage_mod = getattr(self.split_gm, f"submod_{stage_no}")
            stage_mod.__class__.__reduce__ = _direct_serialization_reduce
            stage_no += 1
        except AttributeError:
            break
def forward(self, *args, **kwargs):
    """
    Run the pipeline locally through ``self.executor``.

    Positional-only calls are forwarded as-is. When keyword arguments are
    given, a signature is rebuilt from the root graph's placeholders and used
    to bind the call (applying placeholder defaults) into positional order.
    """
    if not kwargs:
        return self.executor.run(*args)

    sig_params = []
    for node in self.split_gm.graph.nodes:
        if node.op != "placeholder":
            continue

        if node.args:
            # A placeholder's first arg holds its default value.
            sig_params.append(
                Parameter(
                    node.target,
                    Parameter.POSITIONAL_OR_KEYWORD,
                    default=node.args[0],
                )
            )
            continue

        # Star prefixes in the placeholder name encode variadic parameters.
        if node.target.startswith("**"):
            kind, name = Parameter.VAR_KEYWORD, node.target[2:]
        elif node.target.startswith("*"):
            kind, name = Parameter.VAR_POSITIONAL, node.target[1:]
        else:
            kind, name = Parameter.POSITIONAL_OR_KEYWORD, node.target
        sig_params.append(Parameter(name, kind))

    bound = Signature(sig_params).bind(*args, **kwargs)
    bound.apply_defaults()
    return self.executor.run(*bound.arguments.values())
def get_stage_module(self, stage_idx: int) -> torch.nn.Module:
    """
    Return a stage module corresponding to `stage_idx` of the `pipe`.

    Raises:
        ValueError: if `stage_idx` is outside ``[0, num_stages)``.
    """
    if not 0 <= stage_idx < self.num_stages:
        raise ValueError(f"Invalid stage index {stage_idx}!")
    stage_name = f"submod_{stage_idx}"
    return getattr(self.split_gm, stage_name)
@staticmethod
def _number_and_count_forward_stages(gm: fx.GraphModule):
num_stages = 0
found_idxs: Dict[int, None] = {}
for node in gm.graph.nodes:
if node.op == "call_module" and node.target.startswith("submod_"):
node.meta["stage_idx"] = int(node.target[len("submod_") :])
found_idxs.setdefault(node.meta["stage_idx"])
num_stages += 1
# this assert will fail if a split point is inserted before the first layer, which creates empty first submodule
# Update: the following assert may fail against some torch versions >=
# 2.2.0, as:
# submod_0, submod_1, submod_2, ...
# may be named as
# submod_0, submod_2, submod_4, ...
# TODO: investigate
# assert all(i in found_idxs for i in range(num_stages))
return num_stages
@staticmethod
def _from_traced(
    mod: torch.nn.Module,
    exported_program: ExportedProgram,
    multi_use_param_spec: Optional[MultiUseParamSpec] = None,
    output_loss_value_spec=None,
    split_policy: Optional[
        Callable[[torch.fx.GraphModule], torch.fx.GraphModule]
    ] = None,
):
    """
    Build a `Pipe` from an exported program by splitting its graph at the
    `pipe_split` markers and moving parameters/buffers into the stage that
    uses them.

    Args:
        mod: the original (eager) module; used for its state dict, buffers
            and attribute lookups.
        exported_program: the result of exporting ``mod``.
        multi_use_param_spec: accepted but not read by this function.
        output_loss_value_spec: see below; ``None`` means no backward pass is
            generated.
        split_policy: optional callable applied to the traced GraphModule
            before splitting (e.g. to insert split markers automatically).

    Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate
    which value in the output of `forward` is the loss value on which PiPPy should apply
    backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``,
    you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns
    a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify
    ``output_loss_value_spec={'loss': True, 'model_out': False}``

    Raises:
        RuntimeError: if ``output_loss_value_spec`` is given but no loss value
            is found in the output.
    """
    traced = exported_program.module()

    if split_policy is not None:
        logger.info("Auto-splitting model")
        traced = split_policy(traced)  # type: ignore[arg-type]

    logger.debug(traced.print_readable(print_output=False))

    # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving
    # parameters relies on the invariant that parameter accesses happen once. This is not necessarily
    # the case (especially with custom tracers), so fix that up here.
    get_attr_nodes: Dict[str, fx.Node] = {}
    for node in traced.graph.nodes:
        if node.op == "get_attr":
            get_attr_nodes.setdefault(node.target, node)

            if get_attr_nodes[node.target] != node:
                node.replace_all_uses_with(get_attr_nodes[node.target])
                traced.graph.erase_node(node)

    # Collapse runs of back-to-back split markers into one.
    # The split callback and the peephole below match the registered operator
    # (`aten_pipe_split_alias`), so it must be matched here as well; the Python
    # wrapper `pipe_split` is still accepted in case a node targets it directly.
    pipe_split_targets = (pipe_split, aten_pipe_split_alias)
    # avoid looking at next node by keeping track of previous pipe_split
    prev_pipe_split_idx = -1
    pipe_split_nodes_to_erase = set()
    for i, node in enumerate(traced.graph.nodes):
        if node.op == "call_function" and node.target in pipe_split_targets:
            if prev_pipe_split_idx == i - 1:
                pipe_split_nodes_to_erase.add(node)
            prev_pipe_split_idx = i

    for node in pipe_split_nodes_to_erase:
        traced.graph.erase_node(node)

    traced.recompile()

    part_idx = 0

    def split_callback(n: fx.Node):
        # Every marker bumps the partition index; the marker itself lands in
        # the new partition and is removed by the peephole below.
        nonlocal part_idx
        if (n.op, n.target) == (
            "call_function",
            aten_pipe_split_alias,
        ):
            logger.debug(f"Found pipe_split {part_idx}")  # noqa: G004
            part_idx += 1
        return part_idx

    # TODO: what does split do with module invocations? does it move the modules
    # into the submodules?
    split = split_module(traced, mod, split_callback)
    # a (custom) tracer can produce dead code like orphan get_attr nodes
    split.graph.eliminate_dead_code()

    # peephole to remove pipe_split
    for submodule in split.modules():
        if isinstance(submodule, fx.GraphModule):
            for node in submodule.graph.nodes:
                if (node.op, node.target) == (
                    "call_function",
                    aten_pipe_split_alias,
                ):
                    submodule.graph.erase_node(node)
            submodule.recompile()

    for name, submodule in split.named_children():
        if isinstance(submodule, fx.GraphModule):
            new_submod = _outline_submodules(submodule.graph)
            # Replace old submod
            split.register_module(name, new_submod)

    # TODO: backport this into split_module
    def delete_user_reference(node, user):
        """
        Delete reference of `node` from `user`'s arg list.
        Args:
            - node: a `get_attr` node at root.
            - user: a submodule node that uses `node`.
        """
        assert len(user.kwargs) == 0
        use_idxs = [i for i, arg in enumerate(user.args) if arg == node]
        assert len(use_idxs) == 1
        args_copy = list(user.args)
        args_copy.pop(use_idxs[0])
        user.args = tuple(args_copy)
        logger.debug(
            f"Deleted {node} from user {user}, arg index = {use_idxs[0]}"  # noqa: G004
        )

    # A list of param referrals for deferred deletion.
    # To be accumulated in `move_param_to_callee`.
    to_delete = []

    def _recursive_getattr_with_parent(mod, fqn):
        # Returns getattr call given a nested FQN, and the last parent
        # Result is (parent, attr); (None, None) if the parent path is
        # missing, (parent, None) if only the final attribute is missing.
        atoms = fqn.split(".")
        for atom in atoms[:-1]:
            if not hasattr(mod, atom):
                return None, None
            mod = getattr(mod, atom)
        if not hasattr(mod, atoms[-1]):
            return mod, None
        attr = getattr(mod, atoms[-1])
        return mod, attr

    def move_param_to_callee(
        root,
        callee_name,
        param_fqn,
    ):
        """
        Move a parameter from the root module to a submodule.
        Args:
            root: The root module.
            callee_name: The name of the submodule to move the parameter to.
            param_fqn: The fully qualified name of the parameter to move.
        """
        # `atoms` is a list of strings representing the path to the
        # parameter in the original model
        atoms = param_fqn.split(".")
        mod_itr, param_val = _recursive_getattr_with_parent(root, param_fqn)
        # Check whether the parameter is a buffer or a parameter
        is_buffer = atoms[-1] in mod_itr._buffers

        # Check whether the parameter is a tensor
        assert isinstance(param_val, torch.Tensor), (
            f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}."
            + (
                f" It might happen if module '{param_fqn}' was passed to some 'leaf function'"
                f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect "
                f"usages of '{param_fqn}' in the traced graph."
                if isinstance(param_val, torch.nn.Module)
                else ""
            )
        )

        # Get submodule
        callee = root.get_submodule(callee_name)
        assert not hasattr(
            callee, param_fqn
        ), f"Module {callee_name} already has a parameter named {param_fqn}"

        # Assign the parameter to the submodule
        if is_buffer:
            _assign_attr(
                param_val,
                callee,
                param_fqn,
                attr_kind=_AttrKind.BUFFER,
                persistent=True,  # TODO: handle non-persistent buffer
            )
        else:
            _assign_attr(
                param_val,
                callee,
                param_fqn,
                attr_kind=_AttrKind.PARAMETER,
            )
        logger.debug(f"Moved parameter {param_fqn} to {callee_name}")  # noqa: G004

        # Next step is to replace placeholder of submodule with a get_attr.
        # Those placeholders are created by `split_module` inside each
        # submodule.
        # Update: this step is now moved to `_sink_params` because
        # `_sink_params` can do it recursively (i.e. for modules inside
        # submodule)

        to_delete.append((mod_itr, atoms[-1]))

    # Get the list of all parameters in the root module
    attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes))
    for node in attr_nodes:
        # Check whether the parameter is used in only one submodule
        if len(node.users) > 1:
            logger.info(
                f"Parameter {node.target} used in multiple stages: {node.users}."  # noqa: G004
            )

        for user in node.users:
            assert user.op == "call_module"
            # Move parameter into submodule
            move_param_to_callee(
                split,
                user.target,
                node.target,
            )

    # [aliasing] store tensor id -> list of FQNs, built from state dict
    # Also assign non-persistent buffers
    id_to_fqns: Dict[int, Set[str]] = defaultdict(set)
    for fqn, tensor in mod.state_dict(keep_vars=True).items():
        id_to_fqns[id(tensor)].add(fqn)
    for fqn, tensor in mod.named_buffers():
        id_to_fqns[id(tensor)].add(fqn)

    # After moving the params to their corresponding hierarchies, we also
    # need to move the `get_attr` nodes from the root of the graph to those
    # hierarchies.
    # [aliasing] use id -> fqn mapping to list out all valid FQNs
    inputs_to_state: Dict[str, List[str]] = {}
    for attr in attr_nodes:
        _, tensor = _recursive_getattr_with_parent(mod, attr.target)
        fqns = list(id_to_fqns[id(tensor)])
        if fqns:
            inputs_to_state[attr.name] = fqns
        elif attr.target in exported_program.constants:  # lifted constants
            inputs_to_state[attr.name] = [attr.target]

    # [aliasing] for each submodule split, assign attributes on FQNs that may be used.
    # We determine this based on whether or not the FQN attribute parent exists.
    # i.e. if the last submodule exists, assign the attribute.
    added_attributes: Dict[str, List[str]] = defaultdict(list)
    for fqn, tensor in mod.state_dict(keep_vars=True).items():
        for name, submod in split.named_children():
            if isinstance(submod, fx.GraphModule):
                parent, child = _recursive_getattr_with_parent(submod, fqn)
                if (
                    parent and child is None
                ):  # parent exists, attribute doesn't -> assign
                    added_attributes[name].append(fqn)
                    setattr(parent, fqn.split(".")[-1], tensor)

    # Deferral deletion: Remove the original attributes (to params) from the
    # root GraphModule
    for mod_itr, last_atom in to_delete:
        try:
            delattr(mod_itr, last_atom)
        except AttributeError:
            # This is expected if the parameter is used in multiple stages
            pass

    # This is done by (1) `_sink_params` at each submodule;
    for name, submod in split.named_children():
        if isinstance(submod, fx.GraphModule):
            _sink_params(submod, inputs_to_state, [])
            submod.graph.lint()
            submod.recompile()

    # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory.
    # After _sink_params() routine has run, clean up unused attributes that we previously added.
    # Determine this based on the get_attr nodes - if not used, remove it.
    for name, attributes in added_attributes.items():
        submod = getattr(split, name)
        unused_attributes = set(attributes)
        # track used attributes in the submodule, running DFS on subgraph hierarchy
        stack = [("", submod)]  # (scope, submodule)
        while stack:
            scope, _mod = stack.pop()
            if isinstance(_mod, (fx.GraphModule, InterpreterModule)):
                for node in _mod.graph.nodes:
                    if node.op == "get_attr":
                        # get_attr might get access deeper level attribute
                        fqn = scope + "." + node.target if scope else node.target
                        if fqn in unused_attributes:  # used, remove it
                            unused_attributes.remove(fqn)
            for _name, _submod in _mod.named_children():
                stack.append((scope + "." + _name if scope else _name, _submod))
        # delete unused attributes
        for attr in unused_attributes:
            mod_itr, atoms = submod, attr.split(".")
            for atom in atoms[:-1]:
                mod_itr = getattr(mod_itr, atom)
            delattr(mod_itr, atoms[-1])

    for node in attr_nodes:
        # And (2): remove `get_attr` node from submod's arg list
        for user in copy.copy(node.users):
            assert user.op == "call_module"
            delete_user_reference(node, user)
        # And (3): remove the `get_attr` node from the root graph.
        split.graph.erase_node(node)

    split.delete_all_unused_submodules()
    split.graph.lint()
    split.recompile()

    num_stages = Pipe._number_and_count_forward_stages(split)

    has_loss_and_backward = False
    generated_loss_spec = output_loss_value_spec

    if output_loss_value_spec is not None:
        loss_node, output_node, generated_loss_spec = _find_loss_output(
            mod, split.graph, output_loss_value_spec
        )
        if loss_node is not None:
            _insert_stage_symbolic_backward(
                split.graph,
                loss_node,
                output_node,
            )
            split.recompile()
            has_loss_and_backward = True
            logger.debug("Pipeline is in training mode, backward pass generated")
        else:
            raise RuntimeError(
                f"Did not find any loss value according to {output_loss_value_spec=}"
            )
    else:
        logger.debug("Pipeline is in inference mode, backward pass not generated")

    logger.debug("Full pipe model:\n" f"{split}")  # noqa: G004

    return Pipe(
        split,
        num_stages,
        has_loss_and_backward,
        generated_loss_spec,
    )
def print_readable(self):
    """
    Print the pipe in a human-readable format.
    This will print both the root pipe and each stage module.
    """
    root_gm = self.split_gm
    root_gm.print_readable()
@staticmethod
def _trace_with_export(
    mod: torch.nn.Module,
    example_args: Tuple[Any, ...],
    example_kwargs: Optional[Dict[str, Any]] = None,
) -> ExportedProgram:
    """
    Export ``mod`` with ``torch.export.export`` using the example inputs.

    Raises:
        RuntimeError: if export fails for any reason; the original exception
            is attached as the cause.
    """
    logger.info("Tracing model ...")
    try:
        return torch.export.export(
            mod,
            example_args,
            example_kwargs,
        )
    except Exception as e:
        raise RuntimeError(
            "It seems that we cannot capture your model as a full graph. "
            "Typical reasons include graph breaks, data/shape-dependent "
            "control flow, or missing meta kernels for custom operators. "
            "You can use our manual pipeline interfaces, or try to fix the "
            "graph breaks, see https://pytorch.org/docs/stable/export.html"
        ) from e
@staticmethod
def from_tracing(
    mod: torch.nn.Module,
    example_args: Tuple[Any, ...],
    example_kwargs: Optional[Dict[str, Any]] = None,
    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
):
    """
    Export ``mod`` with the example inputs and turn the result into a `Pipe`.

    Args:
        mod: the module to pipeline.
        example_args: example positional inputs used for export.
        example_kwargs: example keyword inputs used for export.
        split_policy: optional callable forwarded to `Pipe._from_traced`.

    Returns:
        The constructed `Pipe`.
    """
    # Trace with export
    exported_program = Pipe._trace_with_export(
        mod,
        example_args,
        example_kwargs,
    )

    # If a param will be used in multiple pipeline stages, we default the
    # strategy to REPLICATE'ing the param across stages instead of
    # TRANSMIT'ting it.
    # The loss spec stays None here; the former derivation from
    # `output_chunk_spec` (via `_LossReducer`) is deprecated.
    pipe = Pipe._from_traced(
        mod,
        exported_program,
        MultiUseParameterConfig.REPLICATE,
        output_loss_value_spec=None,
        split_policy=split_policy,
    )

    # Users want the first pipeline stage to accept kwargs if the original
    # program does. This is controlled by the `_codegen` field of the graph,
    # so we make a copy here. Note: we only want the input spec and not the
    # output spec, because the output spec is for the last stage. Maybe a
    # TODO? Not sure yet.
    split = pipe.split_gm
    traced = exported_program.module()
    submod0 = next(iter(split.children()))
    submod0_sign = signature(submod0.forward)
    model_sign = signature(traced.forward)

    if len(model_sign.parameters) == len(submod0_sign.parameters):
        # Support kwargs for the first stage
        first_stage_codegen = copy.deepcopy(traced.graph._codegen)
        submod0.graph._codegen = first_stage_codegen
        # `_replace` is actually not "private" or internal. based on this doc:
        # To prevent conflicts with field names, the method and attribute names
        # start with an underscore
        first_stage_codegen.pytree_info = first_stage_codegen.pytree_info._replace(
            out_spec=None
        )
        submod0.recompile()
    else:
        # We don't change the signature of the first stage if it takes
        # different number of args than original model
        logger.info(
            f"Original model takes {len(model_sign.parameters)} args but the "  # noqa: G004
            f"first pipeline stage takes {len(submod0_sign.parameters)}. "
            "Please provide args to respective pipeline stages."
        )

    return pipe
def __str__(self):
    """Return the string form of the split graph module (`self.split_gm`)."""
    # Go through the built-in so the standard special-method lookup is used
    # instead of invoking the dunder by hand.
    return str(self.split_gm)
def __repr__(self):
    """Return the `repr` of the split graph module (`self.split_gm`)."""
    # Go through the built-in so the standard special-method lookup is used
    # instead of invoking the dunder by hand.
    return repr(self.split_gm)
def info(self) -> PipeInfo:
    """
    Summarize this pipe as a `PipeInfo` record.

    Returns
    -------
    PipeInfo
        Holds the graph of the split module, the number of stages, and
        whether the pipe has loss and backward.
    """
    split_graph = self.split_gm.graph
    stage_count = self.num_stages
    with_backward = self.has_loss_and_backward
    return PipeInfo(
        graph=split_graph,
        num_stages=stage_count,
        has_loss_and_backward=with_backward,
    )
def build_stage(
    self,
    stage_index: int,
    device: torch.device,
    group: Optional[ProcessGroup] = None,
) -> _PipelineStage:
    """
    Build the `PipelineStage` for one stage of this pipe.

    The resulting stage can be run with `PipelineSchedule`s.

    Arguments
    ---------
    stage_index:
        Index of the stage to build; used to look up the stage module via
        `get_stage_module`.
    device:
        Device passed to the stage. It is also used to patch the device of
        ops in the stage module's graph.
    group:
        Distributed process group passed to the stage. (default: `None`)

    Returns
    -------
    A `_PipelineStage` wrapping the selected stage module.
    """
    stage_module = self.get_stage_module(stage_index)

    # The PT2 tracer does not treat `x.device` as symbolic: the device seen
    # at tracing time is burned into the generated code. As a workaround,
    # rewrite the "device" kwarg of ops such as `torch.ones`, `torch.zeros`,
    # `torch.rand`, etc. This is only possible on a graph module.
    if not isinstance(stage_module, torch.fx.GraphModule):
        logger.warning(
            f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}"  # noqa: G004
        )
    else:
        _modify_graph_op_device(stage_module, device)

    # Pass a detached `PipeInfo` rather than `self`. Be careful about what
    # goes into it: holding a reference to `Pipe` or `Pipe.split_gm` would
    # stop Python from recycling them, and recycling is what lets the stage
    # modules irrelevant to the current rank be freed automatically.
    return _PipelineStage(stage_module, stage_index, self.info(), device, group)
class SplitPoint(Enum):
    """Position of a pipeline split relative to an annotated submodule.

    Used as the values of the spec passed to `annotate_split_points`.
    """

    # `annotate_split_points` installs `_split_before_forward`, which calls
    # `pipe_split()` before running the submodule's original forward.
    BEGINNING = 1
    # `annotate_split_points` installs `_split_after_forward`, which calls
    # `pipe_split()` after running the submodule's original forward.
    END = 2
# For backward compatibility, we keep the `PipeSplitWrapper` class because
# `class SplitPoint` used to be defined inside this class.
class PipeSplitWrapper:
    """Backward-compatibility holder exposing `SplitPoint` as a class attribute."""

    # Class alias so that `PipeSplitWrapper.SplitPoint` keeps resolving (BC)
    SplitPoint = SplitPoint
def _split_before_forward(self, *args, **kwargs):
    """Replacement `forward`: emit a pipeline split, then run the saved original forward."""
    pipe_split()
    output = self._orig_forward(*args, **kwargs)
    return output
def _split_after_forward(self, *args, **kwargs):
    """Replacement `forward`: run the saved original forward, then emit a pipeline split.

    The split is emitted in a `finally` clause, so it also happens when the
    original forward raises.
    """
    try:
        output = self._orig_forward(*args, **kwargs)
    finally:
        pipe_split()
    return output
def annotate_split_points(mod: torch.nn.Module, spec: Dict[str, SplitPoint]):
    """
    Mark pipeline split points on submodules of `mod`, in place.

    For every entry of `spec`, the submodule named by the dotted qualified
    name gets its `forward` replaced by a wrapper that calls `pipe_split()`
    before (`SplitPoint.BEGINNING`) or after (`SplitPoint.END`) the original
    forward. The original forward is kept on the submodule as
    `_orig_forward`.

    The whole spec is resolved and validated before any module is modified,
    so an invalid entry leaves `mod` untouched.

    Arguments
    ---------
    mod:
        Root module whose submodules are annotated.
    spec:
        Maps a submodule's qualified name (e.g. `"layers.3"`) to the split
        point type to place on it.

    Raises
    ------
    AttributeError
        If any component of a qualified name does not exist.
    ValueError
        If a split point type is neither `SplitPoint.BEGINNING` nor
        `SplitPoint.END`.
    """
    # TODO: make this implementation out-of-place?
    split_wrappers = (_split_before_forward, _split_after_forward)

    # Pass 1: resolve and validate everything without mutating anything.
    pending = []
    for qualname, split_type in spec.items():
        atoms = qualname.split(".")
        mod_to_wrap = mod
        # Walk every atom (including the last one) so that a missing target
        # gets the same descriptive error as a missing intermediate module.
        for i, atom in enumerate(atoms):
            try:
                mod_to_wrap = getattr(mod_to_wrap, atom)
            except AttributeError as e:
                raise AttributeError(
                    f"Specified target {qualname} referenced "
                    f'nonexistent module {".".join(atoms[: i + 1])}'
                ) from e

        if split_type == SplitPoint.BEGINNING:
            wrapper = _split_before_forward
        elif split_type == SplitPoint.END:
            wrapper = _split_after_forward
        else:
            raise ValueError("Unknown split point type.")

        pending.append((mod_to_wrap, mod_to_wrap.forward, wrapper))

    # Pass 2: install the wrappers.
    for mod_to_wrap, current_forward, wrapper in pending:
        # If the module was already annotated, its current `forward` is one
        # of our wrappers. Saving that as `_orig_forward` would make the
        # wrapper call itself forever, so keep the original that is already
        # stored and only swap the wrapper.
        if getattr(current_forward, "__func__", None) not in split_wrappers:
            mod_to_wrap._orig_forward = current_forward
        mod_to_wrap.forward = MethodType(wrapper, mod_to_wrap)
def pipeline(
    module: torch.nn.Module,
    mb_args: Tuple[Any, ...],
    mb_kwargs: Optional[Dict[str, Any]] = None,
    split_spec: Optional[Dict[str, SplitPoint]] = None,
    split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
) -> Pipe:
    """
    Trace a module and split it into a pipeline.

    See `Pipe` for more details.

    Arguments
    ---------
    module:
        The module to be split.
    mb_args:
        Example positional inputs, in micro-batch form.
    mb_kwargs:
        Example keyword inputs, in micro-batch form. (default: `None`)
    split_spec:
        A dictionary using submodule names as split markers. When given,
        `module` is annotated in place. (default: `None`)
    split_policy:
        The policy to use for splitting the module. Mutually exclusive with
        `split_spec`. (default: `None`)

    Returns
    -------
    A pipeline representation of class `Pipe`.

    Raises
    ------
    ValueError
        If both `split_spec` and `split_policy` are specified.
    """
    if split_policy is not None and split_spec is not None:
        raise ValueError(
            "Cannot specify both `split_spec` and `split_policy`. Please use only one of them."
        )

    if split_spec is not None:
        # Annotate split points in the module based on user spec
        annotate_split_points(module, split_spec)

    # When a split spec was given, `split_policy` is necessarily `None` here,
    # which is what `from_tracing` defaults to when the argument is omitted.
    return Pipe.from_tracing(
        mod=module,
        example_args=mb_args,
        example_kwargs=mb_kwargs,
        split_policy=split_policy,
    )