[FSDP][state_dict] Restore the state_dict_config for NO_SHARD (#100855)
Any change to the user's configuration should be temporary. This PR fixes an issue where calling state_dict/load_state_dict with NO_SHARD permanently changed state_dict_config and state_dict_type instead of restoring the user's settings afterward.
Differential Revision: [D45593313](https://our.internmc.facebook.com/intern/diff/D45593313/)
**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D45593313/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100855
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao, https://github.com/rohan-varma
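
A minimal repro sketch of the symptom (hypothetical, not taken from this PR's tests; it assumes an initialized default process group, e.g. launched via torchrun, and enough ranks/devices for FSDP): calling `state_dict()` on a `NO_SHARD` module internally switches to the full state_dict path, and before this fix that switch leaked out of the call.

```python
# Hypothetical repro sketch: assumes an initialized default process group.
import torch.nn as nn
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    StateDictType,
)

model = FSDP(nn.Linear(8, 8), sharding_strategy=ShardingStrategy.NO_SHARD)
FSDP.set_state_dict_type(model, StateDictType.SHARDED_STATE_DICT)

_ = model.state_dict()  # NO_SHARD internally requires FULL_STATE_DICT

# Before this PR, the internal override was never undone, so the call below
# observed FULL_STATE_DICT instead of the SHARDED_STATE_DICT the user set.
settings = FSDP.get_state_dict_type(model)
assert settings.state_dict_type == StateDictType.SHARDED_STATE_DICT
```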
diff --git a/test/distributed/_composable/test_compose.py b/test/distributed/_composable/test_compose.py
index 126abd5..3b7c820 100644
--- a/test/distributed/_composable/test_compose.py
+++ b/test/distributed/_composable/test_compose.py
@@ -302,6 +302,8 @@
                 self.assertIsInstance(tensor, ShardedTensor)
             elif "u2" in fqn:
                 self.assertIsInstance(tensor, torch.Tensor)
+        # Ensure that get_state_dict_type can still correctly get the settings.
+        _ = FSDP.get_state_dict_type(model)


 instantiate_parametrized_tests(TestFSDPCheckpoint)
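
The single added line above works as a regression check because `FSDP.get_state_dict_type` walks all FSDP submodules and asserts that they report identical state_dict settings; if the `NO_SHARD` override had leaked on `u2` only, the call would raise instead of returning. A sketch of that usage (the helper name is illustrative, not part of the API):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def assert_settings_intact(model) -> None:
    # Illustrative helper: get_state_dict_type raises when FSDP submodules
    # report mismatched state_dict settings, so simply calling it after a
    # state_dict() round trip detects a leaked override.
    settings = FSDP.get_state_dict_type(model)
    print(settings.state_dict_type, type(settings.state_dict_config).__name__)
```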
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index c41a50f..6b60d37 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -1,7 +1,18 @@
+import contextlib
import functools
import math
import warnings
-from typing import Any, Callable, cast, Dict, Iterator, List, no_type_check, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)
import torch
import torch.distributed as dist
@@ -635,6 +646,20 @@
_deregister_orig_params(module, fsdp_state)

+@contextlib.contextmanager
+def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
+    old_state_dict_config = fsdp_state._state_dict_config
+    old_state_dict_type = fsdp_state._state_dict_type
+    try:
+        fsdp_state._state_dict_config = FullStateDictConfig()
+        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        yield
+    finally:
+        # Restore the user's settings even if the wrapped code raises.
+        fsdp_state._state_dict_config = old_state_dict_config
+        fsdp_state._state_dict_type = old_state_dict_type
+
+
@no_type_check
@torch.no_grad()
def _post_state_dict_hook(
@@ -650,17 +675,23 @@
     what postprocessing will be done.
     """
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()

-    _post_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
-    }
-    processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
-        module, fsdp_state, state_dict, prefix
-    )
+    with context:
+        _post_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
+        }
+        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
     return processed_state_dict
@@ -678,24 +709,26 @@
     be done.
     """
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
         warnings.warn(
             "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
             "be returned."
         )
+    else:
+        context = contextlib.nullcontext()

-    _pre_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
-    }
-    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
-        fsdp_state,
-        module,
-        *args,
-        **kwargs,
-    )
+    with context:
+        _pre_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
+        }
+        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
+            fsdp_state,
+            module,
+            *args,
+            **kwargs,
+        )
@no_type_check
@@ -713,21 +746,27 @@
     be done.
     """
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()

-    _pre_load_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
-    }
-    # Code that is common for all state_dict impls
-    if fsdp_state._device_handle.is_available():
-        fsdp_state._device_handle.synchronize()
-    # Dispatch into state_dict specific implementation of pre-hook.
-    _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
-        module, fsdp_state, state_dict, prefix
-    )
+    with context:
+        _pre_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        if fsdp_state._device_handle.is_available():
+            fsdp_state._device_handle.synchronize()
+        # Dispatch into state_dict specific implementation of pre-hook.
+        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
@no_type_check
@@ -738,18 +777,24 @@
     *args: Any,
 ) -> None:
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will "
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()

-    _post_load_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
-    }
-    # Code that is common for all state_dict impls
-    # Dispatch into state_dict type specific implementation of post-hook for
-    # loading state_dict.
-    _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
+    with context:
+        _post_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        # Dispatch into state_dict type specific implementation of post-hook for
+        # loading state_dict.
+        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
def _register_all_state_dict_hooks(state: _FSDPState):
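
The heart of the fix is the `_replace_with_full_state_dict_type` context manager. For illustration, the same save/restore pattern in standalone form (a generic sketch; `swap_attrs` is hypothetical and not part of FSDP):

```python
import contextlib
from typing import Any, Generator

@contextlib.contextmanager
def swap_attrs(obj: Any, **temp_values: Any) -> Generator[None, None, None]:
    # Save the current values before installing the temporary ones.
    saved = {name: getattr(obj, name) for name in temp_values}
    try:
        for name, value in temp_values.items():
            setattr(obj, name, value)
        yield
    finally:
        # try/finally guarantees the originals come back even if the body
        # raises, which is why the hooks now wrap their dispatch in a
        # context manager instead of mutating fsdp_state in place.
        for name, value in saved.items():
            setattr(obj, name, value)

# Usage: the attributes are restored once the block exits, error or not.
#     with swap_attrs(state, mode="full"):
#         ...
```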