import torch


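# A scripted pointwise op; on CUDA, an element-wise chain like scale * x / shift
# is the kind of pattern the JIT fuser can compile into a single kernel.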
@torch.jit.script
def fn(x, scale, shift):
    return scale * x / shift


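# Drives fn through a loop; TorchScript compiles the loop body once, so the
# specialized (and possibly fused) graph is reused on every iteration.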
@torch.jit.script
def recurrent(x, scale, shift):
    y = x
    for i in range(100):
        y = fn(y, scale, shift)
    return y


x = torch.randn(2, 2, device='cuda')
scale = torch.randn(2, 2, device='cuda', requires_grad=True)
shift = torch.randn(2, 2, device='cuda', requires_grad=True)
inputs = [x, scale, shift]

out = recurrent(x, scale, shift)
print(recurrent.graph_for(x, scale, shift))  # graph specialized for these inputs
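# A minimal sketch (not in the original snippet): run backward through the
# scripted loop so the differentiated graph is exercised as well.
out.sum().backward()
print(scale.grad.shape, shift.grad.shape)
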


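# The same pattern with the scale/shift arithmetic inlined instead of routed
# through a helper function.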
import torch


@torch.jit.script
def recurrent_scaleshift(x, scale, shift):
    y = x
    for i in range(64):
        y = scale * y + shift
    return y


x = torch.randn(2, 2, device='cuda')
scale = torch.randn(2, 2, device='cuda', requires_grad=True)
shift = torch.randn(2, 2, device='cuda', requires_grad=True)
inputs = [x, scale, shift]
out = recurrent_scaleshift(x, scale, shift)
print(recurrent_scaleshift.graph_for(x, scale, shift))  # graph specialized for these inputs
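# A minimal sketch (not in the original snippet): torch.autograd.grad returns
# the gradients directly without mutating .grad on the leaves.
g_scale, g_shift = torch.autograd.grad(out.sum(), [scale, shift])
print(g_scale.shape, g_shift.shape)
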


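# Autograd with an empty tensor: backward on CPU, then the same ops after
# moving the tensor to CUDA.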
import torch

x = torch.tensor([])
x.requires_grad = True
x.mean().backward()  # no error triggered on CPU
x = x.cuda()  # differentiable copy; the original CPU tensor remains the autograd leaf
x.mean().backward()