# Owner(s): ["oncall: distributed"]
import sys
from typing import Optional, List, cast
from torch.distributed.checkpoint.storage import WriteResult
from torch.distributed.checkpoint import (
StorageReader,
StorageWriter,
CheckpointException,
load_state_dict,
save_state_dict,
)
import torch
import torch.distributed as dist
import torch.nn
import torch.futures
from torch.futures import Future
from torch.distributed._shard import sharded_tensor
from torch.distributed.checkpoint.default_planner import (
_create_default_local_metadata,
)
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
Metadata,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
SavePlan,
SavePlanner,
LoadPlan,
LoadPlanner,
)
from torch.distributed._shard.sharded_tensor import (
state_dict_hook,
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.testing._internal.common_distributed import (
requires_nccl,
skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
run_tests,
)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
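# Module combining a ShardedTensor attribute with a regular Parameter; the
# registered state_dict_hook makes the ShardedTensor show up in state_dict().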
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
self.regular = torch.nn.Parameter(torch.ones(4, 4))
self.extra_sharded: Optional[ShardedTensor] = None
self.extra_param: Optional[torch.nn.Parameter] = None
self._register_state_dict_hook(state_dict_hook)
def spec(self) -> ChunkShardingSpec:
# pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
return ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
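# Tests for the metadata produced by _create_default_local_metadata over
# state_dicts containing sharded tensors, regular tensors, and plain bytes.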
class TestDistributedCheckpointing(ShardedTensorTestBase):
@property
def world_size(self) -> int:
return 2
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_tensor_metadata_with_missing_rank_spec(self) -> None:
spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:1/cuda:1",
],
)
st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
md = _create_default_local_metadata({"st": st})
st_md = md.state_dict_metadata["st"]
self.assertEqual(1, len(st_md.chunks))
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_default_metadata(self) -> None:
device = f"cuda:{dist.get_rank()}"
spec = ChunkShardingSpec(
dim=0,
placements=[
"rank:0/cuda:0",
"rank:1/cuda:1",
],
)
state_dict = {
"sharded": sharded_tensor.rand(
spec,
(
10,
10,
),
),
"replicated": torch.rand(4, device=device),
"bytes": [1, 2, 3, 4],
}
metadata = _create_default_local_metadata(state_dict)
self.assertTrue("bytes" in metadata.state_dict_metadata)
self.assertIsInstance(
metadata.state_dict_metadata["bytes"], BytesStorageMetadata
)
self.assertTrue("replicated" in metadata.state_dict_metadata)
self.assertIsInstance(
metadata.state_dict_metadata["replicated"], TensorStorageMetadata
)
md = metadata.state_dict_metadata["replicated"]
self.assertEqual(md.size, state_dict["replicated"].size())
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(1, len(md.chunks))
self.assertTrue("sharded" in metadata.state_dict_metadata)
self.assertIsInstance(
metadata.state_dict_metadata["sharded"], TensorStorageMetadata
)
md = metadata.state_dict_metadata["sharded"]
self.assertEqual(md.properties.dtype, torch.float32)
self.assertEqual(md.size, state_dict["sharded"].size())
self.assertEqual(2, len(md.chunks))
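# Base class for the fault-injecting storage implementations below. fail_conf
# maps a hook name (e.g. "fail_write_data") to the list of ranks that should
# raise when that hook runs.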
class TestStorageBase:
def __init__(self, fail_conf):
self.fail_conf = fail_conf
self.rank = 0 if not dist.is_initialized() else dist.get_rank()
    def _get_ranks(self, name):
        return self.fail_conf.get(name, None)
def _fail_rank(self, name):
ranks = self._get_ranks(name)
if ranks is not None and self.rank in ranks:
raise ValueError(f"rank fail {self.rank} for {name}")
def _fail_rank_async(self, name, result=None):
ranks = self._get_ranks(name)
fut = Future()
if ranks is not None and self.rank in ranks:
fut.set_exception(
ValueError(f"async rank fail {self.rank} for {name}")
)
else:
fut.set_result(result)
return fut
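# StorageWriter that performs no real I/O; each hook only raises on the ranks
# configured in fail_conf, synchronously or via a failed Future for the
# *_async entries.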
class FaultyStorageWriter(TestStorageBase, StorageWriter):
def __init__(self, fail_conf):
super().__init__(fail_conf)
def set_up_storage_writer(self, is_coordinator: bool) -> None:
self._fail_rank("fail_set_up_storage_writer")
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
self._fail_rank("fail_prepare_local_plan")
return plan
def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]:
self._fail_rank("fail_prepare_global_plan")
return plans
def write_data(
self, plan: SavePlan, planner: SavePlanner
) -> Future[List[WriteResult]]:
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])
def finish(
self, metadata: Metadata, results: List[List[WriteResult]]
) -> None:
self._fail_rank("fail_finish")
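# StorageReader counterpart: read_metadata returns the metadata it was
# constructed with, and the remaining hooks only inject the configured failures.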
class FaultyStorageReader(TestStorageBase, StorageReader):
def __init__(self, metadata, fail_conf):
super().__init__(fail_conf)
self.metadata = metadata
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
self._fail_rank("fail_set_up_storage_reader")
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
self._fail_rank("fail_prepare_local_plan")
return plan
def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]:
self._fail_rank("fail_prepare_global_plan")
return plans
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
self._fail_rank("fail_read_data")
return self._fail_rank_async("fail_read_data_async")
def read_metadata(self) -> Metadata:
self._fail_rank("fail_read_metadata")
return self.metadata
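# Verifies that storage failures on any rank surface as a CheckpointException
# that records which ranks failed, both with and without a process group.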
class TestDistributedFailure(ShardedTensorTestBase):
def get_spec(self):
return ChunkShardingSpec(
dim=0,
placements=[
f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
],
)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_dummy_writer_works(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}
save_state_dict(state_dict, FaultyStorageWriter({}))
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(2)
@requires_nccl()
def test_dummy_reader_works(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}
metadata = _create_default_local_metadata(state_dict)
load_state_dict(state_dict, FaultyStorageReader(metadata, {}))
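    # Run `callback`; if any ranks are configured to fail, expect a
    # CheckpointException whose failures map covers exactly those ranks.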
def _test_dist_failure(self, callback, kwargs):
        bad_ranks = list(kwargs.values())[0] if len(kwargs) > 0 else []
        # An empty bad_ranks list means the callback is expected to succeed.
if len(bad_ranks) == 0:
callback()
else:
with self.assertRaises(CheckpointException) as cm:
callback()
e = cast(CheckpointException, cm.exception)
for rank, wrapped_ex in e.failures.items():
ex = wrapped_ex[0]
self.assertTrue(rank in bad_ranks, msg=f"{rank} did not fail")
if not kwargs.get("ignore_exception_type", False):
self.assertEqual(ValueError, type(ex), str(ex))
failed_ranks = e.failures.keys()
for rank in bad_ranks:
self.assertTrue(
rank in failed_ranks,
msg=f"{rank} was supposed to fail was fine",
)
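    # Helpers that run save/load against faulty storage configured via kwargs
    # and delegate the failure checks to _test_dist_failure.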
def _test_save(self, state_dict, coordinator=0, **kwargs):
no_dist = not dist.is_initialized()
def _save():
save_state_dict(
state_dict,
storage_writer=FaultyStorageWriter(kwargs),
coordinator_rank=coordinator,
no_dist=no_dist,
)
self._test_dist_failure(_save, kwargs)
def _test_load(self, state_dict, coordinator=0, **kwargs):
no_dist = not dist.is_initialized()
def _load():
metadata = _create_default_local_metadata(state_dict)
load_state_dict(
state_dict,
storage_reader=FaultyStorageReader(metadata, kwargs),
coordinator_rank=coordinator,
no_dist=no_dist,
)
self._test_dist_failure(_load, kwargs)
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_save_error_handling(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}
self._test_save(state_dict, fail_set_up_storage_writer=[0])
self._test_save(state_dict, fail_finish=[0])
self._test_save(state_dict, fail_prepare_global_plan=[0])
self._test_save(state_dict, fail_prepare_local_plan=[0])
self._test_save(state_dict, fail_write_data=[2])
self._test_save(state_dict, fail_write_data_async=[3])
self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
self._test_save(state_dict, coordinator=1, fail_finish=[1])
def test_save_error_handling_no_dist(self) -> None:
state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
self.assertFalse(dist.is_initialized())
self._test_save(state_dict, fail_set_up_storage_writer=[0])
self._test_save(state_dict, fail_finish=[0])
self._test_save(state_dict, fail_prepare_global_plan=[0])
self._test_save(state_dict, fail_prepare_local_plan=[0])
self._test_save(state_dict, fail_write_data=[0])
self._test_save(state_dict, fail_write_data_async=[0])
@with_comms(init_rpc=False)
@skip_if_lt_x_gpu(4)
@requires_nccl()
def test_load_error_handling(self) -> None:
state_dict = {
"sharded": sharded_tensor.rand(self.get_spec(), 20, 20),
"replicated": torch.rand(10, 10),
"bytes": [1, 2, 3, 4],
}
self._test_load(state_dict)
self._test_load(state_dict, fail_set_up_storage_reader=[0])
self._test_load(state_dict, fail_prepare_global_plan=[0])
self._test_load(state_dict, fail_read_metadata=[0])
self._test_load(state_dict, fail_prepare_local_plan=[1])
self._test_load(state_dict, fail_read_data=[3])
self._test_load(state_dict, fail_read_data_async=[1])
self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
self._test_load(state_dict, coordinator=2, fail_read_data=[0])
self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
self._test_load(state_dict, coordinator=1, fail_prepare_global_plan=[1])
def test_load_error_handling_no_dist(self) -> None:
state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
self._test_load(state_dict)
self._test_load(state_dict, fail_set_up_storage_reader=[0])
self._test_load(state_dict, fail_read_metadata=[0])
self._test_load(state_dict, fail_prepare_local_plan=[0])
self._test_load(state_dict, fail_prepare_global_plan=[0])
self._test_load(state_dict, fail_read_data=[0])
self._test_load(state_dict, fail_read_data_async=[0])
if __name__ == "__main__":
run_tests()