| from collections import defaultdict, abc as container_abcs |
| import torch |
| from copy import deepcopy |
| from itertools import chain |
| import warnings |
| import functools |
| |
# Names exported by ``from <this module> import *``.
__all__ = ["Optimizer"]
| |
class _RequiredParameter:
    """Marker type for an optimizer option that has no default value.

    Only one instance (``required``) is created; optimizer ``defaults`` use it
    to flag options every parameter group must supply explicitly.
    """

    def __repr__(self):
        return "<required parameter>"


# The single sentinel instance; compared against option defaults by identity.
required = _RequiredParameter()
| |
| |
| def _use_grad_for_differentiable(func): |
| def _use_grad(self, *args, **kwargs): |
| prev_grad = torch.is_grad_enabled() |
| try: |
| torch.set_grad_enabled(self.defaults['differentiable']) |
| ret = func(self, *args, **kwargs) |
| finally: |
| torch.set_grad_enabled(prev_grad) |
| return ret |
| return _use_grad |
| |
| |
class Optimizer(object):
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    def __init__(self, params, defaults):
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults

        # Wrap this class's step() with profiler instrumentation (once per class).
        self._hook_for_profile()

        # A single Tensor must be wrapped in an iterable by the caller.
        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        # A flat iterable of Tensors is treated as one parameter group.
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

        # "Warn once" latch read by _cuda_graph_capture_health_check (Python has
        # no TORCH_WARN_ONCE equivalent; https://github.com/pytorch/pytorch/issues/72948).
        # NOTE(review): because it starts out True, the "capturable=True but running
        # uncaptured" warning is suppressed for freshly constructed optimizers; it can
        # only fire when the attribute is absent (e.g. after unpickling, since
        # __getstate__ does not save it). Confirm this is intended.
        self._warned_capturable_if_run_uncaptured = True

    def __getstate__(self):
        """Return the picklable state: defaults, per-parameter state and groups."""
        return {
            'defaults': self.defaults,
            'state': self.state,
            'param_groups': self.param_groups,
        }

    def __setstate__(self, state):
        """Restore from ``state`` and re-install the profiler hook."""
        self.__dict__.update(state)
        self._hook_for_profile()  # To support multiprocessing pickle/unpickle.
        # Back-fill the key for saved defaults that lack it.
        self.defaults.setdefault('differentiable', False)

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key != 'params':
                    format_string += ' {0}: {1}\n'.format(key, group[key])
        format_string += ')'
        return format_string

    # Currently needed by Adam and AdamW
    def _cuda_graph_capture_health_check(self):
        """Check ``defaults['capturable']`` against the current CUDA capture state.

        Does nothing unless CUDA is available.

        Raises:
            RuntimeError: if the current CUDA stream is capturing but the
                optimizer was constructed with ``capturable=False``.
        """
        if torch.has_cuda and torch.cuda.is_available():
            capturing = torch.cuda.is_current_stream_capturing()

            if capturing and not self.defaults['capturable']:
                raise RuntimeError("Attempting CUDA graph capture of step() for an instance of " +
                                   self.__class__.__name__ +
                                   " but this instance was constructed with capturable=False.")

            if (
                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                and self.defaults["capturable"]
                and (not capturing)
            ):
                # warnings.warn (rather than print) lets users filter, silence or
                # escalate this message through the standard warnings machinery.
                warnings.warn("Warning: This instance was constructed with capturable=True, but step() " +
                              "is running without CUDA graph capture. If you never intend to graph-capture this " +
                              "instance, capturable=True can impair performance, and you should set capturable=False.")
                self._warned_capturable_if_run_uncaptured = True

    def _optimizer_step_code(self):
        """Entry point for the PyTorch profiler (``torch.profiler``).

        When python tracing is enabled the profiler will hook into this
        function at the CPython level to inspect the optimizer's parameters and
        param groups. It is called after `step()` since many optimizers
        lazily initialize state.

        This is a workaround due to lack of a proper step hook on the optimizer,
        and will be removed if it exists.
        """
        pass

    def _hook_for_profile(self):
        """Set the zero_grad profile name and wrap ``type(self).step`` for profiling.

        The wrapper marks itself with ``hooked = True`` so a class's ``step`` is
        wrapped at most once, no matter how many instances are created.
        """
        self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)

        def profile_hook_step(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                obj, *_ = args
                profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
                with torch.autograd.profiler.record_function(profile_name):
                    out = func(*args, **kwargs)
                    obj._optimizer_step_code()
                    return out

            return wrapper

        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = profile_hook_step(self.__class__.step)
            self.__class__.step.hooked = True

    def state_dict(self):
        r"""Returns the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * state - a dict holding current optimization state. Its content
            differs between optimizer classes.
        * param_groups - a list containing all parameter groups where each
            parameter group is a dict

        Parameters are replaced by integer indices, assigned in order of first
        appearance across ``param_groups``.
        """
        # Save order indices instead of Tensors
        param_mappings = {}
        start_index = 0

        def pack_group(group):
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != 'params'}
            param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                                   if id(p) not in param_mappings})
            packed['params'] = [param_mappings[id(p)] for p in group['params']]
            start_index += len(packed['params'])
            return packed
        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
                        for k, v in self.state.items()}
        return {
            'state': packed_state,
            'param_groups': param_groups,
        }

    def load_state_dict(self, state_dict):
        r"""Loads the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            ValueError: if the number of parameter groups, or the size of any
                group, differs from this optimizer's.
        """
        # deepcopy, to be consistent with module API
        state_dict = deepcopy(state_dict)
        # Validate the state_dict
        groups = self.param_groups
        saved_groups = state_dict['param_groups']

        if len(groups) != len(saved_groups):
            raise ValueError("loaded state dict has a different number of "
                             "parameter groups")
        param_lens = (len(g['params']) for g in groups)
        saved_lens = (len(g['params']) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError("loaded state dict contains a parameter group "
                             "that doesn't match the size of optimizer's group")

        # Update the state: map saved integer indices back to live parameters.
        id_map = {old_id: p for old_id, p in
                  zip(chain.from_iterable((g['params'] for g in saved_groups)),
                      chain.from_iterable((g['params'] for g in groups)))}

        def cast(param, value, key=None):
            r"""Recursively move tensors in ``value`` to ``param``'s device.

            Floating-point params also impose their dtype. Tensors stored under
            the ``'step'`` key are returned unchanged.
            """
            if isinstance(value, torch.Tensor):
                # Floating-point types are a bit special here. They are the only ones
                # that are assumed to always match the type of params.
                # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
                if key != "step":
                    if param.is_floating_point():
                        value = value.to(param.dtype)
                    value = value.to(param.device)
                return value
            elif isinstance(value, dict):
                return {k: cast(param, v, key=k) for k, v in value.items()}
            elif isinstance(value, str):
                # str is Iterable, but str(<generator>) would yield the generator's
                # repr instead of the original text, so strings pass through as-is.
                return value
            elif isinstance(value, container_abcs.Iterable):
                return type(value)(cast(param, v) for v in value)
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state = defaultdict(dict)
        for k, v in state_dict['state'].items():
            if k in id_map:
                param = id_map[k]
                state[param] = cast(param, v)
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(group, new_group):
            new_group['params'] = group['params']
            return new_group
        param_groups = [
            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({'state': state, 'param_groups': param_groups})

    def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        foreach = self.defaults.get('foreach', False)

        # Instances restored without __init__/__setstate__ may lack the name.
        if not hasattr(self, "_zero_grad_profile_name"):
            self._hook_for_profile()
        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            # Sparse grads are zeroed individually; dense ones are
                            # batched per (device, dtype) for the foreach kernel.
                            if (not foreach or p.grad.is_sparse):
                                p.grad.zero_()
                            else:
                                per_device_and_dtype_grads[p.grad.device][p.grad.dtype].append(p.grad)
            if foreach:
                for _, per_dtype_grads in per_device_and_dtype_grads.items():
                    for grads in per_dtype_grads.values():
                        torch._foreach_zero_(grads)

    def step(self, closure):
        r"""Performs a single optimization step (parameter update).

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
        """
        raise NotImplementedError

    def add_param_group(self, param_group):
        r"""Add a param group to the :class:`Optimizer` s `param_groups`.

        This can be useful when fine tuning a pre-trained network as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.

        Raises:
            TypeError: if ``params`` is a set or contains a non-Tensor.
            ValueError: if a param is a non-leaf Tensor (unless differentiable),
                a required option is missing, or a param is already in another group.
        """
        assert isinstance(param_group, dict), "param group must be a dict"

        params = param_group['params']
        if isinstance(params, torch.Tensor):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, torch.Tensor):
                raise TypeError("optimizer can only optimize Tensors, "
                                "but one of the params is " + torch.typename(param))
            if not self.defaults.get('differentiable', None) and not (param.is_leaf or param.retains_grad):
                raise ValueError("can't optimize a non-leaf Tensor")

        # Fill in options the group did not specify from the optimizer defaults.
        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
                                 name)
            else:
                param_group.setdefault(name, default)

        params = param_group['params']
        if len(params) != len(set(params)):
            warnings.warn("optimizer contains a parameter group with duplicate parameters; "
                          "in future, this will cause an error; "
                          "see github.com/pytorch/pytorch/issues/40967 for more information", stacklevel=3)

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))

        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)