[FSDP2] Ran post-acc-grad hooks manually (#129450)
FSDP2 accumulates gradients for sharded parameters outside of the autograd engine's normal accumulation logic. We can respect registered post-accumulate-grad hooks by running them manually.
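For reference, here is a minimal sketch of the per-parameter optimizer-in-backward pattern that this enables, mirroring the new tests below. The toy module, sizes, and hyperparameters are illustrative, and the snippet assumes a process group is already initialized (e.g. via `torchrun`):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# Assumes torch.distributed is already initialized (e.g. launched with torchrun).
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()  # illustrative model
for layer in model:
    fully_shard(layer)
fully_shard(model)

# One foreach=False optimizer per parameter so that each parameter can step as
# soon as its sharded gradient is ready (see the CPU-overhead caveats below).
param_to_optim = {
    param: torch.optim.AdamW([param], lr=1e-2, foreach=False)
    for param in model.parameters()
}

def optim_hook(param: nn.Parameter) -> None:
    param_to_optim[param].step()
    param_to_optim[param].zero_grad()

for param in model.parameters():
    # FSDP2 now runs this hook manually after accumulating the sharded gradient.
    param.register_post_accumulate_grad_hook(optim_hook)

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()  # optimizer steps run inside backward via the hooks
```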
**Discussion**
After discussing with @soulitzer: changing FSDP2 to make the sharded parameters autograd leaves would require nontrivial changes to FSDP and some changes to the autograd engine (around forward vs. backward streams), and those changes may not preserve eager-mode performance and would add some complexity.
Under the FSDP2 design, the sharded parameters never participate in autograd, so calling `register_post_accumulate_grad_hook` on them would otherwise be a no-op. In other words, since the autograd engine never fires these hooks itself, there is virtually no risk of FSDP2 re-running a hook that has already run.
Given these constraints, a reasonable near-term solution is for FSDP2 to run the post-accumulate-grad hooks manually.
**Caveats**
- Running `foreach=False` optimizer _per parameter tensor_ incurs significantly higher CPU overhead compared to `foreach=True` (partially due to `DTensor` being a `__torch_dispatch__` tensor subclass).
- In preliminary benchmarking of Llama3-8B on 8 GPUs, this CPU overhead is mostly tolerable, but with fewer GPUs or a less compute-intensive model, it may not be.
- One solution for native Adam/AdamW is to use `fused=True`, which makes both the CPU overhead lower and GPU compute faster. However, this is generally not an option for user-defined optimizers.
- If this CPU overhead blocks adoption of this feature, then we should seriously consider an FSDP-specific API like `register_post_backward_hook(params: List[nn.Parameter]) -> None` that passes all parameters in the `FSDPParamGroup` to the hook together, so that the user can still run a `foreach=True` optimizer step on that `List[nn.Parameter]` (see the hypothetical sketch after this list).
- The post-accumulate-grad hook runs in the reduce-scatter stream. Our current stream handling logic does not have the default stream wait for the reduce-scatter stream until the end of backward. Unless we add that, we cannot simply run the post-accumulate-grad hook in the default stream.
- This means that optimizer compute will overlap with backward compute, which may slow down end-to-end execution slightly (e.g. due to SM contention or wave quantization effects). For example, on Llama3-8B, we see a ~3% decrease in MFU when running the optimizer in backward even though the optimizer steps are fully overlapped and there are no CPU-boundedness issues.
- This PR's goal is only to run the hook manually. State dict etc. for optimizer-in-backward is out of scope.
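For illustration only, here is a rough sketch of what the group-level hook floated above might look like. `register_post_backward_hook` does not exist today; its name and signature are taken from the proposal, so treat everything below as hypothetical:

```python
from typing import Dict, List, Tuple

import torch
import torch.nn as nn

# Hypothetical: one foreach=True optimizer per FSDPParamGroup, created lazily
# the first time the (proposed) group-level hook fires for that group.
_group_optims: Dict[Tuple[nn.Parameter, ...], torch.optim.Optimizer] = {}

def post_backward_hook(params: List[nn.Parameter]) -> None:
    key = tuple(params)
    if key not in _group_optims:
        _group_optims[key] = torch.optim.AdamW(params, lr=1e-2, foreach=True)
    _group_optims[key].step()
    _group_optims[key].zero_grad()

# Hypothetical registration (this API is only a proposal):
# fsdp_module.register_post_backward_hook(post_backward_hook)
```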
**Experiments (torchtitan)**
- Llama3-8B on 2 GPUs, local batch size 1, with full activation checkpointing, and bf16/fp32 mixed precision:
- Without optimizer-in-backward: 82.03 GiB reserved memory; 28.1% MFU
- With optimizer-in-backward (`foreach=False`): 72.84 GiB reserved memory; 28.9% MFU (speedup from overlapping more of the optimizer step with backward)
- With optimizer-in-backward (`fused=True`): 70.84 GiB reserved memory; 30.4% MFU
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129450
Approved by: https://github.com/weifengpy, https://github.com/yf225
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_autograd.py b/test/distributed/_composable/fsdp/test_fully_shard_autograd.py
index 8192907..c5cc5ee 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_autograd.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_autograd.py
@@ -2,6 +2,9 @@
import collections
import copy
+import functools
+import itertools
+import unittest
from typing import Any, List, Optional, Type, Union
@@ -11,13 +14,20 @@
from torch.distributed._composable.fsdp import fully_shard
from torch.nn.parallel.scatter_gather import _is_namedtuple
+from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
DoubleLinear,
FSDPTest,
+ FSDPTestMultiThread,
+ MLP,
)
from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ ModelArgs,
+ Transformer,
+)
class TestFullyShardAutograd(FSDPTest):
@@ -232,5 +242,91 @@
_optim.zero_grad(set_to_none=(iter_idx % 2))
+class TestFullyShardPostAccGradHookMultiThread(FSDPTestMultiThread):
+ @property
+ def world_size(self) -> int:
+ return 2
+
+ @unittest.skipIf(not TEST_CUDA, "no cuda")
+ def test_post_acc_grad_hook_runs(self):
+ param_name_to_hook_count = collections.defaultdict(int)
+
+ def hook(param_name: str, param: torch.Tensor) -> None:
+ nonlocal param_name_to_hook_count
+ param_name_to_hook_count[param_name] += 1
+
+ model = MLP(8)
+ for module in (model.in_proj, model.out_proj, model):
+ fully_shard(module)
+ for param_name, param in model.named_parameters():
+ param_hook = functools.partial(hook, param_name)
+ param.register_post_accumulate_grad_hook(param_hook)
+
+ inp = torch.randn((2, 8), device="cuda")
+ model(inp).sum().backward()
+ param_names = {param_name for param_name, _ in model.named_parameters()}
+ self.assertEqual(param_names, set(param_name_to_hook_count.keys()))
+ for param_name, count in param_name_to_hook_count.items():
+ self.assertEqual(count, 1)
+
+
+class TestFullyShardPostAccGradHookMultiProcess(FSDPTest):
+ @property
+ def world_size(self) -> int:
+ return min(torch.cuda.device_count(), 2)
+
+ @skip_if_lt_x_gpu(2)
+ def test_post_acc_grad_hook_optim_parity(self):
+ """
+ Tests parity of running the optimizer via the post-accumulate-grad
+ hook vs. normally.
+ """
+ torch.manual_seed(42)
+ model_args = ModelArgs(dropout_p=0.0)
+ model = Transformer(model_args)
+
+ ref_model = copy.deepcopy(model).cuda()
+ for module in itertools.chain(ref_model.layers, [ref_model]):
+ fully_shard(module)
+ optim_kwargs = {"lr": 1e-2, "foreach": False}
+ ref_optim = torch.optim.AdamW(ref_model.parameters(), **optim_kwargs)
+ lr_scheduler_kwargs = {"step_size": 5}
+ ref_lr_scheduler = torch.optim.lr_scheduler.StepLR(
+ ref_optim, **lr_scheduler_kwargs
+ )
+
+ for module in itertools.chain(model.layers, [model]):
+ fully_shard(module)
+ param_to_optim = {}
+ param_to_lr_scheduler = {}
+ for param in model.parameters():
+ param_to_optim[param] = torch.optim.AdamW([param], **optim_kwargs)
+ param_to_lr_scheduler[param] = torch.optim.lr_scheduler.StepLR(
+ param_to_optim[param], **lr_scheduler_kwargs
+ )
+
+ def optim_hook(param: nn.Parameter) -> None:
+ param_to_optim[param].step()
+ param_to_optim[param].zero_grad()
+ param_to_lr_scheduler[param].step()
+
+ for param in model.parameters():
+ param.register_post_accumulate_grad_hook(optim_hook)
+
+ torch.manual_seed(42 + self.rank)
+ inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+ for _ in range(10):
+ ref_loss = ref_model(inp).sum()
+ ref_loss.backward()
+ ref_optim.step()
+ ref_optim.zero_grad()
+ ref_lr_scheduler.step()
+ loss = model(inp).sum()
+ loss.backward()
+ self.assertTrue(torch.equal(ref_loss, loss))
+ for ref_param, param in zip(ref_model.parameters(), model.parameters()):
+ self.assertTrue(torch.equal(ref_param, param))
+
+
if __name__ == "__main__":
run_tests()
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py
index 7c65de7..bfad074 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py
@@ -27,14 +27,31 @@
@skip_if_lt_x_gpu(2)
def test_fully_shard_training_memory(self):
self.run_subtests(
- {"reshard_after_forward": [True, False], "use_cpu_offload": [True, False]},
+ {
+ "reshard_after_forward": [True, False],
+ "use_cpu_offload": [True, False],
+ "run_optim_in_backward": [True, False],
+ },
self._test_fully_shard_training_memory,
)
def _test_fully_shard_training_memory(
- self, reshard_after_forward: bool, use_cpu_offload: bool
+ self,
+ reshard_after_forward: bool,
+ use_cpu_offload: bool,
+ run_optim_in_backward: bool,
):
- if not reshard_after_forward and use_cpu_offload:
+ if (
+ # CPU offloading is typically for memory savings, so we expect
+ # users to want to reshard after forward
+ (not reshard_after_forward and use_cpu_offload)
+ # Optimizer in backward frees sharded gradient GPU memory early for
+ # memory savings, so we expect users to want to reshard after
+ # forward; plus, it has no real effect with CPU offloading
+ or (
+ run_optim_in_backward and (not reshard_after_forward or use_cpu_offload)
+ )
+ ):
return # skip since not a common use case
assert (
self.world_size == 2
@@ -74,7 +91,11 @@
fully_shard_fn(module)
fully_shard_fn(model)
# Do not use foreach since intermediates increase peak memory
- optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
+ optim_kwargs = {"lr": 1e-2, "foreach": False}
+ if run_optim_in_backward:
+ self._register_optim_in_backward(model, **optim_kwargs)
+ else:
+ optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
# Init: Each module is moved to GPU before sharding parameters
peak_mem_mb = self._get_peak_active_memory_mb()
@@ -131,8 +152,14 @@
3 * max_unsharded_numel + non_block_numel
) * 4 / 1e6 + buffer_mb
if not use_cpu_offload:
- # 2x sharded parameters/gradients
- expected_mem_mb += 2 * model_sharded_numel * 4 / 1e6
+ if run_optim_in_backward:
+ # 1x sharded parameters
+                expected_mem_mb += model_sharded_numel * 4 / 1e6
+ # 1x sharded block gradients
+                expected_mem_mb += max_unsharded_numel // self.world_size * 4 / 1e6
+ else:
+ # 2x sharded parameters/gradients
+ expected_mem_mb += 2 * model_sharded_numel * 4 / 1e6
else:
            assert use_cpu_offload
# Sharded parameters, unsharded parameters, 1.5x max unsharded
@@ -146,17 +173,21 @@
torch.cuda.reset_peak_memory_stats()
# Optimizer step: unsharded parameters/gradients freed
- optim.step()
+ if not run_optim_in_backward:
+ optim.step()
mem_mb = self._get_peak_active_memory_mb()
expected_mem_mb = buffer_mb
if not use_cpu_offload:
- # 1x sharded parameters, 1x sharded gradients, 2x sharded optimizer
- # states
- expected_mem_mb += (4 * model_sharded_numel) * 4 / 1e6
+ # 1x sharded parameters, 2x sharded optimizer states
+ expected_mem_mb += (3 * model_sharded_numel) * 4 / 1e6
+ if not run_optim_in_backward:
+ # 1x sharded gradients
+ expected_mem_mb += model_sharded_numel * 4 / 1e6
self.assertLessEqual(mem_mb - base_mem_mb, expected_mem_mb)
# Zero grad: sharded gradients freed
- optim.zero_grad()
+ if not run_optim_in_backward:
+ optim.zero_grad()
torch.cuda.reset_peak_memory_stats() # reset after freeing
mem_mb = self._get_peak_active_memory_mb()
expected_mem_mb = 0
@@ -175,6 +206,20 @@
mem_stats = torch.cuda.memory_stats()
return round(mem_stats["active_bytes.all.current"] / 1e6)
+ def _register_optim_in_backward(
+ self, model: torch.nn.Module, **optim_kwargs
+ ) -> None:
+ param_to_optim = {}
+ for param in model.parameters():
+ param_to_optim[param] = torch.optim.AdamW([param], **optim_kwargs)
+
+ def optim_hook(param: torch.nn.Parameter) -> None:
+ param_to_optim[param].step()
+ param_to_optim[param].zero_grad()
+
+ for param in model.parameters():
+ param.register_post_accumulate_grad_hook(optim_hook)
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index 33de0d7..8f3edca 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -340,6 +340,12 @@
new_sharded_grad
)
fsdp_param.sharded_param.grad = new_sharded_dtensor_grad
+ if not ca.compiled_autograd_enabled:
+ for hook in (
+ getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {})
+ or {}
+ ).values():
+ hook(fsdp_param.sharded_param)
padded_sharded_numel = padded_unsharded_size.numel() // world_size
flat_grad_offset += padded_sharded_numel
post_reduce_event = post_reduce_stream.record_event()