# blob: f2255637a69128909acb4a170018b90140ce967f [file] [log] [blame]
# Owner(s): ["oncall: distributed"]
import os
import sys
import unittest
from functools import partial, wraps
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d
from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
MultiProcessTestCase,
MultiThreadedTestCase,
requires_nccl,
TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
def new_subgroups(group_size: int, pg_tag=None):
world_size = dist.get_world_size()
subgroups = []
cur_subgroup = None
for subgroup_id in range(world_size // group_size):
start_rank = subgroup_id * group_size
end_rank = start_rank + group_size
ranks_in_subgroup = list(range(start_rank, end_rank))
subgroup = c10d._new_group_with_tag(
ranks=ranks_in_subgroup,
pg_tag=pg_tag,
)
subgroups.append(subgroup)
rank = dist.get_rank()
if rank in ranks_in_subgroup:
cur_subgroup = subgroup
return cur_subgroup, subgroups
class TestExpand(MultiThreadedTestCase):
    """Exercise ``ft_c._expand_group`` with each accepted group description.

    Every test runs on four threaded ranks and inspects the
    ``(tag, rankset, group_size)`` triple returned by ``_expand_group``.
    """

    @property
    def world_size(self):
        # Number of threaded ranks used by every test in this class.
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        # Flat rank list without a tag.
        tag, ranks, size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], ranks)
        self.assertEqual(4, size)

        # An explicit tag is returned as given.
        tag, ranks, size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        # Nested rank lists: rankset is flattened, size is the inner length.
        tag, ranks, size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], ranks)
        self.assertEqual(2, size)

        tag, ranks, size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        # Inner lists of different lengths are rejected.
        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        # The default (world) process group.
        tag, ranks, size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], ranks)
        self.assertEqual(4, size)

        tag, ranks, size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        # An untagged two-rank subgroup.
        pair_pg, _ = new_subgroups(group_size=2)
        tag, ranks, size = ft_c._expand_group(pair_pg)
        self.assertEqual(c10d._get_group_tag(pair_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(pair_pg), ranks)
        self.assertEqual(2, size)

        # One single-rank group per rank, all created with the tag "my_pg";
        # each rank keeps the group that was created for its own rank.
        singletons = [
            c10d._new_group_with_tag([r], pg_tag="my_pg")
            for r in range(dist.get_world_size())
        ]
        solo_pg = singletons[dist.get_rank()]

        tag, ranks, size = ft_c._expand_group(solo_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], ranks)
        self.assertEqual(1, size)

        tag, ranks, size = ft_c._expand_group(solo_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        # The same check is performed twice, each time on a freshly built
        # 1D mesh over ranks 0..3.
        for _ in range(2):
            mesh = dt.DeviceMesh("cpu", torch.arange(4))
            tag, ranks, size = ft_c._expand_group(mesh)
            self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
            self.assertEqual([0, 1, 2, 3], ranks)
            self.assertEqual(4, size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))

        # A bare 2D mesh is rejected; a mesh dimension has to be named.
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            ft_c._expand_group(mesh)

        this_rank = dist.get_rank()
        # For each mesh dim, the two rank sets a rank may belong to.
        candidates_per_dim = {
            0: ([0, 2], [1, 3]),
            1: ([0, 1], [2, 3]),
        }
        for mesh_dim, (first, second) in candidates_per_dim.items():
            tag, ranks, size = ft_c._expand_group((mesh, mesh_dim))
            expected_ranks = first if this_rank in first else second
            self.assertEqual(
                c10d._get_group_tag(mesh.get_group(mesh_dim=mesh_dim)), tag
            )
            self.assertEqual(expected_ranks, ranks)
            self.assertEqual(2, size)
class TestPgTag(MultiThreadedTestCase):
    """Tests for creating and looking up process groups by (ranks, tag)."""

    @property
    def world_size(self):
        # Number of threaded ranks used by every test in this class.
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    # The behavior we want is as follows:
    #
    # - rankset+tag will always result in the same PG.
    #   Do we enforce this by failing creation of new PGs or returning
    #   existing ones? Return the existing one.
    #
    # - default tag gives existing behavior.
    #   This means we should create duplicates.
    #
    # - _expand_group on a _default-tagged pg should always resolve to it.
    #   This means we can't depend on empty tag + rankset.

    def test_pg_creation_with_tag(self):
        # Same tag and same ranks: the second request yields an equal group.
        blu_first, _ = new_subgroups(group_size=2, pg_tag="blu")
        blu_second, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(blu_first, blu_second)

        # A different tag yields a different group.
        blu2_group, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(blu_first, blu2_group)

        # Untagged groups differ from tagged ones and from each other.
        untagged_first, _ = new_subgroups(group_size=2)
        self.assertNotEqual(blu_first, untagged_first)

        untagged_second, _ = new_subgroups(group_size=2)
        self.assertNotEqual(untagged_first, untagged_second)

    def test_pg_lookup_roundtrip(self):
        created = [
            new_subgroups(group_size=2, pg_tag="blu")[0],
            new_subgroups(group_size=2, pg_tag="blu2")[0],
            new_subgroups(group_size=2)[0],
            new_subgroups(group_size=2)[0],
        ]

        def roundtrip(pg):
            # Expand the group, then look it up again from (tag, ranks).
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        for pg in created:
            self.assertEqual(pg, roundtrip(pg))

    def test_pg_lookup_with_tag(self):
        blu_pg, _ = new_subgroups(group_size=2, pg_tag="blu")
        bla_pg, _ = new_subgroups(group_size=2, pg_tag="bla")
        untagged_pg, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            # Same as a plain roundtrip, but with an explicit tag override.
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(blu_pg, roundtrip(bla_pg, "blu"))
        self.assertEqual(blu_pg, roundtrip(untagged_pg, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(blu_pg, roundtrip(blu_pg, ""))

    def test_find_or_create_pg(self):
        found_or_created = c10d._find_or_create_pg_by_ranks_and_tag(
            "blu", [0, 1, 2, 3], 2
        )
        blu_pg, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(found_or_created, blu_pg)

    def test_find_root_pg(self):
        # Empty tag plus all four ranks resolves to the world group.
        root = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, root)
@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
@property
def world_size(self):
return 4
def setUp(self):
super().setUp()
self._spawn_threads()
@parametrize("device", ["cpu", "cuda"])
def test_broadcast(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
if dist.get_rank() == 0:
tensor = torch.ones([4], device=device)
else:
tensor = torch.zeros([4], device=device)
mesh = dt.DeviceMesh(device, torch.arange(4))
res = ft_c.broadcast(tensor, 0, mesh)
self.assertEqual(res, torch.ones([4], device=device))
@parametrize("device", ["cpu", "cuda"])
def test_all_reduce_eager(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
tensor = torch.ones([4], device=device)
mesh = dt.DeviceMesh(device, torch.arange(4))
res = ft_c.all_reduce(tensor, "sum", mesh)
self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))
mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))
@parametrize("device", ["cpu", "cuda"])
def test_all_reduce_coalesced_eager(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
t0 = torch.ones([4], device=device)
t1 = torch.ones([6], device=device) + 2
mesh = dt.DeviceMesh(device, torch.arange(4))
res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
self.assertEqual(res[0], t0 * 4)
self.assertEqual(res[1], t1 * 4)
@parametrize("device", ["cpu", "cuda"])
def test_all_gather_tensor(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
# testing 1d/2d mesh
mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
for mesh in [mesh_1d, mesh_2d]:
dims_to_gather = [0, 1, 2]
for dim in dims_to_gather:
output_size = [3, 3, 3]
output_size[dim] *= mesh.size(0)
# each rank have its own tensor, all_gather gives a bigger tensor
local_tensor = torch.ones([3, 3, 3], device=device)
gathered_tensor = ft_c.all_gather_tensor(
local_tensor, gather_dim=dim, group=(mesh, 0)
)
self.assertEqual(gathered_tensor, torch.ones(output_size))
@parametrize("device", ["cpu", "cuda"])
def test_all_gather_into_tensor_coalesced(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
mesh = dt.DeviceMesh(device, torch.arange(4))
res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
self.assertEqual(2, len(res))
self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
self.assertEqual(
torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
)
@parametrize("device", ["cpu", "cuda"])
def test_reduce_scatter_tensor(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
# testing 1d/2d mesh
mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
for mesh in [mesh_1d, mesh_2d]:
dims_to_scatter = [0, 1]
for dim in dims_to_scatter:
group_size = mesh.size(0)
input_size = [3, 3]
output_size = [3, 3]
output_size[dim] *= group_size
input_tensor = torch.ones(output_size, device=device)
res_num = 1 * group_size
rs_tensor = ft_c.reduce_scatter_tensor(
input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
)
self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
@parametrize("device", ["cpu", "cuda"])
def test_reduce_scatter_into_tensor_coalesced(self, device):
if device == "cuda":
if torch.cuda.device_count() < self.world_size:
self.skipTest("Not enough CUDA devices")
torch.cuda.set_device(dist.get_rank())
tensors = [
torch.ones([4], dtype=torch.int64, device=device),
torch.ones([4], dtype=torch.int64, device=device) + 1,
]
mesh = dt.DeviceMesh(device, torch.arange(4))
res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
self.assertEqual(2, len(res))
self.assertEqual(torch.tensor([4], device=device), res[0])
self.assertEqual(torch.tensor([8], device=device), res[1])
class TestMetaCollectives(TestCase):
def test_all_reduce(self):
x = torch.rand((2, 3, 4), device="meta")
out = ft_c.all_reduce(x, "sum", "0")
self.assertEqual(x.size(), out.size())
class TestGradCollectives(MultiThreadedTestCase):
@property
def world_size(self):
return 2
def setUp(self):
super().setUp()
self._spawn_threads()
def test_all_reduce(self):
x = torch.rand([4], requires_grad=True)
y = torch.rand([4], requires_grad=True)
out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
(out + y).sum().backward()
self.assertIsNone(x.grad)
class TestMakeFx(MultiThreadedTestCase):
@property
def world_size(self):
return 2
def setUp(self):
super().setUp()
self._spawn_threads()
def tearDown(self):
super().tearDown()
# race condition with threads causes is_fx_tracing flag to be set incorrectly.
torch.fx._symbolic_trace._is_fx_tracing_flag = False
self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())
def test_all_reduce_tracing(self):
def allred(input):
return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1
graph = make_fx(allred)(torch.rand(4))
FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))
mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))
def allred_mesh(input):
return ft_c.all_reduce(input, "sum", mesh) + 1
mesh_graph = make_fx(allred_mesh)(torch.rand(4))
FileCheck().check_not("get_attr").check("wait_tensor").run(
str(mesh_graph.graph)
)
def allred_mesh_dim(input):
return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1
mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
FileCheck().check_not("get_attr").check("wait_tensor").run(
str(mesh_dim_graph.graph)
)
# Backend for the multi-process suites below: NCCL when CUDA is available,
# Gloo otherwise.
if torch.cuda.is_available():
    BACKEND = dist.Backend.NCCL
else:
    BACKEND = dist.Backend.GLOO

# Default number of ranks for the multi-process suites below.
WORLD_SIZE = 2
def exit_if_lt_x_gpu(x):
if torch.cuda.device_count() < x:
sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)
def with_comms(func=None):
if func is None:
return partial(
with_comms,
)
@wraps(func)
def wrapper(self, *args, **kwargs):
if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
self.dist_init()
func(self)
self.destroy_comms()
return wrapper
class TestCollectivesWithNCCL(MultiProcessTestCase):
    """Multi-process functional-collective tests (one process per rank).

    The process group is created by ``dist_init`` with the module-level
    ``BACKEND`` and torn down by ``destroy_comms``; both are invoked by the
    ``with_comms`` decorator around each decorated test.
    """

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        # Device whose index equals this process's rank.
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        # File-based rendezvous on the per-test file name.
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        local_device = f"cuda:{self.rank}"
        tensors = [
            torch.ones([4], device=local_device),
            torch.ones([4], device=local_device) + 1,
        ]
        mesh = dt.DeviceMesh(local_device, torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # Split i has (i + 1) * (rank + 1) rows; num_rows is their total.
        num_rows = self.world_size * (rank + 1) * (self.world_size + 1) // 2
        x = torch.ones(num_rows, 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]

        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )

        # Chunk idx of the output is expected to be filled with idx + 1.
        expected = torch.cat(
            [
                torch.full_like(piece, (idx + 1))
                for idx, piece in enumerate(torch.split(x, split_sizes))
            ]
        )
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # Same layout as test_all_to_all_single, with a 1D input.
        num_elems = self.world_size * (rank + 1) * (self.world_size + 1) // 2
        x = torch.ones(num_elems, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]

        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )

        expected = torch.cat(
            [
                torch.full_like(piece, (idx + 1))
                for idx, piece in enumerate(torch.split(x, split_sizes))
            ]
        )
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)

        # No explicit split sizes are passed.
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )

        expected = torch.cat(
            [
                torch.full_like(piece, (idx + 1))
                for idx, piece in enumerate(torch.chunk(x, self.world_size))
            ]
        )
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        # fullgraph=True: compilation must not fall back on a graph break.
        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        sample = torch.randn(8, device=self.device)
        compiled_allreduce(sample, self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        # NOTE(review): the compiled callable is created but the call below
        # uses the eager ``allreduce`` -- confirm whether the compiled version
        # was meant to be exercised here.
        compiled_allreduce = torch.compile(allreduce, fullgraph=True)

        # Not wrapped in with_comms: this test sets up its own "fake" group.
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        sample = torch.randn(8, device=self.device)
        allreduce(sample, pg=dist.group.WORLD)
class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    """Re-runs the inherited tests with four ranks and adds a sub-group test."""

    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        mesh_dim_names = ["dp", "tp"]
        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        # Template [0., 1.]; each rank offsets it by 2 * <some rank index>.
        base = torch.arange(2, dtype=torch.float32, device=device)

        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # rank0: [0., 1.], rank1: [2., 3.]
            send_tensor = base + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # rank0: [2., 3.], rank1: [0., 1.]
            peer_rank = (rank - 1 + 2) % 2
            expected = base + 2 * peer_rank
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )
@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
def setUp(self):
super().setUp()
self._spawn_threads()
@property
def world_size(self):
return 2
@parametrize("compile", [True, False])
def test_all_to_all_single(self, compile: bool = True) -> None:
group = dist.group.WORLD.group_name
t = torch.ones((self.world_size, 2), requires_grad=True)
def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
sizes = [1] * world_size
t = t * 2
assert t.requires_grad
out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
out = out + 0
return out
if compile:
compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
else:
compiled = my_func
out = compiled(t, self.world_size)
self.assertEqual(out.shape, t.shape)
self.assertEqual(out, torch.full_like(t, 2.0))
self.assertIsNotNone(out.grad_fn)
self.assertTrue(out.requires_grad)
loss = out.sum()
loss.backward()
self.assertEqual(t.grad, torch.full_like(t, 2.0))
def test_all_to_all_single_inductor(self) -> None:
group = dist.group.WORLD.group_name
t = torch.rand((self.world_size, 2), requires_grad=True)
def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
sizes = [1] * world_size
t = t * 10
assert t.requires_grad
out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
out = out + 2
return out.sum()
compiled = torch.compile(my_func, fullgraph=True)
def run_with_backward():
out = compiled(t, self.world_size)
out.backward()
res, codes = run_and_get_code(run_with_backward)
for code in codes:
FileCheck().check_count(
"_c10d_functional.all_to_all_single.default", 1, exactly=True
).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
code
)
self.assertIsNotNone(t.grad)
@parametrize("compile", [True, False])
def test_all_gather_tensor(self, compile: bool) -> None:
group = dist.group.WORLD.group_name
def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
assert t.requires_grad
out = ft_c.all_gather_tensor_autograd(
t * 1.0,
gather_dim=dim,
group=group,
)
out = out * 1.0
return out
if compile:
compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
else:
compiled = my_func
dims_to_gather = [0, 1, 2]
for dim in dims_to_gather:
output_size = [3, 3, 3]
output_size[dim] *= self.world_size
# each rank have its own tensor, all_gather gives a bigger tensor
local_tensor = torch.ones([3, 3, 3], requires_grad=True)
gathered_tensor = compiled(local_tensor, dim)
self.assertEqual(gathered_tensor, torch.ones(output_size))
gathered_tensor.sum().backward()
self.assertEqual(
local_tensor.grad,
torch.full((3, 3, 3), fill_value=float(self.world_size)),
)
@parametrize("compile", [True, False])
def test_reduce_scatter_tensor(self, compile: bool) -> None:
group = dist.group.WORLD.group_name
def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
assert t.requires_grad
rs_tensor = (
ft_c.reduce_scatter_tensor_autograd(
input_tensor * 1.0, "sum", scatter_dim=dim, group=group
)
* 1.0
)
return rs_tensor
if compile:
compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
else:
compiled = my_func
dims_to_scatter = [0, 1]
for dim in dims_to_scatter:
group_size = self.world_size
input_size = [3, 3]
output_size = [3, 3]
output_size[dim] *= group_size
input_tensor = torch.ones(output_size, requires_grad=True)
rs_tensor = compiled(input_tensor, dim)
res_num = 1 * group_size
self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
rs_tensor.sum().backward()
self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    """Multi-process (one process per rank) autograd test for all_to_all_single.

    The process group is created by ``dist_init`` with the module-level
    ``BACKEND`` and torn down by ``destroy_comms``; both are invoked by the
    ``with_comms`` decorator.
    """

    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        # Device whose index equals this process's rank.
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        # File-based rendezvous on the per-test file name.
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group = self.process_group.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        # One row is exchanged with every rank.
        sizes = [1] * self.world_size
        assert t.requires_grad
        exchanged = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group)
        out = exchanged + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)

        out.sum().backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))
if __name__ == "__main__":
run_tests()