from torch.utils.benchmark import Timer | |
def time_with_torch_timer(fn, args, kwargs=None, iters=100): | |
kwargs = kwargs or {} | |
env = {"args": args, "kwargs": kwargs, "fn": fn} | |
fn_call = "fn(*args, **kwargs)" | |
# Measure end-to-end time | |
timer = Timer(stmt=f"{fn_call}", globals=env) | |
tt = timer.timeit(iters) | |
return tt |