DDP Communication Hooks
=======================
DDP communication hook is a generic interface to control how gradients are
communicated across workers, by overriding the vanilla allreduce in
`DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel>`_.
A few built-in communication hooks are provided, and users can easily apply
any of them to optimize communication. The hook interface also supports
user-defined communication strategies for more advanced use cases.
How to Use a Communication Hook?
--------------------------------
To use a communication hook, the user just needs to let the DDP model register
the hook before the training loop, as shown below.

:meth:`torch.nn.parallel.DistributedDataParallel.register_comm_hook`
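
For example, a minimal sketch that registers the built-in FP16 compression hook.
It assumes the default process group has already been initialized and that
``model`` and ``rank`` (the local GPU index) are defined elsewhere::

    from torch.nn.parallel import DistributedDataParallel
    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

    # Assumes dist.init_process_group(...) has already been called.
    ddp_model = DistributedDataParallel(model, device_ids=[rank])

    # The state argument of a stateless hook is a process group or None.
    ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)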
What Does a Communication Hook Operate On?
------------------------------------------
A communication hook provides a flexible way to allreduce gradients.
Therefore, it mainly operates on the gradients on each replica before allreduce,
which are bucketized to increase the overlap between communication and computation.
In particular, :class:`torch.distributed.GradBucket` represents a bucket of gradient tensors to be allreduced.
.. autoclass:: torch.distributed.GradBucket
.. autofunction:: torch.distributed.GradBucket.index
.. autofunction:: torch.distributed.GradBucket.buffer
.. autofunction:: torch.distributed.GradBucket.gradients
.. autofunction:: torch.distributed.GradBucket.is_last
.. autofunction:: torch.distributed.GradBucket.set_buffer
.. autofunction:: torch.distributed.GradBucket.parameters
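
For illustration, a user-defined hook is simply a callable that takes a state
object and a :class:`torch.distributed.GradBucket` and returns a
``torch.futures.Future`` of the reduced tensor. The sketch below mirrors vanilla
allreduce on the bucket's flattened buffer; it is an example of the interface
rather than a recommended hook::

    import torch
    import torch.distributed as dist

    def custom_allreduce_hook(
        process_group: dist.ProcessGroup, bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        # Fall back to the default (global) process group if none is given.
        group_to_use = process_group if process_group is not None else dist.group.WORLD
        world_size = group_to_use.size()

        # Divide first so that the allreduce produces an average, then kick off
        # an asynchronous allreduce on the bucket's flattened gradient buffer.
        tensor = bucket.buffer().div_(world_size)
        fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

        # The hook must return a Future whose value is the reduced tensor.
        return fut.then(lambda f: f.value()[0])

    # Registration (assuming `ddp_model` is a DistributedDataParallel instance):
    # ddp_model.register_comm_hook(None, custom_allreduce_hook)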
Default Communication Hooks
---------------------------
Default communication hooks are simple **stateless** hooks, so the input state
in ``register_comm_hook`` is either a process group or ``None``.
The input ``bucket`` is a :class:`torch.distributed.GradBucket` object.
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.default_hooks
.. autofunction:: allreduce_hook
.. autofunction:: fp16_compress_hook
.. autofunction:: bf16_compress_hook
Additionally, communication hook wrappers are provided so that
:func:`~fp16_compress_hook` or :func:`~bf16_compress_hook` compression can be
combined with other communication hooks, as in the sketch below.
.. autofunction:: fp16_compress_wrapper
.. autofunction:: bf16_compress_wrapper
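
For instance, a rough sketch that wraps the PowerSGD hook described in the next
section so that its communicated tensors are cast to FP16, assuming ``ddp_model``
is an already constructed DDP model::

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    # Assumes `ddp_model` is an initialized DistributedDataParallel instance.
    state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1)
    ddp_model.register_comm_hook(
        state, default_hooks.fp16_compress_wrapper(powerSGD.powerSGD_hook)
    )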
PowerSGD Communication Hook
---------------------------
PowerSGD (`Vogels et al., NeurIPS 2019 <https://arxiv.org/abs/1905.13727>`_)
is a gradient compression algorithm, which can provide very high compression
rates and accelerate bandwidth-bound distributed training.
This algorithm needs to maintain both hyperparameters and internal state.
Therefore, the PowerSGD communication hook is a **stateful** hook, and the user
needs to provide a state object defined as below.
PowerSGD State
^^^^^^^^^^^^^^^^
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
.. autoclass:: PowerSGDState
PowerSGD Hooks
^^^^^^^^^^^^^^^^
.. warning ::
    PowerSGD typically requires extra memory of the same size as the model's
    gradients to enable error feedback, which can compensate for biased
    compressed communication and improve accuracy.
.. warning ::
    PowerSGD hooks may conflict with `Apex automatic mixed precision package <https://github.com/NVIDIA/apex>`_.
    Please use PyTorch `native automatic mixed precision package <https://pytorch.org/docs/stable/amp.html>`_
    instead.
.. autofunction:: powerSGD_hook
.. autofunction:: batched_powerSGD_hook
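
For example, a typical registration looks like the following; the hyperparameter
values here are only illustrative, and ``ddp_model`` is assumed to be an already
constructed DDP model::

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    # `matrix_approximation_rank` trades off compression rate against accuracy,
    # and `start_powerSGD_iter` defers compression past the first iterations.
    state = powerSGD.PowerSGDState(
        process_group=None,  # None means the default process group.
        matrix_approximation_rank=1,
        start_powerSGD_iter=10,
    )
    ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)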
Debugging Communication Hooks
-----------------------------
As the name implies, debugging communication hooks are **only** used for debugging and performance optimization purposes.
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks
.. warning ::
    Debugging communication hooks do not necessarily output the correct results.
.. autofunction:: noop_hook
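
For example, registering the no-op hook as sketched below disables gradient
communication entirely, which can help estimate the communication overhead of
allreduce; ``ddp_model`` is assumed to be an already constructed DDP model::

    from torch.distributed.algorithms.ddp_comm_hooks import debugging_hooks

    # Gradients are no longer synchronized across workers, so the resulting
    # model updates are not correct; use this only to measure the headroom
    # obtained by removing communication.
    ddp_model.register_comm_hook(None, debugging_hooks.noop_hook)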
Checkpointing of Communication Hooks
------------------------------------
.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook
A stateful communication hook can be saved as a part of model checkpointing to enable trainer restarts.
To make a hook serializable, ``__setstate__`` and ``__getstate__`` should be defined.
.. warning ::
    ``__getstate__`` should exclude non-serializable attributes from a returned dictionary.
.. warning ::
    ``__setstate__`` should properly initialize non-serializable attributes that were excluded from the provided ``state``.
:class:`PowerSGDState` has ``__setstate__`` and ``__getstate__`` implemented and can be used as a reference.
.. class:: PowerSGDState
    :noindex:

.. automethod:: PowerSGDState.__getstate__
.. automethod:: PowerSGDState.__setstate__
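
For a user-defined stateful hook, a minimal sketch of the same pattern might
look like the following; the class and attribute names here are hypothetical::

    class CustomHookState:
        def __init__(self, process_group, compression_ratio=0.5):
            self.process_group = process_group      # not serializable
            self.compression_ratio = compression_ratio
            self.iteration = 0

        def __getstate__(self):
            # Exclude the non-serializable process group from the checkpoint.
            state = self.__dict__.copy()
            del state["process_group"]
            return state

        def __setstate__(self, state):
            # Reinitialize the excluded attribute; None falls back to the
            # default process group when the hook is registered again.
            self.process_group = None
            self.__dict__.update(state)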
Here is a simple, end-to-end example of saving and reloading PowerSGD state and hook.
::

    import os
    import sys
    import tempfile

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp
    import torch.nn as nn
    import torch.optim as optim

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
    from torch.nn.parallel import DistributedDataParallel


    class SimpleModel(nn.Module):
        def __init__(self):
            super(SimpleModel, self).__init__()
            self.fc1 = nn.Linear(24, 24)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(24, 12)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))


    def setup(rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'

        # Initialize the process group.
        dist.init_process_group("nccl", rank=rank, world_size=world_size)


    def cleanup():
        dist.destroy_process_group()


    def run_demo(demo_fn, world_size):
        mp.spawn(
            demo_fn,
            args=(world_size,),
            nprocs=world_size,
            join=True)


    def demo_serialization(rank, world_size):
        setup(rank, world_size)

        CHECKPOINT = tempfile.gettempdir() + "/checkpoint.pt"

        model = SimpleModel().to(rank)
        ddp_model = DistributedDataParallel(model, device_ids=[rank])

        powersgd_hook = powerSGD.powerSGD_hook
        powersgd_state = powerSGD.PowerSGDState(process_group=None)

        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
        ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        # Save the model together with the hook and its state on rank 0.
        state = {
            'state_dict': ddp_model.state_dict(),
            'comm_hook': powersgd_hook,
            'comm_hook_state': powersgd_state}

        if rank == 0:
            torch.save(state, CHECKPOINT)

        dist.barrier()
        map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
        checkpoint = torch.load(CHECKPOINT, map_location=map_location)

        ddp_model.load_state_dict(checkpoint['state_dict'])
        powersgd_hook = checkpoint['comm_hook']
        powersgd_state = checkpoint['comm_hook_state']

        # Re-register the reloaded hook and state on the restarted trainer.
        ddp_model.register_comm_hook(powersgd_state, powersgd_hook)

        if rank == 0:
            os.remove(CHECKPOINT)
        cleanup()


    if __name__ == "__main__":
        n_gpus = torch.cuda.device_count()
        assert n_gpus >= 2, f"Requires at least 2 GPUs to run, but got {n_gpus}"
        world_size = n_gpus
        run_demo(demo_serialization, world_size)
Acknowledgements
----------------
Many thanks to PowerSGD paper author **Thijs Vogels** for the code review on
the PowerSGD communication hook, as well as the
`comparison experiments <https://observablehq.com/@tvogels/powersgd-benchmark>`_,
which show that the performance of the PowerSGD communication hook is on par
with the implementation in the original `paper <https://arxiv.org/abs/1905.13727>`_.