| r"""Functional interface""" |
| import math |
| from torch import Tensor |
| from typing import List |
| |
| from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401 |
| from .adagrad import adagrad, _make_sparse # type: ignore[attr-defined] # noqa: F401 |
| from .adam import adam # type: ignore[attr-defined] # noqa: F401 |
| from .adamw import adamw # type: ignore[attr-defined] # noqa: F401 |
| from .adamax import adamax # type: ignore[attr-defined] # noqa: F401 |
| from .asgd import asgd # type: ignore[attr-defined] # noqa: F401 |
| from .nadam import nadam # type: ignore[attr-defined] # noqa: F401 |
| from .radam import radam # type: ignore[attr-defined] # noqa: F401 |
| from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401 |
| from .rprop import rprop # type: ignore[attr-defined] # noqa: F401 |
| from .sgd import sgd # type: ignore[attr-defined] # noqa: F401 |
| |
| |
| # TODO: use foreach API in optim._functional to do all the computation |
| |
| |
| def sparse_adam(params: List[Tensor], |
                grads: List[Tensor],
                exp_avgs: List[Tensor],
                exp_avg_sqs: List[Tensor],
                state_steps: List[int],
                *,
                eps: float,
                beta1: float,
                beta2: float,
                lr: float,
                maximize: bool):
| r"""Functional API that performs Sparse Adam algorithm computation. |
| |
| See :class:`~torch.optim.SparseAdam` for details. |
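
    Example (an illustrative sketch only; the tensor shapes, the ``to_sparse()``
    conversion, and the hyperparameter values below are hypothetical, and the
    moment buffers would normally be managed by :class:`~torch.optim.SparseAdam`
    rather than built by hand)::

        import torch

        param = torch.zeros(3)
        grad = torch.tensor([0.0, 1.0, 0.0]).to_sparse()  # sparse COO gradient
        exp_avg = torch.zeros(3)                          # first moment buffer
        exp_avg_sq = torch.zeros(3)                       # second moment buffer
        sparse_adam([param], [grad], [exp_avg], [exp_avg_sq], [1],
                    eps=1e-8, beta1=0.9, beta2=0.999, lr=1e-3, maximize=False)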
| """ |
| for i, param in enumerate(params): |
| grad = grads[i] |
| grad = grad if not maximize else -grad |
| grad = grad.coalesce() # the update is non-linear so indices must be unique |
| grad_indices = grad._indices() |
| grad_values = grad._values() |
| size = grad.size() |
| |
| exp_avg = exp_avgs[i] |
| exp_avg_sq = exp_avg_sqs[i] |
| step = state_steps[i] |
| |
| |
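
        # Helper: rebuild a sparse tensor that reuses the gradient's indices and
        # overall size for the given values (or an empty sparse tensor when there
        # is nothing to wrap).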
        def make_sparse(values):
            constructor = grad.new
            if grad_indices.dim() == 0 or values.dim() == 0:
                return constructor().resize_as_(grad)
            return constructor(grad_indices, values, size)

        # Decay the first and second moment running averages:
        #      old <- b * old + (1 - b) * new
        # <==> old += (1 - b) * (new - old)
        old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
        exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
        exp_avg.add_(make_sparse(exp_avg_update_values))
        old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
        exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
        exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
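        # Net effect at the gradient's nonzero indices (all other entries of the
        # moment buffers are left untouched):
        #   exp_avg    <- beta1 * exp_avg    + (1 - beta1) * grad
        #   exp_avg_sq <- beta2 * exp_avg_sq + (1 - beta2) * grad ** 2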

        # Working on the dense value tensors again is intentional; it avoids
        # another sparse_mask() call.
        numer = exp_avg_update_values.add_(old_exp_avg_values)
        exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
        denom = exp_avg_sq_update_values.sqrt_().add_(eps)
        del exp_avg_update_values, exp_avg_sq_update_values

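        # Standard Adam bias correction, folded into the step size:
        #   step_size = lr * sqrt(1 - beta2 ** step) / (1 - beta1 ** step)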
        bias_correction1 = 1 - beta1 ** step
        bias_correction2 = 1 - beta2 ** step
        step_size = lr * math.sqrt(bias_correction2) / bias_correction1

        param.add_(make_sparse(-step_size * numer.div_(denom)))