# Owner(s): ["oncall: distributed"]

import torch
from torch.distributed.distributed_c10d import _get_default_group
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests


class TestShardUtilsDistributed(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _create_tensor(self, *size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        return torch.rand(*size).cuda()

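    # Shard tensors of various shapes, gather the shards back on rank 0,
    # and check that the round trip reproduces the original tensor.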
    @skip_if_lt_x_gpu(2)
    def test_create_chunk_sharded_tensor(self):
        for size in ((1,), (1, 6), (12,), (12, 6), (25,), (25, 6)):
            tensor = self._create_tensor(*size)
            sharded_tensor = _create_chunk_sharded_tensor(
                tensor,
                self.rank,
                self.world_size,
                torch.cuda.device_count(),
                _get_default_group(),
            )
            output = torch.empty(*size).cuda() if self.rank == 0 else None
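            # ShardedTensor.gather collects every shard onto dst rank 0;
            # non-destination ranks pass None as the output buffer.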
            sharded_tensor.gather(0, output)
            if self.rank == 0:
                self.assertEqual(tensor, output)


if __name__ == "__main__":
    run_tests()