[DSD] Add unittest to verify HSDP1 + broadcast_from_rank0 (#128755)
HSDP1 + broadcast_from_rank0 behaves differently from FSDP1 + broadcast_from_rank0, so we need a unit test to cover this use case.
This test relies on the fix from https://github.com/pytorch/pytorch/pull/128446.
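For reference, a minimal standalone sketch of what the new test exercises (illustrative only, not the test code: it assumes 4 GPUs, a `torchrun` launch, and a toy `nn.Linear` in place of the test's `CompositeParamModel`): wrap the model with FSDP in `HYBRID_SHARD` mode over a 2D device mesh, then load a full state dict held only on rank 0 with `broadcast_from_rank0=True`.

```python
# Illustrative sketch (assumed model/shapes); launch with
# `torchrun --nproc-per-node=4 this_file.py` on a 4-GPU host.
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    set_model_state_dict,
    StateDictOptions,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(100, 100, device="cuda")  # stand-in for CompositeParamModel

# HSDP1 = FSDP with HYBRID_SHARD over a 2D (replicate, shard) device mesh.
mesh = init_device_mesh("cuda", (2, dist.get_world_size() // 2))
hsdp_model = FSDP(
    copy.deepcopy(model),
    device_mesh=mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)

# Only rank 0 holds the full state dict; broadcast_from_rank0 broadcasts it
# to the other ranks and reshards it to match the HSDP placement.
full_sd = model.state_dict() if dist.get_rank() == 0 else {}
set_model_state_dict(
    hsdp_model,
    model_state_dict=full_sd,
    options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
)

dist.destroy_process_group()
```

The diff below does the same through `run_subtests`, passing the HSDP wrapper as a `functools.partial` to the shared `_test_broadcast_from_rank0` helper.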
Differential Revision: [D58621436](https://our.internmc.facebook.com/intern/diff/D58621436/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128755
Approved by: https://github.com/Skylion007, https://github.com/wz337
ghstack dependencies: #128685
diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py
index ac62635..7736350 100644
--- a/test/distributed/checkpoint/test_state_dict.py
+++ b/test/distributed/checkpoint/test_state_dict.py
@@ -33,7 +33,11 @@
set_optimizer_state_dict,
StateDictOptions,
)
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
+from torch.distributed.fsdp import (
+ FullyShardedDataParallel as FSDP,
+ ShardingStrategy,
+ StateDictType,
+)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -70,7 +74,7 @@
@property
def world_size(self) -> int:
- return 2
+ return min(4, torch.cuda.device_count())
def _test_save_load(
self,
@@ -567,55 +571,71 @@
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
- @with_comms
- @skip_if_lt_x_gpu(2)
- def test_broadcast_from_rank0(self) -> None:
- def inner_test(wrapper):
- model = CompositeParamModel(device=torch.device("cuda"))
- optim = torch.optim.Adam(model.parameters())
- fsdp_model = wrapper(copy.deepcopy(model))
- fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
+ def _test_broadcast_from_rank0(self, wrapper) -> None:
+ model = CompositeParamModel(device=torch.device("cuda"))
+ optim = torch.optim.Adam(model.parameters())
+ fsdp_model = wrapper(copy.deepcopy(model))
+ fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
- batch = torch.rand(8, 100, device="cuda")
- model(batch).sum().backward()
- optim.step()
- states, optim_states = get_state_dict(model, optim)
+ batch = torch.rand(8, 100, device="cuda")
+ model(batch).sum().backward()
+ optim.step()
+ states, optim_states = get_state_dict(model, optim)
- fsdp_model(batch).sum().backward()
- fsdp_optim.step()
+ fsdp_model(batch).sum().backward()
+ fsdp_optim.step()
- def check(equal):
- fsdp_states = get_model_state_dict(
- fsdp_model,
- options=StateDictOptions(full_state_dict=True),
- )
- fsdp_optim_states = get_optimizer_state_dict(
- fsdp_model,
- fsdp_optim,
- options=StateDictOptions(full_state_dict=True),
- )
- if equal:
- self.assertEqual(states, fsdp_states)
- self.assertEqual(optim_states, fsdp_optim_states)
- else:
- self.assertNotEqual(states, fsdp_states)
- self.assertNotEqual(optim_states, fsdp_optim_states)
-
- check(equal=True)
- fsdp_model(batch).sum().backward()
- fsdp_optim.step()
- check(equal=False)
-
- # Drop the states to simulate loading from rank0
- if dist.get_rank() > 0:
- load_states = {}
- load_states2 = {}
- load_optim_states = {}
+ def check(equal):
+ fsdp_states = get_model_state_dict(
+ fsdp_model,
+ options=StateDictOptions(full_state_dict=True),
+ )
+ fsdp_optim_states = get_optimizer_state_dict(
+ fsdp_model,
+ fsdp_optim,
+ options=StateDictOptions(full_state_dict=True),
+ )
+ if equal:
+ self.assertEqual(states, fsdp_states)
+ self.assertEqual(optim_states, fsdp_optim_states)
else:
- load_states = copy.deepcopy(states)
- load_states2 = copy.deepcopy(states)
- load_optim_states = copy.deepcopy(optim_states)
+ self.assertNotEqual(states, fsdp_states)
+ self.assertNotEqual(optim_states, fsdp_optim_states)
+ check(equal=True)
+ fsdp_model(batch).sum().backward()
+ fsdp_optim.step()
+ check(equal=False)
+
+ # Drop the states to simulate loading from rank0
+ if dist.get_rank() > 0:
+ load_states = {}
+ load_states2 = {}
+ load_optim_states = {}
+ else:
+ load_states = copy.deepcopy(states)
+ load_states2 = copy.deepcopy(states)
+ load_optim_states = copy.deepcopy(optim_states)
+
+ set_model_state_dict(
+ fsdp_model,
+ model_state_dict=load_states,
+ options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
+ )
+ set_optimizer_state_dict(
+ fsdp_model,
+ fsdp_optim,
+ optim_state_dict=load_optim_states,
+ options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
+ )
+
+ check(equal=True)
+ # Verify the `strict` flag.
+ load_states = load_states2
+ if load_states:
+ key = next(iter(load_states.keys()))
+ load_states.pop(key)
+ with self.assertRaisesRegex(RuntimeError, "Missing key"):
set_model_state_dict(
fsdp_model,
model_state_dict=load_states,
@@ -623,30 +643,10 @@
broadcast_from_rank0=True, full_state_dict=True
),
)
- set_optimizer_state_dict(
- fsdp_model,
- fsdp_optim,
- optim_state_dict=load_optim_states,
- options=StateDictOptions(
- broadcast_from_rank0=True, full_state_dict=True
- ),
- )
- check(equal=True)
- # Verify the `strict` flag.
- load_states = load_states2
- if load_states:
- key = next(iter(load_states.keys()))
- load_states.pop(key)
- with self.assertRaisesRegex(RuntimeError, "Missing key"):
- set_model_state_dict(
- fsdp_model,
- model_state_dict=load_states,
- options=StateDictOptions(
- broadcast_from_rank0=True, full_state_dict=True
- ),
- )
-
+ @with_comms
+ @skip_if_lt_x_gpu(2)
+ def test_broadcast_from_rank0(self) -> None:
device_mesh = init_device_mesh("cuda", (self.world_size,))
self.run_subtests(
{
@@ -655,7 +655,24 @@
functools.partial(FSDP, device_mesh=device_mesh),
]
},
- inner_test,
+ self._test_broadcast_from_rank0,
+ )
+
+ @with_comms
+ @skip_if_lt_x_gpu(4)
+ def test_broadcast_from_rank0_hsdp(self) -> None:
+ device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
+ self.run_subtests(
+ {
+ "wrapper": [
+ functools.partial(
+ FSDP,
+ device_mesh=device_mesh,
+ sharding_strategy=ShardingStrategy.HYBRID_SHARD,
+ ),
+ ]
+ },
+ self._test_broadcast_from_rank0,
)
@with_comms