# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import Shard
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import (
    _create_chunk_dtensor,
    _create_chunk_sharded_tensor,
    _gather_state_dict,
)
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)


class TestShardUtilsDistributed(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @skip_if_lt_x_gpu(2)
    def test_create_chunk_sharded_tensor(self):
        # Cover sizes that split evenly, unevenly, and with fewer rows than ranks.
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            sharded_tensor = _create_chunk_sharded_tensor(
                tensor,
                self.rank,
                self.world_size,
                torch.cuda.device_count(),
                _get_default_group(),
            )
            # Gather all shards onto rank 0 and compare against the original tensor.
            output = torch.empty(*size).cuda() if self.rank == 0 else None
            sharded_tensor.gather(0, output)
            if self.rank == 0:
                self.assertEqual(tensor, output)


class TestShardUtilsDistributedDTensor(DTensorTestBase):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_create_chunk_dtensor(self):
        device_mesh = self.build_device_mesh()
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            tensor_chunks = torch.chunk(tensor, self.world_size, dim=0)
            dtensor = _create_chunk_dtensor(tensor, self.rank, device_mesh)
            local_tensor = dtensor.to_local()
            if local_tensor.numel() != 0:
                # Each rank's local shard should match the corresponding chunk.
                self.assertEqual(local_tensor, tensor_chunks[self.rank])
            else:
                # Ranks beyond the number of chunks hold an empty local shard.
                self.assertEqual(self.rank >= len(tensor_chunks), True)

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_gather_state_dict_dtensor(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]
        torch.random.manual_seed(dist.get_rank())
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)

        state_dict = {"dtensor": dist_tensor}
        gathered_state_dict = _gather_state_dict(state_dict)
        # The gathered value should equal an explicit all-gather of the local shards.
        expected_gathered_dtensor = funcol.all_gather_tensor(
            dist_tensor.to_local(), gather_dim=0, group=(device_mesh, 0)
        )
        self.assertEqual(expected_gathered_dtensor, gathered_state_dict["dtensor"])


if __name__ == "__main__":
    run_tests()