[PT-D][Checkpoint] Rename DCP storage layer init() (#92869)
Rename DCP storage layer init() and update tests accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92869
Approved by: https://github.com/kumpera
diff --git a/test/distributed/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py
index 96c9811..6d0111a 100644
--- a/test/distributed/checkpoint/test_checkpoint.py
+++ b/test/distributed/checkpoint/test_checkpoint.py
@@ -187,8 +187,8 @@
def __init__(self, fail_conf):
super(FaultyStorageWriter, self).__init__(fail_conf)
- def init(self, is_coordinator: bool) -> None:
- self._fail_rank("fail_init")
+ def set_up_storage_writer(self, is_coordinator: bool) -> None:
+ self._fail_rank("fail_set_up_storage_writer")
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
self._fail_rank("fail_prepare_local_plan")
@@ -215,8 +215,8 @@
super(FaultyStorageReader, self).__init__(fail_conf)
self.metadata = metadata
- def init(self, metadata: Metadata, is_coordinator: bool) -> None:
- self._fail_rank("fail_init")
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+ self._fail_rank("fail_set_up_storage_reader")
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
self._fail_rank("fail_prepare_local_plan")
@@ -329,7 +329,7 @@
"bytes": [1, 2, 3, 4],
}
- self._test_save(state_dict, fail_init=[0])
+ self._test_save(state_dict, fail_set_up_storage_writer=[0])
self._test_save(state_dict, fail_finish=[0])
self._test_save(state_dict, fail_prepare_global_plan=[0])
@@ -337,7 +337,7 @@
self._test_save(state_dict, fail_write_data=[2])
self._test_save(state_dict, fail_write_data_async=[3])
- self._test_save(state_dict, coordinator=1, fail_init=[1])
+ self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
self._test_save(state_dict, coordinator=1, fail_finish=[1])
def test_save_error_handling_no_dist(self) -> None:
@@ -345,7 +345,7 @@
self.assertFalse(dist.is_initialized())
- self._test_save(state_dict, fail_init=[0])
+ self._test_save(state_dict, fail_set_up_storage_writer=[0])
self._test_save(state_dict, fail_finish=[0])
self._test_save(state_dict, fail_prepare_global_plan=[0])
@@ -364,14 +364,14 @@
}
self._test_load(state_dict)
- self._test_load(state_dict, fail_init=[0])
+ self._test_load(state_dict, fail_set_up_storage_reader=[0])
self._test_load(state_dict, fail_prepare_global_plan=[0])
self._test_load(state_dict, fail_read_metadata=[0])
self._test_load(state_dict, fail_prepare_local_plan=[1])
self._test_load(state_dict, fail_read_data=[3])
self._test_load(state_dict, fail_read_data_async=[1])
- self._test_load(state_dict, coordinator=3, fail_init=[0])
+ self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
self._test_load(state_dict, coordinator=2, fail_read_data=[0])
self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
@@ -380,7 +380,7 @@
def test_load_error_handling_no_dist(self) -> None:
state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
self._test_load(state_dict)
- self._test_load(state_dict, fail_init=[0])
+ self._test_load(state_dict, fail_set_up_storage_reader=[0])
self._test_load(state_dict, fail_read_metadata=[0])
self._test_load(state_dict, fail_prepare_local_plan=[0])
self._test_load(state_dict, fail_prepare_global_plan=[0])
diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py
index 3d5ca4c..a6016b8 100644
--- a/torch/distributed/checkpoint/filesystem.py
+++ b/torch/distributed/checkpoint/filesystem.py
@@ -345,7 +345,7 @@
self.thread_count = thread_count
self.per_thread_copy_ahead = per_thread_copy_ahead
- def init(self, is_coordinator: bool) -> None:
+ def set_up_storage_writer(self, is_coordinator: bool) -> None:
pass
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
@@ -513,7 +513,7 @@
with (self.path / ".metadata").open("rb") as metadata_file:
return pickle.load(metadata_file)
- def init(self, metadata: Metadata, is_coordinator: bool) -> None:
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
self.storage_data = metadata.storage_data
assert self.storage_data is not None
diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py
index 2093bad..11b8e36 100644
--- a/torch/distributed/checkpoint/state_dict_loader.py
+++ b/torch/distributed/checkpoint/state_dict_loader.py
@@ -92,7 +92,7 @@
assert planner is not None
metadata = storage_reader.read_metadata()
planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
- storage_reader.init(metadata, distW.is_coordinator)
+ storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_reader.prepare_local_plan(local_plan)
diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py
index c89eed4..0ace087 100644
--- a/torch/distributed/checkpoint/state_dict_saver.py
+++ b/torch/distributed/checkpoint/state_dict_saver.py
@@ -86,7 +86,7 @@
def local_step():
assert planner is not None
planner.set_up_planner(state_dict, distW.is_coordinator)
- storage_writer.init(distW.is_coordinator)
+ storage_writer.set_up_storage_writer(distW.is_coordinator)
local_plan = planner.create_local_plan()
local_plan = storage_writer.prepare_local_plan(local_plan)
return local_plan
diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py
index dbc8fda..73cd5ff 100644
--- a/torch/distributed/checkpoint/storage.py
+++ b/torch/distributed/checkpoint/storage.py
@@ -37,7 +37,7 @@
A subclass should expect the following sequence of calls.
- 1) (all ranks) init()
+ 1) (all ranks) set_up_storage_writer()
2) (all ranks) prepare_local_plan()
3) (coordinator) prepare_global_plan()
4) (all ranks) write_data()
@@ -45,7 +45,7 @@
"""
@abc.abstractmethod
- def init(self, is_coordinator: bool) -> None:
+ def set_up_storage_writer(self, is_coordinator: bool) -> None:
"""
Initialize this instance.
@@ -146,10 +146,10 @@
A subclass should expect the following sequence of calls by ``load_state_dict``:
1) (all ranks) read_metadata()
- 2) (all ranks) init
- 3) (all ranks) prepare_local_plan
- 4) (coordinator) prepare_global_plan
- 5) (all ranks) read_data
+ 2) (all ranks) set_up_storage_reader()
+ 3) (all ranks) prepare_local_plan()
+ 4) (coordinator) prepare_global_plan()
+ 5) (all ranks) read_data()
"""
@abc.abstractmethod
@@ -164,7 +164,7 @@
pass
@abc.abstractmethod
- def init(self, metadata: Metadata, is_coordinator: bool) -> None:
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
"""
Initialize this instance.