| import abc |
| from dataclasses import dataclass |
| from typing import List, Any |
| |
| from torch.futures import Future |
| |
| from .metadata import ( |
| Metadata, |
| MetadataIndex, |
| ) |
| |
| from .planner import ( |
| LoadPlan, |
| SavePlan, |
| SavePlanner, |
| LoadPlanner, |
| ) |
| |
| __all__ = ["WriteResult", "StorageWriter", "StorageReader"] |
| |
| |
| @dataclass(frozen=True) |
| class WriteResult: |
| index: MetadataIndex |
| |
| size_in_bytes: int |
| storage_data: Any |
| |
| |
| class StorageWriter(abc.ABC): |
| """ |
| Interface used by ``save_state_dict`` to write to storage. |
| |
| One StorageWriter instance acts as both the coordinator and the follower |
| in a distributed checkpoint. As part of initialization, each instance |
| is told its role. |
| |
| A subclass should expect the following sequence of calls. |
| |
| 1) (all ranks) init() |
| 2) (all ranks) prepare_local_plan() |
| 3) (coordinator) prepare_global_plan() |
| 4) (all ranks) write_data() |
| 5) (coordinator) finish() |
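| |
| As an illustration only, the driver loop could be imagined roughly like |
| this; ``rank``, ``planner``, ``local_plan``, ``gathered_plans``, |
| ``gathered_results`` and ``metadata`` are assumptions of the sketch, not |
| part of this interface:: |
| |
|     writer.init(is_coordinator=(rank == 0)) |
|     local_plan = writer.prepare_local_plan(local_plan) |
|     # the coordinator gathers all plans, lets the writer adjust them |
|     # globally, and scatters them back to each rank (omitted here) |
|     if rank == 0: |
|         gathered_plans = writer.prepare_global_plan(gathered_plans) |
|     results = writer.write_data(local_plan, planner).wait() |
|     if rank == 0: |
|         writer.finish(metadata, gathered_results) |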
| """ |
| |
| @abc.abstractmethod |
| def init(self, is_coordinator: bool) -> None: |
| """ |
| Initialize this instance. |
| |
| Args: |
| is_coordinator (bool): Whether this instance is responsible for coordinating |
| the checkpoint. |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def prepare_local_plan(self, plan: SavePlan) -> SavePlan: |
| """ |
| Perform storage-specific local planning. |
| |
| While this method can produce a completely different plan, the recommended |
| way is to store storage-specific data in ``SavePlan::storage_data``. |
| |
| Args: |
| plan (SavePlan): The local plan from the ``SavePlanner`` in use. |
| |
| Returns: |
| A transformed ``SavePlan`` after storage local planning |
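| |
| A minimal sketch of one possible override, assuming ``SavePlan`` is a |
| dataclass (so ``dataclasses.replace`` applies); ``self._prefix`` is a |
| hypothetical attribute of the subclass:: |
| |
|     import dataclasses |
| |
|     def prepare_local_plan(self, plan: SavePlan) -> SavePlan: |
|         # leave the plan itself untouched and only attach the |
|         # storage-specific payload |
|         return dataclasses.replace(plan, storage_data=self._prefix) |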
| """ |
| pass |
| |
| @abc.abstractmethod |
| def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: |
| """ |
| Perform centralized planning of storage. |
| |
| This method is only called on the coordinator instance. |
| |
| While this method can produce a completely different plan, the preferred |
| way is to store storage-specific data in ``SavePlan::storage_data``. |
| |
| Args: |
| plans: A list of ``SavePlan`` instances, one for each rank. |
| |
| Returns: |
| A list of transformed ``SavePlan`` after storage global planning |
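| |
| Sketch only: a writer that gives every rank its own output file might |
| record that decision here, again assuming ``SavePlan`` is a dataclass; |
| the naming scheme is an assumption:: |
| |
|     import dataclasses |
| |
|     def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: |
|         # annotate each rank's plan with the file that rank should write to |
|         return [ |
|             dataclasses.replace(plan, storage_data=f"data_{rank}.bin") |
|             for rank, plan in enumerate(plans) |
|         ] |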
| """ |
| pass |
| |
| @abc.abstractmethod |
| def write_data( |
| self, plan: SavePlan, planner: SavePlanner |
| ) -> Future[List[WriteResult]]: |
| """ |
| Write all items from ``plan`` using ``planner`` to resolve the data. |
| |
| A subclass should call ``SavePlanner::resolve_data`` on each item |
| from the plan to get access to the underlying object to write. |
| |
| Subclasses should call ``resolve_data`` lazily, as it can allocate memory. |
| For tensors, make the following assumptions: |
| |
| - They might be on any device, including one that does not match the device recorded in ``WriteItem::tensor_data``. |
| - They might be views or non-contiguous; only the projection needs to be saved. |
| |
| Args: |
| plan (SavePlan): The save plan to execute. |
| planner (SavePlanner): Planner object to be used to resolve items to data. |
| |
| Returns: |
| A future that completes to a list of WriteResult |
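| |
| A simplified single-file sketch; ``plan.items``, ``item.index``, the |
| assumption that non-tensor items resolve to a ``BytesIO``, and the file |
| name are illustrative guesses rather than guarantees of this API:: |
| |
|     import torch |
|     from torch.futures import Future |
| |
|     def write_data(self, plan, planner): |
|         results = [] |
|         with open("data_0.bin", "wb") as stream: |
|             for item in plan.items: |
|                 data = planner.resolve_data(item)  # resolved lazily, per item |
|                 offset = stream.tell() |
|                 if isinstance(data, torch.Tensor): |
|                     tensor = data.detach().cpu()   # may live on any device |
|                     if not tensor.is_contiguous(): |
|                         tensor = tensor.contiguous()  # save only the projection |
|                     torch.save(tensor, stream) |
|                 else: |
|                     torch.save(data.getvalue(), stream)  # BytesIO payload |
|                 results.append( |
|                     WriteResult( |
|                         index=item.index, |
|                         size_in_bytes=stream.tell() - offset, |
|                         storage_data=offset, |
|                     ) |
|                 ) |
|         fut: Future = Future() |
|         fut.set_result(results) |
|         return fut |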
| """ |
| pass |
| |
| @abc.abstractmethod |
| def finish( |
| self, metadata: Metadata, results: List[List[WriteResult]] |
| ) -> None: |
| """ |
| Writes the metadata and marks the current checkpoint as successful. |
| |
| The actual format/schema used for serializing ``metadata`` is an |
| implementation detail. The only requirement is that it can be |
| recovered into the same object graph. |
| |
| Args: |
| metadata (Metadata): metadata for the new checkpoint |
| results: A list of WriteResults from all ranks. |
| |
| Returns: |
| None |
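| |
| For illustration, a file-system based writer might pickle the metadata |
| next to the data files; the path, the temporary-file-then-rename pattern |
| and the use of pickle are assumptions, not requirements:: |
| |
|     import os |
|     import pickle |
| |
|     def finish(self, metadata, results): |
|         # persist whatever read_metadata() will need later; the write |
|         # results could be folded into the metadata before serializing |
|         with open(".metadata.tmp", "wb") as stream: |
|             pickle.dump(metadata, stream) |
|         os.rename(".metadata.tmp", ".metadata")  # rename marks the checkpoint as complete |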
| """ |
| pass |
| |
| |
| class StorageReader(abc.ABC): |
| """ |
| Interface used by ``load_state_dict`` to read from storage. |
| |
| One StorageReader instance acts as both the coordinator and the follower |
| in a distributed checkpoint. As part of initialization, each instance |
| is told its role. |
| |
| A subclass should expect the following sequence of calls from ``load_state_dict``: |
| |
| 1) (all ranks) read_metadata() |
| 2) (all ranks) init |
| 3) (all ranks) prepare_local_plan |
| 4) (coordinator) prepare_global_plan |
| 5) (all ranks) read_data |
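| |
| Sketched roughly, with the same caveats as the ``StorageWriter`` sketch |
| (the surrounding names are assumptions):: |
| |
|     metadata = reader.read_metadata() |
|     reader.init(metadata, is_coordinator=(rank == 0)) |
|     local_plan = reader.prepare_local_plan(local_plan) |
|     if rank == 0: |
|         gathered_plans = reader.prepare_global_plan(gathered_plans) |
|     reader.read_data(local_plan, planner).wait() |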
| """ |
| |
| @abc.abstractmethod |
| def read_metadata(self) -> Metadata: |
| """ |
| Reads the checkpoint metadata. |
| |
| Returns: |
| The metadata object associated with the checkpoint being loaded. |
| |
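| Paired with the pickling sketch in ``StorageWriter.finish``, a |
| file-system reader could be as small as the following; the path and the |
| format are assumptions:: |
| |
|     import pickle |
| |
|     def read_metadata(self) -> Metadata: |
|         with open(".metadata", "rb") as stream: |
|             return pickle.load(stream) |
| |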
| """ |
| pass |
| |
| @abc.abstractmethod |
| def init(self, metadata: Metadata, is_coordinator: bool) -> None: |
| """ |
| Initialize this instance. |
| |
| Args: |
| metadata (Metadata): The metadata schema to use. |
| is_coordinator (bool): Whether this instance is responsible for coordinating |
| the checkpoint. |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: |
| """ |
| Perform storage-specific local planning. |
| |
| While this method can produce a completely different plan, the recommended |
| way is to store storage-specific data in ``LoadPlan::storage_data``. |
| |
| Args: |
| plan (LoadPlan): The local plan from the ``LoadPlanner`` in use. |
| |
| Returns: |
| A transformed ``LoadPlan`` after storage local planning |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def prepare_global_plan(self, plans: List[LoadPlan]) -> List[LoadPlan]: |
| """ |
| Perform centralized planning of storage loading. |
| |
| This method is only called on the coordinator instance. |
| |
| While this method can produce a completely different plan, the preferred |
| way is to store storage-specific data in ``LoadPlan::storage_data``. |
| |
| Args: |
| plans: A list of ``LoadPlan`` instances, one for each rank. |
| |
| Returns: |
| A list of transformed ``LoadPlan`` after storage global planning |
| """ |
| pass |
| |
| @abc.abstractmethod |
| def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: |
| """ |
| Reads all items from ``plan`` using ``planner`` to resolve the data. |
| |
| A subclass should call ``LoadPlanner::load_bytes`` to deserialize a BytesIO |
| object into the right place. |
| |
| A subclass should call ``LoadPlanner::resolve_tensor`` to get access to the |
| tensors that it should load data into. |
| |
| It is the StorageReader's responsibility to properly schedule any |
| cross-device copies required. |
| |
| Args: |
| plan (LoadPlan): The local plan to execute on |
| planner (LoadPlanner): The planner object to use to resolve items. |
| |
| Returns: |
| A future that completes once all reads are finished. |
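| |
| A single-file sketch that mirrors the ``write_data`` sketch above; |
| ``plan.items``, the hypothetical ``self._offset_for`` lookup and the file |
| name are assumptions, and a real planner may require an explicit commit |
| step after the tensor copy, which is omitted here:: |
| |
|     import io |
|     import torch |
|     from torch.futures import Future |
| |
|     def read_data(self, plan, planner): |
|         with open("data_0.bin", "rb") as stream: |
|             for item in plan.items: |
|                 stream.seek(self._offset_for(item))  # hypothetical offset lookup |
|                 obj = torch.load(stream) |
|                 if isinstance(obj, torch.Tensor): |
|                     target = planner.resolve_tensor(item)  # tensor to load into |
|                     target.copy_(obj)  # cross-device copy is this layer's job |
|                 else: |
|                     planner.load_bytes(item, io.BytesIO(obj)) |
|         fut: Future = Future() |
|         fut.set_result(None) |
|         return fut |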
| """ |
| pass |