| from concurrent.futures import Future |
| from typing import Any, Dict, List, Optional |
| |
| import torch.distributed as dist |
| import torch.distributed.checkpoint.state_dict_loader as loader |
| import torch.distributed.checkpoint.state_dict_saver as saver |
| from torch.distributed.checkpoint.metadata import Metadata, STATE_DICT_TYPE |
| from torch.distributed.checkpoint.storage import ( |
| LoadPlanner, |
| SavePlanner, |
| StorageReader, |
| StorageWriter, |
| ) |
| |
| |
| __all__: List[str] = [] |
| |
| |
| class _Checkpointer: |
| """This base class specefies a high level API for saving and loading |
| distributed `state_dict` 's. It provides an abstraction over the low-level APIs |
| provided by :py:mod:`torch.distributed.checkpoint.storage`, essentially calling |
| :py:meth: `torch.distributed.state_dict_saver.save` and |
| :py:meth: `torch.distributed.state_dict_loader.load` with the provided storage |
| readers and writers. |
| |
| .. warning:: |
| This feature is experimental and subject to removal/change. |
| |
| """ |
| |
| def __init__( |
| self, |
| storage_writer: StorageWriter, |
| storage_reader: StorageReader, |
| *, |
| process_group: Optional[dist.ProcessGroup] = None, |
| coordinator_rank: int = 0, |
| no_dist: bool = False, |
| load_planner: Optional[LoadPlanner] = None, |
| save_planner: Optional[SavePlanner] = None, |
| ): |
| """Initializes the Checkpointer instance. |
| |
| Args: |
| storage_writer: Instance of StorageWrite use to perform writes. |
| storage_reader: StorageReader used to load data from. |
| process_group: ProcessGroup to be used for cross-rank synchronization. |
| coordinator_rank: Rank to use to coordinate the checkpoint. rank0 is used by default. |
| no_dist: If ``True``, distributed checkpoint will not load in SPMD style. (Default: ``False``) |
| loader_planner: Instance of LoadPlanner to use when loading. |
| save_planner: Instance of SavePlanner to use when saving. |
| """ |
| self.storage_writer = storage_writer |
| self.storage_reader = storage_reader |
| self.process_group = process_group |
| self.coordinator_rank = coordinator_rank |
| self.no_dist = no_dist |
| self.load_planner = load_planner |
| self.save_planner = save_planner |
| |
| def save( |
| self, |
| state_dict: STATE_DICT_TYPE, |
| ) -> Metadata: |
| """Calls :py:meth: `torch.distributed.state_dict_saver.save`. Utilizing values passed during initialization.""" |
| return saver.save( |
| state_dict, |
| self.storage_writer, |
| process_group=self.process_group, |
| coordinator_rank=self.coordinator_rank, |
| no_dist=self.no_dist, |
| planner=self.save_planner, |
| ) |
| |
| def async_save( |
| self, |
| state_dict: STATE_DICT_TYPE, |
| ) -> Future: |
| """ |
| Calls :py:meth: `torch.distributed.state_dict_saver._async_save`. Utilizing values passed during initialization. |
| |
| Returns: |
| Future: A future holding the resultant Metadata object from `save`. |
| """ |
| return saver.async_save( |
| state_dict, |
| storage_writer=self.storage_writer, |
| process_group=self.process_group, |
| planner=self.save_planner, |
| ) |
| |
| def load(self, state_dict: Dict[str, Any]) -> None: |
| """Calls :py:meth: `torch.distributed.state_dict_loader.load`. Utilizing values passed during initialization.""" |
| loader.load( |
| state_dict, |
| storage_reader=self.storage_reader, |
| process_group=self.process_group, |
| planner=self.load_planner, |
| ) |