# Owner(s): ["module: optimizer"]
from copy import deepcopy

import torch

from optim.test_optim import TestOptim, TestDifferentiableOptimizer  # noqa: F401
from optim.test_lrscheduler import TestLRScheduler  # noqa: F401
from optim.test_swa_utils import TestSWAUtils  # noqa: F401
from torch.testing._internal.common_optimizers import optim_db, optims, OptimizerErrorEnum
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU, skipMPS
from torch.testing._internal.common_utils import markDynamoStrictTest, run_tests, TestCase
@markDynamoStrictTest
class TestOptimRenewed(TestCase):
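    """Device-generic optimizer tests, parametrized over optim_db via @optims and
    instantiated per device by instantiate_device_type_tests below."""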

    @onlyCPU
    @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
    def test_errors(self, device, dtype, optim_info):
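        """Each optimizer's optim_error_inputs_func should raise the expected error,
        either at construction time or on the first step() call."""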
        optim_cls = optim_info.optim_cls
        error_inputs = optim_info.optim_error_inputs_func(device=device, dtype=dtype)

        for error_input in error_inputs:
            optim_input = error_input.optimizer_error_input
            params, kwargs = optim_input.params, optim_input.kwargs
            if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    optim_cls(params, **kwargs)
            elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                optim = optim_cls(params, **kwargs)
                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                    optim.step()
            else:
                raise NotImplementedError(f"Unknown error type {error_input.error_on}")

    def _test_derived_optimizers(self, device, dtype, optim_info, flag):
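        """Run the same training loop with the given flag (foreach/fused) off and on,
        and check that the resulting parameters and optimizer state match."""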
        assert flag in ("foreach", "fused")

        # why 7? iteration 7 is where we start to see differences for RAdam
        # params interacting with the small eps value, because that's right
        # after rho_t becomes greater than 5 in step 6.
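        # (For reference, RAdam computes rho_t = rho_inf - 2 * t * beta2**t / (1 - beta2**t)
        # with rho_inf = 2 / (1 - beta2) - 1, and only takes the eps-dependent adaptive
        # branch once rho_t > 5; with the default beta2=0.999 that first happens at step 6.)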
        kIterations = 7

        optim_inputs = optim_info.optim_inputs_func()
        optim_cls = optim_info.optim_cls
        for optim_input in optim_inputs:
            updated_params, state = [], []
            kwargs = deepcopy(optim_input.kwargs)
            if (kwargs.get("capturable", False) and
                    (str(device) == "cpu" or optim_cls.__name__ == "ASGD")):
                # capturable is not supported on CPU nor in single tensor ASGD
                continue
            for flag_value in (False, True):
                kwargs[flag] = flag_value
                input = torch.tensor(
                    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], dtype=dtype, device=device
                ).reshape(3, 2)

                torch.manual_seed(1)
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Sigmoid(),
                    torch.nn.Linear(3, 1),
                    torch.nn.Sigmoid(),
                )
                model.to(dtype=dtype, device=device)

                # foreach/fused optimizers should be tested with a
                # zero_size tensor as its last param.
                # ref: https://github.com/pytorch/pytorch/issues/100701
                empty_param = torch.empty((), device=device, dtype=dtype, requires_grad=True)
                empty_param.grad = torch.rand_like(empty_param)
                params = list(model.parameters()) + [empty_param]

                optimizer = optim_cls(params, **kwargs)
                for i in range(kIterations):
                    optimizer.zero_grad()

                    # Test that step behaves as expected (a no-op) when grads are set to None
                    if i != 3:
                        output = model(input)
                        loss = output.sum()
                        loss.backward()

                    optimizer.step()
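
                # Stash the final optimizer state and parameters for this flag value so the
                # two runs (flag off vs. flag on) can be compared below.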
                state.append(optimizer.state)
                updated_params.append(model.parameters())

            og_state, new_state = state
            for og_p, new_p in zip(updated_params[0], updated_params[1]):
                self.assertEqual(og_p, new_p)

                # check that optimizer states are the same
                og_p_state = og_state[og_p]
                new_p_state = new_state[new_p]

                for k in og_p_state:
                    self.assertEqual(og_p_state[k], new_p_state[k])

    @skipMPS  # MPS doesn't support torch.float64, see https://github.com/pytorch/pytorch/issues/115350
    @optims([optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float64])
    def test_foreach_matches_forloop(self, device, dtype, optim_info):
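        """The foreach implementation should match the default forloop implementation."""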
        self._test_derived_optimizers(device, dtype, optim_info, "foreach")

    @onlyCPU
    @optims(optim_db)
    def test_optim_infos_do_not_specify_global_cliquey_kwargs(self, device, dtype, optim_info):
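        """optim_inputs_func should not set foreach, fused, or differentiable; these
        "global cliquey" flags are handled separately rather than baked into individual
        optim inputs."""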
        global_cliquey_flags = ["foreach", "fused", "differentiable"]
        for optim_input in optim_info.optim_inputs_func():
            self.assertFalse(any(f for f in global_cliquey_flags if f in optim_input.kwargs))

instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)


if __name__ == '__main__':
    run_tests()