Corrected comments in fsdp (#80456)
Currently, the comments on the pre- and post-division steps in `FullyShardedDataParallel._post_backward_hook` state the following:
> Average grad by world_size for consistency with PyTorch DDP.
This does not match what actually happens: the pre-division factor may or may not equal `world_size`. For example, for `world_size = 3`, `gradient_predivide_factor = 2` (and `gradient_postdivide_factor = 1.5`), yet the two divisions together still amount to dividing by `world_size`.
This PR clarifies the pre- and post-division comments in the code.
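
For reference, here is a minimal sketch (not the actual FSDP implementation; the factor heuristic and the standalone function name `get_gradient_predivide_factor` are assumptions that only mirror FSDP's power-of-two choice) showing how a pre-division factor different from `world_size` still produces an overall division by `world_size`:

```python
def get_gradient_predivide_factor(world_size: int) -> float:
    # Assumed heuristic mirroring FSDP: keep doubling the factor while it
    # still divides world_size and the remaining quotient exceeds the factor.
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)

world_size = 3
predivide_factor = get_gradient_predivide_factor(world_size)  # 2.0
postdivide_factor = world_size / predivide_factor             # 1.5

# grad / predivide_factor / postdivide_factor == grad / world_size,
# i.e. the same averaging as PyTorch DDP, but split into two smaller
# divisions to reduce the risk of gradient underflow/overflow.
assert predivide_factor * postdivide_factor == world_size
```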
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80456
Approved by: https://github.com/rohan-varma
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index f373490..8d7b5e0 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -2915,7 +2915,9 @@
param.grad.data = param.grad.data.to(self.mixed_precision.reduce_dtype)
if self.gradient_predivide_factor > 1 and self.communication_hook is None:
- # Average grad by world_size for consistency with PyTorch DDP.
+ # Average grad by the pre-division factor. Together, the pre- and post-division factors
+ # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
+ # This is a two-step process to avoid potential underflow and overflow.
param.grad.div_(self.gradient_predivide_factor)
grad = param.grad.data
@@ -2942,7 +2944,9 @@
output, input_flattened, group=self.process_group
)
if self.gradient_postdivide_factor > 1:
- # Average grad by world_size for consistency with PyTorch DDP.
+ # Average grad by the post-division factor. Together, the pre- and post-division factors
+ # lead to an overall averaging by world_size, required for consistency with PyTorch DDP.
+ # This is a two-step process to avoid potential underflow and overflow.
output.div_(self.gradient_postdivide_factor)
# Note that we need to cast grads back to the full precision if