| # Owner(s): ["oncall: distributed"] |
| |
| import contextlib |
| import os |
| import sys |
| from typing import Any, Optional |
| |
| import torch |
| import torch.distributed as dist |
| |
| if not dist.is_available(): |
| print("Distributed not available, skipping tests", file=sys.stderr) |
| sys.exit(0) |
| |
| from torch.distributed.algorithms.join import Join, Joinable, JoinHook |
| from torch.testing._internal.common_distributed import ( |
| MultiProcessTestCase, |
| require_n_gpus_for_nccl_backend, |
| ) |
| from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN |
| |
| if TEST_WITH_DEV_DBG_ASAN: |
| print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr) |
| sys.exit(0) |
| |
| BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO |
| WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) |
| |
| # Constants used for testing post-hooks |
| BEFORE_CONSTANT = 41 |
| AFTER_CONSTANT = 42 |
| |
| |
| class AllReducerJoinHook(JoinHook): |
| r""" |
| Join hook for :class:`AllReducer`. |
| |
| Arguments: |
| allreducer (AllReducer): the :class:`AllReducer` object using this |
| hook. |
| num_allreduces (int): the number of all-reduces to shadow per |
| iteration. |
| run_post_hook (bool): a flag enabling the post-hook logic. |
| """ |
| def __init__( |
| self, |
| allreducer, |
| num_allreduces, |
| run_post_hook |
| ): |
| self.allreducer = allreducer |
| self.num_allreduces = num_allreduces |
| self.run_post_hook = run_post_hook |
| |
| def main_hook(self): |
| r""" |
| Shadows each all-reduce; the number of all-reduces is passed into the |
| constructor as ``num_allreduces``. |
| """ |
| device = self.allreducer.device |
| for _ in range(self.num_allreduces): |
| t = torch.zeros(1, device=device) |
| dist.all_reduce(t) |
| |
| def post_hook(self, is_last_joiner: bool): |
| r""" |
| Broadcasts a tensor containing a magic constant ``AFTER_CONSTANT`` from |
| the last joiner to all other processes. |
| """ |
| if not self.run_post_hook: |
| return |
| rank = dist.get_rank(self.allreducer.process_group) |
| common_rank = self.allreducer.find_common_rank(rank, is_last_joiner) |
| device = self.allreducer.device |
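        # The agreed-upon rank (the highest rank among the last joiners)
        # rewrites its tensor to ``AFTER_CONSTANT`` and broadcasts it so that
        # every process observes the update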
| if rank == common_rank: |
| self.allreducer.post_hook_tensor = torch.tensor([AFTER_CONSTANT], device=device) |
| dist.broadcast(self.allreducer.post_hook_tensor, src=common_rank) |
| |
| |
| class AllReducer(Joinable): |
| r""" |
| Example :class:`Joinable` that performs some number of all-reduces as its |
| per-iteration collective communication. |
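
    A minimal usage sketch (assumes the default process group is already
    initialized; ``device`` and ``num_local_inputs`` are placeholders)::

        allreducer = AllReducer(device, dist.group.WORLD)
        with Join([allreducer]):
            for _ in range(num_local_inputs):
                total = allreducer()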
| """ |
| def __init__(self, device, process_group): |
| super(AllReducer, self).__init__() |
| self.device = device |
| self.process_group = process_group |
| self.post_hook_tensor = torch.tensor([BEFORE_CONSTANT], device=self.device) |
| |
| def __call__(self, num_allreduces=1): |
| r""" |
        All-reduces a dim-1 tensor of ones ``num_allreduces``-many times and
        returns the accumulated sum.
| """ |
| Join.notify_join_context(self) |
| device = self.device |
| total = 0 |
| for _ in range(num_allreduces): |
| t = torch.ones(1, device=device) |
| dist.all_reduce(t) |
| total += t.item() |
| return total |
| |
| def join_hook(self, **kwargs) -> JoinHook: |
| r""" |
| Returns a join hook that shadows some number of all-reduces; by default, |
| this number is 1. |
| """ |
| num_allreduces = kwargs.get("num_allreduces", 1) |
| run_post_hook = kwargs.get("run_post_hooks", False) |
| return AllReducerJoinHook( |
| self, |
| num_allreduces, |
| run_post_hook |
| ) |
| |
| @property |
| def join_device(self) -> torch.device: |
| return self.device |
| |
| @property |
| def join_process_group(self) -> Any: |
| return self.process_group |
| |
| def find_common_rank(self, rank, to_consider): |
| r""" |
        Returns the maximum rank, across the process group, among the ranks
        whose ``to_consider`` flag is set.
| """ |
| common_rank = torch.tensor( |
| [rank if to_consider else -1], |
| device=self.device |
| ) |
| dist.all_reduce(common_rank, op=dist.ReduceOp.MAX, group=self.process_group) |
| common_rank = common_rank.item() |
| assert common_rank >= 0 |
| return common_rank |
| |

class TestJoin(MultiProcessTestCase):
| r"""Test cases for the generic join context.""" |
| def setUp(self): |
| super(TestJoin, self).setUp() |
| os.environ["WORLD_SIZE"] = str(self.world_size) |
| os.environ["BACKEND"] = BACKEND |
| self._spawn_processes() |
| |
| @property |
| def device(self): |
| return torch.device(self.rank) if BACKEND == dist.Backend.NCCL \ |
| else torch.device("cpu") |
| |
| @property |
| def world_size(self): |
| return WORLD_SIZE |
| |
| @property |
| def process_group(self): |
| return dist.group.WORLD |
| |
| def tearDown(self): |
| try: |
| dist.destroy_process_group() |
| except AssertionError: |
| pass |
| try: |
| os.remove(self.file_name) |
| except OSError: |
| pass |
| |
| def dist_init(self, rank, world_size, backend=BACKEND): |
| store = dist.FileStore(self.file_name, world_size) |
| return dist.init_process_group( |
| backend=backend, |
| store=store, |
| rank=rank, |
| world_size=world_size |
| ) |
| |
| def construct_uneven_inputs(self, base, offset, device=None): |
| r""" |
| Returns uneven inputs: rank i gets ``base`` + i * ``offset`` inputs. |
| """ |
| if device is None: |
| device = self.device |
| return [torch.zeros(1, device=device) for _ in range(base + self.rank * offset)] |
| |
| def construct_even_inputs(self, base, device=None): |
| r"""Returns even inputs: each rank gets ``base`` inputs.""" |
| if device is None: |
| device = self.device |
| return [torch.zeros(1, device=device) for _ in range(base)] |
| |
| @property |
| def base_num_inputs(self): |
| r"""Base number of inputs to be used by all ranks.""" |
| return 3 |
| |
| @property |
| def offset(self): |
| r"""Rank i gets i * ``offset`` additional inputs.""" |
| return 1 |
| |
| def _test_join_base( |
| self, |
| uneven_inputs: bool, |
| num_joinables: int, |
| enable: bool, |
| throw_on_early_termination: bool, |
| num_allreduces: int, |
| run_post_hooks: bool, |
| expected_total: Optional[int] = None, |
| ): |
| r""" |
| Skeleton for all :class:`Join` tests. |
| |
| Arguments: |
| uneven_inputs (bool): ``True`` to use uneven inputs; ``False`` |
| otherwise. |
| num_joinables (int): number of :class:`AllReducer` s to construct. |
| enable (bool): ``True`` to enable the join context manager; |
| ``False`` otherwise. |
| throw_on_early_termination (bool): ``True`` to raise an exception |
| upon detecting uneven inputs; ``False`` otherwise. |
| num_allreduces (int): number of all-reduces to perform per input. |
| run_post_hooks (bool): ``True`` to run post-hooks; ``False`` |
| otherwise. |
| expected_total (Optional[int]): ``None`` to not check the expected |
| all-reduce total; otherwise, the expected total; default is |
| ``None``. |
| """ |
| self.dist_init(self.rank, self.world_size) |
| |
| allreducers = [ |
| AllReducer(self.device, self.process_group) |
| for _ in range(num_joinables) |
| ] |
| for allreducer in allreducers: |
| self.assertEqual(allreducer.post_hook_tensor.item(), BEFORE_CONSTANT) |
| |
| inputs = self.construct_uneven_inputs(self.base_num_inputs, self.offset) \ |
| if uneven_inputs \ |
| else self.construct_even_inputs(self.base_num_inputs) |
| allreduce_total = 0 |
| |
| # Expect a `RuntimeError` if `throw_on_early_termination=True` |
| # Rank 0 exhausts its inputs first |
| expected_msg = "Rank 0 exhausted all inputs." if self.rank == 0 \ |
| else "Detected at least one rank that exhausted inputs. " \ |
| "Throwing across all ranks." |
| with self.assertRaisesRegex( |
| RuntimeError, |
| expected_msg |
| ) if throw_on_early_termination else contextlib.suppress(): |
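            # Extra keyword arguments to ``Join`` are forwarded to each
            # ``Joinable.join_hook()`` (here, ``num_allreduces`` and
            # ``run_post_hooks``)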
| with Join( |
| allreducers, |
| enable=enable, |
| throw_on_early_termination=throw_on_early_termination, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks |
| ): |
| for _ in inputs: |
| for allreducer in allreducers: |
| allreduce_total += allreducer(num_allreduces) |
| |
| if throw_on_early_termination: |
| return |
| |
| # Check `expected_total` if not `None` |
        if expected_total is not None:
| self.assertEqual(allreduce_total, expected_total) |
| |
        # All `AllReducer` instances should have received the updated
        # `post_hook_tensor` from the last-joined process
| if run_post_hooks: |
| for allreducer in allreducers: |
| self.assertEqual(allreducer.post_hook_tensor.item(), AFTER_CONSTANT) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_single_joinable_main_hooks(self): |
| r"""Tests the main hooks of a single :class:`Joinable`.""" |
| num_joinables = 1 |
| num_allreduces = 1 |
| run_post_hooks = False |
| # Non-joined processes all-reduce a 1, so this rank's all-reduce total |
| # should be precisely equal to the total number of inputs processed |
| # before it joined |
| expected_total = self.world_size * self.base_num_inputs |
| # Rank i runs for i additional iterations |
| for num_joined in range(1, self.rank + 1): |
| expected_total += (self.world_size - num_joined) * self.offset |
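        # e.g. with world_size=2, base_num_inputs=3, offset=1: rank 0 expects
        # 3 * 2 = 6, and rank 1 runs one extra iteration with only itself
        # unjoined, so it expects 6 + 1 = 7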
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=expected_total |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_single_joinable_post_hooks(self): |
| r"""Tests the post-hooks of a single :class:`Joinable`.""" |
| num_joinables = 1 |
        num_allreduces = 0  # set to 0 to skip the main hooks
        run_post_hooks = True
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=None |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_single_joinable(self): |
| r""" |
| Tests the main hooks and post-hooks of a single :class:`Joinable` |
| together. |
| |
| This combines ``test_single_joinable_main_hooks()`` and |
| ``test_single_joinable_post_hooks()`` into a single test to ensure that |
| main hooks and post-hooks operate correctly together. |
| """ |
| num_joinables = 1 |
| num_allreduces = 1 |
| run_post_hooks = True |
| |
| expected_total = self.world_size * self.base_num_inputs |
| for num_joined in range(1, self.rank + 1): |
| expected_total += (self.world_size - num_joined) * self.offset |
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=expected_total |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_multiple_joinables(self): |
| r""" |
| Tests the main hooks and post-hooks of multiple :class:`Joinable` s |
| together. |
| |
| This generalizes ``test_single_joinable()`` to multiple |
| :class:`Joinable` s. |
| """ |
| num_joinables = 3 |
| num_allreduces = 1 |
| run_post_hooks = True |
| |
| expected_total = self.world_size * self.base_num_inputs |
| for num_joined in range(1, self.rank + 1): |
| expected_total += (self.world_size - num_joined) * self.offset |
        # The expected total is now multiplied by a factor of `num_joinables`
| expected_total *= num_joinables |
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=expected_total |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_single_joinable_disable(self): |
| r"""Tests ``enable=False`` for a single :class:`Joinable`.""" |
| num_joinables = 1 |
| num_allreduces = 1 |
| uneven_inputs = False |
| enable = False |
| run_post_hooks = False |
| |
| expected_total = self.world_size * self.base_num_inputs |
| |
| self._test_join_base( |
| uneven_inputs=uneven_inputs, |
| num_joinables=num_joinables, |
| enable=enable, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=expected_total |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_multiple_joinable_disable(self): |
| r""" |
| Tests ``enable=False`` for multiple :class:`Joinable` s. |
| |
| This generalizes ``test_single_joinable_disable`` to multiple |
| :class:`Joinable` s. |
| """ |
| num_joinables = 3 |
| num_allreduces = 1 |
| uneven_inputs = False |
| enable = False |
| run_post_hooks = False |
| |
| expected_total = self.world_size * self.base_num_inputs * num_joinables |
| |
| self._test_join_base( |
| uneven_inputs=uneven_inputs, |
| num_joinables=num_joinables, |
| enable=enable, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=expected_total |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_single_joinable_throw(self): |
| r""" |
| Tests ``throw_on_early_termination=True`` for a single |
| :class:`Joinable`. |
| """ |
| num_joinables = 1 |
| num_allreduces = 1 |
| throw_on_early_termination = True |
| run_post_hooks = False |
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=throw_on_early_termination, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=None |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_multiple_joinables_throw(self): |
| r""" |
| Tests ``throw_on_early_termination=True`` for multiple |
| :class:`Joinable` s together. |
| |
| This generalizes ``test_single_joinable_throw`` to multiple |
| :class:`Joinable` s. |
| """ |
| num_joinables = 3 |
| num_allreduces = 1 |
| throw_on_early_termination = True |
| run_post_hooks = False |
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=throw_on_early_termination, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=None |
| ) |
| |
| @require_n_gpus_for_nccl_backend( |
| WORLD_SIZE, BACKEND |
| ) |
| def test_join_kwargs(self): |
| r""" |
| Tests passing keyword arguments to the context manager. |
| """ |
| num_joinables = 1 |
| num_allreduces = 2 |
| run_post_hooks = False |
| |
| expected_total = self.world_size * self.base_num_inputs |
| for num_joined in range(1, self.rank + 1): |
| expected_total += (self.world_size - num_joined) * self.offset |
        # The expected total is now multiplied by a factor of `num_allreduces`
| expected_total *= num_allreduces |
| |
| self._test_join_base( |
| uneven_inputs=True, |
| num_joinables=num_joinables, |
| enable=True, |
| throw_on_early_termination=False, |
| num_allreduces=num_allreduces, |
| run_post_hooks=run_post_hooks, |
| expected_total=expected_total |
| ) |
| |

if __name__ == "__main__":
| run_tests() |