# Owner(s): ["oncall: distributed"]
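# Tests for the c10d ProcessGroupWrapper debug wrapper: collectives that are
# inconsistent across ranks (op type, tensor shape, device, dtype) should raise
# descriptive errors, and a hang caused by a missing rank should be detected
# through the wrapper's monitored barrier.
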
import os
import sys
from datetime import timedelta
import torch
import torch.distributed as c10d
if not c10d.is_available():
print("c10d not available, skipping tests", file=sys.stderr)
sys.exit(0)
from test_c10d_common import LOOPBACK
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
requires_nccl,
requires_gloo,
skip_if_lt_x_gpu,
with_dist_debug_levels,
create_device,
)
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
class AbstractProcessGroupWrapperTest(MultiProcessTestCase):
def setUp(self):
super(AbstractProcessGroupWrapperTest, self).setUp()
self._spawn_processes()
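
    # Assert that `exception` names the collective op type and, for non-barrier
    # ops, the offending tensor's shape, device, and dtype.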
def _validate_error(self, exception, op_type, rank, tensor):
err = str(exception)
self.assertTrue(
op_type in err, f"Got {err} but expected {op_type} to be in error."
)
        # Barrier is not called with a tensor, so skip the tensor-related checks.
if op_type != "BARRIER":
self.assertTrue(
f"{list(tensor.shape)}" in err,
f"Did not find shapes {list(tensor.shape)} in error {err}",
)
# For CUDA, only assert on device type, not index
if "cuda" in str(tensor.device):
self.assertTrue(
"cuda" in err, f"Did not find cuda device in error {err}"
)
else:
self.assertTrue(
str(tensor.device) in err,
f"Did not find tensor device {str(tensor.device)} in error {err}",
)
            # C++ and Python dtype names are not exactly the same
            # (e.g. Float vs. torch.float32).
if "float" in str(tensor.dtype):
self.assertTrue("Float" in err, "Expected Float type")
elif "int" in str(tensor.dtype):
self.assertTrue("Long" in err, "Expected Long type")
else:
self.fail(f"Unexpected dtype {str(tensor.dtype)} for error {err}")
def _test_collective_hang(self, wrapper_pg, use_cuda=False):
        # All ranks except rank 1 call allreduce; wrapper_pg should detect the
        # hang and report rank 1 as the faulty rank.
faulty_rank = 1
if self.rank != faulty_rank:
tensor = torch.randn(20, 10)
if use_cuda:
tensor = tensor.to(self.rank)
if self.rank == 0:
# Rank 0 reports faulty ranks
err = f"Ranks {faulty_rank} failed to pass monitoredBarrier"
else:
err = "Please check rank 0 logs for faulty rank"
# Gloo can sometimes throw the following error if a rank exits early
# before rank 0 calls into the allreduce.
err += "|Connection closed by peer|Connection reset by peer"
with self.assertRaisesRegex(RuntimeError, err):
wrapper_pg.allreduce([tensor])
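
    # Run a few matching allreduces, then issue deliberately mismatched
    # collectives across ranks (allreduce vs. reduce, reduce vs. barrier,
    # scatter vs. reduce_scatter, broadcast vs. allgather) and check the
    # wrapper's error via _validate_error.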
def _test_collectives_op_mismatch(self, wrapper_pg, use_cuda=False):
tensor = torch.randn(20, 10)
if use_cuda:
tensor = tensor.to(self.rank)
works = []
# Run a few successful collectives
for _ in range(10):
work = wrapper_pg.allreduce([tensor])
works.append(work)
for w in works:
w.wait()
        # Simulate an op mismatch: rank 0 calls allreduce while the other ranks
        # call reduce. The resulting error should name the inconsistent collective
        # and include the rank, tensor shape, device, and dtype.
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
if self.rank == 0:
wrapper_pg.allreduce([tensor])
else:
wrapper_pg.reduce([tensor])
self._validate_error(
exception=cm.exception,
op_type="ALLREDUCE" if self.rank == 0 else "REDUCE",
rank=self.rank,
tensor=tensor,
)
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
if self.rank == 0:
wrapper_pg.reduce([tensor])
else:
wrapper_pg.barrier()
self._validate_error(
exception=cm.exception,
op_type="REDUCE" if self.rank == 0 else "BARRIER",
rank=self.rank,
tensor=tensor,
)
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
scatter_result = [torch.ones(4) * i for i in range(self.world_size)]
scattered_tensor = torch.empty(4)
if self.rank == 0:
wrapper_pg.scatter(scattered_tensor, scatter_result, 0)
else:
wrapper_pg.reduce_scatter(scattered_tensor, scatter_result)
self._validate_error(
exception=cm.exception,
op_type="SCATTER" if self.rank == 0 else "REDUCE_SCATTER",
rank=self.rank,
tensor=scattered_tensor,
)
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
if self.rank == 0:
wrapper_pg.broadcast(tensor, 0)
else:
output_tensors = [
torch.zeros_like(tensor) for _ in range(self.world_size)
]
wrapper_pg.allgather([output_tensors], [tensor])
self._validate_error(
exception=cm.exception,
op_type="BROADCAST" if self.rank == 0 else "ALLGATHER",
rank=self.rank,
tensor=tensor,
)
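
    # Issue allreduce and scatter calls whose tensor shapes (or number of
    # dimensions) differ across ranks and verify that the wrapper reports the
    # mismatch.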
def _test_collective_shape_mismatch(self, wrapper_pg, use_cuda=False):
wrapper_pg.barrier()
dim = 2 if self.rank == 0 else 10
tensor = torch.randn(20, dim)
if use_cuda:
tensor = tensor.to(self.rank)
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
wrapper_pg.allreduce([tensor])
self._validate_error(
exception=cm.exception,
op_type="ALLREDUCE",
rank=self.rank,
tensor=tensor,
)
# Check errors are raised when dimensionality of shapes is different
tensor = torch.randn(20, 10, 2) if self.rank == 0 else torch.randn(20, 10)
if use_cuda:
tensor = tensor.to(self.rank)
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
wrapper_pg.allreduce([tensor])
self._validate_error(
exception=cm.exception,
op_type="ALLREDUCE",
rank=self.rank,
tensor=tensor,
)
# Check shape errors with scatter
input = [
torch.tensor(
[self.rank] if self.rank == 0 else [self.rank, self.rank],
device=self.rank if use_cuda else "cpu",
)
for _ in range(self.world_size)
]
outputs = [
torch.tensor(
[-1] if self.rank == 0 else [-1, -1],
device=self.rank if use_cuda else "cpu",
)
for _ in range(self.world_size)
]
root_rank = 0
opts = c10d.ScatterOptions()
opts.rootRank = root_rank
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
if self.rank == root_rank:
wrapper_pg.scatter([outputs[self.rank]], [input], opts).wait()
else:
wrapper_pg.scatter([outputs[self.rank]], [], opts).wait()
self._validate_error(
exception=cm.exception,
op_type="SCATTER",
rank=self.rank,
tensor=outputs[self.rank],
)
# Spawning subprocesses is not safe under dev-dbg ASAN, so skip these tests.
if not TEST_WITH_DEV_DBG_ASAN:
@requires_gloo()
@requires_nccl()
class ProcessGroupNCCLWrapperTest(AbstractProcessGroupWrapperTest):
def setUp(self):
super(AbstractProcessGroupWrapperTest, self).setUp()
self._spawn_processes()
        # NCCL_BLOCKING_WAIT overrides NCCL_ASYNC_ERROR_HANDLING, so tests that
        # set NCCL_BLOCKING_WAIT still exercise the behavior they expect.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
@property
def world_size(self) -> int:
return 2
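
    # With with_new_group=True the group is created through c10d.new_group()
    # (the path the DETAIL debug-level tests rely on); otherwise a
    # ProcessGroupNCCL is constructed directly and wrapped with
    # c10d._create_process_group_wrapper.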
def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl",
rank=self.rank,
world_size=self.world_size,
store=store,
timeout=timedelta(seconds=timeout),
)
if with_new_group:
pg = c10d.new_group(backend="nccl", timeout=timedelta(seconds=timeout))
else:
_pg = c10d.ProcessGroupNCCL(
store, self.rank, self.world_size, timeout=timedelta(seconds=timeout)
)
pg = c10d._create_process_group_wrapper(
_pg,
"unused",
store,
self.rank,
self.world_size,
timeout=timeout,
)
return pg
@requires_nccl()
@skip_if_lt_x_gpu(2)
def test_collective_hang(self):
pg = self._create_wrapper_pg(timeout=2.0)
self._test_collective_hang(pg)
    # NOTE: these tests are separated by debug level instead of being combined
    # into one due to https://github.com/pytorch/pytorch/issues/55967; they can
    # be combined once that issue is resolved.
@requires_nccl()
@skip_if_lt_x_gpu(2)
@with_dist_debug_levels(levels=["DETAIL"])
def test_collectives_op_mismatch_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True)
self._test_collectives_op_mismatch(pg, use_cuda=True)
self._test_nccl_only_op_mismatch(pg)
@requires_nccl()
@skip_if_lt_x_gpu(2)
@with_dist_debug_levels(levels=["OFF"])
def test_collectives_op_mismatch(self):
pg = self._create_wrapper_pg(with_new_group=False)
self._test_collectives_op_mismatch(pg, use_cuda=True)
self._test_nccl_only_op_mismatch(pg)
@requires_nccl()
@skip_if_lt_x_gpu(2)
@with_dist_debug_levels(levels=["DETAIL"])
def test_collective_shape_mismatch_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True)
self._test_collective_shape_mismatch(pg, use_cuda=True)
self._test_nccl_only_shape_mismatch(pg)
@requires_nccl()
@skip_if_lt_x_gpu(2)
@with_dist_debug_levels(levels=["OFF"])
def test_collective_shape_mismatch(self):
pg = self._create_wrapper_pg(with_new_group=False)
self._test_collective_shape_mismatch(pg, use_cuda=True)
self._test_nccl_only_shape_mismatch(pg)
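
    # Mismatch _allgather_base against _reduce_scatter_base; these base
    # collectives are only exercised for the NCCL backend in this file.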
def _test_nccl_only_op_mismatch(self, wrapper_pg):
device = f"cuda:{self.rank}"
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
output = torch.zeros(4 + self.rank, device=device)
input = torch.ones(4 * self.world_size, device=device)
if self.rank == 0:
wrapper_pg._allgather_base(output, input).wait()
else:
wrapper_pg._reduce_scatter_base(output, input).wait()
self._validate_error(
exception=cm.exception,
op_type="ALLGATHER_BASE" if self.rank == 0 else "REDUCE_SCATTER_BASE",
rank=self.rank,
tensor=input,
)
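
    # Two shape-mismatch cases for _reduce_scatter_base: output sizes that
    # differ across ranks, then input sizes that differ across ranks.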
def _test_nccl_only_shape_mismatch(self, wrapper_pg):
device = f"cuda:{self.rank}"
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
output = torch.zeros(4 + self.rank, device=device)
input = torch.ones(4 * self.world_size, device=device)
wrapper_pg._reduce_scatter_base(output, input).wait()
self._validate_error(
exception=cm.exception,
op_type="REDUCE_SCATTER_BASE",
rank=self.rank,
tensor=input,
)
with self.assertRaisesRegex(RuntimeError, ".*") as cm:
output = torch.zeros(4, device=device)
input = torch.ones((4 + self.rank) * self.world_size, device=device)
wrapper_pg._reduce_scatter_base(output, input).wait()
self._validate_error(
exception=cm.exception,
op_type="REDUCE_SCATTER_BASE",
rank=self.rank,
tensor=input,
)
@requires_gloo()
class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest):
def setUp(self):
super(ProcessGroupGlooWrapperTest, self).setUp()
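
    # Build ProcessGroupGloo options bound to the loopback interface with the
    # given thread count and timeout.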
def opts(self, threads=2, timeout=10.0):
opts = c10d.ProcessGroupGloo._Options()
opts._timeout = timeout
opts._devices = [create_device(interface=LOOPBACK)]
opts._threads = threads
return opts
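
    # Mirrors the NCCL helper above: either create the group through
    # c10d.new_group() or wrap a directly constructed ProcessGroupGloo with
    # c10d._create_process_group_wrapper.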
def _create_wrapper_pg(self, with_new_group=False, timeout=10.0):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="gloo", rank=self.rank, world_size=self.world_size, store=store
)
if with_new_group:
pg = c10d.new_group(backend="gloo")
else:
_pg = c10d.ProcessGroupGloo(
store, self.rank, self.world_size, self.opts(timeout=timeout)
)
pg = c10d._create_process_group_wrapper(
_pg,
"unused",
store,
self.rank,
self.world_size,
timeout=timeout,
)
return pg
def test_collective_hang(self):
pg = self._create_wrapper_pg(timeout=2.0)
self._test_collective_hang(pg)
    # NOTE: these tests are separated by debug level instead of being combined
    # into one due to https://github.com/pytorch/pytorch/issues/55967; they can
    # be combined once that issue is resolved.
@with_dist_debug_levels(levels=["DETAIL"])
def test_collectives_op_mismatch_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True)
self._test_collectives_op_mismatch(pg)
@with_dist_debug_levels(levels=["OFF"])
def test_collectives_op_mismatch(self):
pg = self._create_wrapper_pg(with_new_group=False)
self._test_collectives_op_mismatch(pg)
@with_dist_debug_levels(levels=["DETAIL"])
def test_collective_shape_mismatch_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True)
self._test_collective_shape_mismatch(pg)
@with_dist_debug_levels(levels=["OFF"])
def test_collective_shape_mismatch(self):
pg = self._create_wrapper_pg(with_new_group=False)
self._test_collective_shape_mismatch(pg)
@skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["DETAIL"])
def test_collectives_op_mismatch_cuda_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True)
self._test_collectives_op_mismatch(pg, use_cuda=True)
@skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["OFF"])
def test_collectives_op_mismatch_cuda(self):
pg = self._create_wrapper_pg(with_new_group=False)
self._test_collectives_op_mismatch(pg, use_cuda=True)
@skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["DETAIL"])
def test_collective_shape_mismatch_cuda_debug_mode(self):
pg = self._create_wrapper_pg(with_new_group=True)
self._test_collective_shape_mismatch(pg, use_cuda=True)
@skip_if_lt_x_gpu(4)
@with_dist_debug_levels(levels=["OFF"])
def test_collective_shape_mismatch_cuda(self):
pg = self._create_wrapper_pg(with_new_group=False)
self._test_collective_shape_mismatch(pg, use_cuda=True)
if __name__ == "__main__":
assert (
not torch.cuda._initialized
), "test_pg_wrapper must not have initialized CUDA context on main process"
run_tests()