| import logging |
| import warnings |
| |
| from copy import deepcopy |
from typing import Any, Collection, Dict, List, Mapping, Optional, Union
| |
| import torch |
| from torch import optim |
| from torch.distributed._shard.sharded_tensor import ShardedTensor |
| |
| |
| __all__: List[str] = [] |
| |
| logger = logging.getLogger(__name__) |
| |
| |
| class _NamedOptimizer(optim.Optimizer): |
| """ |
    ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict``
    keyed by parameter name: the numeric index an optimizer normally uses as
    the key is replaced by the parameter's fully qualified name (FQN) string.
    Users can initialize it just like a regular PyTorch optimizer; the only
    difference is that the FQN of each parameter must also be passed in.
| |
| Args: |
| named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]): |
| Mapping from FQN to parameter. |
| optimizer_class (optim.Optimizer): |
| The class of optimizer to instantiate. |
| param_groups (Collection[Mapping[str, Any]]): |
            ``param_groups`` to pass to the optimizer if specified.
            The keys of the inner map need to be FQNs.
            Default: None
        args: positional arguments to pass to the optimizer constructor.
        kwargs: keyword arguments to pass to the optimizer constructor.
| |
| Example:: |
| >>> # xdoctest: +SKIP("distributed") |
| >>> from torch import optim |
| >>> from torch.distributed.optim import _NamedOptimizer |
| >>> |
| >>> # Define the named optimizer. |
| >>> m = Model(...) |
| >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) |
| >>> # Forward pass + backward pass. |
| >>> named_optim.step() |
| >>> ... |
        >>> # Calling state_dict on the named optimizer returns a state_dict keyed by FQN.
| >>> named_optim.state_dict() |
| |
| Warning: This API is still in development and subject to change. |
| |
| TODO: Add tutorial for _NamedOptimizer. |
| TODO: Add documentation in the docstring for the public attributes |
| like self.param_groups and self.named_parameters. |
| """ |
| |
| def __init__( |
| self, |
| named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]], |
| optimizer_class: optim.Optimizer, |
        param_groups: Optional[Collection[Mapping[str, Any]]] = None,
| *args, |
| **kwargs, |
| ) -> None: |
| torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer") |
| self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment] |
| self._param_groups_check() |
| self.named_parameters = dict(named_parameters) |
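        # If no param_groups are given, the optimizer is built over all named parameters.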
| params_for_optimizer = ( |
| self.named_parameters.values() if param_groups is None else param_groups |
| ) |
| self._optimizer = optimizer_class( # type: ignore[operator] |
| params_for_optimizer, |
| *args, |
| **kwargs, |
| ) |
| if param_groups is None: |
| self.ordered_param_keys = list(self.named_parameters.keys()) |
| else: |
            warnings.warn(
                "Since param_groups is passed in, the optimizer will be initialized "
                "from param_groups rather than from all parameters of the module."
            )
| param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] |
| ordered_param_keys = [] |
| for group in param_groups: |
| for param in group["params"]: |
| if param not in param_to_key: |
                        raise ValueError(
                            f"Expected param {param} from param_groups to be present in named_parameters, but it is missing."
                        )
| ordered_param_keys.append(param_to_key[param]) |
| self.ordered_param_keys = ordered_param_keys |
| # Update param_groups from optimizer. |
| self.param_groups = self._optimizer.param_groups |
| |
| def _param_groups_check(self): |
| if self.param_groups is not None: |
| for param_group in self.param_groups: |
| assert isinstance(param_group, dict), "param group must be a dict" |
| assert "params" in param_group, "param group must contain key params" |
| params = param_group["params"] |
| if isinstance(params, torch.Tensor): |
| params = [params] |
| params = list(params) |
| for param in params: |
| if not isinstance(param, torch.Tensor): |
| raise TypeError( |
| "optimizer can only optimize Tensors, " |
| "but one of the params is " + torch.typename(param) |
| ) |
| param_group["params"] = params |
| |
| def state_dict(self) -> Dict[str, Any]: |
| """ |
        Return the ``state_dict`` of the optimizer. Instead of using a number to
        index parameters, the parameter's fully qualified name (FQN) is used as the key.
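
        Example (a minimal sketch; the FQNs shown are illustrative and depend on the module)::
            >>> # xdoctest: +SKIP("distributed")
            >>> named_optim.state_dict()["state"].keys()
            dict_keys(['linear.weight', 'linear.bias'])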
| """ |
| state_dict = self._optimizer.state_dict() |
| param_groups = state_dict["param_groups"] |
| |
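        # Remap optimizer state from integer parameter indices to FQNs.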
| ret_state = { |
| self.ordered_param_keys[st_key]: state_val |
| for st_key, state_val in state_dict["state"].items() |
| } |
| |
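        # Rebuild each param_group so "params" lists FQNs instead of indices;
        # the remaining hyperparameters are deep-copied unchanged.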
| ret_groups = [] |
| for group in param_groups: |
| param_keys = [] |
| for param in group["params"]: |
| param_keys.append(self.ordered_param_keys[param]) |
| ret_group = {"params": sorted(param_keys)} |
| for k, v in group.items(): |
| if k != "params": |
| ret_group[k] = deepcopy(v) |
| ret_groups.append(ret_group) |
| |
| return {"state": ret_state, "param_groups": ret_groups} |
| |
| def step(self): |
| """ |
| Performs a single optimization step. |
| |
| This will call :meth:`torch.optim.Optimizer.step` on the wrapped |
| optimizer. |
| """ |
| self._optimizer.step() |
| |
| def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: |
| """ |
        Define the default behavior for loading a ``state_dict`` into ``_NamedOptimizer``.
| |
        Example::
            >>> # xdoctest: +SKIP("distributed")
            >>> my_model = MyModule()
            >>> optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad)
            >>> ...
            >>> optim_state_dict = optimizer.state_dict()
            >>> ...
            >>> optimizer.load_state_dict(optim_state_dict)
            >>> ...

| Args: |
            state_dict (Dict[str, Any]): A ``state_dict`` to load into the optimizer.
| Note that this state dict update is performed in place. |
| |
        .. note:: PyTorch initializes optimizer states lazily, so the state may
            be empty when ``load_state_dict`` is called. ``_NamedOptimizer`` is
            stricter: ``load_state_dict`` can only be called after the state has
            been initialized (e.g. after the first call to ``step``), so that
            the incoming optimizer ``state_dict`` can be validated.
| """ |
| new_state_dict = self._optimizer.state_dict() |
| state = state_dict["state"] |
| new_state = new_state_dict["state"] |
| if len(new_state) == 0: |
            raise ValueError(
                "Expected the optimizer state to be initialized before calling "
                "load_state_dict, but found it uninitialized."
            )
| |
| # Load state of state_dict |
| if len(new_state) != len(state): |
            raise ValueError(
                f"Expected {len(new_state)} entries in `state_dict` state, but found {len(state)}."
            )
| for idx, param_key in enumerate(self.ordered_param_keys): |
| # When the conditional training is performed, not all parameters are updated in the optim. |
| if param_key not in state.keys(): |
| continue |
| if len(state[param_key]) != len(new_state[idx]): |
                raise ValueError(
                    f"Expected {len(new_state[idx])} states for parameter {param_key}, but found {len(state[param_key])}."
                )
| # Iterate through all optimizer states. |
| for state_key, state_val in new_state[idx].items(): |
| if state_key not in state[param_key]: |
                    raise ValueError(
                        f"Expected state {state_key} for parameter {param_key}, but it was not found."
                    )
| |
| src_state_val = state[param_key][state_key] |
| if isinstance(state_val, ShardedTensor): |
| assert isinstance(src_state_val, ShardedTensor) |
| num_shards = len(state_val.local_shards()) |
| num_new_shards = len(src_state_val.local_shards()) |
| if num_shards != num_new_shards: |
                        raise ValueError(
                            f"Expected {num_new_shards} shards for {param_key}/{state_key}, but found {num_shards}."
                        )
| for shard, src_shard in zip( |
| state_val.local_shards(), src_state_val.local_shards() |
| ): |
| shard.tensor.detach().copy_(src_shard.tensor) |
| elif isinstance(state_val, torch.Tensor): |
| assert isinstance(src_state_val, torch.Tensor) |
| state_val.detach().copy_(src_state_val) |
| else: |
| new_state[idx][state_key] = deepcopy(src_state_val) |
| |
| # Load param_groups of state_dict |
| src_param_groups = state_dict["param_groups"] |
| new_param_groups = new_state_dict["param_groups"] |
| |
| if len(new_param_groups) != len(src_param_groups): |
            raise ValueError(
                f"Expected {len(new_param_groups)} param_groups in `state_dict`, but found {len(src_param_groups)}."
            )
| src_group_map = {} |
| for group in src_param_groups: |
| param_keys = [] |
| for param_key in group["params"]: |
| param_keys.append(param_key) |
| src_group_map[_gen_param_group_key(param_keys)] = group |
| new_group_map = {} |
| for new_group in new_param_groups: |
| param_keys = [] |
| for param_key in new_group["params"]: |
| param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload] |
| new_group_map[_gen_param_group_key(param_keys)] = new_group |
| for group_key, new_group in new_group_map.items(): |
| if group_key not in src_group_map: |
                raise ValueError(
                    f"Expected group {group_key} to be in `state_dict`, but it is missing."
                )
| src_group = src_group_map[group_key] |
| if len(src_group) != len(new_group): |
                raise ValueError(
                    f"Expected param_group {group_key} to have {len(new_group)} entries, but found {len(src_group)}."
                )
| for k in src_group: |
| if k not in new_group: |
                    raise ValueError(
                        f"Expected key {k} to be in group {group_key} of `state_dict`, but it is missing."
                    )
| if k != "params": |
| new_group[k] = deepcopy(src_group[k]) |
| |
| self._optimizer.load_state_dict(new_state_dict) |
| |
| # pyre-ignore [2] |
| def add_param_group(self, param_group: Any) -> None: |
        raise NotImplementedError(
            "add_param_group is not supported yet; it may be implemented in the future."
        )
| |
| |
| def _gen_param_group_key(param_keys: List[str]) -> str: |
| """ |
    Concatenate all param keys into a unique identifier for one param group.
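
    Example (illustrative param names)::
        >>> _gen_param_group_key(["linear.weight", "linear.bias"])
        'linear.bias/linear.weight'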
| """ |
| return "/".join(sorted(param_keys)) |