[DSD] Add unittest to verify HSDP1 + broadcast_from_rank0 (#128755)

HSDP1 + broadcast_from_rank0 actually behaves differently from FSDP1 + broadcast_from_rank0. So we need an unittest to cover this use case.

This test relies on the fix from https://github.com/pytorch/pytorch/pull/128446.

Differential Revision: [D58621436](https://our.internmc.facebook.com/intern/diff/D58621436/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128755
Approved by: https://github.com/Skylion007, https://github.com/wz337
ghstack dependencies: #128685
diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py
index ac62635..7736350 100644
--- a/test/distributed/checkpoint/test_state_dict.py
+++ b/test/distributed/checkpoint/test_state_dict.py
@@ -33,7 +33,11 @@
     set_optimizer_state_dict,
     StateDictOptions,
 )
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    ShardingStrategy,
+    StateDictType,
+)
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.distributed.optim import _apply_optimizer_in_backward
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -70,7 +74,7 @@
 
     @property
     def world_size(self) -> int:
-        return 2
+        return min(4, torch.cuda.device_count())
 
     def _test_save_load(
         self,
@@ -567,55 +571,71 @@
         set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
         self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
 
-    @with_comms
-    @skip_if_lt_x_gpu(2)
-    def test_broadcast_from_rank0(self) -> None:
-        def inner_test(wrapper):
-            model = CompositeParamModel(device=torch.device("cuda"))
-            optim = torch.optim.Adam(model.parameters())
-            fsdp_model = wrapper(copy.deepcopy(model))
-            fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
+    def _test_broadcast_from_rank0(self, wrapper) -> None:
+        model = CompositeParamModel(device=torch.device("cuda"))
+        optim = torch.optim.Adam(model.parameters())
+        fsdp_model = wrapper(copy.deepcopy(model))
+        fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
 
-            batch = torch.rand(8, 100, device="cuda")
-            model(batch).sum().backward()
-            optim.step()
-            states, optim_states = get_state_dict(model, optim)
+        batch = torch.rand(8, 100, device="cuda")
+        model(batch).sum().backward()
+        optim.step()
+        states, optim_states = get_state_dict(model, optim)
 
-            fsdp_model(batch).sum().backward()
-            fsdp_optim.step()
+        fsdp_model(batch).sum().backward()
+        fsdp_optim.step()
 
-            def check(equal):
-                fsdp_states = get_model_state_dict(
-                    fsdp_model,
-                    options=StateDictOptions(full_state_dict=True),
-                )
-                fsdp_optim_states = get_optimizer_state_dict(
-                    fsdp_model,
-                    fsdp_optim,
-                    options=StateDictOptions(full_state_dict=True),
-                )
-                if equal:
-                    self.assertEqual(states, fsdp_states)
-                    self.assertEqual(optim_states, fsdp_optim_states)
-                else:
-                    self.assertNotEqual(states, fsdp_states)
-                    self.assertNotEqual(optim_states, fsdp_optim_states)
-
-            check(equal=True)
-            fsdp_model(batch).sum().backward()
-            fsdp_optim.step()
-            check(equal=False)
-
-            # Drop the states to simulate loading from rank0
-            if dist.get_rank() > 0:
-                load_states = {}
-                load_states2 = {}
-                load_optim_states = {}
+        def check(equal):
+            fsdp_states = get_model_state_dict(
+                fsdp_model,
+                options=StateDictOptions(full_state_dict=True),
+            )
+            fsdp_optim_states = get_optimizer_state_dict(
+                fsdp_model,
+                fsdp_optim,
+                options=StateDictOptions(full_state_dict=True),
+            )
+            if equal:
+                self.assertEqual(states, fsdp_states)
+                self.assertEqual(optim_states, fsdp_optim_states)
             else:
-                load_states = copy.deepcopy(states)
-                load_states2 = copy.deepcopy(states)
-                load_optim_states = copy.deepcopy(optim_states)
+                self.assertNotEqual(states, fsdp_states)
+                self.assertNotEqual(optim_states, fsdp_optim_states)
 
+        check(equal=True)
+        fsdp_model(batch).sum().backward()
+        fsdp_optim.step()
+        check(equal=False)
+
+        # Drop the states to simulate loading from rank0
+        if dist.get_rank() > 0:
+            load_states = {}
+            load_states2 = {}
+            load_optim_states = {}
+        else:
+            load_states = copy.deepcopy(states)
+            load_states2 = copy.deepcopy(states)
+            load_optim_states = copy.deepcopy(optim_states)
+
+        set_model_state_dict(
+            fsdp_model,
+            model_state_dict=load_states,
+            options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
+        )
+        set_optimizer_state_dict(
+            fsdp_model,
+            fsdp_optim,
+            optim_state_dict=load_optim_states,
+            options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
+        )
+
+        check(equal=True)
+        # Verify the `strict` flag.
+        load_states = load_states2
+        if load_states:
+            key = next(iter(load_states.keys()))
+            load_states.pop(key)
+        with self.assertRaisesRegex(RuntimeError, "Missing key"):
             set_model_state_dict(
                 fsdp_model,
                 model_state_dict=load_states,
@@ -623,30 +643,10 @@
                     broadcast_from_rank0=True, full_state_dict=True
                 ),
             )
-            set_optimizer_state_dict(
-                fsdp_model,
-                fsdp_optim,
-                optim_state_dict=load_optim_states,
-                options=StateDictOptions(
-                    broadcast_from_rank0=True, full_state_dict=True
-                ),
-            )
 
-            check(equal=True)
-            # Verify the `strict` flag.
-            load_states = load_states2
-            if load_states:
-                key = next(iter(load_states.keys()))
-                load_states.pop(key)
-            with self.assertRaisesRegex(RuntimeError, "Missing key"):
-                set_model_state_dict(
-                    fsdp_model,
-                    model_state_dict=load_states,
-                    options=StateDictOptions(
-                        broadcast_from_rank0=True, full_state_dict=True
-                    ),
-                )
-
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_broadcast_from_rank0(self) -> None:
         device_mesh = init_device_mesh("cuda", (self.world_size,))
         self.run_subtests(
             {
@@ -655,7 +655,24 @@
                     functools.partial(FSDP, device_mesh=device_mesh),
                 ]
             },
-            inner_test,
+            self._test_broadcast_from_rank0,
+        )
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_broadcast_from_rank0_hsdp(self) -> None:
+        device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
+        self.run_subtests(
+            {
+                "wrapper": [
+                    functools.partial(
+                        FSDP,
+                        device_mesh=device_mesh,
+                        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
+                    ),
+                ]
+            },
+            self._test_broadcast_from_rank0,
         )
 
     @with_comms