| # Owner(s): ["oncall: distributed"] |
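
"""
Parity tests comparing torch.optim optimizers (SGD, Adam, AdamW) against the
functional counterparts registered in functional_optim_map, which expose a
per-parameter step_param API.
"""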
| |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.optim.utils import functional_optim_map
from torch.optim import Adam, AdamW, SGD
from torch.testing._internal.common_utils import run_tests, TestCase
| |

class MyModule(nn.Module):
| def __init__(self): |
| super().__init__() |
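        # Seed the RNG so that every MyModule instance starts from identical weights.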
| torch.manual_seed(0) |
| self.lin1 = nn.Linear(3, 3, bias=False) |
| self.lin2 = nn.Linear(3, 3, bias=False) |
| |
| def forward(self, t1): |
| return self.lin2(F.relu(self.lin1(t1))) |
| |
| |
| class TestFunctionalOptimParity(TestCase): |
| def _validate_parameters(self, params_1, params_2): |
| for p1, p2 in zip(params_1, params_2): |
| self.assertEqual(p1, p2) |
| |
| def _test_functional_optim_parity(self, optim_cls, *args, **kwargs): |
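        """
        Runs ``optim_cls`` and its functional counterpart on two identically
        initialized modules and asserts that the parameters stay equal after
        each optimizer step.
        """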
| module_optim = MyModule() |
| module_functional = MyModule() |
        optim = optim_cls(module_optim.parameters(), *args, **kwargs)
| functional_optim_cls = functional_optim_map.get(optim_cls, None) |
| if not functional_optim_cls: |
| raise ValueError(f"Functional optimizer not implemented for {optim_cls}") |
| optim_functional = functional_optim_cls( |
| [], *args, **kwargs, _allow_empty_param_list=True |
| ) |
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {functional_optim_cls} must implement a step_param method."
            )
| |
| # Initial weights should match |
| self._validate_parameters( |
| module_optim.parameters(), module_functional.parameters() |
| ) |
| # Save old parameters to verify optimizer modifies them. |
| old_module_optim_params = [ |
| param.clone().detach() for param in module_optim.parameters() |
| ] |
| old_module_functional_params = [ |
| param.clone().detach() for param in module_functional.parameters() |
| ] |
| |
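        # Run several steps on the same input and check that both optimizers
        # produce identical parameter updates.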
| t1 = torch.randn(3, 3) |
| for _ in range(10): |
| module_optim.zero_grad() |
| module_functional.zero_grad() |
| # Forward + Backward |
| optim_out = module_optim(t1).sum() |
| functional_out = module_functional(t1).sum() |
| optim_out.backward() |
| functional_out.backward() |
| # Optimizer step |
| optim.step() |
| # Functional optimizer step_param |
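            # step_param updates a single parameter given its gradient, so
            # stepping every parameter mirrors a full optim.step().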
| for param in module_functional.parameters(): |
| grad = param.grad |
| optim_functional.step_param(param, grad) |
| |
| # Validate parameters are equal |
| for optim_param, functional_param in zip( |
| module_optim.parameters(), module_functional.parameters() |
| ): |
| self.assertEqual(optim_param, functional_param) |
| # Validate parameters are modified. |
| for i, (optim_param, functional_param) in enumerate( |
| zip(module_optim.parameters(), module_functional.parameters()) |
| ): |
| self.assertNotEqual(old_module_optim_params[i], optim_param) |
| self.assertNotEqual(old_module_functional_params[i], functional_param) |
| |
| def test_functional_optim_parity_sgd(self): |
| self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01) |
| |
| def test_functional_optim_parity_adam(self): |
| self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6) |
| |
| def test_functional_optim_parity_adam_w(self): |
| self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |