blob: e2daa51c9aed9dd0e396f77934ff99139693ae82 [file] [log] [blame]
import contextlib
import functools
import logging
import os
import warnings
from enum import auto, Enum
from itertools import accumulate, chain
from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
List,
NamedTuple,
no_type_check,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed._tensor import DTensor
from torch.distributed.fsdp._common_utils import (
_set_fsdp_flattened,
HandleTrainingState,
)
from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
from ._utils import _no_dispatch_record_stream, _same_storage_as_data_ptr
# Names exported as this module's public API
__all__ = [
    "FlatParameter",
    "FlatParamHandle",
    "FlatParamShardMetadata",
    "ParamInfo",
    "SharedParamInfo",
    "HandleShardingStrategy",
]
# Module-level logger (used to report alignment padding at construction time)
log = logging.getLogger(__name__)
"""
[Note: Fully Sharded Module]
We define the "fully sharded module" to be the original ``nn.Module`` that owns
a ``FlatParamHandle``. It is the *single* module logically responsible for the
*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given
forward or backward pass. The fully sharded module should be passed to the
``FlatParamHandle`` constructor.
For the wrapper code path:
- The ``FullyShardedDataParallel`` module wrapping the fully sharded module
runs the unshard/reshard on behalf of the fully sharded module by overriding
``nn.Module.forward``.
- The fully sharded module is exactly the module passed to the
``FullyShardedDataParallel`` constructor's ``module`` argument.
For the non-wrapper code path:
- Hooks registered on the fully sharded module run the unshard/reshard.
- The fully sharded module may either be the direct argument to ``fully_shard``
or a submodule chosen by the provided wrapping policy.
"""
# Name of the environment variable toggling whether to use unsafe `setattr()`
# for view setting in `_use_sharded_views()` and `_use_unsharded_views()`.
# We should use 'safe' by default since it respects method overrides, but for
# special cases such as for high CPU overhead or for intentionally bypassing
# checks in the overrides, we may use 'unsafe'. The unsafe path is selected
# only when the variable is set to exactly "1".
_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR"
# Arbitrary, easily recognizable fill value for padding elements in tensors
# (chosen for debuggability)
_FLAT_PARAM_PADDING_VALUE = 42
class ParamInfo(NamedTuple):
    """Information for an original parameter."""

    # Parameter name as registered directly on `module` (no module prefix)
    param_name: str  # unprefixed
    # Module that directly owns the parameter
    module: nn.Module
    # Name of `module` as yielded by `named_modules()` on the module whose
    # tree was walked when building the flat parameter
    module_name: str
class SharedParamInfo(NamedTuple):
    """
    Additional information for a shared parameter.

    For each shared parameter, we designate one module and its parameter
    variable to be the primary owner, determined as the first one encountered
    in the parameter walk. These are prefixed with "prim". The primary module
    and parameter do not have their own :class:`SharedParamInfo` instance.
    """

    # The non-primary (later-encountered) reference to the shared parameter
    param_name: str  # unprefixed
    module: nn.Module
    module_name: str
    # The primary owner, i.e. the first reference encountered in the walk
    prim_param_name: str  # unprefixed
    prim_module: nn.Module
    prim_module_name: str
class _ShardParamInfo(NamedTuple):
    """
    Shard-related information for an original parameter.

    When ``in_shard`` is ``False``, every other field is ``None``.
    """

    # Whether any part of the parameter lies in this rank's shard
    in_shard: bool
    # Use to index into the sharded flat parameter, e.g.
    # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]`
    offset_in_shard: Optional[int]
    numel_in_shard: Optional[int]
    # Use to get part of the parameter in the local shard from a flattened
    # version of the unsharded parameter, e.g.
    # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]`
    intra_param_start_idx: Optional[int]
    intra_param_end_idx: Optional[int]  # inclusive
class FlatParamShardMetadata(NamedTuple):
    """
    This holds metadata specific to this rank's shard of the flat parameter.

    Only original parameters that have some part in this rank's shard are
    represented; all four fields are index-aligned with each other.

    Attributes:
        param_names (Tuple[str, ...]): Prefixed parameter names of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's
            shard of the parameters; see :class:`FlatParameter`.
        param_numels (Tuple[int, ...]): Parameter numels of this rank's shard
            of the parameters; see :class:`FlatParameter`.
        param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in
            units of numels, both ends inclusive) giving this rank's part of
            each flattened original parameter.
    """

    param_names: Tuple[str, ...]
    param_shapes: Tuple[torch.Size, ...]
    param_numels: Tuple[int, ...]
    param_offsets: Tuple[Tuple[int, int], ...]
# TODO (awgu): Prefix these with "Handle" for now to avoid circular imports and
# inadvertent misuses; coalesce with those in fully_sharded_data_parallel.py
# later
class HandleShardingStrategy(Enum):
    """
    Sharding strategy that a :class:`FlatParamHandle` applies to its
    ``FlatParameter``.

    NOTE(review): The members presumably mirror the public FSDP
    ``ShardingStrategy`` enum of the same names -- confirm against
    fully_sharded_data_parallel.py.
    """

    FULL_SHARD = auto()
    SHARD_GRAD_OP = auto()
    # No sharded flat parameter exists under this strategy
    NO_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()
class FlatParameter(nn.Parameter):
    """
    This is the flat parameter used by :class:`FullyShardedDataParallel`. It is
    comprised of one or more original parameters, which are flattened and
    concatenated to construct the flat parameter.

    Under the current design, this parameter logically represents both the
    unsharded and sharded flat parameter, and its data changes storages
    dynamically.
        - In the :class:`FullyShardedDataParallel` constructor, the parameter
        is initialized as unsharded and then sharded in-place.
        - At runtime, the parameter is lazily (re)-initialized. The sharded
        parameter data is saved in ``self._local_shard``, and a new ``Tensor``
        ``self._full_param_padded`` is created, which is the all-gather
        destination and owns the unsharded parameter storage thereafter. (See
        :meth:`FlatParamHandle.init_flat_param_attributes`.)
        - Throughout runtime, the parameter data changes storages as needed,
        e.g. to the sharded flat parameter, low precision sharded flat
        parameter, or the unsharded flat parameter.

    NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter``
    padding, we have two versions of the per-parameter numels, one that
    includes the padding (``_numels_with_padding``) and one that does not
    (``_numels``). The former may have length longer than the other data
    structures, while the latter has the same length as the number of actual
    original parameters like the other per-parameter data structures.

    Attributes:
        _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size
            without right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding.
        _padded_unsharded_size (torch.Size): Unsharded flat parameter's size
            with right-hand-side padding for divisibility by the world size.
            For ``use_orig_params=True``, this includes alignment padding. This
            is only set for sharded strategies since they require padding for
            the all-gather.
        _sharded_size (torch.Size): Sharded flat parameter's size with padding.
            This is also set for ``NO_SHARD``, in which case it is the same as
            the unsharded sizes. (We omit "padded" because there is no
            analogous unpadded one.)

        _num_params (int): Number of original parameters flattened into this
            flat parameter. This is the length of the per-parameter data
            structures.
        _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info
            entry; see :class:`ParamInfo` for details.
        _shapes (Tuple[torch.Size, ...]): Each parameter's original shape.
        _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN)
            prefixed from the ``_fully_sharded_module``. The names are
            guaranteed to be unique in the subtree rooted at that module.
        _param_extensions (Tuple[Optional[Any], ...]): Each parameter's
            extension (i.e. some per-parameter state) used to customize
            pre-flatten and post-unflatten behavior or ``None``. This is
            experimental, and users should not depend on its existence in the
            future.
        _numels_with_padding (Tuple[int, ...]): Each parameter's numel
            including entries for the padding. This is used to construct views
            into the flat parameter via ``torch.split()``. This may have length
            longer than ``_num_params``.
        _numels (Tuple[int, ...]): Each parameter's numel excluding entries for
            padding. This has length equal to ``_num_params``.
        _is_padding_mask (List[bool]): Mask aligned with
            ``_numels_with_padding`` marking which entries are padding rather
            than an original parameter.
        _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's
            shard parameter info; see :class:`_ShardParamInfo` for details.
        _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter
            info entries; see :class:`SharedParamInfo` for details.
        _modules (Set[nn.Module]): Modules that contain some original parameter
            that is flattened into the flat parameter.

        _shard_numel_padded (int): Numel padded for this rank's sharded flat
            parameter.
        _local_shard (Tensor): Sharded flat parameter with padding if using a
            sharded strategy. If using ``NO_SHARD``, then this is the unpadded
            unsharded flat parameter, and there is no notion of a sharded flat
            parameter or padded unsharded flat parameter.
        _full_param_padded (Tensor): Unsharded flat parameter with padding.
            This is not defined for ``NO_SHARD``. When using mixed precision
            for parameters, this has the low precision.
        _full_prec_full_param_padded (Tensor): Full precision unsharded flat
            parameter with padding. This is used for unsharding outside of
            computation when using mixed precision for parameters. This is
            never defined for ``NO_SHARD``.
        _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]):
            Flat parameter's :class:`AccumulateGrad` object and post-backward
            hook handle.
        _post_backward_called (bool): Whether the post-backward hook has been
            called; initialized to ``False`` by :meth:`_init_metadata`.
        _mp_shard (Tensor): Low precision sharded flat parameter with padding.
            This is only defined when parameter mixed precision is enabled. For
            ``NO_SHARD``, this is used for computation.
        _cpu_grad (Tensor): Sharded gradient with padding stored on CPU.
            This is only defined when offloading parameters is enabled.
        _saved_grad_shard (Tensor): Sharded gradient with padding from previous
            iterations for gradient accumulation without :meth:`no_sync`.

        _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``,
            then each original parameter variable; otherwise, ``None``. This
            does not include any padding tensors.
        _shared_params (Optional[List[nn.Parameter]]): The original shared
            parameter variables if ``use_orig_params=True`` and ``None``
            otherwise.
        _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor``
            views created in the forward and tracked by autograd when
            ``use_orig_params=True`` and is ``None`` otherwise. This is to
            preserve those ``Tensor`` variables for the backward to ensure that
            the ``FlatParameter`` 's ``AccumulateGrad`` object does not change
            in which case the post-backward hook does not run. This is relevant
            for cases like reentrant activation checkpointing.
        _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``,
            a mask over the original parameters' gradients indicating if it is
            logically ``None`` or not; otherwise, ``None``. This does not
            include entries for padding. This mask is needed because only some
            of the parameters may have ``None`` gradient, in which case the
            flat gradient must be non-``None`` and must use zeros to
            approximate those original ``None`` gradients. This mask informs
            FSDP to set the original parameter gradients to ``None`` (instead
            of zeros) as needed.
    """

    def _init_metadata(
        self,
        param_infos: List[ParamInfo],
        numels: List[int],
        shapes: List[torch.Size],
        fqns: List[str],
        shared_param_infos: List[SharedParamInfo],
        param_extensions: List[Optional[Any]],
        params: Optional[List[nn.Parameter]],
        shared_params: Optional[List[nn.Parameter]],
        is_padding_mask: List[bool],
    ) -> None:
        """
        Populates the attributes describing the original parameters that make
        up this flat parameter.

        This is kept apart from the constructor so that the constructor only
        deals with tensor data. It should be called once per model, whereas
        the constructor may run several times (e.g. when reloading from a
        checkpoint, where only the tensor data is passed). Since
        :meth:`load_state_dict` is implemented via :meth:`copy_`, the metadata
        is correctly assumed to be unchanged.

        ``numels``, ``params``, and ``is_padding_mask`` include entries for
        padding; the remaining per-parameter lists do not.

        Args:
            See the Attributes in the class docstring.
        """
        expected_len = len(param_infos)
        assert expected_len == len(shapes)
        assert expected_len == len(fqns)
        assert expected_len == len(param_extensions)
        self._num_params = expected_len
        self._param_infos = param_infos
        self._shapes = shapes
        self._fqns = fqns
        self._param_extensions = param_extensions
        self._is_padding_mask = is_padding_mask

        # Keep only the numels that belong to real parameters
        self._numels = tuple(
            count for count, padded in zip(numels, is_padding_mask) if not padded
        )
        self._numels_with_padding = tuple(numels)
        assert len(self._numels) == self._num_params

        self._shared_param_infos = tuple(shared_param_infos)
        owning_modules = {info.module for info in self._param_infos}
        sharing_modules = {info.module for info in self._shared_param_infos}
        self._modules = owning_modules | sharing_modules

        # Both must be given (original parameters are tracked) or both omitted
        assert (params is None) == (shared_params is None)
        if params is not None:
            assert shared_params is not None and len(shared_params) == len(
                shared_param_infos
            )
            # Drop the padding tensors from the tracked original parameters
            self._params: Optional[List[nn.Parameter]] = [
                entry for entry, padded in zip(params, is_padding_mask) if not padded
            ]
            self._shared_params: Optional[List[nn.Parameter]] = shared_params
            # Mark the original parameters to avoid flattening them into
            # another `FlatParameter` during recursive construction
            for orig_param in chain(self._params, self._shared_params):
                _set_fsdp_flattened(orig_param)
            self._is_grad_none_mask: Optional[List[bool]] = [False] * self._num_params
            self._tensors: Optional[List[Optional[Tensor]]] = [None] * self._num_params
        else:
            self._params = None
            self._shared_params = None
            self._is_grad_none_mask = None
            self._tensors = None

        self._unpadded_unsharded_size = self.size()
        _set_fsdp_flattened(self)
        # Tracks whether the `FlatParameter`'s post-backward hook has been
        # called to modify the behavior of the post-backward callback
        self._post_backward_called = False
class FlatParamHandle:
"""
This handle manages a flat parameter (:class:`FlatParameter`). This
includes sharding and view management.
Args:
params (Sequence[nn.Parameter]): The parameters to flatten into the
flat parameter.
fully_sharded_module (nn.Module): See [Note: Fully Sharded Module].
device (torch.device): The compute and communication device, which
should be a non-CPU device. We refer to it as the compute device.
sharding_strategy (ShardingStrategy): Sharding strategy to apply to
this handle's ``FlatParameter``.
offload_params (bool): Whether to offload the handle's
``FlatParameter`` to CPU.
mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision
setting passed to the FSDP constructor.
mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed
precision setting passed to the FSDP constructor.
keep_low_precision_grads (bool): Whether to keep gradients in low
precision.
use_orig_params (bool): If ``True``, then FSDP preserves the original
parameter variables and returns them from ``named_parameters()``
(e.g. to support different optimizer hyperparameters within one
:class:`FlatParameter`). If ``False``, then FSDP reconstructs the
parameters every iteration and returns the :class:`FlatParameter` s
from ``named_parameters()``.
"""
##################
# INITIALIZATION #
##################
def __init__(
    self,
    params: Sequence[Union[nn.Parameter, Tensor]],
    fully_sharded_module: nn.Module,
    device: torch.device,
    sharding_strategy: HandleShardingStrategy,
    offload_params: bool,
    mp_param_dtype: Optional[torch.dtype],
    mp_reduce_dtype: Optional[torch.dtype],
    keep_low_precision_grads: bool,
    process_group: dist.ProcessGroup,
    use_orig_params: bool,
):
    """
    Records the handle's configuration, resolves the forward/backward and
    reduction dtypes, flattens ``params`` into ``self.flat_param`` together
    with its metadata, and finally calls
    ``_use_unsharded_views(as_params=False)``.

    The rank and world size are read from ``process_group``. Address
    alignment is enabled exactly when ``use_orig_params`` is ``True``.

    Raises:
        ValueError: If ``params`` is empty.
    """
    super().__init__()
    param_list = list(params)
    if not param_list:
        raise ValueError(
            f"Cannot construct a {self.__class__.__name__} with an empty parameter list"
        )
    self._init_setattr_fns()
    needs_alignment = use_orig_params
    self._init_get_unflat_views_fn(needs_alignment)

    # Device and process group state
    self.device = device
    self.process_group = process_group
    self.rank = process_group.rank()
    self.world_size = process_group.size()

    # Configuration flags
    self._sharding_strategy = sharding_strategy
    self._offload_params = offload_params
    self._use_orig_params = use_orig_params
    self._keep_low_precision_grads = keep_low_precision_grads
    self._training_state = HandleTrainingState.IDLE
    self._debug_level = dist.get_debug_level()
    self._fully_sharded_module = fully_sharded_module

    # Optimistically assume a valid input `params` and set dtype attributes
    # before `_init_flat_param_and_metadata()`, which performs the actual
    # validation
    self._orig_param_dtype = param_list[0].dtype
    self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype)
    assert self._fwd_bwd_param_dtype is not None  # mypy

    if needs_alignment:
        self._aligned_numel = _get_aligned_numel(
            unsharded_dtype=self._fwd_bwd_param_dtype
        )
    else:
        self._aligned_numel = 0
    self._init_flat_param_and_metadata(
        param_list, fully_sharded_module, self._aligned_numel, use_orig_params  # type: ignore[arg-type]
    )
    self._use_unsharded_views(as_params=False)
def _init_setattr_fns(self):
    """
    Chooses the functions used to set tensors and parameters on modules.

    The unsafe variants are used only when the environment variable named by
    ``_FSDP_USE_UNSAFE_SETATTR`` equals ``"1"``; otherwise a single safe
    setter serves both tensors and parameters.
    """
    self._setattr_tensor: Callable[[nn.Module, str, Tensor], None]
    self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None]
    env_value = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "")
    if env_value == "1":
        tensor_setter, param_setter = _unsafe_setattr_tensor, _unsafe_setattr_param
    else:
        tensor_setter = param_setter = _safe_setattr_tensor_or_param
    self._setattr_tensor = tensor_setter
    self._setattr_param = param_setter
def _init_get_unflat_views_fn(self, align_addresses: bool):
    """
    Binds ``self._get_unflat_views`` to the aligned implementation when
    ``align_addresses`` is truthy and to the unaligned one otherwise.
    """
    if align_addresses:
        chosen = self._get_unflat_views_aligned
    else:
        chosen = self._get_unflat_views_unaligned
    self._get_unflat_views = chosen
def _init_flat_param_and_metadata(
    self,
    params: List[Union[Tensor, nn.Parameter]],
    module: nn.Module,
    aligned_numel: int,
    use_orig_params: bool,
) -> None:
    """
    Walks ``module``'s tree, collects the members of ``params`` (inserting
    alignment padding when ``aligned_numel > 0``), builds
    ``self.flat_param`` from them, and initializes its metadata.

    NOTE: This should only be called once at construction time, after which
    the ``FlatParameter`` metadata is assumed to be static.

    NOTE: The elements of ``params`` should only be ``Tensor`` s when
    composing with ``DTensor`` -based tensor parallelism, in which case the
    elements may be ``DTensor`` local shards.

    Args:
        params: Tensors to flatten; only those found in ``module``'s tree
            are used.
        module: Root of the module tree to walk.
        aligned_numel: Numel multiple that each original parameter's start
            offset is padded to; ``0`` disables padding.
        use_orig_params: Whether to record the original parameter variables
            on the ``FlatParameter``.

    Raises:
        ValueError: If ``params`` is empty, ``aligned_numel`` is negative,
            the tensors fail validation, or none of ``params`` is found in
            ``module``'s tree.
    """
    if len(params) == 0:
        raise ValueError("Expects non-empty `params`")
    if aligned_numel < 0:
        raise ValueError(
            f"Expects non-negative `aligned_numel` but got {aligned_numel}"
        )
    dtype, requires_grad, device = self._validate_tensors_to_flatten(params)
    params_set = set(params)
    # For alignment padding, only `numels` gets strictly non-`None`
    # elements, and all other lists get `None` elements for padding.
    param_infos: List[ParamInfo] = []
    numels: List[int] = []
    shapes: List[torch.Size] = []
    fqns: List[str] = []
    shared_param_infos: List[SharedParamInfo] = []
    # Maps an already-visited parameter to its primary (first) owner
    shared_param_memo: Dict[
        Union[Tensor, nn.Parameter], Tuple[nn.Module, str, str]
    ] = {}
    params_to_flatten: List[Union[Tensor, nn.Parameter]] = []
    shared_params: List[Union[Tensor, nn.Parameter]] = []
    param_extensions: List[Any] = []
    is_padding_mask: List[bool] = []
    total_numel = total_numel_without_padding = 0
    for submodule_name, submodule in module.named_modules():
        for param_name, param in submodule.named_parameters(recurse=False):
            if param not in params_set:
                continue
            if param in shared_param_memo:  # shared reference
                prim_module, prim_module_name, prim_param_name = shared_param_memo[
                    param
                ]
                shared_params.append(param)
                shared_param_infos.append(
                    SharedParamInfo(
                        param_name,
                        submodule,
                        submodule_name,
                        prim_param_name,
                        prim_module,
                        prim_module_name,
                    )
                )
            else:
                if aligned_numel > 0:
                    # Pad so that this parameter starts at a multiple of
                    # `aligned_numel`; no padding if already aligned
                    numel_to_pad = aligned_numel - (total_numel % aligned_numel)
                    if numel_to_pad > 0 and numel_to_pad < aligned_numel:
                        padding_tensor = _construct_padding_tensor(
                            numel_to_pad, dtype, requires_grad, device
                        )
                        params_to_flatten.append(padding_tensor)
                        is_padding_mask.append(True)
                        numels.append(numel_to_pad)
                        total_numel += numel_to_pad
                param, extension = _ext_pre_flatten_transform(param)
                param_extensions.append(extension)
                shared_param_memo[param] = (submodule, submodule_name, param_name)
                params_to_flatten.append(param)
                is_padding_mask.append(False)
                param_infos.append(ParamInfo(param_name, submodule, submodule_name))
                numels.append(param.numel())
                shapes.append(param.shape)
                fqn = (
                    submodule_name + "." + param_name
                    if submodule_name
                    else param_name
                )
                fqns.append(fqn)
                total_numel += param.numel()
                total_numel_without_padding += param.numel()
    if len(params_to_flatten) == 0:
        # Separate the summary from the details with a newline; the two
        # adjacent literals are concatenated as-is
        raise ValueError(
            "`params` were not found in `module`'s tree\n"
            f"params: {params}\nmodule: {module}"
        )
    if (
        self.rank == 0
        and aligned_numel > 0
        and total_numel != total_numel_without_padding
    ):
        log.info(
            f"FSDP FlatParameter address alignment created "
            f"{total_numel - total_numel_without_padding} "
            f"numel of padding ({total_numel} vs. {total_numel_without_padding})"
        )
    # Pass `aligned_numel=0` since we already included padding tensors
    self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param(
        params_to_flatten,
        aligned_numel=0,
        requires_grad=requires_grad,
    )
    self.flat_param._init_metadata(
        param_infos,
        numels,
        shapes,
        fqns,
        shared_param_infos,
        param_extensions,
        _convert_to_params(params_to_flatten) if use_orig_params else None,
        _convert_to_params(shared_params) if use_orig_params else None,
        is_padding_mask,
    )
def _validate_tensors_to_flatten(
    self, tensors: List[Union[Tensor, nn.Parameter]]
) -> Tuple:
    """
    Checks that ``tensors`` can be flattened together and returns the
    ``(dtype, requires_grad, device)`` they share.

    Each tensor must not be exactly a ``FlatParameter``, the first tensor
    must be floating point, and every later tensor must match the previous
    one in dtype, ``requires_grad``, and device.

    Raises:
        ValueError: If any of the checks above fails.
    """
    seen_dtype: Optional[torch.dtype] = None
    seen_requires_grad: Optional[bool] = None
    seen_device: Optional[torch.device] = None
    for candidate in tensors:
        if type(candidate) is FlatParameter:
            raise ValueError("Cannot flatten a `FlatParameter`")
        if seen_dtype is None:
            # Only reached for the first tensor
            if not candidate.is_floating_point():
                raise ValueError("Cannot flatten integer dtype tensors")
        elif candidate.dtype != seen_dtype:
            raise ValueError(
                f"Must flatten tensors with uniform dtype but got {seen_dtype} "
                f"and {candidate.dtype}"
            )
        # TODO: Relax the following for `use_orig_params=True`.
        if (
            seen_requires_grad is not None
            and candidate.requires_grad != seen_requires_grad
        ):
            raise ValueError("Must flatten tensors with uniform `requires_grad`")
        if seen_device is not None and candidate.device != seen_device:
            raise ValueError(
                "Must flatten tensors on the same device but got both "
                f"{seen_device} and {candidate.device}"
            )
        seen_dtype, seen_requires_grad, seen_device = (
            candidate.dtype,
            candidate.requires_grad,
            candidate.device,
        )
    assert seen_requires_grad is not None
    return seen_dtype, seen_requires_grad, seen_device
def flatten_tensors(
    self,
    tensors: List[Tensor],
    aligned_numel: int,
) -> Tensor:
    """
    Concatenates the flattened ``tensors`` into one 1D tensor. When
    ``aligned_numel`` is greater than 0, padding is inserted before a tensor
    whenever the running numel is not a multiple of ``aligned_numel``, where
    ``aligned_numel`` gives the numel required to have address alignment.

    NOTE: The padding alignment algorithm must be kept in sync with
    :meth:`_init_flat_param_and_metadata`. We separate the two methods
    because the initialization happens once, whereas this method may be
    called multiple times throughout training (e.g. for checkpointing).

    Raises:
        ValueError: If ``tensors`` is empty, ``aligned_numel`` is negative,
            or the tensors fail validation.
    """
    if len(tensors) == 0:
        raise ValueError("Expects non-empty `tensors`")
    if aligned_numel < 0:
        raise ValueError(
            f"Expects non-negative `aligned_numel` but got {aligned_numel}"
        )
    dtype, requires_grad, device = self._validate_tensors_to_flatten(tensors)
    pieces: List[Tensor] = []
    running_numel = 0
    for item in tensors:
        if aligned_numel > 0:
            remainder = running_numel % aligned_numel
            if remainder != 0:
                # Fill up to the next multiple of `aligned_numel`
                gap = aligned_numel - remainder
                pieces.append(
                    _construct_padding_tensor(gap, dtype, requires_grad, device)
                )
                running_numel += gap
        pieces.append(torch.flatten(_detach_if_needed(item)))
        running_numel += item.numel()
    return torch.cat(pieces, dim=0)
def flatten_tensors_into_flat_param(
    self,
    tensors: List[Tensor],
    aligned_numel: int,
    requires_grad: bool,
) -> FlatParameter:
    """
    Flattens ``tensors`` via :meth:`flatten_tensors` and wraps the result in
    a ``FlatParameter`` with the given ``requires_grad``.
    """
    return FlatParameter(
        self.flatten_tensors(tensors, aligned_numel),
        requires_grad=requires_grad,
    )
def _init_param_reduce_dtypes(
    self,
    mp_param_dtype: Optional[torch.dtype],
    mp_reduce_dtype: Optional[torch.dtype],
) -> None:
    """
    Precondition: ``self._orig_param_dtype`` is set; it is the fallback for
    any dtype that is not specified.

    Postcondition: This sets ``self._fwd_bwd_param_dtype`` and
    ``self._reduce_dtype``, plus the two ``_low_prec_*_specified`` flags. If
    ``mp_param_dtype`` or ``mp_reduce_dtype`` is ``None``, then we assume
    the original parameter dtype. One special case is if ``mp_param_dtype``
    is not ``None`` and ``mp_reduce_dtype`` is ``None``, in which case we
    assume the gradient reduction dtype matches the forward/backward
    parameter dtype.
    """
    # Save whether these dtypes were specified so that we permit the
    # parameter dtype to change up until the lazy initialization
    param_given = mp_param_dtype is not None
    reduce_given = mp_reduce_dtype is not None
    self._low_prec_param_dtype_specified = param_given
    self._low_prec_reduce_dtype_specified = reduce_given
    if param_given and not reduce_given:
        # Special case: infer gradient reduction mixed precision
        self._fwd_bwd_param_dtype = mp_param_dtype
        self._reduce_dtype = self._fwd_bwd_param_dtype
    else:
        self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype
        self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype
    assert self._fwd_bwd_param_dtype is not None
    assert self._reduce_dtype is not None
###################################
# SHARD INITIALIZATION & METADATA #
###################################
@torch.no_grad()
def shard(self):
    """
    Shards the handle's ``FlatParameter``. This allocates new memory for
    the sharded flat parameter and frees the unsharded flat parameter's
    storage.

    For a non-sharded strategy, the data is left as is and only the shard
    metadata is initialized, covering the entire flat parameter.

    Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard
    metadata attributes are set for all sharding strategies.
    """
    flat_param = self.flat_param
    if not self.uses_sharded_strategy:
        # The "shard" is the whole flat parameter with no padding
        self._init_shard_metadata(0, 0, flat_param.numel() - 1)
    else:
        # The storage is resized to 0 below, so the flat parameter must start
        # at the beginning of it
        _p_assert(
            flat_param.storage_offset() == 0,
            "The `FlatParameter` is not the sole occupant of its storage",
        )
        # Keep a reference to the unsharded storage before `set_()` points
        # the parameter at the new shard
        orig_storage = flat_param._typed_storage()
        sharded_flat_param, numel_padded = FlatParamHandle._get_shard(
            flat_param, self.rank, self.world_size
        )
        flat_param.set_(sharded_flat_param)  # type: ignore[call-overload]
        # Every rank's padded shard has the same numel, so this rank's
        # interval in the unsharded flat parameter follows from its rank
        start_idx = sharded_flat_param.numel() * self.rank
        end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1  # inclusive
        self._init_shard_metadata(numel_padded, start_idx, end_idx)
        # Free the unsharded storage now that the shard has its own memory
        if orig_storage._size() > 0:
            orig_storage._resize_(0)
    if self._use_orig_params:
        self._use_sharded_views()
def _init_shard_metadata(
self,
numel_padded: int,
unsharded_start_idx: int,
unsharded_end_idx: int,
) -> None:
"""
Initializes shard-related metadata for this rank's shard of the flat
parameter: ``_sharded_size``, ``_shard_param_infos``, and
``_shard_numel_padded``.
Args:
numel_padded (int): Numel padded for this rank's sharded flat
parameter.
unsharded_start_idx (int): Start index in the unsharded flat
parameter assigned to this rank.
unsharded_end_idx (int): End index (inclusive) in the unsharded
flat parameter assigned to this rank.
Precondition: ``self.flat_param`` 's data is the sharded flat
parameter.
"""
flat_param = self.flat_param
flat_param._sharded_size = flat_param.size() # type: ignore[attr-defined]
sharded_flat_param_numel = flat_param.numel() # includes `numel_padded`
_p_assert(
unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx,
f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}",
)
_p_assert(
numel_padded <= sharded_flat_param_numel,
f"numel_padded: {numel_padded} "
f"sharded_flat_param_numel: {sharded_flat_param_numel}",
)
shard_param_infos = self._get_shard_metadata(
unsharded_start_idx, unsharded_end_idx
)
assert (
len(shard_param_infos) == flat_param._num_params
), f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}"
flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined]
flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined]
def _get_shard_metadata(
    self,
    unsharded_start_idx: int,
    unsharded_end_idx: int,
) -> Tuple[_ShardParamInfo, ...]:
    """
    Computes the shard metadata based on ``unsharded_start_idx`` and
    ``unsharded_end_idx`` (inclusive), which give the interval of the
    unsharded flat parameter specifying the shard.

    Returns one :class:`_ShardParamInfo` per original parameter (padding
    entries are skipped), in flattening order.
    """
    flat_param_offsets = self._get_flat_param_offsets()
    assert len(flat_param_offsets) == len(
        self.flat_param._numels_with_padding
    ), f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}"
    shard_param_infos: List[_ShardParamInfo] = []
    sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1
    # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices
    # into the unsharded flat parameter (inclusive) of the given parameter
    for i, (
        (unsharded_param_start_idx, unsharded_param_end_idx),
        is_padding,
    ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)):
        if is_padding:
            continue
        # The two inclusive intervals overlap iff each starts no later than
        # the other ends
        in_sharded_flat_param = (
            unsharded_start_idx <= unsharded_param_end_idx
            and unsharded_end_idx >= unsharded_param_start_idx
        )
        if not in_sharded_flat_param:
            shard_param_info = _ShardParamInfo(False, None, None, None, None)
        else:
            if unsharded_start_idx <= unsharded_param_start_idx:
                # The parameter begins inside the shard, so the shard holds
                # it from its first element
                intra_param_start_idx = 0
                offset_in_shard = unsharded_param_start_idx - unsharded_start_idx
            else:
                # The shard begins strictly inside this parameter. This
                # branch can happen at most once since the rank's unsharded
                # start index can only intersect one parameter
                intra_param_start_idx = (
                    unsharded_start_idx - unsharded_param_start_idx
                )
                offset_in_shard = 0
            assert (
                offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel
            ), (
                f"Invalid `offset_in_shard` of {offset_in_shard} for "
                f"sharded flat parameter with {sharded_flat_param_numel} numel"
            )
            # Clip the parameter's end to the shard's end, then express it
            # relative to the parameter's own start (inclusive)
            intra_param_end_idx = (
                min(unsharded_param_end_idx, unsharded_end_idx)
                - unsharded_param_start_idx
            )
            numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1
            shard_param_info = _ShardParamInfo(
                True,
                offset_in_shard,
                numel_in_shard,
                intra_param_start_idx,
                intra_param_end_idx,
            )
        shard_param_infos.append(shard_param_info)
    return tuple(shard_param_infos)
@staticmethod
def _get_unpadded_shard(
tensor: Tensor,
rank: int,
world_size: int,
) -> Tuple[Tensor, int]:
"""
Returns the shard of ``tensor`` without any padding for the given
``rank`` and ``world_size`` and the numel to pad for that shard.
If ``tensor`` is already flattened or may be viewed in the flattened
shape (which is true in the expected usage), then this method does not
allocate any new tensor memory.
"""
chunks = torch.flatten(tensor).chunk(world_size)
if len(chunks) < (rank + 1):
# This rank gets an empty chunk fully padded with zeros since there
# are not enough chunks across ranks
chunk = chunks[0].new_empty(0)
else:
chunk = chunks[rank]
numel_to_pad = chunks[0].numel() - chunk.numel()
assert (
numel_to_pad >= 0
), "Chunk's size should be at most the first chunk's size"
return chunk, numel_to_pad
@staticmethod
def _get_shard(
tensor: Tensor,
rank: int,
world_size: int,
) -> Tuple[Tensor, int]:
"""
Returns the shard of ``tensor`` with padding for the given ``rank`` and
``world_size`` and the numel padded for that shard.
This method allocates new memory (via :meth:`clone`) since the
unsharded ``tensor`` may be deallocated after this method returns.
"""
chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard(
tensor, rank, world_size
)
shard = chunk.clone()
if numel_to_pad > 0:
shard = F.pad(shard, [0, numel_to_pad])
return shard, numel_to_pad
@staticmethod
def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size:
"""
Returns the shape of ``tensor`` after sharding including padding. This
requires ``tensor`` to have 1D shape and ensures that the returned
shape is 1D.
"""
assert len(tensor.shape) == 1, f"{tensor.shape}"
unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard(
tensor, rank, world_size
)
unpadded_sharded_size = unpadded_sharded_tensor.size()
assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}"
return torch.Size([unpadded_sharded_size[0] + numel_to_pad])
def _get_flat_param_offsets(self) -> List[Tuple[int, int]]:
"""
Returns [start, end] offsets of each original parameter's flattened
data in the unsharded flat parameter (without padding).
NOTE: The returned list includes elements for alignment padding.
"""
cumulative_sum = list(accumulate(self.flat_param._numels_with_padding))
starts = [0] + cumulative_sum[:-1]
ends = [end - 1 for end in cumulative_sum] # inclusive
param_offsets = list(zip(starts, ends))
return param_offsets
@no_type_check
def shard_metadata(
    self,
) -> FlatParamShardMetadata:
    """
    Collects the shard-related metadata for this rank's shard of the flat
    parameter.

    Only original parameters whose ``_shard_param_infos`` entry reports
    ``in_shard`` are included. For each of them, the FQN, shape, numel, and
    the ``(intra_param_start_idx, intra_param_end_idx)`` pair are recorded.

    NOTE: The returned tuple does not include elements for alignment
    padding but does account for the padding.
    """
    flat_param = self.flat_param
    per_param = zip(
        flat_param._fqns,
        flat_param._shapes,
        flat_param._numels,
        flat_param._shard_param_infos,
    )
    # Keep only the parameters that have data in this rank's shard
    kept = [entry for entry in per_param if entry[3].in_shard]
    fqns = tuple(fqn for fqn, _, _, _ in kept)
    shapes = tuple(shape for _, shape, _, _ in kept)
    numels = tuple(numel for _, _, numel, _ in kept)
    intra_param_offsets = [
        (info.intra_param_start_idx, info.intra_param_end_idx)
        for _, _, _, info in kept
    ]
    return FlatParamShardMetadata(fqns, shapes, numels, intra_param_offsets)
@no_type_check
@torch.no_grad()
def init_flat_param_attributes(self) -> None:
    """
    This initializes some attributes on the handle's ``FlatParameter``.
    This should be called during lazy initialization since it requires the
    parameter to be on the compute device if not offloading to CPU and we
    want to give users the chance to move the parameter appropriately after
    the FSDP constructor.

    For each tensor attribute on the ``FlatParameter``, see the unshard and
    reshard methods in this class for the allocation and free pattern.

    Attributes set on the ``FlatParameter`` by this method:
        - ``_local_shard``: always.
        - ``_cpu_grad``: only with parameter CPU offloading.
        - ``_mp_shard``: only with parameter mixed precision.
        - ``_full_param_padded`` and ``_padded_unsharded_size``: only for
          sharded strategies.
        - ``_full_prec_full_param_padded``: only for sharded strategies
          with parameter mixed precision.
    """
    flat_param = self.flat_param
    if flat_param.dtype != self._orig_param_dtype:
        # Entering this branch means that the user changed the parameter
        # dtype after FSDP initialization, in which case we may need to
        # refresh some saved dtype attributes (dtypes specified as a part
        # of mixed precision take precedence).
        if not self._low_prec_param_dtype_specified:
            self._fwd_bwd_param_dtype = flat_param.dtype
        # For `reduce_dtype`, require `param_dtype` was not specified since
        # then we infer the `reduce_dtype` from the specified `param_dtype`
        if (
            not self._low_prec_reduce_dtype_specified
            and not self._low_prec_param_dtype_specified
        ):
            self._reduce_dtype = flat_param.dtype
        self._orig_param_dtype = flat_param.dtype
    cpu_device = torch.device("cpu")
    # Validate the device placement before saving any tensor attributes
    if self._offload_params:
        _p_assert(
            flat_param.device == cpu_device,
            f"Expects the `FlatParameter` to be on CPU when parameter CPU "
            f"offloading is enabled, not {flat_param.device}",
        )
    else:
        self._check_on_compute_device(self.flat_param)
    # The flat parameter's current data is saved as the local shard
    flat_param._local_shard = flat_param.data
    if self._offload_params:
        # Pin the memory for faster H2D transfer
        flat_param._local_shard = flat_param._local_shard.pin_memory()
        # Pre-allocate the sharded gradient on CPU to enable non-blocking
        # D2H transfer during the backward pass
        flat_param._cpu_grad = torch.zeros_like(
            flat_param._local_shard, device=cpu_device
        ).pin_memory()
    if self._uses_param_mixed_precision:
        # For parameter mixed precision, we maintain a low precision
        # sharded tensor on the compute device to be all-gathered (for
        # sharded strategies) or directly used (for `NO_SHARD`) for
        # computation.
        flat_param._mp_shard = torch.zeros_like(
            flat_param._local_shard,
            device=self.device,
            dtype=self._fwd_bwd_param_dtype,
        )
        # Only the tensor metadata is kept; the storage is freed right away
        _free_storage(flat_param._mp_shard)
    if self.uses_sharded_strategy:
        # We maintain a padded unsharded tensor that serves as the
        # all-gather destination and owns the original parameter storages.
        unsharded_param_dtype = (
            self._fwd_bwd_param_dtype
            if self._uses_param_mixed_precision
            else flat_param.dtype
        )  # use low precision if parameter mixed precision is enabled
        # The padded unsharded numel is the local numel times the world size
        padded_unsharded_numel = flat_param.numel() * self.world_size
        flat_param._full_param_padded = torch.zeros(
            padded_unsharded_numel,
            device=self.device,
            dtype=unsharded_param_dtype,
        )
        # Record the size before freeing the storage
        flat_param._padded_unsharded_size = flat_param._full_param_padded.size()
        _free_storage(flat_param._full_param_padded)
        if self._uses_param_mixed_precision:
            # For parameter mixed precision, we maintain a full precision
            # padded unsharded tensor for when we force full precision.
            flat_param._full_prec_full_param_padded = torch.zeros(
                padded_unsharded_numel,
                device=self.device,
                dtype=flat_param.dtype,  # full precision
            )
            _free_storage(flat_param._full_prec_full_param_padded)
###################
# UNSHARD/RESHARD #
###################
def pre_unshard(self) -> bool:
    """
    Prepares the flat parameter's data for the upcoming unshard.

    If using the original parameters, their changes are first written
    back. Then, unless the handle uses a sharded strategy without CPU
    offloading and is already unsharded, this either switches to the low
    precision shard or moves an offloaded flat parameter to the compute
    device.

    Returns: ``False`` if this is a no-op and ``True`` otherwise.

    Postcondition: ``self.flat_param`` 's data is on the device for
    communication and is what should be all-gathered. This means that it
    matches the dtype of the expected unsharded parameter.
    """
    did_work = self._writeback_orig_params() if self._use_orig_params else False
    already_unsharded_on_device = (
        self.uses_sharded_strategy
        and not self._offload_params
        and not self.needs_unshard()
    )
    if not already_unsharded_on_device:
        if self._uses_param_mixed_precision and not self._force_full_precision:
            self._use_low_precision_shard()
            did_work = True
        elif self._offload_params and self.flat_param.device != self.device:
            # NOTE: This creates a new tensor distinct from any attributes.
            self.flat_param_to(self.device, non_blocking=True)
            did_work = True
    self._check_on_compute_device(self.flat_param)
    return did_work
def _use_low_precision_shard(self):
"""
Allocates the low precision shard directly on the compute device and
switches to using the low precision sharded flat parameter.
"""
self._check_low_precision_shard()
flat_param = self.flat_param
_alloc_storage(
flat_param._mp_shard, flat_param._local_shard.size() # type: ignore[attr-defined]
)
# `copy_()` implicitly casts to the low precision
flat_param._mp_shard.copy_( # type: ignore[attr-defined]
flat_param._local_shard.to( # type: ignore[attr-defined]
self.device, non_blocking=True
)
)
# Invariant: `_mp_shard` is always on the compute device.
flat_param.data = flat_param._mp_shard # type: ignore[attr-defined]
def unshard(self):
    """
    Runs the unshard logic.

    If the handle needs unsharding, this allocates the padded unsharded
    flat parameter, all-gathers into it, and switches to using it.
    Otherwise, this only switches to using the unsharded flat parameter,
    which is the existing padded unsharded flat parameter for sharded
    strategies and the flat parameter itself for ``NO_SHARD``.
    """
    if self.needs_unshard():
        all_gather_dest = self._alloc_padded_unsharded_flat_param()
        gathered = self._all_gather_flat_param(all_gather_dest)
        self._use_unsharded_flat_param(gathered)
        return
    # Even when not needing an unshard, we should switch to using the
    # unsharded flat parameter
    if self.uses_sharded_strategy:
        already_unsharded = self._get_padded_unsharded_flat_param()
    else:
        already_unsharded = self.flat_param
    self._use_unsharded_flat_param(already_unsharded)
def needs_unshard(self) -> bool:
    """
    Returns whether the handle's flat parameter still needs to be
    unsharded, which is never the case without a sharded strategy and
    otherwise holds when the padded unsharded flat parameter's storage
    size differs from its numel.
    """
    if not self.uses_sharded_strategy:
        return False
    padded_param = self._get_padded_unsharded_flat_param()
    allocated_numel = padded_param._typed_storage()._size()
    return allocated_numel != padded_param.numel()
def _alloc_padded_unsharded_flat_param(self):
"""
Allocates the *padded* unsharded flat parameter. The unpadded unsharded
flat parameter is always a view into the padded one. This padded
parameter is saved to a different attribute on the ``FlatParameter``
depending on if we force full precision.
"""
self._check_sharded_strategy()
flat_param = self.flat_param
unsharded_flat_param = self._get_padded_unsharded_flat_param()
self._check_storage_freed(unsharded_flat_param)
_alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined]
return unsharded_flat_param
def _get_padded_unsharded_flat_param(self) -> torch.Tensor:
"""
Returns a reference to the padded unsharded flat parameter depending on
the calling context. This should only be called if using a sharded
strategy.
"""
self._check_sharded_strategy()
flat_param = self.flat_param
if self._force_full_precision:
# When parameter mixed precision is enabled, we use a different
# tensor as the all-gather destination to preserve the invariant
# that `_full_param_padded` is in the low precision
unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined]
_p_assert(
unsharded_flat_param.dtype != self._fwd_bwd_param_dtype,
f"Expects full precision but got {self._fwd_bwd_param_dtype}",
)
else:
unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined]
return unsharded_flat_param
def _all_gather_flat_param(
self,
padded_unsharded_flat_param: Tensor,
) -> Tensor:
"""
All-gathers the handle's flat parameter to the destination
``padded_unsharded_flat_param``, and switches to using the all-gathered
tensor.
"""
_p_assert(
hasattr(self, "process_group") and hasattr(self, "world_size"),
"Expects a process group and world size to have been set via `shard()`",
)
sharded_flat_param = self.flat_param.data
expected_numel = sharded_flat_param.numel() * self.world_size
_p_assert(
padded_unsharded_flat_param.numel() == expected_numel,
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
)
dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
self.process_group,
)
return padded_unsharded_flat_param
def _use_unsharded_flat_param(
self,
padded_unsharded_flat_param: torch.Tensor,
) -> None:
"""
Switches to using the *unpadded* unsharded flat parameter, which is a
view into the *padded* unsharded flat parameter.
"""
unsharded_size = self.flat_param._unpadded_unsharded_size
self.flat_param.data = padded_unsharded_flat_param[
: unsharded_size.numel()
].view(
unsharded_size
) # this `.view()` is not autograd visible
in_forward = self._training_state == HandleTrainingState.FORWARD
in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE
if self._use_orig_params:
# We use `Tensor` views in the forward so that they are tracked by
# autograd. We use them in the pre-backward as well to support
# reentrant activation checkpointing, which needs the views to be
# tracked by autograd in the backward pass's recomputed forward.
self._use_unsharded_views(
as_params=(not in_forward and not in_pre_backward)
)
elif in_forward:
self._use_unsharded_views(as_params=False)
def post_unshard(self):
    """
    Runs the post-unshard logic: frees the low precision shard when using
    both parameter mixed precision and a sharded strategy, and then checks
    that the flat parameter is on the compute device.
    """
    should_free_mp_shard = (
        self._uses_param_mixed_precision and self.uses_sharded_strategy
    )
    if should_free_mp_shard:
        self._free_low_precision_sharded_param()
    self._check_on_compute_device(self.flat_param)
def _free_low_precision_sharded_param(self):
    """
    Frees the storage of the low precision sharded flat parameter
    (``_mp_shard``) after recording the current CUDA stream on it.
    """
    self._check_low_precision_shard()
    mp_shard = self.flat_param._mp_shard  # type: ignore[attr-defined]
    # `_mp_shard` is allocated in the pre-unshard stream, consumed in the
    # unshard stream for sharded strategies, and consumed in both the
    # unshard and default streams for `NO_SHARD`. For sharded strategies,
    # the current stream here is the unshard stream, and for `NO_SHARD`,
    # it is the default stream. For `NO_SHARD`, only recording for the
    # default stream suffices since the default stream waits for the
    # unshard stream.
    consumer_stream = torch.cuda.current_stream()
    _no_dispatch_record_stream(mp_shard, consumer_stream)
    _free_storage(mp_shard)
@torch.no_grad()
def unshard_grad(self):
    """
    Unshards the handle's ``FlatParameter`` 's gradient. If all ranks have
    ``None`` gradient, then all original parameters will as well. This
    method performs an all-reduce and an all-gather. The additional
    all-reduce is tolerable since this method is not meant to be used on
    the computation critical path.

    For ``NO_SHARD``, this only sets the original parameters' gradients to
    be views into the existing gradient.

    Postcondition: ``_saved_grad_shard`` is defined and contains the value
    to set ``flat_param.grad`` after gradients are resharded.
    """
    if not self.uses_sharded_strategy:
        self._use_unsharded_grad_views()
        return
    flat_param = self.flat_param
    self._check_unsharded(flat_param)

    # Check if all ranks have a `None` gradient
    num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device)
    num_grad_none[0] = flat_param.grad is None
    dist.all_reduce(num_grad_none, group=self.process_group)
    if num_grad_none[0] == self.world_size:
        flat_param._saved_grad_shard = None  # type: ignore[attr-defined]
        self._use_unsharded_grad_views()
        return

    if flat_param.grad is None:
        # In the case that only some ranks have `None` gradient, we use
        # zeros to approximate as a best effort attempt
        if self._debug_level == dist.DebugLevel.DETAIL:
            warnings.warn(
                f"[Rank {self.rank}] Only some but not all ranks have a "
                "`None` `FlatParameter` gradient, so FSDP is using zeros to "
                "approximate those ranks' sharded gradients being `None`"
            )
        flat_param._saved_grad_shard = None  # type: ignore[attr-defined]
        sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device)  # type: ignore[attr-defined]
    else:
        self._check_sharded(flat_param.grad)
        flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
        sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
    # Allocate the all-gather destination only once the local shard is
    # known so that its dtype matches the shard being gathered; relying on
    # the default dtype breaks the all-gather for any other gradient dtype
    padded_unsharded_grad = torch.empty(
        flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
        device=self.device,
        dtype=sharded_grad.dtype,
    )
    dist.all_gather_into_tensor(
        padded_unsharded_grad, sharded_grad, self.process_group
    )
    # Drop the padding and restore the unsharded flat parameter's shape
    unsharded_size = self.flat_param._unpadded_unsharded_size
    flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view(
        unsharded_size
    )
    self._use_unsharded_grad_views()
def reshard_grad(self):
    """
    Reshards the gradient. If using the original parameters, this first
    switches to the sharded gradient views. For sharded strategies, this
    then restores ``flat_param.grad`` from ``_saved_grad_shard`` and
    deletes that attribute.
    """
    if self._use_orig_params:
        self._use_sharded_grad_views()
    if self.uses_sharded_strategy:
        flat_param = self.flat_param
        flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
        del flat_param._saved_grad_shard
def prepare_gradient_for_backward(self):
    """
    Prepares the gradient for the backward computation by saving and
    clearing any existing sharded gradient in ``.grad`` to enable computing
    a new unsharded gradient.

    This only acts if ``.grad`` exists and either does not have the
    unpadded unsharded size or is on a different device than the flat
    parameter. In that case, ``.grad`` is set to ``None`` at the end.
    """
    _p_assert(
        self._training_state
        in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE),
        "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)",
    )
    flat_param = self.flat_param
    if flat_param.grad is not None and (
        flat_param.grad.size() != flat_param._unpadded_unsharded_size
        or flat_param.grad.device != flat_param.device  # grad on CPU
    ):
        self._check_on_compute_device(self.flat_param)
        grad_offloaded = flat_param.grad.device != self.device
        # A gradient off the compute device is only valid with CPU offloading
        _p_assert(
            not grad_offloaded or self._offload_params,
            f"Expects the sharded gradient to be on {self.device} "
            f"but got {flat_param.grad.device}",
        )
        # A gradient with the local shard's size is treated as the sharded
        # gradient from a previous iteration that synchronized gradients
        prev_iter_synced_gradients = (
            flat_param.grad.size()
            == flat_param._local_shard.size()  # type: ignore[attr-defined]
        )
        if prev_iter_synced_gradients:
            # TODO (awgu): Gradient accumulation outside `no_sync()`
            # does not work with CPU offloading. The issue should be
            # that, in the post-backward hook, we cannot do an addition
            # between a CPU tensor (the existing sharded gradient) and
            # a GPU tensor (the new sharded gradient).
            if not grad_offloaded:
                flat_param._saved_grad_shard = flat_param.grad.data  # type: ignore[attr-defined]
                sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
            else:
                _p_assert(
                    hasattr(flat_param, "_cpu_grad"),
                    "`_cpu_grad` should be defined if the gradient is on CPU",
                )
                sharded_grad = flat_param._cpu_grad  # type: ignore[attr-defined]
            # If user specified to keep the gradient in low precision, then
            # the gradient may still be of the low precision dtype if the
            # user did not set the gradient to `None` after the previous
            # backward, in which case FSDP should cast back to the full
            # precision dtype so that FSDP can accumulate in that dtype in
            # the post-backward hook and assign to `.grad` in that dtype in
            # the post-backward callback.
            local_shard_dtype = flat_param._local_shard.dtype  # type: ignore[attr-defined]
            if (
                self._keep_low_precision_grads
                and sharded_grad.dtype != local_shard_dtype
            ):
                sharded_grad.data = sharded_grad.to(local_shard_dtype)
        else:
            # Otherwise, the only accepted size is the padded unsharded one
            padded_unsharded_size = flat_param._padded_unsharded_size  # type: ignore[attr-defined]
            _p_assert(
                flat_param.grad.size() == padded_unsharded_size,
                "Expects `.grad` to be the unsharded gradient in "
                f"`no_sync()` with size {padded_unsharded_size} "
                f"but got size {flat_param.grad.size()}",
            )
        # Clear `.grad` so that a new unsharded gradient can be computed
        flat_param.grad = None
def prepare_gradient_for_optim(self):
"""
Prepares the gradient for optimizer computation by moving the sharded
gradient to the ``.grad`` attribute.
"""
def cast_grad_to_param_dtype_if_needed(flat_param):
if self._keep_low_precision_grads:
assert flat_param.grad is not None # mypy
if flat_param.grad.dtype != self._fwd_bwd_param_dtype:
flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype)
if self._use_orig_params:
self._use_sharded_grad_views()
flat_param = self.flat_param
# TODO (awgu): We should replace these conditional checks to encode
# the logical intention more directly.
if hasattr(flat_param, "_cpu_grad"):
# NOTE: This branch includes `NO_SHARD`.
self._check_sharded(flat_param)
self._check_on_cpu(flat_param)
flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined]
cast_grad_to_param_dtype_if_needed(flat_param)
elif hasattr(flat_param, "_saved_grad_shard"):
self._check_sharded(flat_param)
self._check_on_compute_device(flat_param)
self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined]
# If no sharded gradient was computed this iteration, then there is
# no need to forward `_saved_grad_shard` to `grad`
if flat_param._post_backward_called: # type: ignore[attr-defined]
flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
cast_grad_to_param_dtype_if_needed(flat_param)
else:
_p_assert(
not self.uses_sharded_strategy
or not flat_param._post_backward_called, # type: ignore[attr-defined]
"All sharded parameters that received a gradient in the "
"post-backward should use `_saved_grad_shard`",
)
# Delete `_saved_grad_shard` since its existence indicates a previous
# gradient to accumulate with in the post-backward hook
if hasattr(flat_param, "_saved_grad_shard"):
delattr(flat_param, "_saved_grad_shard")
@contextlib.contextmanager
def to_cpu(self):
    """
    Moves the unpadded unsharded flat parameter to CPU while in the context
    and moves it back to the previous device upon exit. For now, this
    assumes the ``FlatParameter`` is the unpadded unsharded flat parameter
    since (1) there is no reason to include the padding in the copy and (2)
    there is no use case for the sharded flat parameter.

    Precondition: ``self.flat_param`` 's data is the unpadded unsharded
    flat parameter on the compute device, and the handle uses a sharded
    strategy.

    Postcondition: Same as the precondition.
    """
    self._check_sharded_strategy()
    _p_assert(
        self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
        f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
    )
    self._check_on_compute_device(self.flat_param)
    # Check that the unpadded unsharded flat parameter is a view into the
    # padded unsharded flat parameter as expected
    # NOTE: This check is not strictly needed for correctness but is a
    # useful sanity check since the tensor should only be used internally.
    unpadded_storage_ptr = self.flat_param._typed_storage()._data_ptr()
    padded_storage_ptr = (
        self._get_padded_unsharded_flat_param()._typed_storage()._data_ptr()
    )
    _p_assert(
        unpadded_storage_ptr == padded_storage_ptr,
        "Expects the unpadded parameter to be a view into the padded parameter",
    )
    # Move to CPU first and only then free the padded unsharded tensor
    self.flat_param_to(torch.device("cpu"))
    self._free_unsharded_flat_param()
    try:
        yield
    finally:
        # Runs on both normal and exceptional exit from the context
        _p_assert(
            self.flat_param.size() == self.flat_param._unpadded_unsharded_size,
            f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}",
        )
        padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param()
        # Copy from CPU to the compute device
        padded_unsharded_flat_param[: self.flat_param.numel()].copy_(
            self.flat_param
        )
        # Switch back to the view into the re-allocated padded tensor
        self._use_unsharded_flat_param(padded_unsharded_flat_param)
def reshard(self, free_unsharded_flat_param: bool):
    """
    Runs the reshard logic: switches to using the sharded flat parameter
    and then, if ``free_unsharded_flat_param``, frees the unsharded flat
    parameter.
    """
    # The switch must come before the free to prevent "use-after-free"-type
    # bugs with external profiling tools, where for `use_orig_params=True`,
    # the `param` does not point to valid memory when setting
    # `param.data = ...` in `_use_sharded_views()`.
    self._use_sharded_flat_param()
    if not free_unsharded_flat_param:
        return
    self._free_unsharded_flat_param()
def post_reshard(self):
    """
    Runs the post-reshard logic, which frees the low precision shard for
    ``NO_SHARD`` with parameter mixed precision unless full precision was
    forced.

    Precondition: ``self.flat_param`` 's data points to the full precision
    sharded flat parameter.
    """
    # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it
    # is also the low precision *unsharded* flat parameter. Hence, we delay
    # the free until the reshard.
    if not self._uses_param_mixed_precision:
        return
    if self.uses_sharded_strategy:
        return
    if self._force_full_precision:
        return  # did not use the low precision shard
    self._free_low_precision_sharded_param()
def _free_unsharded_flat_param(self):
    """
    Frees the storage of the padded unsharded flat parameter. Which tensor
    is freed is decided by :meth:`_get_padded_unsharded_flat_param` since
    the unshard may have forced full precision, in which case a different
    tensor is used.
    """
    self._check_sharded_strategy()
    padded_param = self._get_padded_unsharded_flat_param()
    self._check_storage_allocated(padded_param)
    self._check_on_compute_device(padded_param)
    # Do not free the memory until all ops in the current stream finish
    current_stream = torch.cuda.current_stream()
    _no_dispatch_record_stream(padded_param, current_stream)
    _free_storage(padded_param)
def _use_sharded_flat_param(self) -> None:
"""Switches to using the sharded flat parameter."""
flat_param = self.flat_param
if self._offload_params:
device = flat_param._local_shard.device # type: ignore[attr-defined]
_p_assert(
device == torch.device("cpu"),
f"Expects the local shard to be on CPU but got {device}",
)
flat_param.data = flat_param._local_shard # type: ignore[attr-defined]
if self._use_orig_params:
self._use_sharded_views()
# For the post-forward reshard, we may try to use sharded gradient
# views (or unsharded gradient views if a gradient was accumulated
# in `no_sync()`), but for the post-backward reshard, we delay the
# call to after the reduce-scatter.
if self._training_state == HandleTrainingState.FORWARD:
# TODO: Change `_unpadded_unsharded_size` if we change the
# gradient to be computed directly with padding.
accumulated_grad_in_no_sync = (
flat_param.grad is not None
and self.uses_sharded_strategy
and flat_param.grad.shape == flat_param._unpadded_unsharded_size
)
if accumulated_grad_in_no_sync:
self._use_unsharded_grad_views()
else:
self._use_sharded_grad_views()
#########
# VIEWS #
#########
@no_type_check
def _get_unflat_views_unaligned(
    self,
    tensor: Optional[torch.Tensor] = None,
) -> Iterator[Tensor]:
    """
    Returns an iterator over unflattened ``Tensor`` views into ``tensor``
    if it is not ``None`` or into ``flat_param`` otherwise. The source is
    split by ``flat_param._numels``, and each piece is viewed with the
    matching entry of ``flat_param._shapes`` before being passed through
    the post-unflatten extension transform.

    Examples for ``tensor`` include ``flat_param.grad`` or unsharded
    tensor optimizer state.
    """
    flat_param = self.flat_param
    source = flat_param if tensor is None else tensor
    flat_pieces = torch.split(source, flat_param._numels, dim=0)
    per_param = zip(
        flat_pieces,
        flat_param._shapes,
        flat_param._param_extensions,
    )
    # The split happens eagerly above; only the views are produced lazily
    return (
        _ext_post_unflatten_transform(piece.view(shape), extension)
        for piece, shape, extension in per_param
    )
@no_type_check
def _get_unflat_views_aligned(
    self,
    tensor: Optional[Tensor] = None,
) -> List[Tensor]:
    """
    This has the same contract as :meth:`_get_unflat_views_unaligned`
    except that it splits by ``flat_param._numels_with_padding`` and skips
    the splits flagged in ``flat_param._is_padding_mask``, which may incur
    slightly more CPU overhead. The views are returned as a list.
    """
    flat_param = self.flat_param
    source = flat_param if tensor is None else tensor
    all_splits = torch.split(source, flat_param._numels_with_padding, dim=0)
    # Drop the splits that only hold alignment padding
    param_splits = [
        split
        for split, is_padding in zip(all_splits, flat_param._is_padding_mask)
        if not is_padding
    ]
    # The position among the non-padding splits indexes the per-parameter
    # metadata
    return [
        _ext_post_unflatten_transform(
            split.view(flat_param._shapes[idx]),
            flat_param._param_extensions[idx],
        )
        for idx, split in enumerate(param_splits)
    ]
@no_type_check
def _use_unsharded_views(self, as_params: bool) -> None:
    """
    Unflattens the unsharded flat parameter by setting the original
    parameter variables to be views into it.

    Args:
        as_params (bool): If ``True``, then registers the original
            parameters as ``nn.Parameter`` s; if ``False``, then registers
            the original parameters only as ``Tensor`` s. ``False`` should
            be used during forward/backward computation and when hiding the
            original parameters from :meth:`nn.Module.named_parameters`.
    """
    flat_param = self.flat_param
    self._check_unsharded(flat_param)
    views = self._get_unflat_views()
    # First pass: the parameters recorded in `_param_infos`
    for i, (view, (param_name, module, _)) in enumerate(
        zip(views, flat_param._param_infos)
    ):
        if self._use_orig_params and as_params:
            if type(view) is DTensor:
                # A `DTensor` `view` is not compatible with assigning
                # `param.data = view`, so we cannot preserve the parameter
                # variable.
                self._setattr_param(module, param_name, nn.Parameter(view))
                continue
            # Reuse the saved parameter variable and only swap its data
            param = self.flat_param._params[i]
            self._setattr_param(module, param_name, param)
            param.data = view
        elif as_params:
            self._setattr_param(module, param_name, nn.Parameter(view))
        else:  # `as_params=False`
            param_var: Tensor = view
            if self._use_orig_params:
                if self._training_state == HandleTrainingState.FORWARD:
                    # Save the `Tensor` for the pre-backward
                    self.flat_param._tensors[i] = view  # save for pre-backward
                elif self._training_state == HandleTrainingState.BACKWARD_PRE:
                    # Use the saved `Tensor` variable from the forward to
                    # preserve the autograd graph so that the post-backward
                    # hook fires (e.g. for reentrant AC)
                    tensor = self.flat_param._tensors[i]
                    tensor.data = view
                    param_var = tensor
            self._setattr_tensor(module, param_name, param_var)
            if (
                self._use_orig_params
                and self._training_state == HandleTrainingState.FORWARD
            ):
                # In the forward, the `Tensor` is also placed directly into
                # the module's `_parameters` dict under the parameter name
                module._parameters[param_name] = param_var
    # Second pass: the shared parameters, which take whatever is currently
    # set on their primary module under the primary parameter name
    for i, (
        param_name,
        module,
        _,
        prim_param_name,
        prim_module,
        _,
    ) in enumerate(self.flat_param._shared_param_infos):
        prim_param: Union[Tensor, nn.Parameter] = getattr(
            prim_module, prim_param_name
        )
        _p_assert(
            not as_params or isinstance(prim_param, nn.Parameter),
            f"as_params={as_params} type(prim_param)={type(prim_param)}",
        )
        if self._use_orig_params and as_params:
            shared_param = self.flat_param._shared_params[i]
            self._setattr_param(module, param_name, shared_param)
            shared_param.data = prim_param
        elif as_params:
            self._setattr_param(module, param_name, prim_param)
        else:
            self._setattr_tensor(module, param_name, prim_param)
            if (
                self._use_orig_params
                and self._training_state == HandleTrainingState.FORWARD
            ):
                module._parameters[param_name] = prim_param
@no_type_check
def _use_unsharded_grad_views(self) -> None:
    """
    Unflattens the unsharded flat parameter's gradient by setting the
    original parameter variables' gradients to be views into it.

    If ``flat_param.grad`` is ``None``, then the gradients of all entries
    in ``_params`` and ``_shared_params`` are set to ``None`` instead.
    """
    # Expects the gradient to be in `flat_param.grad`
    if self.flat_param.grad is None:
        for param in chain(self.flat_param._params, self.flat_param._shared_params):
            param.grad = None
        return
    self._check_unsharded(self.flat_param.grad)
    views = self._get_unflat_views(self.flat_param.grad)
    # First pass: the parameters recorded in `_param_infos`
    for i, (view, (param_name, module, _)) in enumerate(
        zip(views, self.flat_param._param_infos)
    ):
        _p_assert(
            hasattr(module, param_name),
            f"{self.flat_param._fqns[i]} is missing",
        )
        param = getattr(module, param_name)
        if param.shape != view.shape or param.dtype != view.dtype:
            # NOTE: This is a hack using `.data` to side step the
            # check that parameter/gradient sizes and dtypes match. Here,
            # `param` can have the sharded size, and `grad` can have the
            # unsharded size. Orthogonally, `param` can have the full
            # precision dtype from `reshard()`, and `grad` can have the
            # parameter low precision dtype. Both of these mismatches
            # happen when running in `no_sync()`.
            if param.grad is None:
                # Create a placeholder gradient whose data is replaced next
                param.grad = torch.empty_like(param)
            param.grad.data = view
        else:
            param.grad = view
    # Second pass: the shared parameters take the gradient of the variable
    # currently set on their primary module
    for i, (
        param_name,
        module,
        module_name,
        prim_param_name,
        prim_module,
        _,
    ) in enumerate(self.flat_param._shared_param_infos):
        _p_assert(
            hasattr(module, param_name),
            f"{module_name + '.' + param_name if module_name else param_name} is missing",
        )  # did not save FQN info in `_shared_param_infos`
        param = getattr(module, param_name)
        prim_param = getattr(prim_module, prim_param_name)
        if (
            param.shape != prim_param.grad.shape
            or param.dtype != prim_param.grad.dtype
        ):
            # NOTE: This is the same hack to use `.data` to side step the
            # size check.
            if param.grad is None:
                param.grad = torch.empty_like(param)
            param.grad.data = prim_param.grad
        else:
            param.grad = prim_param.grad
@contextlib.contextmanager
def unflatten_as_params(self) -> Generator:
    """
    Context manager that assumes the flat parameter is unsharded. On entry,
    it unflattens the original parameters as ``nn.Parameter`` views into
    the flat parameter. On exit, including when the body raises, it
    restores the original parameters as ``Tensor`` views into the flat
    parameter.
    """
    self._use_unsharded_views(as_params=True)
    try:
        yield
    finally:
        # Always go back to plain `Tensor` views
        self._use_unsharded_views(as_params=False)
@no_type_check
@torch.no_grad()
def _use_sharded_views(self) -> None:
"""
Sets the original parameter variables' data to be flattened views into
the sharded flat parameter.
The views are kept as flattened to simplify the case where a parameter
is sharded across ranks. Parameters whose data is not present in the
sharded flat parameter have their data set to a size-0 empty tensor. We
do not delete them to ensure to preserve expected behaviors like model
printability. Parameters whose data is present must preserve their
variables to be passable to an optimizer.
"""
if not self.uses_sharded_strategy:
# For `NO_SHARD`, use the *unflattened* unsharded views since we
# have the unsharded parameter
self._use_unsharded_views(as_params=True)
return
flat_param = self.flat_param
self._check_sharded(flat_param)
# Construct once and reuse for all parameters not in the local shard
size_0_empty_tensor = torch.empty(
0,
dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
device=self.flat_param.device,
requires_grad=False,
)
for param, shard_param_info, (param_name, module, _) in zip(
flat_param._params, flat_param._shard_param_infos, flat_param._param_infos
):
self._setattr_param(module, param_name, param)
if not shard_param_info.in_shard:
# Allow the original data to be freed via garbage collection
param.data = size_0_empty_tensor
else:
offset = shard_param_info.offset_in_shard
numel_in_shard = shard_param_info.numel_in_shard
param.data = flat_param[offset : offset + numel_in_shard]
assert self.flat_param._shared_params is not None
for i, (
param,
(param_name, module, _, prim_param_name, prim_module, _),
) in enumerate(
zip(self.flat_param._shared_params, self.flat_param._shared_param_infos)
):
self._setattr_param(module, param_name, param)
prim_param = getattr(prim_module, prim_param_name)
param.data = prim_param # could be both empty and non-empty
if self._training_state == HandleTrainingState.BACKWARD_POST:
# Clear the saved `Tensor`s since they are unneeded now
for i in range(len(self.flat_param._tensors)):
self.flat_param._tensors[i] = None
@no_type_check
@torch.no_grad()
def _use_sharded_grad_views(self) -> None:
"""
Sets the original parameter variables' gradients to be flattened
views into the sharded flat parameter's gradient. This is a no-op if
there is no gradient.
Parameters whose data is not present in the sharded flat parameter and
parameters with ``requires_grad=False`` have their gradients set to
``None``. Since the gradient variables do not need to be preserved,
this method does not manipulate existing ``Tensor`` data directly and
creates new ``Tensor`` variables instead.
"""
flat_param = self.flat_param
self._check_sharded(flat_param)
grad = self.sharded_grad
if grad is None:
for param in chain(flat_param._params, flat_param._shared_params):
param.grad = None
return
self._check_sharded(grad)
for param, shard_param_info, is_grad_none in zip(
flat_param._params,
flat_param._shard_param_infos,
flat_param._is_grad_none_mask,
):
if not shard_param_info.in_shard:
param.grad = None
else:
numel_in_shard = shard_param_info.numel_in_shard
if param.requires_grad and not is_grad_none:
offset = shard_param_info.offset_in_shard
if self._keep_low_precision_grads or param.dtype != grad.dtype:
# NOTE: This is a hack using `.data` to side step the
# check that parameter/gradient dtypes match. Here,
# `param` has full precision; `grad` has low precision.
if param.grad is None:
# `.grad` must have the same shape as `param`
param.grad = torch.empty_like(param)
param.grad.data = grad[
offset : offset + numel_in_shard
].reshape(param.shape)
else:
param.grad = grad[offset : offset + numel_in_shard].reshape(
param.shape
)
else:
param.grad = None
assert flat_param._shared_params is not None
for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate(
zip(flat_param._shared_params, flat_param._shared_param_infos)
):
in_sharded_flat_param = hasattr(prim_module, prim_param_name)
if in_sharded_flat_param and param.requires_grad:
prim_param = getattr(prim_module, prim_param_name)
param.grad = prim_param.grad # share the same reference
else:
param.grad = None
@no_type_check
@torch.no_grad()
def _writeback_orig_params(self) -> bool:
    """
    Iterates over the original parameters and writes back any parameters
    that changed storages (due to a non-inplace operator) to the handle's
    ``FlatParameter``. This method preserves the ``FlatParameter``'s
    device even if an original parameter's device changes.

    Gradients are handled similarly: an original parameter whose ``.grad``
    is ``None`` while the flat parameter has a gradient gets its gradient
    region zeroed, and an original gradient that no longer shares storage
    with the flat gradient is copied back.

    Raises:
        RuntimeError: If an original parameter or gradient changes storages
        but no longer has the expected flattened shape.
        NotImplementedError: If a shared parameter no longer refers to the
        same object as its primary parameter.

    Returns: ``True`` if some parameter writeback happened, and ``False``
    otherwise. (Gradient-only writebacks do not set the return value.)
    """
    if self.uses_sharded_strategy and not self.is_sharded(self.flat_param):
        # For `NO_SHARD`, we may still need to writeback
        return False
    flat_param = self.flat_param
    wroteback = False
    # Storage pointer used to detect parameters that no longer alias the
    # flat parameter
    flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
    # NOTE: Since this method is called in the pre-unshard, which is only
    # called during computation in the pre-forward or pre-backward, the
    # sharded gradient should be guaranteed to be in `.grad`, not in
    # `._saved_grad_shard`.
    flat_param_grad = (
        flat_param.grad
        if self.uses_sharded_strategy or not self._offload_params
        else flat_param._cpu_grad
    )
    flat_param_grad_data_ptr = (
        None
        if flat_param_grad is None
        else flat_param_grad.untyped_storage().data_ptr()
    )
    for i, (
        param,
        (in_shard, offset_in_shard, numel_in_shard, _, _),
        (param_name, module, _),
    ) in enumerate(
        zip(
            flat_param._params,
            flat_param._shard_param_infos,
            flat_param._param_infos,
        )
    ):
        if not in_shard:
            continue
        if not hasattr(module, param_name):
            # Do not writeback if original parameters are deregistered
            # (e.g. during model checkpointing)
            continue
        # Check for parameter writeback
        param_changed = getattr(module, param_name) is not param
        needs_param_writeback = (
            param_changed  # changed parameter variable itself
            or not _same_storage_as_data_ptr(
                param, flat_param_data_ptr
            )  # changed `.data`
        )
        if param_changed:
            # NOTE: The gradient is not preserved after a parameter change.
            param = getattr(module, param_name)
            # Track the new parameter variable in the handle's metadata
            flat_param._params[i] = param
        if needs_param_writeback:
            expected_shape = torch.Size([numel_in_shard])
            self._writeback_tensor(
                param, flat_param, i, expected_shape, offset_in_shard, True
            )
            wroteback = True
        # Check for gradient writeback
        if param.grad is None and flat_param.grad is not None:
            # Passing `None` as the source zeroes this parameter's region of
            # the flat gradient
            expected_shape = torch.Size([numel_in_shard])
            self._writeback_tensor(
                None, flat_param.grad, i, expected_shape, offset_in_shard, False
            )
        elif param.grad is not None:
            # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
            # memory and owns the gradient storage, so it will never
            # require gradient writeback.
            needs_grad_writeback = (
                flat_param_grad is None
                or not _same_storage_as_data_ptr(
                    param.grad, flat_param_grad_data_ptr
                )
            )
            if needs_grad_writeback:
                if flat_param_grad is None:
                    # No flat gradient exists yet, so start from zeros
                    flat_param_grad = torch.zeros_like(flat_param)
                expected_shape = torch.Size([numel_in_shard])
                self._writeback_tensor(
                    param.grad,
                    flat_param_grad,
                    i,
                    expected_shape,
                    offset_in_shard,
                    False,
                )
                flat_param.grad = flat_param_grad
                # Refresh the local reference and storage pointer so that
                # later iterations compare against the current gradient
                flat_param_grad = flat_param.grad
                flat_param_grad_data_ptr = (
                    flat_param_grad.untyped_storage().data_ptr()
                )
    # TODO: If we want to handle shared parameters, we need to re-generate
    # the shared parameter data structures in case sharedness changed.
    for i, (
        param_name,
        module,
        _,
        prim_param_name,
        prim_module,
        _,
    ) in enumerate(flat_param._shared_param_infos):
        if getattr(module, param_name) is not getattr(prim_module, prim_param_name):
            raise NotImplementedError(
                "Changing shared parameters is not supported yet"
            )
    return wroteback
def _writeback_tensor(
self,
src_tensor: Optional[Tensor],
dst_tensor: Tensor,
tensor_index: int,
expected_shape: torch.Size,
offset: int,
is_param: bool, # else gradient
) -> None:
"""
Writes back ``src_tensor`` to ``dst_tensor`` at offset ``offset``,
where ``src_tensor`` should have shape ``expected_shape``. ``is_param``
indicates if the tensor is the parameter (if ``True``) or gradient (if
``False``). If ``src_tensor`` is ``None``, then the effect is zeroing
instead of copying. ``tensor_index`` gives the index of ``src_tensor``
in the metadata structures.
Raises:
RuntimeError: If the ``src_tensor`` does not have the expected
shape.
"""
_p_assert(
len(expected_shape) == 1,
f"Expects a 1D expected shape but got {expected_shape}",
)
if self._debug_level == dist.DebugLevel.DETAIL:
rank = self.rank if hasattr(self, "rank") else dist.get_rank()
src_shape = src_tensor.shape if src_tensor is not None else None
src_device = src_tensor.device if src_tensor is not None else None
warnings.warn(
f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs "
f"writeback in {self._training_state}\n"
f"expected shape={expected_shape} shape={src_shape} "
f"expected device={dst_tensor.device} device={src_device}"
)
if src_tensor is not None and src_tensor.shape != expected_shape:
# NOTE: Gradient shape mismatch is not possible in practice since
# the gradient shape is enforced to match that of the parameter and
# we already check for parameter shape mismatch.
raise RuntimeError(
f"Cannot writeback when the {'parameter' if is_param else 'gradient'} "
f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}"
)
if src_tensor is not None:
dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor)
else:
dst_tensor[offset : offset + expected_shape.numel()].zero_()
assert self.flat_param._is_grad_none_mask is not None
self.flat_param._is_grad_none_mask[tensor_index] = True
def _clear_grads_if_needed(self):
"""
When ``use_orig_params=True``, sets the underlying ``flat_param.grad``
to ``None`` if *all* of the original parameters' ``.grad`` are
``None``. This is targeting ``optim.zero_grad(set_to_none=True)``, in
which case we want to free the gradients as soon after the
``zero_grad()`` call as possible.
"""
if not self._use_orig_params:
return
flat_param = self.flat_param
assert flat_param._params is not None
if all(param.grad is None for param in flat_param._params):
flat_param.grad = None
def _deregister_orig_params(self):
for param_info in self.flat_param._param_infos:
param_name, module, _ = param_info
if hasattr(module, param_name):
delattr(module, param_name)
for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos:
if hasattr(module, param_name):
delattr(module, param_name)
###########
# HELPERS #
###########
def flat_param_to(self, *args, **kwargs):
    """
    Applies ``.to(*args, **kwargs)`` to ``self.flat_param`` in place by
    reassigning its ``.data``. When using original parameters, the views
    are refreshed afterward since their storage may have changed.
    """
    self.flat_param.data = self.flat_param.to(*args, **kwargs)
    if not self._use_orig_params:
        return
    # Pick the view flavor that matches the flat parameter's current state
    if self.is_sharded(self.flat_param):
        self._use_sharded_views()
    else:
        self._use_unsharded_views(as_params=True)
def _get_modules(self) -> Set[nn.Module]:
"""
Returns a :class:`set` of the modules whose parameters are included
in this handle's flat parameter.
"""
return {pi.module for pi in self.flat_param._param_infos}.union(
{spi.module for spi in self.flat_param._shared_param_infos}
)
def is_sharded(self, tensor: Tensor) -> bool:
    """
    Returns whether ``tensor`` is *currently* sharded, i.e. whether its
    size equals the flat parameter's ``_sharded_size``. For ``NO_SHARD``,
    and whenever ``_sharded_size`` is not defined, this returns ``False``.
    """
    flat_param = self.flat_param
    # `_sharded_size` is defined iff `handle.shard()` has been called
    if hasattr(flat_param, "_sharded_size") and self.uses_sharded_strategy:
        return tensor.size() == flat_param._sharded_size  # type: ignore[attr-defined]
    return False
def param_module_names(self) -> Iterator[Tuple[str, str]]:
    """
    Yields ``(param_name, module_name)`` for every parameter in the flat
    parameter: first the entries of ``_param_infos``, then the entries of
    ``_shared_param_infos``.
    """
    flat_param = self.flat_param
    for param_name, _, module_name in flat_param._param_infos:
        yield (param_name, module_name)
    # Unpack the shared infos directly instead of first re-wrapping each one
    # in a temporary `ParamInfo`
    for param_name, _, module_name, _, _, _ in flat_param._shared_param_infos:
        yield (param_name, module_name)
def shared_param_module_names(self) -> Iterator[Tuple[str, str]]:
    """
    Yields ``(param_name, module_name)`` for every entry of the flat
    parameter's ``_shared_param_infos``.
    """
    # Unpack each shared info directly; no intermediate `ParamInfo` list is
    # needed just to read two fields
    for (
        param_name,
        _,
        module_name,
        _,
        _,
        _,
    ) in self.flat_param._shared_param_infos:
        yield (param_name, module_name)
@property
def _fqns_in_shard(self) -> List[str]:
"""Returns the FQNs of the parameters present in this rank's shard."""
fqns_in_shard: List[str] = []
for fqn, shard_param_info in zip(
self.flat_param._fqns, self.flat_param._shard_param_infos # type: ignore[attr-defined]
):
if shard_param_info.in_shard:
fqns_in_shard.append(fqn)
return fqns_in_shard
@property
def sharded_grad(self) -> Optional[Tensor]:
"""Returns the handle's sharded gradient."""
flat_param = self.flat_param
# Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
# - CPU offloading: `_cpu_grad`
# - No CPU offloading + sharded strategies: `_saved_grad_shard`
# - No CPU offloading + `NO_SHARD`: `grad`
if hasattr(flat_param, "_cpu_grad"):
grad = flat_param._cpu_grad # type: ignore[attr-defined]
elif hasattr(flat_param, "_saved_grad_shard"):
# In the post-backward hook, the sharded gradient is still in
# `_saved_grad_shard`.
grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
else:
# If in IDLE or in FORWARD states, then there may be an
# (accumulated) gradient. If accessed in IDLE, then this should
# be due to re-registering the original parameters (e.g. in state
# dict load).
_p_assert(
flat_param.grad is None
or not self.uses_sharded_strategy
or self._training_state
in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE),
"Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` "
"unless in IDLE or FORWARD",
)
grad = flat_param.grad
return grad
def _reset_is_grad_none(self) -> None:
"""
Resets ``_is_grad_none_mask`` as needed. This method should only be
called in the post-backward after gradient computation, in which case
if a parameter requires gradient, then it will surely receive a
gradient and we may reset its mask entry to ``False``.
"""
if not self._use_orig_params:
return
_p_assert(
self._training_state == HandleTrainingState.BACKWARD_POST,
"Expects to only be called in the post-backward after gradient computation",
)
flat_param = self.flat_param
assert flat_param._params is not None # mypy
for i, param in enumerate(flat_param._params): # type: ignore[arg-type]
# As long as the parameter requires gradient, it should receive a
# meaningful gradient (even if the gradient happens to be zeros)
if param.requires_grad:
assert flat_param._is_grad_none_mask is not None # mypy
flat_param._is_grad_none_mask[i] = False
#######################
# CHECKS & INVARIANTS #
#######################
def _check_sharded_strategy(self):
_p_assert(self.uses_sharded_strategy, "Expects sharded strategy")
def _check_on_compute_device(self, tensor: Tensor):
_p_assert(
tensor.device == self.device,
f"Expects tensor to be on the compute device {self.device}",
)
def _check_on_cpu(self, tensor: Tensor):
_p_assert(
tensor.device == torch.device("cpu"),
f"Expects tensor to be on CPU but got {tensor.device}",
)
@staticmethod
def _check_storage_freed(tensor: Tensor):
storage_size: int = tensor._typed_storage()._size()
_p_assert(
storage_size == 0,
f"Expects storage to be freed but got storage with size {storage_size}",
)
@staticmethod
def _check_storage_allocated(tensor: Tensor):
storage_size: int = tensor._typed_storage()._size()
_p_assert(storage_size > 0, "Expects storage to be allocated")
def _check_low_precision_shard(self):
_p_assert(
self._uses_param_mixed_precision,
"Not using low precision for parameters",
)
_p_assert(
getattr(self.flat_param, "_mp_shard", None) is not None,
"Expects `_mp_shard` to exist",
)
device = self.flat_param._mp_shard.device # type: ignore[attr-defined]
_p_assert(
device == self.device,
f"Expects the low precision shard to be on {self.device} but got {device}",
)
def _check_unsharded(self, tensor: Tensor):
msg_prefix = "Expects tensor to be unsharded "
_p_assert(tensor is not None, msg_prefix + "but got `None`")
unsharded_size = self.flat_param._unpadded_unsharded_size
_p_assert(
tensor.size() == unsharded_size,
msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
)
def _check_sharded(self, tensor: Tensor):
msg_prefix = "Expects tensor to be sharded "
_p_assert(tensor is not None, msg_prefix + "but got `None`")
sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined]
_p_assert(
tensor.size() == sharded_size,
msg_prefix + f"with size {sharded_size} but got {tensor.size()}",
)
##############
# PROPERTIES #
##############
@property
def uses_sharded_strategy(self) -> bool:
    """Whether the handle's sharding strategy is anything other than ``NO_SHARD``."""
    strategy = self._sharding_strategy
    return strategy != HandleShardingStrategy.NO_SHARD
@property
def _uses_param_mixed_precision(self) -> bool:
return self._fwd_bwd_param_dtype != self._orig_param_dtype
@property
def _uses_reduce_mixed_precision(self) -> bool:
return self._reduce_dtype != self._orig_param_dtype
@property
def _force_full_precision(self) -> bool:
return (
self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS
and self._uses_param_mixed_precision
)
# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks.
def _unsafe_setattr_param(
module: nn.Module, param_name: str, param: nn.Parameter
) -> None:
module._parameters[param_name] = param
# This bypasses any overrides in case `module` is an instance of an
# `nn.Module` subclass
super(nn.Module, module).__setattr__(param_name, param)
def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None:
module._parameters.pop(param_name, None)
# This bypasses any overrides in case `module` is an instance of an
# `nn.Module` subclass
super(nn.Module, module).__setattr__(param_name, tensor)
def _safe_setattr_tensor_or_param(
module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter]
):
# Call `delattr()` and `setattr()` to go through `nn.Module` checks
if hasattr(module, param_name):
delattr(module, param_name)
setattr(module, param_name, tensor_or_param)
def _convert_to_params(
tensors: List[Union[torch.Tensor, nn.Parameter]]
) -> List[nn.Parameter]:
return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors]
def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor:
return (
param_or_tensor.detach()
if isinstance(param_or_tensor, nn.Parameter)
else param_or_tensor
)
def _get_aligned_numel(unsharded_dtype: torch.dtype):
    """
    Returns the number of elements of ``unsharded_dtype`` that fit in one
    16-byte alignment unit.
    """
    # NOTE: This alignment constraint comes from TorchInductor.
    ALIGNMENT = 16  # bytes
    return ALIGNMENT // _get_dtype_size(unsharded_dtype)
@functools.lru_cache(8)
def _get_dtype_size(dtype):
return torch.empty((), dtype=dtype).element_size()
def _construct_padding_tensor(
    padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device
):
    """
    Returns a 1D tensor of ``padding_numel`` elements, each equal to
    ``_FLAT_PARAM_PADDING_VALUE``, built as a ones tensor (created with the
    given ``dtype``, ``requires_grad``, and ``device``) scaled by that
    value.
    """
    # NOTE: Set the padding value as a magic number for debuggability. The
    # value itself should never be used in any user-facing computation.
    ones = torch.ones(
        (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device
    )
    return ones * _FLAT_PARAM_PADDING_VALUE
# A handles key represents the group of `FlatParamHandle`s involved in a given
# module's forward. These will be all-gathered together in the pre-forward and
# pre-backward. The key is a variable-length tuple of handles.
_HandlesKey = Tuple[FlatParamHandle, ...]