| # mypy: allow-untyped-defs |
| import weakref |
| from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Set, Tuple |
| |
| import torch |
| import torch.nn as nn |
| from torch.distributed._composable_state import _State |
| from torch.nn.parallel import DistributedDataParallel |
| |
| from .contract import _get_registry, contract |
| |
| |
| _ROOT_MODULE_PREFIX = "" |
| |
| |
| class _ReplicateState(_State): |
| def __init__(self) -> None: |
| super().__init__() |
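        # An empty ParameterList is only a placeholder so that ``self.module``
        # is always a valid ``nn.Module``; ``init()`` replaces it with the
        # real module during lazy initialization.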
| self.module: nn.Module = nn.ParameterList() |
| self.has_initialized: bool = False |
| self._param_list: nn.ParameterList = nn.ParameterList() |
        # TODO(@fegin): this variable was originally created for testing; we
        # should remove it if possible.
| self._orig_module = self.module |
| self._param_names: List[str] = [] |
| self._no_sync: bool = False |
| self._init_args: Optional[Tuple[Any, ...]] = None |
| self._init_kwargs: Dict[str, Any] = {} |
| self._comm_hook_args: List[Any] = [] |
| |
| def _collect_params( |
| self, |
| module: nn.Module, |
| ignored_modules: Set[nn.Module], |
| ignored_params: Set[nn.Parameter], |
| prefix: str = _ROOT_MODULE_PREFIX, |
| ) -> None: |
        # Skip if the module is already managed by the fully_shard API.
| if _is_fully_sharded(module): |
| return |
| |
| # if a module is ignored, all descendants of the module are ignored. |
| if module in ignored_modules: |
| return |
| |
| recurse_prefix = ( |
| f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX |
| ) |
| |
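        # Only the module's direct parameters are collected here; descendants
        # are handled by the recursive calls below so that ignored subtrees
        # can be pruned entirely.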
| for n, p in module.named_parameters(recurse=False): |
| if p not in ignored_params: |
| self._param_list.append(p) |
| self._param_names.append(f"{recurse_prefix}{n}") |
| |
| for name, child_module in module.named_children(): |
| self._collect_params( |
| child_module, |
| ignored_modules, |
| ignored_params, |
| prefix=f"{recurse_prefix}{name}", |
| ) |
| |
| def lazy_init(self) -> None: |
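        # The actual initialization runs with Dynamo disabled so that DDP
        # construction is not traced into a compiled graph.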
| @torch._disable_dynamo(recursive=True) |
| def _lazy_init(): |
| assert self._init_args is not None |
| self.init(*self._init_args, **self._init_kwargs) |
| self.register_comm_hook() |
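            # Clear the recorded args so that subsequent forward calls skip
            # lazy initialization (see ``forward_pre_hook``).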
| self._init_args = () |
| self._init_kwargs = {} |
| |
| _lazy_init() |
| |
| def init( |
| self, |
| module: nn.Module, |
| ignored_modules: Set[nn.Module], |
| **kwargs, |
| ) -> None: |
| if self.has_initialized: |
| return |
| |
| self.has_initialized = True |
| |
| self.module = module |
| ignored_params = {p for m in ignored_modules for p in m.parameters()} |
| from torch.distributed.tensor.parallel.ddp import _localize_dtensor |
| |
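        # Convert any DTensor parameters to plain local tensors so that the
        # parameters collected below (and handed to DDP) are regular tensors.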
| _localize_dtensor(module) |
| self._collect_params(module, ignored_modules, ignored_params) |
| |
| if "device_id" in kwargs: |
            # replicate() supports a small usability enhancement: the user can
            # pass device_id as a Union[int, torch.device] even for CPU devices,
            # so code does not have to change between CPU and GPU runs. We
            # derive the device_ids that DDP expects from it.
| if kwargs["device_id"] is not None: |
| device_id = kwargs["device_id"] |
| # Convert to device_ids that DDP expects. |
| if isinstance(device_id, torch.device) and device_id.type == "cpu": |
                    # CPU modules expect device_ids=None
| kwargs["device_ids"] = None |
| else: |
| # GPU modules expect device_ids=[cuda_device] |
| kwargs["device_ids"] = [device_id] |
| else: |
| kwargs["device_ids"] = None |
| kwargs.pop("device_id") |
| |
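        # DDP wraps a ParameterList holding only the managed parameters rather
        # than the original module; ``forward_pre_hook``/``forward_post_hook``
        # below drive DDP's ``_pre_forward``/``_post_forward`` around the
        # module's own forward pass.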
| self._ddp = DistributedDataParallel(self._param_list, **kwargs) |
| # Weakref to the DDP instance is currently only used for testing. |
| replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp) |
| |
| def register_comm_hook(self) -> None: |
| for comm_args, comm_kwargs in self._comm_hook_args: |
| self._ddp.register_comm_hook(*comm_args, **comm_kwargs) |
| self._comm_hook_args.clear() |
| |
| def record_init_args(self, *args, **kwargs) -> None: |
| self._init_args = args |
| self._init_kwargs = kwargs |
| |
| def forward_pre_hook( |
| self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] |
| ) -> Any: |
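        # DDP construction is deferred until the first forward call.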
| if self._init_args or self._init_kwargs: |
| self.lazy_init() |
| self._ddp.require_backward_grad_sync = not self._no_sync |
| return self._ddp._pre_forward(*args, **kwargs) |
| |
| def forward_post_hook( |
| self, |
| module: nn.Module, |
        input: Tuple[torch.Tensor, ...],
| output: torch.Tensor, |
| ) -> torch.Tensor: |
| return self._ddp._post_forward(output) |
| |
| |
| def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: |
| raise AssertionError( |
| "DDP does not support deepcopy. Please use state dict for serialization." |
| ) |
| |
| |
| # Follow the same pattern as FSDP/fully_shard |
| class DDP: |
| def __new__(cls, *args, **kwargs): |
| """ |
| Override ``__new__`` to remove the DDP class and directly construct |
| the original class for cases like indexing into a container module. |
| """ |
| # Use index 2 since 0 is the dynamically constructed `DDP<...>` class |
| # and index 1 is the `DDP` class itself |
| orig_cls = cls.__mro__[2] |
| return orig_cls.__new__(orig_cls, *args, **kwargs) |
| |
| def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None: |
| """ |
        Sets whether the module should synchronize gradients. This can be used
        to implement gradient accumulation without communication.
| |
| Args: |
| requires_gradient_sync (bool): Whether to reduce gradients for the |
| module's parameters. |
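
        A minimal sketch of gradient accumulation over two micro-batches
        (``model``, ``opt``, and ``batches`` are placeholders for a replicated
        module, an optimizer, and an iterable of inputs)::

            >>> # xdoctest: +SKIP("requires an initialized process group")
            >>> for step, batch in enumerate(batches):
            ...     model.set_requires_gradient_sync(step % 2 == 1)
            ...     loss = model(batch).sum()
            ...     loss.backward()
            ...     if step % 2 == 1:
            ...         opt.step()
            ...         opt.zero_grad()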
| """ |
| replicate.state(self)._no_sync = not requires_gradient_sync |
| |
| def register_comm_hook(self, *args, **kwargs) -> None: |
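        """
        Records the hook arguments; the hook is registered on the underlying
        ``DistributedDataParallel`` instance at lazy initialization time.
        """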
| replicate.state(self)._comm_hook_args.append((args, kwargs)) |
| |
| |
| @contract(state_cls=_ReplicateState) |
| def replicate( |
| module: nn.Module, |
| ignored_modules: Optional[Iterable[torch.nn.Module]] = None, |
| **kwargs, |
| ) -> nn.Module: |
| r"""Replicates a module |
| |
| Args: |
        module (torch.nn.Module): module to replicate
        ignored_modules (Optional[Iterable[torch.nn.Module]]): modules whose
            parameters should not be managed by ``replicate()``; all
            descendants of an ignored module are ignored as well
| |
| Example:: |
| >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) |
| >>> module = nn.Linear(3, 3) |
| >>> replicate(module) |
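
    Keyword arguments are forwarded to ``DistributedDataParallel``
    (``device_id`` may be given as an ``int`` or ``torch.device``). A sketch
    with an explicit device, assuming a default process group has been
    initialized::

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> replicate(nn.Linear(3, 3), device_id=torch.device("cuda:0"))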
| """ |
| torch._C._log_api_usage_once("torch.distributed.replicate") |
| |
| # TODO(fegin): using kwargs is not a good idea if we would like to make |
| # replicate a formal API to replace DDP. |
| if "device_id" in kwargs: |
| if not isinstance(kwargs["device_id"], (int, torch.device)): |
| raise RuntimeError( |
| "Expected device_id to be int or torch.device, " |
| f"but got {type(kwargs['device_id'])}" |
| ) |
| |
| if _is_fully_sharded(module): |
| raise RuntimeError( |
| "Cannot apply `replicate()` on a Module already managed by `fully_shard`" |
| ) |
| |
| if ignored_modules is None: |
        ignored_modules = set()
| else: |
| ignored_modules = set(ignored_modules) |
| |
| state = cast(_ReplicateState, replicate.state(module)) |
| module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) |
| device_mesh = kwargs.get("device_mesh", None) |
| if device_mesh is not None: |
| from torch.distributed.device_mesh import _mesh_resources |
| |
| if _mesh_resources.get_parent_mesh(device_mesh) is not None: |
            # TODO: This is a temporary workaround to enable DDP + TP.
            # We should move this logic into DDP so that the 2D implementation
            # is sound and the state_dict works out of the box.
            #
            # This won't conflict with what is done in the DDP class because
            # the module that replicate passes to DDP is NOT the original
            # module.
| from torch.distributed.tensor.parallel.ddp import ( |
| _localize_dtensor, |
| _reconstruct_dtensor, |
| ) |
| |
| module.register_forward_pre_hook(_reconstruct_dtensor) |
| module.register_forward_hook(_localize_dtensor) |
| |
| module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] |
| |
| state.record_init_args(module, ignored_modules, **kwargs) |
| |
| # Place DDP leftmost for highest priority in the method resolution order |
| cls = module.__class__ |
| dct = {"__deepcopy__": unimplemented_deepcopy} |
| new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) |
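    # For example, wrapping an ``nn.Linear`` keeps the instance usable as a
    # ``Linear`` while ``type(module).__name__`` becomes "DDPLinear".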
| module.__class__ = new_cls |
| return module |
| |
| |
| def _is_fully_sharded(module: nn.Module) -> bool: |
| r"""Check if module is marked with fully_shard.""" |
| registry = _get_registry(module) |
| if registry is None: |
| return False |
| return "fully_shard" in registry |