[PT-D][Checkpoint] Rename DCP storage layer init() (#92869)

Rename the DCP storage layer init() hooks to set_up_storage_writer() and set_up_storage_reader(), mirroring the planner's set_up_planner() hook, and update the filesystem implementation, callers, docstrings, and tests accordingly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92869
Approved by: https://github.com/kumpera
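
For downstream code that subclasses the storage layer, the migration is a mechanical rename. A minimal sketch (the MyWriter class and its logging body are hypothetical, for illustration only):

```python
from torch.distributed.checkpoint import FileSystemWriter

# Hypothetical subclass showing the rename; before this PR the hook
# below was spelled `init(self, is_coordinator)`.
class MyWriter(FileSystemWriter):
    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        # Called once on every rank before planning begins.
        if is_coordinator:
            print("coordinator rank: storage writer ready")
        super().set_up_storage_writer(is_coordinator)
```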
diff --git a/test/distributed/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py
index 96c9811..6d0111a 100644
--- a/test/distributed/checkpoint/test_checkpoint.py
+++ b/test/distributed/checkpoint/test_checkpoint.py
@@ -187,8 +187,8 @@
     def __init__(self, fail_conf):
         super(FaultyStorageWriter, self).__init__(fail_conf)
 
-    def init(self, is_coordinator: bool) -> None:
-        self._fail_rank("fail_init")
+    def set_up_storage_writer(self, is_coordinator: bool) -> None:
+        self._fail_rank("fail_set_up_storage_writer")
 
     def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
         self._fail_rank("fail_prepare_local_plan")
@@ -215,8 +215,8 @@
         super(FaultyStorageReader, self).__init__(fail_conf)
         self.metadata = metadata
 
-    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
-        self._fail_rank("fail_init")
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
+        self._fail_rank("fail_set_up_storage_reader")
 
     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
         self._fail_rank("fail_prepare_local_plan")
@@ -329,7 +329,7 @@
             "bytes": [1, 2, 3, 4],
         }
 
-        self._test_save(state_dict, fail_init=[0])
+        self._test_save(state_dict, fail_set_up_storage_writer=[0])
         self._test_save(state_dict, fail_finish=[0])
         self._test_save(state_dict, fail_prepare_global_plan=[0])
 
@@ -337,7 +337,7 @@
         self._test_save(state_dict, fail_write_data=[2])
         self._test_save(state_dict, fail_write_data_async=[3])
 
-        self._test_save(state_dict, coordinator=1, fail_init=[1])
+        self._test_save(state_dict, coordinator=1, fail_set_up_storage_writer=[1])
         self._test_save(state_dict, coordinator=1, fail_finish=[1])
 
     def test_save_error_handling_no_dist(self) -> None:
@@ -345,7 +345,7 @@
 
         self.assertFalse(dist.is_initialized())
 
-        self._test_save(state_dict, fail_init=[0])
+        self._test_save(state_dict, fail_set_up_storage_writer=[0])
         self._test_save(state_dict, fail_finish=[0])
         self._test_save(state_dict, fail_prepare_global_plan=[0])
 
@@ -364,14 +364,14 @@
         }
 
         self._test_load(state_dict)
-        self._test_load(state_dict, fail_init=[0])
+        self._test_load(state_dict, fail_set_up_storage_reader=[0])
         self._test_load(state_dict, fail_prepare_global_plan=[0])
         self._test_load(state_dict, fail_read_metadata=[0])
         self._test_load(state_dict, fail_prepare_local_plan=[1])
         self._test_load(state_dict, fail_read_data=[3])
         self._test_load(state_dict, fail_read_data_async=[1])
 
-        self._test_load(state_dict, coordinator=3, fail_init=[0])
+        self._test_load(state_dict, coordinator=3, fail_set_up_storage_reader=[0])
         self._test_load(state_dict, coordinator=1, fail_read_metadata=[3])
         self._test_load(state_dict, coordinator=2, fail_read_data=[0])
         self._test_load(state_dict, coordinator=3, fail_read_data_async=[2])
@@ -380,7 +380,7 @@
     def test_load_error_handling_no_dist(self) -> None:
         state_dict = {"replicated": torch.rand(10, 10), "bytes": [1, 2, 3, 4]}
         self._test_load(state_dict)
-        self._test_load(state_dict, fail_init=[0])
+        self._test_load(state_dict, fail_set_up_storage_reader=[0])
         self._test_load(state_dict, fail_read_metadata=[0])
         self._test_load(state_dict, fail_prepare_local_plan=[0])
         self._test_load(state_dict, fail_prepare_global_plan=[0])
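
The test suite threads the failure point through fail_* keyword arguments whose names now mirror the renamed hooks. A self-contained sketch of that fault-injection pattern (the FaultInjector class is hypothetical; the real helpers are FaultyStorageWriter/FaultyStorageReader above):

```python
import torch.distributed as dist

# Hypothetical sketch of the injection pattern used by the faulty storage
# classes: each hook checks whether the current rank is configured to fail
# under a key matching the hook's name.
class FaultInjector:
    def __init__(self, fail_conf):
        # e.g. {"fail_set_up_storage_writer": [0]} makes rank 0 raise there
        self.fail_conf = fail_conf

    def _fail_rank(self, name: str) -> None:
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank in self.fail_conf.get(name, []):
            raise ValueError(f"rank {rank} failing in {name}")

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self._fail_rank("fail_set_up_storage_writer")
```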
diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py
index 3d5ca4c..a6016b8 100644
--- a/torch/distributed/checkpoint/filesystem.py
+++ b/torch/distributed/checkpoint/filesystem.py
@@ -345,7 +345,7 @@
         self.thread_count = thread_count
         self.per_thread_copy_ahead = per_thread_copy_ahead
 
-    def init(self, is_coordinator: bool) -> None:
+    def set_up_storage_writer(self, is_coordinator: bool) -> None:
         pass
 
     def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
@@ -513,7 +513,7 @@
         with (self.path / ".metadata").open("rb") as metadata_file:
             return pickle.load(metadata_file)
 
-    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
         self.storage_data = metadata.storage_data
         assert self.storage_data is not None
 
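End-to-end usage through the public API is unchanged by the rename; save_state_dict()/load_state_dict() invoke the new hooks internally. A single-process sketch (the checkpoint path is arbitrary):

```python
import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.rand(4, 4)}

# save_state_dict() calls set_up_storage_writer() on every rank before planning.
dcp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),
    no_dist=True,  # single-process mode; no process group required
)

# load_state_dict() reads the metadata, then calls set_up_storage_reader().
loaded = {"weight": torch.zeros(4, 4)}
dcp.load_state_dict(
    state_dict=loaded,
    storage_reader=dcp.FileSystemReader("/tmp/ckpt"),
    no_dist=True,
)
```
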
diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py
index 2093bad..11b8e36 100644
--- a/torch/distributed/checkpoint/state_dict_loader.py
+++ b/torch/distributed/checkpoint/state_dict_loader.py
@@ -92,7 +92,7 @@
         assert planner is not None
         metadata = storage_reader.read_metadata()
         planner.set_up_planner(state_dict, metadata, distW.is_coordinator)
-        storage_reader.init(metadata, distW.is_coordinator)
+        storage_reader.set_up_storage_reader(metadata, distW.is_coordinator)
 
         local_plan = planner.create_local_plan()
         local_plan = storage_reader.prepare_local_plan(local_plan)
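
The load side fixes the order: metadata is read first, then handed to both the planner and the reader. A sketch that records that order with a hypothetical RecordingReader (single-process, path arbitrary):

```python
import torch
import torch.distributed.checkpoint as dcp

# Hypothetical reader that records the documented call order:
# read_metadata() -> set_up_storage_reader() -> prepare_local_plan() -> ...
class RecordingReader(dcp.FileSystemReader):
    def __init__(self, path):
        super().__init__(path)
        self.calls = []

    def read_metadata(self):
        self.calls.append("read_metadata")
        return super().read_metadata()

    def set_up_storage_reader(self, metadata, is_coordinator: bool) -> None:
        self.calls.append("set_up_storage_reader")
        super().set_up_storage_reader(metadata, is_coordinator)

state = {"w": torch.rand(2)}
dcp.save_state_dict(state, storage_writer=dcp.FileSystemWriter("/tmp/ckpt_order"), no_dist=True)

reader = RecordingReader("/tmp/ckpt_order")
dcp.load_state_dict(state, storage_reader=reader, no_dist=True)
assert reader.calls[:2] == ["read_metadata", "set_up_storage_reader"]
```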
diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py
index c89eed4..0ace087 100644
--- a/torch/distributed/checkpoint/state_dict_saver.py
+++ b/torch/distributed/checkpoint/state_dict_saver.py
@@ -86,7 +86,7 @@
     def local_step():
         assert planner is not None
         planner.set_up_planner(state_dict, distW.is_coordinator)
-        storage_writer.init(distW.is_coordinator)
+        storage_writer.set_up_storage_writer(distW.is_coordinator)
         local_plan = planner.create_local_plan()
         local_plan = storage_writer.prepare_local_plan(local_plan)
         return local_plan
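
The save side has the same local-step shape: set up the planner, set up the writer, then create and refine the local plan. A matching sketch for the writer (RecordingWriter is hypothetical):

```python
import torch
import torch.distributed.checkpoint as dcp

# Hypothetical writer that records the lifecycle driven by save_state_dict():
# set_up_storage_writer() -> prepare_local_plan() -> ...
class RecordingWriter(dcp.FileSystemWriter):
    def __init__(self, path):
        super().__init__(path)
        self.calls = []

    def set_up_storage_writer(self, is_coordinator: bool) -> None:
        self.calls.append("set_up_storage_writer")
        super().set_up_storage_writer(is_coordinator)

    def prepare_local_plan(self, plan):
        self.calls.append("prepare_local_plan")
        return super().prepare_local_plan(plan)

writer = RecordingWriter("/tmp/ckpt_writer")
dcp.save_state_dict({"w": torch.ones(2)}, storage_writer=writer, no_dist=True)
assert writer.calls[:2] == ["set_up_storage_writer", "prepare_local_plan"]
```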
diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py
index dbc8fda..73cd5ff 100644
--- a/torch/distributed/checkpoint/storage.py
+++ b/torch/distributed/checkpoint/storage.py
@@ -37,7 +37,7 @@
 
     A subclass should expect the following sequence of calls.
 
-    1) (all ranks) init()
+    1) (all ranks) set_up_storage_writer()
     2) (all ranks) prepare_local_plan()
     3) (coordinator) prepare_global_plan()
     4) (all ranks) write_data()
@@ -45,7 +45,7 @@
     """
 
     @abc.abstractmethod
-    def init(self, is_coordinator: bool) -> None:
+    def set_up_storage_writer(self, is_coordinator: bool) -> None:
         """
         Initialize this instance.
 
@@ -146,10 +146,10 @@
     A subclass should expect the following sequence of calls by ``load_state_dict``:
 
     1) (all ranks) read_metadata()
-    2) (all ranks) init
-    3) (all ranks) prepare_local_plan
-    4) (coordinator) prepare_global_plan
-    5) (all ranks) read_data
+    2) (all ranks) set_up_storage_reader()
+    3) (all ranks) prepare_local_plan()
+    4) (coordinator) prepare_global_plan()
+    5) (all ranks) read_data()
     """
 
     @abc.abstractmethod
@@ -164,7 +164,7 @@
         pass
 
     @abc.abstractmethod
-    def init(self, metadata: Metadata, is_coordinator: bool) -> None:
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
         """
         Initialize this instance.
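
Taken together, the renamed abstract contract looks like this as a skeleton subclass (method bodies are placeholders, not a working reader; imports assume the usual torch.distributed.checkpoint exports):

```python
from torch.distributed.checkpoint import StorageReader
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.planner import LoadPlan

# Skeleton only: shows the renamed hook in the documented call order.
class MyReader(StorageReader):
    def read_metadata(self) -> Metadata:  # 1) all ranks
        raise NotImplementedError

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        self.storage_data = metadata.storage_data  # 2) all ranks

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:  # 3) all ranks
        return plan

    def prepare_global_plan(self, plans):  # 4) coordinator only
        return plans

    def read_data(self, plan, planner):  # 5) all ranks
        raise NotImplementedError
```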