| # Copyright (c) Facebook, Inc. and its affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import time |
| |
| import torch |
| from functorch import grad, make_fx |
| from functorch.compile import nnc_jit |
| |
| |
def f(x):
    """Return the sum of the element-wise sine of ``x`` as a scalar tensor."""
    sines = torch.sin(x)
    return torch.sum(sines)
| |
| |
# Eager reference: functorch's grad transform applied to f.
grad_pt = grad(f)

# Input shared by every variant benchmarked in this script.
inp = torch.randn(100)

# FX variant: make_fx wraps grad_pt, and calling the wrapper on inp yields
# the object that is benchmarked as "FX".
_fx_tracer = make_fx(grad_pt)
grad_fx = _fx_tracer(inp)

# NNC variant: grad_pt passed through functorch.compile.nnc_jit.
grad_nnc = nnc_jit(grad_pt)
| |
| |
def bench(name, f, iters=10000, warmup=3):
    """Time repeated calls of ``f`` and print the total elapsed time.

    Args:
        name: Label printed in front of the measurement; ``": "`` is
            appended to it by this function.
        f: Zero-argument callable to time.
        iters: Number of timed calls.
        warmup: Number of untimed calls made before timing starts, so that
            first-call overhead is not included in the measurement.

    Returns:
        Total elapsed seconds for the ``iters`` timed calls (the same value
        that is printed).
    """
    for _ in range(warmup):
        f()
    # perf_counter() is monotonic and uses the highest-resolution clock
    # available; time.time() follows the wall clock, which can be adjusted
    # mid-run and may be coarse on some platforms.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    elapsed = time.perf_counter() - begin
    print(f"{name}: ", elapsed)
    return elapsed
| |
| |
# Time each gradient variant on the same input. bench() appends ": " to the
# label itself, so labels are passed bare to avoid printing "Pytorch: : ...".
bench("Pytorch", lambda: grad_pt(inp))
bench("FX", lambda: grad_fx(inp))
bench("NNC", lambda: grad_nnc(inp))