import time

import torch
import torch.utils
from functorch.compile import aot_function, tvm_compile

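# `a` broadcasts against `b` to a (2000, 2000, 4) intermediate; only `a`
# requires grad, so only one input gradient is computed in the backward pass.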
a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4)


def f(a):
    return (a * b).sum(dim=0)


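# Build TVM-backed compilers for the forward and backward graphs that AOT
# Autograd extracts; each gets its own tuning logfile.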
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
compiled_f = aot_function(f, fw_compiler, bw_compiler)

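# To benchmark aot_function's overhead alone, swap in these identity
# "compilers", which return the extracted FX graph module unchanged: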
# fw_compiler = lambda x, _: x
# bw_compiler = lambda x, _: x
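
# A minimal debugging sketch (assuming the usual functorch compiler-callback
# signature of (fx_graph_module, example_inputs)): print the captured graph
# instead of compiling it, e.g. aot_function(f, debug_compiler, debug_compiler).
# def debug_compiler(fx_g, _):
#     fx_g.graph.print_tabular()
#     return fx_g  # a GraphModule is callable, so the graph runs eagerly
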
iters = 10
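# Warm-up: the first call traces f and invokes the compilers, so keep it out
# of the timed loop.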
out = compiled_f(a)
out.sum().backward()


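# Time `iters` rounds of forward + backward through func. The trailing .sin()
# runs eagerly, outside whatever func compiled, and a.grad is reset each
# iteration so gradients do not accumulate across rounds.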
def bench(func):
    begin = time.time()
    for _ in range(iters):
        out = func(a).sin()
        out.sum().backward()
        a.grad = None
    print(time.time() - begin)


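# Optional JAX comparison: jit(grad(f)) compiles the whole gradient
# computation with XLA; block_until_ready() forces JAX's async dispatch to
# finish so the timing is honest.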
def bench_jax():
    import jax
    import jax.numpy as jnp

    jax_a = jnp.array(a.detach().numpy())
    jax_b = jnp.array(b.detach().numpy())

    def f(a):
        return jnp.sin((a * jax_b).sum(axis=0)).sum()

    jit_f = jax.jit(jax.grad(f))
    jit_f(jax_a)  # warm-up: trigger XLA compilation before timing
    begin = time.time()
    for _ in range(iters):
        out = jit_f(jax_a)
        out.block_until_ready()
    print(time.time() - begin)


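# Eager baseline first, then the AOT/TVM-compiled version; uncomment
# bench_jax() to add the JAX numbers.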
bench(f)
bench(compiled_f)
# bench_jax()