# mypy: allow-untyped-defs
import weakref
from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple
import torch
import torch.nn as nn
from torch.distributed._composable_state import _State
from torch.nn.parallel import DistributedDataParallel
from .contract import _get_registry, contract

_ROOT_MODULE_PREFIX = ""


class _ReplicateState(_State):
def __init__(self) -> None:
super().__init__()
        # Placeholder module; init() replaces this with the real module.
        self.module: nn.Module = nn.ParameterList()
self.has_initialized: bool = False
self._param_list: nn.ParameterList = nn.ParameterList()
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
self._orig_module = self.module
self._param_names: List[str] = []
self._no_sync: bool = False
self._init_args: Optional[Tuple[Any, ...]] = None
self._init_kwargs: Dict[str, Any] = {}
self._comm_hook_args: List[Any] = []
def _collect_params(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
ignored_params: Set[nn.Parameter],
prefix: str = _ROOT_MODULE_PREFIX,
) -> None:
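        """
        Recursively collect the parameters (and their fully qualified names)
        that should be handed to DDP, skipping ignored modules/parameters and
        any subtree already managed by ``fully_shard``.
        """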
        # Skip if the module is managed by the fully_shard API.
if _is_fully_sharded(module):
return
# if a module is ignored, all descendants of the module are ignored.
if module in ignored_modules:
return
recurse_prefix = (
f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX
)
for n, p in module.named_parameters(recurse=False):
if p not in ignored_params:
self._param_list.append(p)
self._param_names.append(f"{recurse_prefix}{n}")
for name, child_module in module.named_children():
self._collect_params(
child_module,
ignored_modules,
ignored_params,
prefix=f"{recurse_prefix}{name}",
)
def lazy_init(self) -> None:
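        # Wrap the real initialization so that dynamo/torch.compile does not
        # trace this one-time setup.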
@torch._disable_dynamo(recursive=True)
def _lazy_init():
assert self._init_args is not None
self.init(*self._init_args, **self._init_kwargs)
self.register_comm_hook()
self._init_args = ()
self._init_kwargs = {}
_lazy_init()
def init(
self,
module: nn.Module,
ignored_modules: Set[nn.Module],
**kwargs,
) -> None:
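        """
        Wrap the collected parameters with ``DistributedDataParallel``. Runs at
        most once and is normally triggered lazily from ``forward_pre_hook``
        via ``lazy_init``.
        """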
if self.has_initialized:
return
self.has_initialized = True
device_mesh = kwargs.get("device_mesh", None)
self.module = module
ignored_params = {p for m in ignored_modules for p in m.parameters()}
from torch.distributed.tensor.parallel.ddp import _localize_dtensor
_localize_dtensor(module)
self._collect_params(module, ignored_modules, ignored_params)
if "device_id" in kwargs:
            # replicate() supports a small usability enhancement where the
            # user can pass in device_id as a Union[int, torch.device] even for
            # CPU devices so users don't have to change code for CPU/GPU runs.
            # We derive the right device_ids to feed into DDP to support this.
if kwargs["device_id"] is not None:
device_id = kwargs["device_id"]
# Convert to device_ids that DDP expects.
if isinstance(device_id, torch.device) and device_id.type == "cpu":
# CPU modules receive device_ids None
kwargs["device_ids"] = None
else:
# GPU modules expect device_ids=[cuda_device]
kwargs["device_ids"] = [device_id]
else:
kwargs["device_ids"] = None
kwargs.pop("device_id")
self._ddp = DistributedDataParallel(self._param_list, **kwargs)
# Weakref to the DDP instance is currently only used for testing.
replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp)
def register_comm_hook(self) -> None:
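        """
        Apply any communication hooks that were recorded before the DDP
        instance existed (see ``DDP.register_comm_hook``).
        """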
for comm_args, comm_kwargs in self._comm_hook_args:
self._ddp.register_comm_hook(*comm_args, **comm_kwargs)
self._comm_hook_args.clear()
def record_init_args(self, *args, **kwargs) -> None:
self._init_args = args
self._init_kwargs = kwargs
def forward_pre_hook(
self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Any:
if self._init_args or self._init_kwargs:
self.lazy_init()
self._ddp.require_backward_grad_sync = not self._no_sync
return self._ddp._pre_forward(*args, **kwargs)
def forward_post_hook(
self,
module: nn.Module,
        input: Tuple[torch.Tensor, ...],
output: torch.Tensor,
) -> torch.Tensor:
return self._ddp._post_forward(output)


def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
raise AssertionError(
"DDP does not support deepcopy. Please use state dict for serialization."
)


# Follow the same pattern as FSDP/fully_shard
class DDP:
def __new__(cls, *args, **kwargs):
"""
Override ``__new__`` to remove the DDP class and directly construct
the original class for cases like indexing into a container module.
"""
# Use index 2 since 0 is the dynamically constructed `DDP<...>` class
# and index 1 is the `DDP` class itself
orig_cls = cls.__mro__[2]
return orig_cls.__new__(orig_cls, *args, **kwargs)
def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None:
"""
        Sets whether the module should synchronize gradients. This can be used
        to implement gradient accumulation without communication.

        Args:
            requires_gradient_sync (bool): Whether to reduce gradients for the
                module's parameters.
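
        Example (illustrative sketch; assumes a default process group has been
        initialized and that ``batches`` is a list of input tensors)::

            >>> # xdoctest: +SKIP("requires an initialized process group")
            >>> model = replicate(nn.Linear(3, 3))
            >>> model.set_requires_gradient_sync(False)  # accumulate locally
            >>> for batch in batches[:-1]:
            ...     model(batch).sum().backward()
            >>> model.set_requires_gradient_sync(True)  # sync on the last step
            >>> model(batches[-1]).sum().backward()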
"""
replicate.state(self)._no_sync = not requires_gradient_sync
def register_comm_hook(self, *args, **kwargs) -> None:
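        """
        Stash the communication hook arguments so the hook can be registered on
        the underlying ``DistributedDataParallel`` instance once it is lazily
        constructed (see ``_ReplicateState.register_comm_hook``).
        """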
replicate.state(self)._comm_hook_args.append((args, kwargs))


@contract(state_cls=_ReplicateState)
def replicate(
module: nn.Module,
ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
**kwargs,
) -> nn.Module:
r"""Replicates a module
Args:
module (torch.nn.Module): module to replicate
Example::
>>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
>>> module = nn.Linear(3, 3)
>>> replicate(module)
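
    The ``ignored_modules`` argument keeps a submodule's parameters out of
    gradient synchronization (illustrative sketch; the ``nn.Sequential`` model
    below is only an example)::

        >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d)
        >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
        >>> replicate(model, ignored_modules={model[1]})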
"""
torch._C._log_api_usage_once("torch.distributed.replicate")
# TODO(fegin): using kwargs is not a good idea if we would like to make
# replicate a formal API to replace DDP.
if "device_id" in kwargs:
if not isinstance(kwargs["device_id"], (int, torch.device)):
raise RuntimeError(
"Expected device_id to be int or torch.device, "
f"but got {type(kwargs['device_id'])}"
)
if _is_fully_sharded(module):
raise RuntimeError(
"Cannot apply `replicate()` on a Module already managed by `fully_shard`"
)
if ignored_modules is None:
        ignored_modules = set()
else:
ignored_modules = set(ignored_modules)
state = cast(_ReplicateState, replicate.state(module))
module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True)
device_mesh = kwargs.get("device_mesh", None)
if device_mesh is not None:
from torch.distributed.device_mesh import _mesh_resources
if _mesh_resources.get_parent_mesh(device_mesh) is not None:
            # TODO: This is a temporary workaround to enable DDP + TP.
            # We should do the logic in DDP so that the 2D implementation is
            # sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in the DDP class, as the
            # module that replicate passes to DDP is NOT the original module.
from torch.distributed.tensor.parallel.ddp import (
_localize_dtensor,
_reconstruct_dtensor,
)
module.register_forward_pre_hook(_reconstruct_dtensor)
module.register_forward_hook(_localize_dtensor)
module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type]
state.record_init_args(module, ignored_modules, **kwargs)
# Place DDP leftmost for highest priority in the method resolution order
cls = module.__class__
dct = {"__deepcopy__": unimplemented_deepcopy}
new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct)
module.__class__ = new_cls
return module


def _is_fully_sharded(module: nn.Module) -> bool:
r"""Check if module is marked with fully_shard."""
registry = _get_registry(module)
if registry is None:
return False
return "fully_shard" in registry