# Owner(s): ["module: dynamo"]
"""
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
with test_adam in OptimizerTests)
"""
import functools

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch.nn import Parameter


class MyOptimizer(torch.optim.Optimizer):
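    """Minimal optimizer whose step() routes through _init_group, used below to
    check that a compiled step() matches the eager update for both real and
    complex parameters."""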
    def __init__(self, params):
        super().__init__(params, {})

    def _init_group(self, params, group):
        any_complex = False
        for p in group["params"]:
            params.append(p)
            any_complex |= p.is_complex()
        return any_complex

    def step(self):
        for group in self.param_groups:
            params = []
            any_complex = self._init_group(params, group)
            if any_complex:
                params[0] -= 1
            else:
                params[0] += 1


class End2EndTests(torch._dynamo.test_case.TestCase):
    # https://github.com/pytorch/torchdynamo/issues/1604
    def test_optimizing_over_tensor_with_requires_grad(self):
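        # Adam optimizes input2 directly: a plain requires_grad tensor that is
        # also fed to the model, rather than a module parameter.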
        class Net(torch.nn.Module):
            def forward(self, x, y):
                z = torch.bmm(x, y)
                z = torch.flatten(z, 1)
                return z

        def training_iter_fn(batch, model, optimizer):
            optimizer.zero_grad()
            out = model(**batch)
            target = torch.tensor([0, 7])
            loss = torch.nn.CrossEntropyLoss()(out, target)
            loss.backward()
            optimizer.step()
            return loss

        net = Net()
        input1 = torch.randn(2, 1, 4)
        input2 = torch.randn(2, 4, 8, requires_grad=True)
        optimizer = torch.optim.Adam([input2], lr=0.1)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_training_iter_fn = torch.compile(training_iter_fn, backend=cnts)
        batch = {"x": input1, "y": input2}
        for _ in range(2):
            opt_training_iter_fn(batch, net, optimizer)
        self.assertEqual(cnts.frame_count, 2)

    def test_state_dict(self):
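        # Builds an Adagrad optimizer and a loss closure inside a compiled
        # function, then runs a closure-based step() outside the compiled region.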
        @torch.compile(backend="eager")
        def _test_state_dict(weight, bias, input):
            def fn_base(optimizer, weight, bias):
                optimizer.zero_grad()
                i = input
                loss = (weight.mv(i) + bias).pow(2).sum()
                loss.backward()
                return loss

            optimizer = torch.optim.Adagrad([weight, bias])
            fn = functools.partial(fn_base, optimizer, weight, bias)
            return optimizer, fn

        optimizer, fn = _test_state_dict(
            Parameter(torch.randn(10, 5)),
            Parameter(torch.randn(10)),
            torch.randn(5, requires_grad=True),
        )
        optimizer.step(fn)

    def test_init_group(self):
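        # Runs MyOptimizer.step() eagerly and compiled (fullgraph) for real and
        # complex params and checks that both produce the same update.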
        for dtype in [torch.float32, torch.cfloat]:
            tensor = torch.randn(5, 5, dtype=dtype)
            params = Parameter(tensor.detach().clone(), requires_grad=False)
            opt_params = Parameter(tensor.detach().clone(), requires_grad=False)

            optim = MyOptimizer([params])
            optim.step()

            opt_optim = MyOptimizer([opt_params])
            opt_step = torch.compile(backend="eager", fullgraph=True)(opt_optim.step)
            opt_step()

            self.assertEqual(params, opt_params)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()