| import argparse |
| import random |
| |
| import torch |
| |
| |
def bench(nt_a, nt_b, niter):
    """Measure the average time of ``nt_a.bmm(nt_b)`` using CUDA events.

    One untimed warm-up call is issued first; the following ``niter`` calls
    are bracketed by a start/end event pair and the device is synchronized
    before the elapsed time is read.

    Args:
        nt_a: Left operand; must expose a ``bmm`` method.
            (Presumably a CUDA nested tensor -- confirm against callers.)
        nt_b: Right operand handed to ``nt_a.bmm``.
        niter: Number of timed iterations. Must be at least 1.

    Returns:
        Elapsed time per iteration, i.e. ``start.elapsed_time(end) / niter``
        (``torch.cuda.Event.elapsed_time`` reports milliseconds).

    Raises:
        ValueError: If ``niter`` is smaller than 1. Without this check a
            zero count divides by zero and a negative count yields a
            meaningless negative average.
    """
    if niter < 1:
        raise ValueError(f"niter must be a positive integer, got {niter!r}")

    # Warmup: keeps one-time setup costs out of the timed region.
    nt_a.bmm(nt_b)

    # Drain pending work so the start event marks a quiet stream.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(niter):
        nt_a.bmm(nt_b)
    end_event.record()

    # elapsed_time is only valid once both events have completed.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / niter
| |
| |
def sweep_n(niter, dtype):
    """Benchmark nested-tensor ``bmm`` across a range of batch sizes.

    For each batch size, builds that many random matrices with 256 rows and
    a column count drawn by ``random.randint(100, 200)``, packs them (and
    their transposes) into two nested tensors on ``"cuda"``, times the
    product with ``bench`` and prints one comma-separated result row.

    Args:
        niter: Number of timed iterations forwarded to ``bench``.
        dtype: Element type used for both nested tensors.
    """
    batch_sizes = (4, 8, 16, 32, 64, 128, 256)
    placement = {"dtype": dtype, "device": "cuda"}

    for ntensor in batch_sizes:
        # One randint draw per matrix, interleaved with randn, as sampled.
        matrices = [
            torch.randn(256, random.randint(100, 200)) for _ in range(ntensor)
        ]
        nt_a = torch.nested.nested_tensor(matrices, **placement)
        # Transposed components so that each pairwise product is defined.
        nt_b = torch.nested.nested_tensor(
            [mat.t() for mat in matrices], **placement
        )

        elapsed = bench(nt_a, nt_b, niter)

        # Column 1 of the size table holds each component's second dimension.
        widths = torch.ops.aten._nested_tensor_size(nt_a)[:, 1]
        row = (
            ntensor,
            dtype,
            widths.min().item(),
            widths.float().mean().item(),
            widths.max().item(),
            elapsed,
        )
        print(",".join(str(field) for field in row))
| |
| |
if __name__ == "__main__":
    # Seed Python's RNG so the randomly drawn matrix widths are repeatable.
    random.seed(123)

    cli_parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    cli_parser.add_argument("--niter", default="10", type=int)
    niter = cli_parser.parse_args().niter

    # Header row matching the fields emitted by sweep_n.
    print("ntensor,dtype,min_length,mean_length,max_length,runtime")
    for precision in (torch.float32, torch.float16):
        sweep_n(niter, precision)