# Owner(s): ["oncall: distributed"]
import contextlib
import os
import sys
from typing import Any, Optional
import torch
import torch.distributed as dist
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.distributed.algorithms.join import Join, Joinable, JoinHook
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
require_n_gpus_for_nccl_backend,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
if TEST_WITH_DEV_DBG_ASAN:
print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
sys.exit(0)
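# Use NCCL when CUDA devices are available; otherwise fall back to Gloo on CPU.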
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))
# Constants used for testing post-hooks
BEFORE_CONSTANT = 41
AFTER_CONSTANT = 42
class AllReducerJoinHook(JoinHook):
r"""
Join hook for :class:`AllReducer`.
Arguments:
allreducer (AllReducer): the :class:`AllReducer` object using this
hook.
num_allreduces (int): the number of all-reduces to shadow per
iteration.
run_post_hook (bool): a flag enabling the post-hook logic.
"""
def __init__(
self,
allreducer,
num_allreduces,
run_post_hook
):
self.allreducer = allreducer
self.num_allreduces = num_allreduces
self.run_post_hook = run_post_hook
def main_hook(self):
r"""
Shadows each all-reduce; the number of all-reduces is passed into the
constructor as ``num_allreduces``.
"""
device = self.allreducer.device
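        # Shadow each all-reduce with a zeros tensor so a joined rank keeps
        # issuing matching collectives without changing the running total.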
for _ in range(self.num_allreduces):
t = torch.zeros(1, device=device)
dist.all_reduce(t)
def post_hook(self, is_last_joiner: bool):
r"""
Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from
the last joiner to all other processes.
"""
if not self.run_post_hook:
return
rank = dist.get_rank(self.allreducer.process_group)
common_rank = self.allreducer.find_common_rank(rank, is_last_joiner)
device = self.allreducer.device
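        # The last joiner (resolved via `find_common_rank`) overwrites its
        # tensor with ``AFTER_CONSTANT`` and broadcasts it to all other ranks.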
if rank == common_rank:
self.allreducer.post_hook_tensor = torch.tensor([AFTER_CONSTANT], device=device)
dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank)
class AllReducer(Joinable):
r"""
Example :class:`Joinable` that performs some number of all-reduces as its
per-iteration collective communication.
"""
def __init__(self, device, process_group):
        super().__init__()
self.device = device
self.process_group = process_group
self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device)
def __call__(self, num_allreduces=1):
r"""
        All-reduces a 1-D ones tensor ``num_allreduces``-many times and
        returns the total result.
"""
Join.notify_join_context(self)
device = self.device
total = 0
for _ in range(num_allreduces):
t = torch.ones(1, device=device)
dist.all_reduce(t)
total += t.item()
return total
def join_hook(self, **kwargs) -> JoinHook:
r"""
Returns a join hook that shadows some number of all-reduces; by default,
this number is 1.
"""
num_allreduces = kwargs.get("num_allreduces", 1)
run_post_hook = kwargs.get("run_post_hooks", False)
return AllReducerJoinHook(
self,
num_allreduces,
run_post_hook
)
@property
def join_device(self) -> torch.device:
return self.device
@property
def join_process_group(self) -> Any:
return self.process_group
def find_common_rank(self, rank, to_consider):
r"""
        Returns the maximum rank, over the process group, among the ranks for
        which ``to_consider`` is ``True``.
"""
common_rank = torch.tensor(
[rank if to_consider else -1],
device=self.device
)
dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group)
common_rank = common_rank.item()
assert common_rank >= 0
return common_rank
class TestJoin(MultiProcessTestCase):
r"""Test cases for the generic join context."""
def setUp(self):
        super().setUp()
os.environ["WORLD_SIZE"] = str(self.world_size)
os.environ["BACKEND"] = BACKEND
self._spawn_processes()
@property
def device(self):
return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \
else torch.device("cpu")
@property
def world_size(self):
return WORLD_SIZE
@property
def process_group(self):
return dist.group.WORLD
def tearDown(self):
try:
dist.destroy_process_group()
except AssertionError:
pass
try:
os.remove(self.file_name)
except OSError:
pass
def dist_init(self, rank, world_size, backend=BACKEND):
store = dist.FileStore(self.file_name, world_size)
return dist.init_process_group(
backend=backend,
store=store,
rank=rank,
world_size=world_size
)
def construct_uneven_inputs(self, base, offset, device=None):
r"""
Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs.
"""
if device is None:
device = self.device
return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)]
def construct_even_inputs(self, base, device=None):
r"""Returns even inputs: each rank gets ``base`` inputs."""
if device is None:
device = self.device
return [torch.zeros(1, device=device) for _ in range(base)]
@property
def base_num_inputs(self):
r"""Base number of inputs to be used by all ranks."""
return 3
@property
def offset(self):
r"""Rank i gets i * ``offset`` additional inputs."""
return 1
def _test_join_base(
self,
uneven_inputs: bool,
num_joinables: int,
enable: bool,
throw_on_early_termination: bool,
num_allreduces: int,
run_post_hooks: bool,
expected_total: Optional[int] = None,
):
r"""
Skeleton for all :class:`Join` tests.
Arguments:
uneven_inputs (bool): ``True`` to use uneven inputs; ``False``
otherwise.
num_joinables (int): number of :class:`AllReducer` s to construct.
enable (bool): ``True`` to enable the join context manager;
``False`` otherwise.
throw_on_early_termination (bool): ``True`` to raise an exception
upon detecting uneven inputs; ``False`` otherwise.
num_allreduces (int): number of all-reduces to perform per input.
run_post_hooks (bool): ``True`` to run post-hooks; ``False``
otherwise.
expected_total (Optional[int]): ``None`` to not check the expected
all-reduce total; otherwise, the expected total; default is
``None``.
"""
self.dist_init(self.rank, self.world_size)
allreducers = [
AllReducer(self.device, self.process_group)
for _ in range(num_joinables)
]
for allreducer in allreducers:
self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT)
inputs = self.construct_uneven_inputs(self.base_num_inputs, self.offset) \
if uneven_inputs \
else self.construct_even_inputs(self.base_num_inputs)
allreduce_total = 0
# Expect a `RuntimeError` if `throw_on_early_termination=True`
# Rank 0 exhausts its inputs first
expected_msg = "Rank 0 exhausted all inputs." if self.rank == 0 \
else "Detected at least one rank that exhausted inputs. " \
"Throwing across all ranks."
with self.assertRaisesRegex(
RuntimeError,
expected_msg
) if throw_on_early_termination else contextlib.suppress():
with Join(
allreducers,
enable=enable,
throw_on_early_termination=throw_on_early_termination,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks
):
for _ in inputs:
for allreducer in allreducers:
allreduce_total += allreducer(num_allreduces)
if throw_on_early_termination:
return
# Check `expected_total` if not `None`
        if expected_total is not None:
self.assertEqual(allreduce_total, expected_total)
        # All `AllReducer` instances should receive the updated
        # `post_hook_tensor` from the last-joined process
if run_post_hooks:
for allreducer in allreducers:
self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_single_joinable_main_hooks(self):
r"""Tests the main hooks of a single :class:`Joinable`."""
num_joinables = 1
num_allreduces = 1
run_post_hooks = False
# Non-joined processes all-reduce a 1, so this rank's all-reduce total
# should be precisely equal to the total number of inputs processed
# before it joined
expected_total = self.world_size * self.base_num_inputs
# Rank i runs for i additional iterations
for num_joined in range(1, self.rank + 1):
expected_total += (self.world_size - num_joined) * self.offset
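        # For example, with world_size=4, base_num_inputs=3, and offset=1:
        # rank 0 expects 4 * 3 = 12, rank 1 expects 12 + 3 = 15, and
        # rank 2 expects 12 + 3 + 2 = 17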
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=expected_total
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_single_joinable_post_hooks(self):
r"""Tests the post-hooks of a single :class:`Joinable`."""
num_joinables = 1
num_allreduces = 0 # set to 0 to skip the main hooks
        run_post_hooks = True
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=None
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_single_joinable(self):
r"""
Tests the main hooks and post-hooks of a single :class:`Joinable`
together.
This combines ``test_single_joinable_main_hooks()`` and
``test_single_joinable_post_hooks()`` into a single test to ensure that
main hooks and post-hooks operate correctly together.
"""
num_joinables = 1
num_allreduces = 1
run_post_hooks = True
expected_total = self.world_size * self.base_num_inputs
for num_joined in range(1, self.rank + 1):
expected_total += (self.world_size - num_joined) * self.offset
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=expected_total
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_multiple_joinables(self):
r"""
Tests the main hooks and post-hooks of multiple :class:`Joinable` s
together.
This generalizes ``test_single_joinable()`` to multiple
:class:`Joinable` s.
"""
num_joinables = 3
num_allreduces = 1
run_post_hooks = True
expected_total = self.world_size * self.base_num_inputs
for num_joined in range(1, self.rank + 1):
expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `num_joinables`
expected_total *= num_joinables
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=expected_total
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_single_joinable_disable(self):
r"""Tests ``enable=False`` for a single :class:`Joinable`."""
num_joinables = 1
num_allreduces = 1
uneven_inputs = False
enable = False
run_post_hooks = False
expected_total = self.world_size * self.base_num_inputs
self._test_join_base(
uneven_inputs=uneven_inputs,
num_joinables=num_joinables,
enable=enable,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=expected_total
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_multiple_joinable_disable(self):
r"""
Tests ``enable=False`` for multiple :class:`Joinable` s.
This generalizes ``test_single_joinable_disable`` to multiple
:class:`Joinable` s.
"""
num_joinables = 3
num_allreduces = 1
uneven_inputs = False
enable = False
run_post_hooks = False
expected_total = self.world_size * self.base_num_inputs * num_joinables
self._test_join_base(
uneven_inputs=uneven_inputs,
num_joinables=num_joinables,
enable=enable,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=expected_total
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_single_joinable_throw(self):
r"""
Tests ``throw_on_early_termination=True`` for a single
:class:`Joinable`.
"""
num_joinables = 1
num_allreduces = 1
throw_on_early_termination = True
run_post_hooks = False
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=throw_on_early_termination,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=None
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_multiple_joinables_throw(self):
r"""
Tests ``throw_on_early_termination=True`` for multiple
:class:`Joinable` s together.
This generalizes ``test_single_joinable_throw`` to multiple
:class:`Joinable` s.
"""
num_joinables = 3
num_allreduces = 1
throw_on_early_termination = True
run_post_hooks = False
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=throw_on_early_termination,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=None
)
@require_n_gpus_for_nccl_backend(
WORLD_SIZE, BACKEND
)
def test_join_kwargs(self):
r"""
Tests passing keyword arguments to the context manager.
"""
num_joinables = 1
num_allreduces = 2
run_post_hooks = False
expected_total = self.world_size * self.base_num_inputs
for num_joined in range(1, self.rank + 1):
expected_total += (self.world_size - num_joined) * self.offset
        # The expected total is now multiplied by a factor of `num_allreduces`
expected_total *= num_allreduces
self._test_join_base(
uneven_inputs=True,
num_joinables=num_joinables,
enable=True,
throw_on_early_termination=False,
num_allreduces=num_allreduces,
run_post_hooks=run_post_hooks,
expected_total=expected_total
)
if __name__ == "__main__":
run_tests()