# Owner(s): ["oncall: distributed"]
import os
import sys
import weakref
from functools import wraps, partial
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._functional_collectives_impl as ft_c_impl
import torch.distributed.distributed_c10d as c10d
import torch.distributed._tensor as dt
from torch.testing import FileCheck
from functorch import make_fx
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
    MultiThreadedTestCase,
    MultiProcessTestCase,
    requires_nccl,
    skip_if_lt_x_gpu,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)
def new_subgroups(group_size: int, pg_tag=None):
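    """
    Split the default world into contiguous subgroups of ``group_size`` ranks,
    tagging each one with ``pg_tag``. Returns the subgroup containing the
    calling rank plus the list of all created subgroups.
    """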
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups
class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)
    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])
    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group

        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)
    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)
    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[1]), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)
class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()
"""
The behavior we want is as follow:
- rankset+tag will always result in the same PG.
Do we enforce this by failing creation of new PGs or returning existing ones?
Return existing one.
- default tag gives existing behavior.
This means we should create duplicates.
- _expand_group on _default-tagged pg should always resolve to it
This mean we can't depend on empty tag + rankset.
"""
    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)
class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()
    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))
    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)
    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_gather = [0, 1, 2]
            for dim in dims_to_gather:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # each rank has its own tensor; all_gather_tensor concatenates them along `dim`
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(gathered_tensor, torch.ones(output_size))
    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1])
    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_scatter = [0, 1]
            for dim in dims_to_scatter:
                group_size = mesh.size(0)
                input_size = [3, 3]
                output_size = [3, 3]
                output_size[dim] *= group_size
                input_tensor = torch.ones(output_size, device=device)
                res_num = 1 * group_size
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])
class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", [1])
        self.assertEqual(x.size(), out.size())
class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", [0, 1])
        (out + y).sum().backward()
        self.assertIsNone(x.grad)
class TestMakeFx(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=[0, 1]) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck() \
            .check("all_reduce") \
            .check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck() \
            .check_not("get_attr") \
            .check("wait_tensor").run(str(mesh_graph.graph))

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck() \
            .check_not("get_attr") \
            .check("wait_tensor").run(str(mesh_dim_graph.graph))
instantiate_parametrized_tests(TestTraceableCollectives)
BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = 2
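# Decorator used by the multi-process tests below: skip when a NCCL run does not
# have enough GPUs, then initialize the process group before the test body and
# tear it down afterwards.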
def with_comms(func=None):
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)

        self.dist_init()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper
class TestCollectivesWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @skip_if_lt_x_gpu(WORLD_SIZE)
    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        tensors = [
            torch.ones([4], device=f"cuda:{self.rank}"),
            torch.ones([4], device=f"cuda:{self.rank}") + 1,
        ]
        mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])
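# The tests below exercise the lazy "wait" semantics of functional collectives:
# a collective returns an AsyncCollectiveTensor wrapper and the implementation
# tracks outstanding work until the result is waited on (explicitly via
# trigger_wait, implicitly by a non-view op, or when the wrapper is collected).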
class TestOpWaitiness(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 1

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def tearDown(self):
        super().tearDown()
        ft_c_impl._wait_all()

    def test_wait_reduce_outstanding_work_count(self):
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())

        tensor = torch.ones([4])
        res = ft_c.all_reduce(tensor, "sum", [0])
        self.assertEqual(1, ft_c_impl._outstanding_wait_count())
        self.assertTrue(ft_c_impl._tensor_needs_wait(res))

        res.trigger_wait()
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())
        self.assertFalse(ft_c_impl._tensor_needs_wait(res))
    def test_add_triggers_wait(self):
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())

        tensor = torch.ones([4])
        res = ft_c.all_reduce(tensor, "sum", [0])
        self.assertEqual(1, ft_c_impl._outstanding_wait_count())
        self.assertTrue(ft_c_impl._tensor_needs_wait(res))

        foo = res + torch.ones([4])
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())
        self.assertFalse(ft_c_impl._tensor_needs_wait(res))
        self.assertFalse(isinstance(foo, ft_c.AsyncCollectiveTensor))

    def test_view_does_not_trigger_wait(self):
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())

        tensor = torch.ones([4])
        res = ft_c.all_reduce(tensor, "sum", [0])
        self.assertEqual(1, ft_c_impl._outstanding_wait_count())
        self.assertTrue(ft_c_impl._tensor_needs_wait(res))

        foo = res.view([2, 2])
        self.assertEqual(1, ft_c_impl._outstanding_wait_count())
        self.assertTrue(ft_c_impl._tensor_needs_wait(res))
        self.assertTrue(ft_c_impl._tensor_needs_wait(foo))
        self.assertTrue(isinstance(foo, ft_c.AsyncCollectiveTensor))

        foo.trigger_wait()
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())
    def test_dead_wrapper_triggers_wait(self):
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())

        tensor = torch.ones([4])
        res = ft_c.all_reduce(tensor, "sum", [0])

        wr = weakref.ref(res)
        self.assertTrue(wr() is not None)
        res = None
        self.assertTrue(wr() is None)
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())

    def test_dead_wrapper_plus_view(self):
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())

        tensor = torch.ones([4])
        res = ft_c.all_reduce(tensor, "sum", [0])
        res = res.view([2, 2])
        self.assertEqual(1, ft_c_impl._outstanding_wait_count())

        res = None
        self.assertEqual(0, ft_c_impl._outstanding_wait_count())
if __name__ == "__main__":
    run_tests()