| import functools |
| import math |
| import warnings |
| from collections import defaultdict, OrderedDict |
| from copy import deepcopy |
| from itertools import chain |
| from typing import ( |
| Any, |
| Callable, |
| cast, |
| DefaultDict, |
| Dict, |
| Hashable, |
| Iterable, |
| List, |
| Optional, |
| overload, |
| Set, |
| Tuple, |
| TypeVar, |
| Union, |
| ) |
| from typing_extensions import ParamSpec, Self, TypeAlias |
| |
| import torch |
| import torch.utils.hooks as hooks |
| from torch.utils._foreach_utils import ( |
| _get_foreach_kernels_supported_devices, |
| _get_fused_kernels_supported_devices, |
| _group_tensors_by_device_and_dtype, |
| Indices, |
| ) |
| from torch.utils.hooks import RemovableHandle |
| |
| Args: TypeAlias = Tuple[Any, ...] |
| Kwargs: TypeAlias = Dict[str, Any] |
| StateDict: TypeAlias = Dict[str, Any] |
| TensorListList: TypeAlias = List[List[torch.Tensor]] |
| DeviceDict = Dict[Optional[torch.device], torch.Tensor] |
| |
| |
| GlobalOptimizerPreHook: TypeAlias = Callable[ |
| ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]] |
| ] |
| GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None] |
| |
| __all__ = [ |
| "Optimizer", |
| "register_optimizer_step_pre_hook", |
| "register_optimizer_step_post_hook", |
| ] |
| _global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict() |
| _global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict() |
| _foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter] |
| |
| |
| class _RequiredParameter: |
| """Singleton class representing a required parameter for an Optimizer.""" |
| |
| def __repr__(self) -> str: |
| return "<required parameter>" |
| |
| |
| required = _RequiredParameter() |
| |
| |
| def _use_grad_for_differentiable(func): |
| def _use_grad(self, *args, **kwargs): |
| import torch._dynamo |
| |
| prev_grad = torch.is_grad_enabled() |
| try: |
| # Note on graph break below: |
| # We need to graph break here to ensure that AOT respects the no_grad annotation. |
| # This is important for perf: without it, functionalization will generate an epilogue |
| # that updates the mutated parameters of the optimizer but is *not* visible to inductor; |
| # as a result, inductor will allocate for every parameter in the model, which is horrible. |
| # With the graph break, AOT correctly sees that this is an inference graph, and |
| # functionalization will generate an epilogue that is appended to the graph and *is* |
| # visible to inductor; as a result, inductor sees that step is in place and is able to |
| # avoid the extra allocation. |
| # In the future, we will either 1) continue to graph break on backward, in which case this |
| # graph break does not matter, or 2) have a fully fused forward and backward graph, which |
| # will run under no_grad by default, letting us remove this graph break so the fully fused |
| # fwd-bwd-optimizer graph can be compiled. |
| # See https://github.com/pytorch/pytorch/issues/104053 |
| torch.set_grad_enabled(self.defaults["differentiable"]) |
| torch._dynamo.graph_break() |
| ret = func(self, *args, **kwargs) |
| finally: |
| torch._dynamo.graph_break() |
| torch.set_grad_enabled(prev_grad) |
| return ret |
| |
| functools.update_wrapper(_use_grad, func) |
| return _use_grad |
| |
| |
| def _get_value(x): |
| # .item() is significantly faster than working with a CPU scalar tensor in eager mode |
| if not torch.jit.is_scripting() and torch.compiler.is_compiling(): |
| return x |
| else: |
| return x.item() if isinstance(x, torch.Tensor) else x |
| |
| |
| def _stack_if_compiling(x): |
| if not torch.jit.is_scripting() and torch.compiler.is_compiling(): |
| return torch.stack(x) |
| else: |
| return x |
| |
| |
| def _dispatch_sqrt( |
| x: float, |
| ): # float annotation is needed because of torchscript type inference |
| if not torch.jit.is_scripting() and isinstance(x, torch.Tensor): |
| return x.sqrt() |
| else: |
| return math.sqrt(x) |
| |
| |
| def _disable_dynamo_if_unsupported(single_tensor_fn=None): |
| # Workaround for TorchScript BC: TorchScript requires all called functions to be in |
| # the global environment at the site at which the maybe_fallback closure is created. |
| if single_tensor_fn: |
| globals()[single_tensor_fn.__name__] = single_tensor_fn |
| |
| def wrapper(func): |
| import inspect |
| |
| disabled_func = torch._disable_dynamo(func) |
| ps = inspect.signature(func).parameters |
| has_state_steps = True |
| try: |
| state_steps_ind = list(ps.keys()).index("state_steps") |
| except ValueError: |
| has_state_steps = False |
| |
| # Today, there are cases where we stack state steps |
| # and pass them as the value arg of foreach ops. |
| # Having state steps on cuda as the value arg is not supported in eager, |
| # but this only occurs in the rare case that the user explicitly deletes |
| # the capturable flag. If capturable=True, this is not a problem. |
| @functools.wraps(func) |
| def maybe_fallback(*args, **kwargs): |
| if torch.compiler.is_compiling() and ( |
| not kwargs.get("capturable", False) |
| and has_state_steps |
| and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda) |
| or ( |
| "state_steps" in kwargs |
| and kwargs["state_steps"] |
| and kwargs["state_steps"][0].is_cuda |
| ) |
| ): |
| return disabled_func(*args, **kwargs) |
| else: |
| return func(*args, **kwargs) |
| |
| return maybe_fallback |
| |
| return wrapper |
| |
| |
| # For any optimizer with a faster implementation, we attempt to default to the |
| # fastest + most stable option whenever possible. For foreach, the requirement is that |
| # the params are all native tensors on CUDA. For fused, there is currently the additional |
| # requirement that the tensors' dtypes must be floating point. Neither alternative supports |
| # torch.jit.script or differentiable, so we fall back to the single tensor |
| # implementation in those cases. |
| def _default_to_fused_or_foreach( |
| params: List[torch.Tensor], differentiable: bool, use_fused: bool = False |
| ) -> Tuple[bool, bool]: |
| if torch.jit.is_scripting() or differentiable: |
| return False, False |
| |
| fused_supported_devices = _get_fused_kernels_supported_devices() |
| foreach_supported_devices = _get_foreach_kernels_supported_devices() |
| fused = use_fused and all( |
| p is None |
| or ( |
| type(p) in _foreach_supported_types |
| and p.device.type in fused_supported_devices |
| and torch.is_floating_point(p) |
| ) |
| for p in params |
| ) |
| foreach = not fused and all( |
| p is None |
| or ( |
| type(p) in _foreach_supported_types |
| and p.device.type in foreach_supported_devices |
| ) |
| for p in params |
| ) |
| return fused, foreach |
| |
| |
| def _view_as_real(params, *state_and_grads): |
| for i, p in enumerate(params): |
| if torch.is_complex(p): |
| params[i] = torch.view_as_real(params[i]) |
| for s in state_and_grads: |
| s[i] = torch.view_as_real(s[i]) |
| |
| |
| def _get_scalar_dtype(is_fused=None): |
| if is_fused: |
| return torch.float32 |
| return ( |
| torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32 |
| ) |
| |
| |
| def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: |
| r"""Return the device type list that supports capturable optimizer.""" |
| capturable_supported_devices = ["cuda"] |
| if not torch.jit.is_scripting(): |
| capturable_supported_devices.append(torch._C._get_privateuse1_backend_name()) |
| if supports_xla: |
| capturable_supported_devices.append("xla") |
| return capturable_supported_devices |
| |
| |
| # Common doc strings among optimizers |
| _foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer |
| is used. If unspecified by the user (so foreach is None), we will try to use |
| foreach over the for-loop implementation on CUDA, since it is usually |
| significantly more performant. Note that the foreach implementation uses |
| ~ sizeof(params) more peak memory than the for-loop version due to the intermediates |
| being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer |
| parameters through the optimizer at a time or switch this flag to False (default: None)""" |
| |
| _fused_doc = r"""fused (bool, optional): whether the fused implementation is used. |
| Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` |
| are supported. (default: None) |
| |
| .. note:: The foreach and fused implementations are typically faster than the for-loop, |
| single-tensor implementation. Thus, if the user has not specified BOTH flags |
| (i.e., when foreach = fused = None), we will attempt defaulting to the foreach |
| implementation when the tensors are all on CUDA. For example, if the user specifies |
| True for fused but nothing for foreach, we will run the fused implementation. If |
| the user specifies False for foreach but nothing for fused (or False for fused but |
| nothing for foreach), we will run the for-loop implementation. If the user specifies |
| True for both foreach and fused, we will prioritize fused over foreach, as it is |
| typically faster. We attempt to use the fastest, so the hierarchy goes fused -> |
| foreach -> for-loop. HOWEVER, since the fused implementation is relatively new, |
| we want to give it sufficient bake-in time, so we default to foreach and NOT |
| fused when the user has not specified either flag.""" |
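| |
| # As a usage sketch of the flags described above (``model`` stands in for any nn.Module |
| # whose parameters are CUDA floating-point tensors): |
| # |
| #     torch.optim.AdamW(model.parameters(), lr=1e-3)                  # defaults to foreach on CUDA |
| #     torch.optim.AdamW(model.parameters(), lr=1e-3, fused=True)      # opt into the fused implementation |
| #     torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=False)   # force the for-loop implementation |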
| |
| _capturable_doc = r"""capturable (bool, optional): whether this instance is safe to |
| capture in a CUDA graph. Passing True can impair ungraphed performance, |
| so if you don't intend to graph capture this instance, leave it False |
| (default: False)""" |
| |
| _differentiable_doc = r"""differentiable (bool, optional): whether autograd should |
| occur through the optimizer step in training. Otherwise, the step() |
| function runs in a torch.no_grad() context. Setting to True can impair |
| performance, so leave it False if you don't intend to run autograd |
| through this instance (default: False)""" |
| |
| _maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the |
| params, instead of minimizing (default: False)""" |
| |
| |
| def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle: |
| r"""Register a pre hook common to all optimizers. The hook should have the following |
| signature:: |
| |
| hook(optimizer, args, kwargs) -> None or modified args and kwargs |
| |
| Args: |
| hook (Callable): A user defined hook which is registered on all optimizers. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
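| |
| Example (a minimal sketch; the hook name and body are illustrative):: |
| |
| >>> from torch.optim.optimizer import register_optimizer_step_pre_hook |
| >>> def log_step(optimizer, args, kwargs): |
| ...     # returning None leaves args and kwargs unchanged |
| ...     print(f"{optimizer.__class__.__name__} is about to step") |
| >>> handle = register_optimizer_step_pre_hook(log_step) |
| >>> # ... every optimizer's step() now triggers the hook ... |
| >>> handle.remove()  # detach the hook once it is no longer needed |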
| """ |
| handle = hooks.RemovableHandle(_global_optimizer_pre_hooks) |
| _global_optimizer_pre_hooks[handle.id] = hook |
| return handle |
| |
| |
| def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle: |
| r"""Register a post hook common to all optimizers. The hook should have the following |
| signature:: |
| |
| hook(optimizer, args, kwargs) -> None |
| |
| Args: |
| hook (Callable): A user defined hook which is registered on all optimizers. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| """ |
| handle = hooks.RemovableHandle(_global_optimizer_post_hooks) |
| _global_optimizer_post_hooks[handle.id] = hook |
| return handle |
| |
| |
| ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] |
| |
| _P = ParamSpec("_P") |
| R = TypeVar("R") |
| T = TypeVar("T") |
| |
| |
| class Optimizer: |
| r"""Base class for all optimizers. |
| |
| .. warning:: |
| Parameters need to be specified as collections that have a deterministic |
| ordering that is consistent between runs. Examples of objects that don't |
| satisfy those properties are sets and iterators over values of dictionaries. |
| |
| Args: |
| params (iterable): an iterable of :class:`torch.Tensor` s or |
| :class:`dict` s. Specifies what Tensors should be optimized. |
| defaults: (dict): a dict containing default values of optimization |
| options (used when a parameter group doesn't specify them). |
| """ |
| |
| OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]] # type: ignore[misc] |
| OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc] |
| |
| _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook] |
| _optimizer_step_post_hooks: Dict[int, OptimizerPostHook] |
| _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' |
| _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' |
| _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]' |
| _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]' |
| |
| def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None: |
| torch._C._log_api_usage_once("python.optimizer") |
| self.defaults = defaults |
| self._optimizer_step_pre_hooks = OrderedDict() |
| self._optimizer_step_post_hooks = OrderedDict() |
| self._optimizer_state_dict_pre_hooks = OrderedDict() |
| self._optimizer_state_dict_post_hooks = OrderedDict() |
| self._optimizer_load_state_dict_pre_hooks = OrderedDict() |
| self._optimizer_load_state_dict_post_hooks = OrderedDict() |
| |
| self._patch_step_function() |
| |
| if isinstance(params, torch.Tensor): |
| raise TypeError( |
| "params argument given to the optimizer should be " |
| "an iterable of Tensors or dicts, but got " + torch.typename(params) |
| ) |
| |
| self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict) |
| self.param_groups: List[Dict[str, Any]] = [] |
| |
| param_groups = list(params) |
| if len(param_groups) == 0: |
| raise ValueError("optimizer got an empty parameter list") |
| if not isinstance(param_groups[0], dict): |
| param_groups = [{"params": param_groups}] |
| |
| for param_group in param_groups: |
| self.add_param_group(cast(dict, param_group)) |
| |
| # Allows _cuda_graph_capture_health_check to rig a poor man's version of TORCH_WARN_ONCE |
| # in Python, which does not otherwise exist: |
| # https://github.com/pytorch/pytorch/issues/72948 |
| self._warned_capturable_if_run_uncaptured = True |
| |
| def __getstate__(self) -> Dict[str, Any]: |
| return { |
| "defaults": self.defaults, |
| "state": self.state, |
| "param_groups": self.param_groups, |
| } |
| |
| def __setstate__(self, state: Dict[str, Any]) -> None: |
| self.__dict__.update(state) |
| if "_optimizer_step_pre_hooks" not in self.__dict__: |
| self._optimizer_step_pre_hooks = OrderedDict() |
| if "_optimizer_step_post_hooks" not in self.__dict__: |
| self._optimizer_step_post_hooks = OrderedDict() |
| if "_optimizer_state_dict_pre_hooks" not in self.__dict__: |
| self._optimizer_state_dict_pre_hooks = OrderedDict() |
| if "_optimizer_state_dict_post_hooks" not in self.__dict__: |
| self._optimizer_state_dict_post_hooks = OrderedDict() |
| if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__: |
| self._optimizer_load_state_dict_pre_hooks = OrderedDict() |
| if "_optimizer_load_state_dict_post_hooks" not in self.__dict__: |
| self._optimizer_load_state_dict_post_hooks = OrderedDict() |
| self._patch_step_function() # To support multiprocessing pickle/unpickle |
| self.defaults.setdefault("differentiable", False) |
| |
| def __repr__(self) -> str: |
| format_string = self.__class__.__name__ + " (" |
| for i, group in enumerate(self.param_groups): |
| format_string += "\n" |
| format_string += f"Parameter Group {i}\n" |
| for key in sorted(group.keys()): |
| if key != "params": |
| format_string += f" {key}: {group[key]}\n" |
| format_string += ")" |
| return format_string |
| |
| # Currently needed by Adam and AdamW |
| def _cuda_graph_capture_health_check(self) -> None: |
| # Note [torch.compile x capturable] |
| # If we are compiling, we try to take the capturable path automatically by |
| # setting the flag to True during tracing. Due to this, we skip all the checks |
| # normally required for determining whether we can use CUDA graphs and |
| # shunt the responsibility to torch.inductor. This saves time during tracing, |
| # since the checks are slow, without sacrificing UX, because inductor will warn |
| # later if CUDA graphs cannot be enabled, e.g., |
| # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390. |
| # Thus, when compiling, inductor will determine if cudagraphs |
| # can be enabled based on whether there is input mutation or CPU tensors. |
| if ( |
| not torch.compiler.is_compiling() |
| and torch.backends.cuda.is_built() |
| and torch.cuda.is_available() |
| ): |
| capturing = torch.cuda.is_current_stream_capturing() |
| |
| if capturing and not all( |
| group["capturable"] for group in self.param_groups |
| ): |
| raise RuntimeError( |
| "Attempting CUDA graph capture of step() for an instance of " |
| + self.__class__.__name__ |
| + " but param_groups' capturable is False." |
| ) |
| |
| if ( |
| (not getattr(self, "_warned_capturable_if_run_uncaptured", False)) |
| and all(group["capturable"] for group in self.param_groups) |
| and (not capturing) |
| ): |
| warnings.warn( |
| "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, " |
| "but step() is running without CUDA graph capture. If you never intend to graph-capture this " |
| "instance, capturable=True can impair performance, and you should set capturable=False." |
| ) |
| self._warned_capturable_if_run_uncaptured = True |
| |
| def _optimizer_step_code(self) -> None: |
| """Entry point for `torch.profile.profiler`. |
| |
| When python tracing is enabled the profiler will hook into this |
| function at the CPython level to inspect the optimizer's parameters and |
| param groups. It is called after `step()` since many optimizers |
| lazily initialize state. |
| |
| This is a workaround due to the lack of a proper step hook on the optimizer, |
| and will be removed once such a hook exists. |
| """ |
| pass |
| |
| @staticmethod |
| def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]: |
| @functools.wraps(func) |
| def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: |
| self, *_ = args |
| self = cast(Optimizer, self) |
| profile_name = f"Optimizer.step#{self.__class__.__name__}.step" |
| with torch.autograd.profiler.record_function(profile_name): |
| # call optimizer step pre hooks |
| for pre_hook in chain( |
| _global_optimizer_pre_hooks.values(), |
| self._optimizer_step_pre_hooks.values(), |
| ): |
| result = pre_hook(self, args, kwargs) |
| if result is not None: |
| if isinstance(result, tuple) and len(result) == 2: |
| args, kwargs = result # type: ignore[assignment] |
| else: |
| raise RuntimeError( |
| f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}." |
| ) |
| |
| out = func(*args, **kwargs) |
| self._optimizer_step_code() |
| |
| # call optimizer step post hooks |
| for post_hook in chain( |
| self._optimizer_step_post_hooks.values(), |
| _global_optimizer_post_hooks.values(), |
| ): |
| post_hook(self, args, kwargs) |
| |
| return out |
| |
| return wrapper |
| |
| @staticmethod |
| def _group_tensors_by_device_and_dtype( |
| tensorlistlist: TensorListList, |
| with_indices: bool = False, |
| ) -> Union[ |
| Dict[Tuple[None, None], Tuple[TensorListList, Indices]], |
| Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]], |
| ]: |
| """Groups a list of lists of tensors by device and dtype. |
| Skips this step if we are compiling since this will occur during inductor lowering. |
| """ |
| if torch.compiler.is_compiling(): |
| return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))} |
| else: |
| return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices) # type: ignore[return-value, arg-type] |
| |
| def _patch_step_function(self) -> None: |
| self._zero_grad_profile_name = ( |
| f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad" |
| ) |
| hooked = getattr(self.__class__.step, "hooked", None) |
| if not hooked: |
| self.__class__.step = self.profile_hook_step(self.__class__.step) # type: ignore[assignment] |
| self.__class__.step.hooked = True # type: ignore[attr-defined] |
| |
| def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle: |
| r"""Register an optimizer step pre hook which will be called before |
| the optimizer step. It should have the following signature:: |
| |
| hook(optimizer, args, kwargs) -> None or modified args and kwargs |
| |
| The ``optimizer`` argument is the optimizer instance being used. If |
| args and kwargs are modified by the pre-hook, then the transformed |
| values are returned as a tuple containing the new_args and new_kwargs. |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
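| |
| Example (a minimal sketch; the hook only observes the call, but it could |
| instead return ``(new_args, new_kwargs)`` to rewrite them):: |
| |
| >>> opt = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1) |
| >>> def pre_hook(optimizer, args, kwargs): |
| ...     print("about to step with", args, kwargs) |
| ...     return None  # or a (new_args, new_kwargs) tuple |
| >>> handle = opt.register_step_pre_hook(pre_hook) |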
| """ |
| handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks) |
| self._optimizer_step_pre_hooks[handle.id] = hook |
| return handle |
| |
| def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle: |
| r"""Register an optimizer step post hook which will be called after optimizer step. |
| It should have the following signature:: |
| |
| hook(optimizer, args, kwargs) -> None |
| |
| The ``optimizer`` argument is the optimizer instance being used. |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
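| |
| Example (a sketch; the step counter is illustrative):: |
| |
| >>> opt = torch.optim.SGD([torch.zeros(2, requires_grad=True)], lr=0.1) |
| >>> steps_taken = [] |
| >>> handle = opt.register_step_post_hook( |
| ...     lambda optimizer, args, kwargs: steps_taken.append(1) |
| ... ) |
| >>> opt.step() |
| >>> len(steps_taken) |
| 1 |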
| """ |
| handle = hooks.RemovableHandle(self._optimizer_step_post_hooks) |
| self._optimizer_step_post_hooks[handle.id] = hook |
| return handle |
| |
| def register_state_dict_pre_hook( |
| self, hook: Callable[["Optimizer"], None], prepend: bool = False |
| ) -> RemovableHandle: |
| r"""Register a state dict pre-hook which will be called before |
| :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the |
| following signature:: |
| |
| hook(optimizer) -> None |
| |
| The ``optimizer`` argument is the optimizer instance being used. |
| The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``. |
| The registered hook can be used to perform pre-processing before the ``state_dict`` |
| call is made. |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| prepend (bool): If True, the provided pre ``hook`` will be fired before |
| all the already registered pre-hooks on ``state_dict``. Otherwise, |
| the provided ``hook`` will be fired after all the already registered |
| pre-hooks. (default: False) |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| """ |
| handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks) |
| self._optimizer_state_dict_pre_hooks[handle.id] = hook |
| if prepend: |
| self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False) |
| return handle |
| |
| def register_state_dict_post_hook( |
| self, |
| hook: Callable[["Optimizer", StateDict], Optional[StateDict]], |
| prepend: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a state dict post-hook which will be called after |
| :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the |
| following signature:: |
| |
| hook(optimizer, state_dict) -> state_dict or None |
| |
| The hook will be called with arguments ``self`` and ``state_dict`` after generating |
| a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally |
| return a new one. The registered hook can be used to perform post-processing |
| on the ``state_dict`` before it is returned. |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| prepend (bool): If True, the provided post ``hook`` will be fired before |
| all the already registered post-hooks on ``state_dict``. Otherwise, |
| the provided ``hook`` will be fired after all the already registered |
| post-hooks. (default: False) |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
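| |
| Example (a sketch; ``optimizer`` is assumed to be an existing optimizer |
| instance, and the extra key is purely illustrative):: |
| |
| >>> def add_note(optimizer, state_dict): |
| ...     state_dict["note"] = "saved by my training script" |
| ...     return state_dict  # returning None would also keep the in-place edit |
| >>> handle = optimizer.register_state_dict_post_hook(add_note) |
| >>> "note" in optimizer.state_dict() |
| True |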
| """ |
| handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks) |
| self._optimizer_state_dict_post_hooks[handle.id] = hook |
| if prepend: |
| self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False) |
| return handle |
| |
| @torch._disable_dynamo |
| def state_dict(self) -> StateDict: |
| r"""Returns the state of the optimizer as a :class:`dict`. |
| |
| It contains two entries: |
| |
| * ``state``: a Dict holding current optimization state. Its content |
| differs between optimizer classes, but some common characteristics |
| hold. For example, state is saved per parameter, and the parameter |
| itself is NOT saved. ``state`` is a Dictionary mapping parameter ids |
| to a Dict with state corresponding to each parameter. |
| * ``param_groups``: a List containing all parameter groups where each |
| parameter group is a Dict. Each parameter group contains metadata |
| specific to the optimizer, such as learning rate and weight decay, |
| as well as a List of parameter IDs of the parameters in the group. |
| |
| NOTE: The parameter IDs may look like indices but they are just IDs |
| associating state with param_group. When loading from a state_dict, |
| the optimizer will zip the param_group ``params`` (int IDs) and the |
| optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to |
| match state WITHOUT additional verification. |
| |
| A returned state dict might look something like: |
| |
| .. code-block:: text |
| |
| { |
| 'state': { |
| 0: {'momentum_buffer': tensor(...), ...}, |
| 1: {'momentum_buffer': tensor(...), ...}, |
| 2: {'momentum_buffer': tensor(...), ...}, |
| 3: {'momentum_buffer': tensor(...), ...} |
| }, |
| 'param_groups': [ |
| { |
| 'lr': 0.01, |
| 'weight_decay': 0, |
| ... |
| 'params': [0] |
| }, |
| { |
| 'lr': 0.001, |
| 'weight_decay': 0.5, |
| ... |
| 'params': [1, 2, 3] |
| } |
| ] |
| } |
| |
| """ |
| |
| for pre_hook in self._optimizer_state_dict_pre_hooks.values(): |
| pre_hook(self) |
| |
| # Save order indices instead of Tensors |
| param_mappings: Dict[int, int] = {} |
| start_index = 0 |
| |
| def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: |
| nonlocal start_index |
| packed = {k: v for k, v in group.items() if k != "params"} |
| param_mappings.update( |
| { |
| id(p): i |
| for i, p in enumerate(group["params"], start_index) |
| if id(p) not in param_mappings |
| } |
| ) |
| packed["params"] = [param_mappings[id(p)] for p in group["params"]] |
| start_index += len(packed["params"]) |
| return packed |
| |
| param_groups = [pack_group(g) for g in self.param_groups] |
| # Remap state to use order indices as keys |
| packed_state = { |
| (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v |
| for k, v in self.state.items() |
| } |
| |
| state_dict = { |
| "state": packed_state, |
| "param_groups": param_groups, |
| } |
| |
| for post_hook in self._optimizer_state_dict_post_hooks.values(): |
| hook_result = post_hook(self, state_dict) |
| if hook_result is not None: |
| state_dict = hook_result |
| return state_dict |
| |
| @staticmethod |
| def _process_value_according_to_param_policy( |
| param: torch.Tensor, |
| value: torch.Tensor, |
| param_id: int, |
| param_groups: List[Dict[Any, Any]], |
| key: Hashable = None, |
| ) -> torch.Tensor: |
| # Floating-point types are a bit special here. They are the only ones |
| # that are assumed to always match the type of params. |
| # Make sure state['step'] is not cast (https://github.com/pytorch/pytorch/issues/74424), |
| # UNLESS fused or capturable, see note [special device hosting for step] |
| fused = False |
| capturable = False |
| assert param_groups is not None |
| for pg in param_groups: |
| if param_id in pg["params"]: |
| fused = pg["fused"] if "fused" in pg else False |
| capturable = pg["capturable"] if "capturable" in pg else False |
| break |
| if key == "step": |
| if capturable or fused: |
| return value.to(dtype=torch.float32, device=param.device) |
| else: |
| return value |
| else: |
| if param.is_floating_point(): |
| return value.to(dtype=param.dtype, device=param.device) |
| else: |
| return value.to(device=param.device) |
| |
| def register_load_state_dict_pre_hook( |
| self, |
| hook: Callable[["Optimizer", StateDict], Optional[StateDict]], |
| prepend: bool = False, |
| ) -> RemovableHandle: |
| r"""Register a load_state_dict pre-hook which will be called before |
| :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the |
| following signature:: |
| |
| hook(optimizer, state_dict) -> state_dict or None |
| |
| The ``optimizer`` argument is the optimizer instance being used and the |
| ``state_dict`` argument is a shallow copy of the ``state_dict`` the user |
| passed in to ``load_state_dict``. The hook may modify the state_dict inplace |
| or optionally return a new one. If a state_dict is returned, it will be |
| loaded into the optimizer in place of the original. |
| |
| The hook will be called with argument ``self`` and ``state_dict`` before |
| calling ``load_state_dict`` on ``self``. The registered hook can be used to |
| perform pre-processing before the ``load_state_dict`` call is made. |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| prepend (bool): If True, the provided pre ``hook`` will be fired before |
| all the already registered pre-hooks on ``load_state_dict``. Otherwise, |
| the provided ``hook`` will be fired after all the already registered |
| pre-hooks. (default: False) |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
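| |
| Example (a sketch; ``optimizer`` is assumed to be an existing optimizer |
| instance, and overriding the learning rate is purely illustrative):: |
| |
| >>> def force_lr(optimizer, state_dict): |
| ...     for group in state_dict["param_groups"]: |
| ...         group["lr"] = 1e-3 |
| ...     # returning None loads the (in-place modified) state_dict |
| >>> handle = optimizer.register_load_state_dict_pre_hook(force_lr) |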
| """ |
| handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks) |
| self._optimizer_load_state_dict_pre_hooks[handle.id] = hook |
| if prepend: |
| self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False) |
| return handle |
| |
| def register_load_state_dict_post_hook( |
| self, hook: Callable[["Optimizer"], None], prepend: bool = False |
| ) -> RemovableHandle: |
| r"""Register a load_state_dict post-hook which will be called after |
| :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the |
| following signature:: |
| |
| hook(optimizer) -> None |
| |
| The ``optimizer`` argument is the optimizer instance being used. |
| |
| The hook will be called with argument ``self`` after calling |
| ``load_state_dict`` on ``self``. The registered hook can be used to |
| perform post-processing after ``load_state_dict`` has loaded the |
| ``state_dict``. |
| |
| Args: |
| hook (Callable): The user defined hook to be registered. |
| prepend (bool): If True, the provided post ``hook`` will be fired before |
| all the already registered post-hooks on ``load_state_dict``. Otherwise, |
| the provided ``hook`` will be fired after all the already registered |
| post-hooks. (default: False) |
| |
| Returns: |
| :class:`torch.utils.hooks.RemovableHandle`: |
| a handle that can be used to remove the added hook by calling |
| ``handle.remove()`` |
| """ |
| handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks) |
| self._optimizer_load_state_dict_post_hooks[handle.id] = hook |
| if prepend: |
| self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined] |
| return handle |
| |
| @torch._disable_dynamo |
| def load_state_dict(self, state_dict: StateDict) -> None: |
| r"""Loads the optimizer state. |
| |
| Args: |
| state_dict (dict): optimizer state. Should be an object returned |
| from a call to :meth:`state_dict`. |
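| |
| Example (a sketch of the usual save/restore round trip; ``opt`` is assumed |
| to be an already constructed optimizer over the same parameters):: |
| |
| >>> saved = opt.state_dict() |
| >>> # ... later, e.g. after re-building the model and optimizer ... |
| >>> opt.load_state_dict(saved) |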
| """ |
| # shallow copy, to be consistent with module API |
| state_dict = state_dict.copy() |
| |
| for pre_hook in self._optimizer_load_state_dict_pre_hooks.values(): |
| hook_result = pre_hook(self, state_dict) |
| if hook_result is not None: |
| state_dict = hook_result |
| |
| # Validate the state_dict |
| groups = self.param_groups |
| |
| # Deepcopy as we write into saved_groups later to update state |
| saved_groups = deepcopy(state_dict["param_groups"]) |
| |
| if len(groups) != len(saved_groups): |
| raise ValueError( |
| "loaded state dict has a different number of " "parameter groups" |
| ) |
| param_lens = (len(g["params"]) for g in groups) |
| saved_lens = (len(g["params"]) for g in saved_groups) |
| if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): |
| raise ValueError( |
| "loaded state dict contains a parameter group " |
| "that doesn't match the size of optimizer's group" |
| ) |
| |
| # Update the state |
| id_map = dict( |
| zip( |
| chain.from_iterable(g["params"] for g in saved_groups), |
| chain.from_iterable(g["params"] for g in groups), |
| ) |
| ) |
| |
| def _cast(param, value, param_id=None, param_groups=None, key=None): |
| r"""Make a deep copy of value, casting all tensors to device of param.""" |
| if isinstance(value, torch.Tensor): |
| return Optimizer._process_value_according_to_param_policy( |
| param, value, param_id, param_groups, key |
| ) |
| elif isinstance(value, dict): |
| return { |
| k: _cast( |
| param, v, param_id=param_id, param_groups=param_groups, key=k |
| ) |
| for k, v in value.items() |
| } |
| elif isinstance(value, Iterable): |
| return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg] |
| else: |
| return value |
| |
| # Copy state assigned to params (and cast tensors to appropriate types). |
| # State that is not assigned to params is copied as is (needed for |
| # backward compatibility). |
| state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict) |
| for k, v in state_dict["state"].items(): |
| if k in id_map: |
| param = id_map[k] |
| state[param] = _cast( |
| param, v, param_id=k, param_groups=state_dict["param_groups"] |
| ) |
| else: |
| state[k] = v |
| |
| # Update parameter groups, setting their 'params' value |
| def update_group( |
| group: Dict[str, Any], new_group: Dict[str, Any] |
| ) -> Dict[str, Any]: |
| new_group["params"] = group["params"] |
| return new_group |
| |
| param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] |
| self.__setstate__({"state": state, "param_groups": param_groups}) |
| |
| for post_hook in self._optimizer_load_state_dict_post_hooks.values(): |
| post_hook(self) |
| |
| @torch._disable_dynamo |
| def zero_grad(self, set_to_none: bool = True) -> None: |
| r"""Resets the gradients of all optimized :class:`torch.Tensor` s. |
| |
| Args: |
| set_to_none (bool): instead of setting to zero, set the grads to None. |
| This will in general have a lower memory footprint, and can modestly improve performance. |
| However, it changes certain behaviors. For example: |
| 1. When the user tries to access a gradient and perform manual ops on it, |
| a None attribute or a Tensor full of 0s will behave differently. |
| 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s |
| are guaranteed to be None for params that did not receive a gradient. |
| 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None |
| (in one case it does the step with a gradient of 0 and in the other it skips |
| the step altogether). |
| """ |
| foreach = self.defaults.get("foreach", False) or self.defaults.get( |
| "fused", False |
| ) |
| |
| if not hasattr(self, "_zero_grad_profile_name"): |
| self._patch_step_function() |
| |
| per_device_and_dtype_grads: Optional[ |
| DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]] |
| ] |
| if foreach: |
| per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) |
| else: |
| per_device_and_dtype_grads = None |
| |
| with torch.autograd.profiler.record_function(self._zero_grad_profile_name): |
| for group in self.param_groups: |
| for p in group["params"]: |
| if p.grad is not None: |
| if set_to_none: |
| p.grad = None |
| else: |
| if p.grad.grad_fn is not None: |
| p.grad.detach_() |
| else: |
| p.grad.requires_grad_(False) |
| if not foreach or p.grad.is_sparse: |
| p.grad.zero_() |
| else: |
| assert per_device_and_dtype_grads is not None |
| per_device_and_dtype_grads[p.grad.device][ |
| p.grad.dtype |
| ].append(p.grad) |
| if foreach: |
| assert per_device_and_dtype_grads is not None |
| for per_dtype_grads in per_device_and_dtype_grads.values(): |
| for grads in per_dtype_grads.values(): |
| torch._foreach_zero_(grads) |
| |
| @overload |
| def step(self, closure: None = ...) -> None: |
| ... |
| |
| @overload |
| def step(self, closure: Callable[[], float]) -> float: |
| ... |
| |
| def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: |
| r"""Performs a single optimization step (parameter update). |
| |
| Args: |
| closure (Callable): A closure that reevaluates the model and |
| returns the loss. Optional for most optimizers. |
| |
| .. note:: |
| Unless otherwise specified, this function should not modify the |
| ``.grad`` field of the parameters. |
| """ |
| raise NotImplementedError |
| |
| @torch._disable_dynamo |
| def add_param_group(self, param_group: Dict[str, Any]) -> None: |
| r"""Add a param group to the :class:`Optimizer` s `param_groups`. |
| |
| This can be useful when fine-tuning a pre-trained network, as frozen layers can be made |
| trainable and added to the :class:`Optimizer` as training progresses. |
| |
| Args: |
| param_group (dict): Specifies what Tensors should be optimized along with group |
| specific optimization options. |
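| |
| Example (a sketch; ``model`` and ``new_head`` stand in for arbitrary |
| ``nn.Module`` s whose parameters should be optimized):: |
| |
| >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1) |
| >>> # later on, start training an extra sub-module with its own learning rate |
| >>> optimizer.add_param_group({"params": new_head.parameters(), "lr": 0.01}) |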
| """ |
| if not isinstance(param_group, dict): |
| raise TypeError(f"param_group must be a dict, but got {type(param_group)}") |
| |
| params = param_group["params"] |
| if isinstance(params, torch.Tensor): |
| param_group["params"] = [params] |
| elif isinstance(params, set): |
| raise TypeError( |
| "optimizer parameters need to be organized in ordered collections, but " |
| "the ordering of tensors in sets will change between runs. Please use a list instead." |
| ) |
| else: |
| param_group["params"] = list(params) |
| |
| for param in param_group["params"]: |
| if not isinstance(param, torch.Tensor): |
| raise TypeError( |
| "optimizer can only optimize Tensors, " |
| "but one of the params is " + torch.typename(param) |
| ) |
| if not self.defaults.get("differentiable", None) and not ( |
| param.is_leaf or param.retains_grad |
| ): |
| raise ValueError("can't optimize a non-leaf Tensor") |
| |
| for name, default in self.defaults.items(): |
| if default is required and name not in param_group: |
| raise ValueError( |
| f"parameter group didn't specify a value of required optimization parameter {name}" |
| ) |
| else: |
| param_group.setdefault(name, default) |
| |
| params = param_group["params"] |
| if len(params) != len(set(params)): |
| warnings.warn( |
| "optimizer contains a parameter group with duplicate parameters; " |
| "in future, this will cause an error; " |
| "see github.com/pytorch/pytorch/issues/40967 for more information", |
| stacklevel=3, |
| ) |
| |
| param_set: Set[torch.Tensor] = set() |
| for group in self.param_groups: |
| param_set.update(set(group["params"])) |
| |
| if not param_set.isdisjoint(set(param_group["params"])): |
| raise ValueError("some parameters appear in more than one parameter group") |
| |
| self.param_groups.append(param_group) |