| import torch |
| |
| import torch._dynamo |
| import torch._dynamo.config |
| import torch._inductor.config as config |
| from benchmark_helper import time_with_torch_timer |
| |
| |
# NOTE(review): torch._dynamo.optimize is an internal/legacy entry point;
# torch.compile is the public equivalent (fullgraph=True ~ nopython=True).
@torch.compile(backend="inductor", fullgraph=True)
def inductor_aten_bmm(a, b):
    """Batched matmul compiled with TorchInductor.

    Intended to be run with ``config.triton.use_bmm = False`` so Inductor
    lowers to the ATen bmm kernel (see ``test_total_time``).
    """
    return torch.bmm(a, b)
| |
| |
# NOTE(review): torch._dynamo.optimize is an internal/legacy entry point;
# torch.compile is the public equivalent (fullgraph=True ~ nopython=True).
@torch.compile(backend="inductor", fullgraph=True)
def inductor_triton_bmm(a, b):
    """Batched matmul compiled with TorchInductor.

    Intended to be run with ``config.triton.use_bmm = True`` so Inductor
    lowers to the Triton bmm template (see ``test_total_time``).
    """
    return torch.bmm(a, b)
| |
| |
def torch_bmm(a, b):
    """Eager-mode baseline: plain batched matmul, no compilation."""
    return a.bmm(b)
| |
| |
def test_total_time(shapes):
    """Benchmark eager vs Inductor-compiled bmm over a list of shape pairs.

    Args:
        shapes: iterable of ``(a_shape, b_shape)`` pairs; each is a 3-element
            batched-matmul shape fed to ``torch.randn`` on CUDA in float16.

    Prints one header line, then one "; "-separated line per shape pair with
    mean times in milliseconds for: eager torch.bmm, Inductor ATen bmm, and
    Inductor Triton bmm.
    """
    print("shape; torch bmm; inductor aten bmm; inductor triton bmm")
    # Iterate pairs directly instead of indexing via range(len(...)).
    for a_shape, b_shape in shapes:
        print(a_shape, "x", b_shape, end="; ")
        a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
        b = torch.randn(b_shape, device="cuda", dtype=a.dtype)

        # Warm-up calls so compilation happens outside the timed region.
        config.triton.use_bmm = False
        inductor_aten_bmm(a, b)

        config.triton.use_bmm = True
        inductor_triton_bmm(a, b)

        torch_ms = time_with_torch_timer(torch_bmm, (a, b)).mean * 1000

        # Toggle the Inductor lowering before timing each compiled variant.
        config.triton.use_bmm = False
        ind_aten_ms = time_with_torch_timer(inductor_aten_bmm, (a, b)).mean * 1000

        config.triton.use_bmm = True
        ind_triton_ms = time_with_torch_timer(inductor_triton_bmm, (a, b)).mean * 1000

        print(torch_ms, ind_aten_ms, ind_triton_ms, sep="; ")
| |
| |
if __name__ == "__main__":
    # (a_shape, b_shape) pairs taken from common transformer workloads.
    bmm_shapes = [
        # BERT (all)
        ([192, 128, 64], [192, 64, 128]),
        ([192, 128, 128], [192, 128, 64]),
        # hf_GPT2 (all)
        ([12, 1024, 1024], [12, 1024, 64]),
        ([12, 1024, 64], [12, 64, 1024]),
        # hf_Albert (all)
        ([12, 512, 64], [12, 64, 512]),
        ([12, 512, 512], [12, 512, 64]),
    ]
    test_total_time(bmm_shapes)