[FSDP][state_dict] Restore the state_dict_config for NO_SHARD (#100855)

Any change to the user configurations should be temporary. This PR fixes an issue where calling state_dict/load_state_dict with the NO_SHARD sharding strategy permanently changed the state_dict_config and state_dict_type.

Differential Revision: [D45593313](https://our.internmc.facebook.com/intern/diff/D45593313/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D45593313/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100855
Approved by: https://github.com/awgu, https://github.com/zhaojuanmao, https://github.com/rohan-varma
diff --git a/test/distributed/_composable/test_compose.py b/test/distributed/_composable/test_compose.py
index 126abd5..3b7c820 100644
--- a/test/distributed/_composable/test_compose.py
+++ b/test/distributed/_composable/test_compose.py
@@ -302,6 +302,8 @@
                 self.assertIsInstance(tensor, ShardedTensor)
             elif "u2" in fqn:
                 self.assertIsInstance(tensor, torch.Tensor)
+        # Ensure that get_state_dict_type can still correctly get the settings.
+        _ = FSDP.get_state_dict_type(model)
 
 
 instantiate_parametrized_tests(TestFSDPCheckpoint)
diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py
index c41a50f..6b60d37 100644
--- a/torch/distributed/fsdp/_state_dict_utils.py
+++ b/torch/distributed/fsdp/_state_dict_utils.py
@@ -1,7 +1,18 @@
+import contextlib
 import functools
 import math
 import warnings
-from typing import Any, Callable, cast, Dict, Iterator, List, no_type_check, Tuple
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    no_type_check,
+    Tuple,
+)
 
 import torch
 import torch.distributed as dist
@@ -635,6 +646,20 @@
         _deregister_orig_params(module, fsdp_state)
 
 
[email protected]
+def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator:
+    """Temporarily force FULL_STATE_DICT settings, restoring the originals on exit."""
+    old_state_dict_config = fsdp_state._state_dict_config
+    old_state_dict_type = fsdp_state._state_dict_type
+    try:
+        fsdp_state._state_dict_config = FullStateDictConfig()
+        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        yield
+    finally:  # runs on normal exit AND when the ``with`` body raises
+        fsdp_state._state_dict_config = old_state_dict_config
+        fsdp_state._state_dict_type = old_state_dict_type
+
+
 @no_type_check
 @torch.no_grad()
 def _post_state_dict_hook(
@@ -650,17 +675,23 @@
     what postprocessing will be done.
     """
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
 
-    _post_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
-    }
-    processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
-        module, fsdp_state, state_dict, prefix
-    )
+    with context:
+        _post_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook,
+        }
+        processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
     return processed_state_dict
 
 
@@ -678,24 +709,26 @@
     be done.
     """
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
         warnings.warn(
             "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
             "be returned."
         )
+    else:
+        context = contextlib.nullcontext()
 
-    _pre_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
-    }
-    _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
-        fsdp_state,
-        module,
-        *args,
-        **kwargs,
-    )
+    with context:
+        _pre_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook,
+        }
+        _pre_state_dict_hook_fn[fsdp_state._state_dict_type](
+            fsdp_state,
+            module,
+            *args,
+            **kwargs,
+        )
 
 
 @no_type_check
@@ -713,21 +746,27 @@
     be done.
     """
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
 
-    _pre_load_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
-    }
-    # Code that is common for all state_dict impls
-    if fsdp_state._device_handle.is_available():
-        fsdp_state._device_handle.synchronize()
-    # Dispatch into state_dict specific implementation of pre-hook.
-    _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
-        module, fsdp_state, state_dict, prefix
-    )
+    with context:
+        _pre_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        if fsdp_state._device_handle.is_available():
+            fsdp_state._device_handle.synchronize()
+        # Dispatch into state_dict specific implementation of pre-hook.
+        _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type](
+            module, fsdp_state, state_dict, prefix
+        )
 
 
 @no_type_check
@@ -738,18 +777,24 @@
     *args: Any,
 ) -> None:
     if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD:
-        fsdp_state._state_dict_config = FullStateDictConfig()
-        fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT
+        context = _replace_with_full_state_dict_type(fsdp_state)
+        warnings.warn(
+            "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will"
+            "be returned."
+        )
+    else:
+        context = contextlib.nullcontext()
 
-    _post_load_state_dict_hook_fn = {
-        StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
-        StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
-        StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
-    }
-    # Code that is common for all state_dict impls
-    # Dispatch into state_dict type specific implementation of post-hook for
-    # loading state_dict.
-    _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
+    with context:
+        _post_load_state_dict_hook_fn = {
+            StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook,
+            StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook,
+            StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook,
+        }
+        # Code that is common for all state_dict impls
+        # Dispatch into state_dict type specific implementation of post-hook for
+        # loading state_dict.
+        _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state)
 
 
 def _register_all_state_dict_hooks(state: _FSDPState):