[DCP] Adds support for meta tensor loading for DCP.load_state_dict() (#113319)
Currently, DCP requires `model.state_dict()` to be materialized before it is passed to `DCP.load_state_dict()`, since DCP loads into the pre-allocated storage of the initialized model state_dict. Therefore, even for fine-tuning and distributed inference, users need to explicitly materialize the model on GPU before calling `DCP.load_state_dict()`.
Today's flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )
model2.to_empty(device="cuda")
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load)
```
This PR adds support for meta tensor loading. When DCP's planner encounters a tensor/DTensor on the meta device, it initializes a corresponding tensor/DTensor on the current device on the fly and swaps it in for the meta-device entry in the state_dict. After this change, users no longer need to manually call `model.to_empty()` when loading existing checkpoints for fine-tuning and distributed inference; instead, they pass `assign=True` to `model.load_state_dict()` so the loaded tensors are assigned to the module's still-meta parameters rather than copied into them.
Updated user flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )
# no longer need to call model2.to_empty(device="cuda")
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load, assign=True)
```
Note that for distributed training, it is still the user's responsibility to reset the parameters (`model.reset_parameters()`) when no checkpoint exists, e.g. on the first run; see the sketch below.
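A minimal sketch of that fallback, reusing the `model2`, `DCP`, and `CHECKPOINT_DIR` names from the flow above (the naive `os.path.isdir` existence check and the `reset_parameters` loop are illustrative assumptions, not part of this PR):
```
import os

if os.path.isdir(CHECKPOINT_DIR):
    # A checkpoint exists: DCP materializes the meta tensors while loading.
    state_dict_to_load = model2.state_dict()
    DCP.load_state_dict(
        state_dict=state_dict_to_load,
        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
    )
    model2.load_state_dict(state_dict_to_load, assign=True)
else:
    # No checkpoint yet: materialize storage and re-initialize parameters manually.
    model2.to_empty(device="cuda")
    for module in model2.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
```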
Note that we need to loop through the state_dict to replace meta tensors/DTensors, rather than calling `model.to_empty()`, since `DCP.load_state_dict()` only takes the state_dict, not the model.
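Conceptually, the replacement amounts to the simplified sketch below (plain tensors only, with a hypothetical helper name; the actual change in `planner_helpers.py` further down also handles `DTensor` and derives the device from the default process group):
```
import torch

def _materialize_meta_entries(state_dict, device="cuda"):
    """Replace meta-device tensors in a state_dict with empty tensors on `device`."""
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor) and value.device.type == "meta":
            # empty_like allocates uninitialized storage with the same shape/dtype;
            # DCP then reads the checkpoint data into this storage.
            state_dict[key] = torch.empty_like(value, device=device)
```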
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113319
Approved by: https://github.com/fegin, https://github.com/LucasLLC
diff --git a/test/distributed/checkpoint/test_tp_checkpoint.py b/test/distributed/checkpoint/test_tp_checkpoint.py
index 692d987..5545782 100644
--- a/test/distributed/checkpoint/test_tp_checkpoint.py
+++ b/test/distributed/checkpoint/test_tp_checkpoint.py
@@ -3,7 +3,7 @@
from copy import deepcopy
import torch
-import torch.distributed.checkpoint as dist_cp
+import torch.distributed.checkpoint as DCP
from torch.distributed._tensor import init_device_mesh
@@ -28,6 +28,19 @@
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+class UnevenShardedModel(torch.nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        torch.manual_seed(5)
+        self.net1 = torch.nn.Linear(5, 10, device=device)
+        self.relu = torch.nn.ReLU()
+        self.net2 = torch.nn.Linear(10, 15, device=device)
+        self.net3 = torch.nn.Linear(15, 1, device=device)
+
+    def forward(self, x):
+        return self.net3(self.net2(self.relu(self.net1(x))))
+
+
class TestTpCheckpoint(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(2)
@@ -48,9 +61,9 @@
        optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
        original_state_dict = deepcopy(model.state_dict())
-        dist_cp.save_state_dict(
+        DCP.save_state_dict(
            state_dict=original_state_dict,
-            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
+            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )
@@ -66,9 +79,9 @@
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertNotEqual(param1.to_local(), param2.to_local())
-        dist_cp.load_state_dict(
+        DCP.load_state_dict(
            state_dict=state_dict,
-            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
+            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )
@@ -76,6 +89,47 @@
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertEqual(param1.to_local(), param2.to_local())
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    @with_temp_dir
+    def test_tp_checkpoint_load_on_meta_device(self):
+        CHECKPOINT_DIR = self.temp_dir
+        mesh_shape = (self.world_size,)
+        tp_mesh = init_device_mesh(self.device_type, mesh_shape)
+
+        # create model and move it to GPU with id rank
+        model = UnevenShardedModel(self.device_type).cuda(self.rank)
+        # Parallelize the module based on the given Parallel Style.
+        parallelize_plan = {
+            "net1": RowwiseParallel(),
+            "net2": ColwiseParallel(),
+        }
+        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)
+        original_state_dict = deepcopy(model.state_dict())
+
+        DCP.save_state_dict(
+            state_dict=original_state_dict,
+            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
+        )
+
+        model2 = parallelize_module(
+            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
+        )
+        state_dict_to_load = model2.state_dict()
+
+        DCP.load_state_dict(
+            state_dict=state_dict_to_load,
+            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+        )
+        model2.load_state_dict(state_dict_to_load, assign=True)
+        state_dict_after_load = model2.state_dict()
+
+        # After loading, check whether params in state_dict_after_load are equal to original_state_dict.
+        for param1, param2 in zip(
+            original_state_dict.values(), state_dict_after_load.values()
+        ):
+            self.assertEqual(param1, param2)
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py
index 775d5b8..d39e2c4 100644
--- a/torch/distributed/checkpoint/default_planner.py
+++ b/torch/distributed/checkpoint/default_planner.py
@@ -40,6 +40,7 @@
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
+    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
@@ -162,6 +163,7 @@
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
+        _init_state_dict(state_dict)
        self.original_state_dict = state_dict
        if self.flatten_sharded_tensors:
diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py
index c01ea0d..a3f6f95 100644
--- a/torch/distributed/checkpoint/planner_helpers.py
+++ b/torch/distributed/checkpoint/planner_helpers.py
@@ -1,12 +1,17 @@
-from typing import Any, List
+from typing import Any, cast, List
import torch
+import torch.distributed as dist
+from torch._utils import _get_device_module
+
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._tensor import DTensor
from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
+from torch.utils._pytree import tree_map_only
+
from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
@@ -253,3 +258,46 @@
length=0,
)
]
+
+
+def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
+    state_dict_assigned_storage = tree_map_only(
+        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
+    )
+    # The in-place version of tree_map_only, tree_map_only_, doesn't seem to work.
+    # So we need to manually update each element in the state_dict with the initialized tensor.
+    for k in state_dict.keys():
+        state_dict[k] = state_dict_assigned_storage[k]
+
+
+def _init_meta_tensor(value: Any) -> Any:
+    """
+    Initialize a torch.Tensor/DTensor on the current device if it is on the meta device;
+    otherwise return the value unchanged.
+    """
+
+    device = getattr(value, "device", None)
+    # DCP does the initialization if it's a meta tensor/DTensor.
+    if device == torch.device("meta"):
+        device_type = dist.distributed_c10d._get_pg_default_device().type
+        device = cast(torch.device, _get_device_module(device_type).current_device())
+        if isinstance(value, DTensor):
+            new_local_tensor = torch.empty_like(value.to_local(), device=device)
+            # We need to pass shape and stride explicitly, since DTensor might be
+            # sharded unevenly.
+            dtensor = DTensor.from_local(
+                new_local_tensor,
+                device_mesh=value.device_mesh,
+                placements=value.placements,
+                shape=value.size(),
+                stride=value.stride(),
+            )
+            return dtensor
+        elif isinstance(value, torch.Tensor):
+            tensor = torch.empty_like(value, device=device)
+            return tensor
+        else:
+            raise RuntimeError(
+                f"Found unsupported type {type(value)} for meta device loading."
+            )
+    else:
+        return value