import timeit

import torch
import torch.nn as nn
from functorch.compile import compiled_module, tvm_compile


# A pass-through "compiler": it hands the traced graph back unchanged, so the
# example runs end to end even when TVM is not installed.
def nop(f, _):
    return f
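

# tvm_compile builds TVM-backed compilers for the extracted forward and
# backward graphs; the tuning logfile stores and reuses auto-tuning results
# between runs. The two `nop` assignments below override them, so compilation
# is effectively disabled by default; delete those assignments to exercise
# the TVM path instead.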
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
fw_compiler = nop
bw_compiler = nop
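

# Run one forward/backward pass and gather the parameter gradients so eager
# and compiled executions can be compared value for value.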
def run(mod, input):
    out = mod(input)
    out.sum().backward()
    grads = [p.grad for p in mod.parameters()]
    return (out, *grads)
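

# A minimal module with one trainable parameter and one buffer, enough to
# exercise both parameter and buffer handling in the compiled graphs.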
class Foo(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.randn(1))
        self.register_buffer("buf", torch.randn(1))

    def forward(self, x):
        return (self.param * x + self.buf).sum(dim=0)
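

# Build an eager module and a compiled counterpart: compiled_module extracts
# the module's forward and backward graphs and hands them to fw_compiler /
# bw_compiler (both no-ops in this default configuration).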
input = torch.randn(1)
mod = Foo()
compiled_mod = compiled_module(mod, fw_compiler, bw_compiler)
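
# Sanity check: the compiled module should produce the same output and
# parameter gradients as eager execution.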
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)
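
# Take one hand-rolled SGD step, applying the update through both the eager
# module and the module held by the compiled wrapper, and reset the stored
# gradient before comparing again.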
out = mod(input)
out.sum().backward()
mod.param.data -= mod.param.grad
compiled_mod.orig_module.param.data -= compiled_mod.orig_module.param.grad
compiled_mod.orig_module.param.grad = None
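
# Eager and compiled execution should still agree after the update.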
for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_close(a, b)
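
# Microbenchmark the forward pass only, reporting microseconds per call.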
for _ in range(5):
    i = 10000
    t = timeit.Timer("mod(input)", globals=globals()).timeit(i)
    print(f"eager {t / i * 1e6:.2f} us/iter")
    t = timeit.Timer("compiled_mod(input)", globals=globals()).timeit(i)
    print(f"compiled {t / i * 1e6:.2f} us/iter")