[DCP] Adds support for meta tensor loading for DCP.load_state_dict() (#113319)

Currently, DCP requires `model.state_dict()` to be materialized before it is passed to DCP for loading, since DCP reuses the pre-allocated storage of the initialized model state_dict. As a result, even for fine-tuning and distributed inference, users have to explicitly materialize the model on GPU before calling `DCP.load_state_dict()`.

Today's flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )

model2.to_empty(device='cuda')
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load)
```

This PR adds support for meta tensor loading. In DCP's planner, when we encounter a tensor/DTensor on the meta device, we initialize a tensor/DTensor on the current device on the fly and use it to replace the meta-device one in the state_dict. After this change, users no longer need to manually call `model.to_empty()` when loading existing checkpoints for fine-tuning and distributed inference; they only need to pass `assign=True` to `model.load_state_dict()` so that the loaded tensors are assigned to the module's parameters instead of being copied into its meta tensors.

Updated user flow:
```
with torch.device("meta"):
    model2 = parallelize_module(
        MLPModule("meta"), tp_mesh, parallelize_plan=parallelize_plan
    )
# no longer need to call model2.to_empty(device='cuda')
state_dict_to_load = model2.state_dict()
DCP.load_state_dict(
    state_dict=state_dict_to_load,
    storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
)
model2.load_state_dict(state_dict_to_load, assign=True)
```

Note that for distributed training, it is still the user's responsibility to reset the parameters (`model.reset_parameters()`), since a checkpoint might not exist.
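
A minimal sketch of one way to handle the missing-checkpoint case, assuming the same `CHECKPOINT_DIR`, `state_dict_to_load`, and `model2` as above, a local filesystem checkpoint, and leaf modules that define `reset_parameters()` (as `nn.Linear` does):
```
import os

if os.path.isdir(CHECKPOINT_DIR):
    DCP.load_state_dict(
        state_dict=state_dict_to_load,
        storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
    )
    model2.load_state_dict(state_dict_to_load, assign=True)
else:
    # No checkpoint to restore from: materialize the meta-initialized module
    # and re-run the per-module initializers instead.
    model2.to_empty(device="cuda")
    for module in model2.modules():
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()
```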

Note that we need to loop through the state_dict to replace meta tensors/DTensors instead of calling `model.to_empty()`, since `DCP.load_state_dict()` only takes in the state_dict, not the model.
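
Conceptually, the planner-side replacement is just a walk over the state_dict values. A minimal sketch of the idea (not the exact DCP code; plain tensors only, with the DTensor branch omitted):
```
import torch


def materialize_meta_tensors(state_dict, device="cuda"):
    # Swap each meta tensor for an uninitialized tensor on the target device
    # so DCP has real storage to read the checkpoint data into.
    for key, value in state_dict.items():
        if isinstance(value, torch.Tensor) and value.is_meta:
            state_dict[key] = torch.empty_like(value, device=device)
```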

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113319
Approved by: https://github.com/fegin, https://github.com/LucasLLC
diff --git a/test/distributed/checkpoint/test_tp_checkpoint.py b/test/distributed/checkpoint/test_tp_checkpoint.py
index 692d987..5545782 100644
--- a/test/distributed/checkpoint/test_tp_checkpoint.py
+++ b/test/distributed/checkpoint/test_tp_checkpoint.py
@@ -3,7 +3,7 @@
 from copy import deepcopy
 
 import torch
-import torch.distributed.checkpoint as dist_cp
+import torch.distributed.checkpoint as DCP
 
 from torch.distributed._tensor import init_device_mesh
 
@@ -28,6 +28,19 @@
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 
+class UnevenShardedModel(torch.nn.Module):
+    def __init__(self, device):
+        super().__init__()
+        torch.manual_seed(5)
+        self.net1 = torch.nn.Linear(5, 10, device=device)
+        self.relu = torch.nn.ReLU()
+        self.net2 = torch.nn.Linear(10, 15, device=device)
+        self.net3 = torch.nn.Linear(15, 1, device=device)
+
+    def forward(self, x):
+        return self.net3(self.net2(self.relu(self.net1(x))))
+
+
 class TestTpCheckpoint(DTensorTestBase):
     @with_comms
     @skip_if_lt_x_gpu(2)
@@ -48,9 +61,9 @@
         optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
         original_state_dict = deepcopy(model.state_dict())
 
-        dist_cp.save_state_dict(
+        DCP.save_state_dict(
             state_dict=original_state_dict,
-            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
+            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
             planner=DefaultSavePlanner(),
         )
 
@@ -66,9 +79,9 @@
         for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
             self.assertNotEqual(param1.to_local(), param2.to_local())
 
-        dist_cp.load_state_dict(
+        DCP.load_state_dict(
             state_dict=state_dict,
-            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
+            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
             planner=DefaultLoadPlanner(),
         )
 
@@ -76,6 +89,47 @@
         for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
             self.assertEqual(param1.to_local(), param2.to_local())
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    @with_temp_dir
+    def test_tp_checkpoint_load_on_meta_device(self):
+        CHECKPOINT_DIR = self.temp_dir
+        mesh_shape = (self.world_size,)
+        tp_mesh = init_device_mesh(self.device_type, mesh_shape)
+
+        # create model and move it to GPU with id rank
+        model = UnevenShardedModel(self.device_type).cuda(self.rank)
+        # Parallelize the module based on the given Parallel Style.
+        parallelize_plan = {
+            "net1": RowwiseParallel(),
+            "net2": ColwiseParallel(),
+        }
+        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)
+        original_state_dict = deepcopy(model.state_dict())
+
+        DCP.save_state_dict(
+            state_dict=original_state_dict,
+            storage_writer=DCP.FileSystemWriter(CHECKPOINT_DIR),
+        )
+
+        model2 = parallelize_module(
+            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
+        )
+        state_dict_to_load = model2.state_dict()
+
+        DCP.load_state_dict(
+            state_dict=state_dict_to_load,
+            storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR),
+        )
+        model2.load_state_dict(state_dict_to_load, assign=True)
+        state_dict_after_load = model2.state_dict()
+
+        # After loading, check that the params in state_dict_after_load are equal to those in original_state_dict.
+        for param1, param2 in zip(
+            original_state_dict.values(), state_dict_after_load.values()
+        ):
+            self.assertEqual(param1, param2)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py
index 775d5b8..d39e2c4 100644
--- a/torch/distributed/checkpoint/default_planner.py
+++ b/torch/distributed/checkpoint/default_planner.py
@@ -40,6 +40,7 @@
     _create_default_metadata_only_plan,
     _create_read_items,
     _create_write_items,
+    _init_state_dict,
 )
 from torch.distributed.checkpoint.utils import find_state_dict_object
 
@@ -162,6 +163,7 @@
         metadata: Metadata,
         is_coordinator: bool,
     ) -> None:
+        _init_state_dict(state_dict)
         self.original_state_dict = state_dict
 
         if self.flatten_sharded_tensors:
diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py
index c01ea0d..a3f6f95 100644
--- a/torch/distributed/checkpoint/planner_helpers.py
+++ b/torch/distributed/checkpoint/planner_helpers.py
@@ -1,12 +1,17 @@
-from typing import Any, List
+from typing import Any, cast, List
 
 import torch
+import torch.distributed as dist
+from torch._utils import _get_device_module
+
 from torch.distributed._shard.metadata import ShardMetadata
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
 from torch.distributed._tensor import DTensor
 from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
 
+from torch.utils._pytree import tree_map_only
+
 from .metadata import (
     BytesStorageMetadata,
     ChunkStorageMetadata,
@@ -253,3 +258,46 @@
                 length=0,
             )
         ]
+
+
+def _init_state_dict(state_dict: STATE_DICT_TYPE) -> None:
+    state_dict_assigned_storage = tree_map_only(
+        torch.Tensor, lambda v: _init_meta_tensor(v), state_dict
+    )
+    # The in-place version of tree_map_only (tree_map_only_) doesn't seem to work,
+    # so we copy the materialized tensors back into the original state_dict here.
+    for k in state_dict.keys():
+        state_dict[k] = state_dict_assigned_storage[k]
+
+
+def _init_meta_tensor(value: Any) -> Any:
+    """
+    Initializes the tensor on the current device if the given torch.Tensor/DTensor is on the meta device.
+    """
+
+    device = getattr(value, "device", None)
+    # DCP does the initialization if it's meta tensor/DTensor.
+    if device == torch.device("meta"):
+        device_type = dist.distributed_c10d._get_pg_default_device().type
+        device = cast(torch.device, _get_device_module(device_type).current_device())
+        if isinstance(value, DTensor):
+            new_local_tensor = torch.empty_like(value.to_local(), device=device)
+            # We need to pass shape and stride explicitly, since DTensor might be
+            # sharded unevenly.
+            dtensor = DTensor.from_local(
+                new_local_tensor,
+                device_mesh=value.device_mesh,
+                placements=value.placements,
+                shape=value.size(),
+                stride=value.stride(),
+            )
+            return dtensor
+        elif isinstance(value, torch.Tensor):
+            tensor = torch.empty_like(value, device=device)
+            return tensor
+        else:
+            raise RuntimeError(
+                f"Found unsupported type {type(value)} for meta device loading."
+            )
+    else:
+        return value