# Owner(s): ["module: optimizer"]

import torch
from torch.optim import (
    Adadelta,
    Adagrad,
    Adam,
    Adamax,
    AdamW,
    ASGD,
    NAdam,
    RAdam,
    RMSprop,
    Rprop,
    SGD,
)
from torch.testing._internal.common_utils import (
    gradcheck,
    load_tests,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
    # `ignored` is the list of values in `opt_differentiable_state`. Passing them
    # explicitly lets `gradcheck` track the state tensors as function inputs,
    # since it cannot unpack the values inside the `opt_differentiable_state`
    # dict on its own.
    p = p.clone()
    p.grad = grad
    opt_differentiable_state = {
        k: v.clone() if isinstance(v, torch.Tensor) else v
        for k, v in opt_differentiable_state.items()
    }
    opt = opt_class([p], **kwargs)
    opt.state[p].update(opt_differentiable_state)
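    # With `differentiable=True` in `kwargs`, autograd records the step below, so
    # the outputs can be differentiated w.r.t. the parameter, grad, and state tensors.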
    opt.step()
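    # Return the updated parameter together with every differentiable state tensor
    # so that `gradcheck` verifies gradients for all of them.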
    return (p,) + tuple(
        v
        for v in opt.state[p].values()
        if isinstance(v, torch.Tensor) and v.requires_grad
    )


@skipIfTorchDynamo("Differentiable optimizers not supported")
class TestDifferentiableOptimizer(TestCase):
    def test_sgd(self):
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state = {"momentum_buffer": mbuff}
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                SGD,
                {"lr": 0.9, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adam(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["max_exp_avg_sq"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )

        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adam,
                {"lr": 0.9, "differentiable": True, "amsgrad": True},
                *state.values(),
            ),
        )

    def test_rmsprop(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["step"] = torch.zeros((), dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["momentum_buffer"] = torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        # Keep `grad_avg` small: with `centered=True`, large values can make the
        # sqrt of `square_avg - grad_avg**2` produce NaNs.
        state["grad_avg"] = 1e-2 * torch.rand(
            10, requires_grad=True, dtype=torch.float64
        )
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                RMSprop,
                {
                    "lr": 0.9,
                    "maximize": True,
                    "momentum": 0.9,
                    "differentiable": True,
                    "centered": True,
                    "weight_decay": 0.1,
                },
                *state.values(),
            ),
        )

    def test_adadelta(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["square_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["acc_delta"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adadelta,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adagrad(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["sum"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adagrad,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    def test_adamax(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step` is not a continuous variable (even though we define it as a float)
        # and so it shouldn't require gradients.
        state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64)
        state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        state["exp_inf"] = torch.rand(10, requires_grad=True, dtype=torch.float64)
        gradcheck(
            _diff_fn,
            (
                p,
                grad,
                state,
                Adamax,
                {"lr": 0.9, "weight_decay": 0.1, "differentiable": True},
                *state.values(),
            ),
        )

    @skipIfTorchDynamo(
        "The in-place mu update fails under dynamo; since this only happens "
        "when differentiable is enabled, skip for now"
    )
    def test_asgd(self):
        state = {}
        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
        # `step`, `eta`, and `mu` are not continuous variables (even though we define
        # them as floats), so they shouldn't require gradients.
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["eta"] = torch.tensor(0.9, requires_grad=False, dtype=torch.float64) |
| state["mu"] = torch.tensor(1.0, requires_grad=False, dtype=torch.float64) |
| state["ax"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| ASGD, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_rprop(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["prev"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["step_size"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| Rprop, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_adamw(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["max_exp_avg_sq"] = torch.rand( |
| 10, requires_grad=True, dtype=torch.float64 |
| ) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| AdamW, |
| {"lr": 0.9, "differentiable": True, "amsgrad": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_nadam(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["mu_product"] = torch.tensor(1.0, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| NAdam, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| NAdam, |
| {"lr": 0.9, "decoupled_weight_decay": True, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| |
| def test_radam(self): |
| state = {} |
| p = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| grad = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| # `step` is not a continuous variable (even though we define it as a float) |
| # and so it shouldn't require gradients. |
| state["step"] = torch.tensor(10.0, requires_grad=False, dtype=torch.float64) |
| state["exp_avg"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| state["exp_avg_sq"] = torch.rand(10, requires_grad=True, dtype=torch.float64) |
| |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| RAdam, |
| {"lr": 0.9, "differentiable": True}, |
| *state.values(), |
| ), |
| ) |
| gradcheck( |
| _diff_fn, |
| ( |
| p, |
| grad, |
| state, |
| RAdam, |
| { |
| "lr": 0.9, |
| "weight_decay": 0.1, |
| "decoupled_weight_decay": True, |
| "differentiable": True, |
| }, |
| *state.values(), |
| ), |
| ) |
| |
| |
| if __name__ == "__main__": |
| print("These tests should be run through test/test_optim.py instead") |