blob: deb3d8f8b60424571ebafaa5116eaa17f9890bce [file] [log] [blame]
import torch
import torch._dynamo
import torch._dynamo.config
import torch._inductor.config as config
import triton
from benchmark_helper import time_with_torch_timer
# Allow TF32 tensor-core math in cuBLAS matmuls. This flag defaults to False
# since PyTorch 1.12 (it was True in 1.7-1.11), so it is set explicitly here.
# NOTE(review): TF32 only affects float32 operations; the benchmarks below use
# float16 operands, so this flag presumably does not change the timed kernels.
torch.backends.cuda.matmul.allow_tf32 = True
# Allow TF32 in cuDNN (convolutions). This flag already defaults to True, so
# the assignment only makes the choice explicit.
torch.backends.cudnn.allow_tf32 = True
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_aten_mm(a, b):
    """Return ``torch.mm(a, b)`` compiled through dynamo + TorchInductor.

    In this script the caller sets ``config.triton.mm = "aten"`` right before
    the first call, which is when dynamo compiles the function, so the
    compiled graph uses the aten matmul lowering. ``nopython=True`` makes
    dynamo raise instead of silently falling back to eager on a graph break.
    """
    product = torch.mm(a, b)
    return product
@torch._dynamo.optimize("inductor", nopython=True)
def inductor_triton_mm(a, b):
    """Return ``torch.mm(a, b)`` compiled through dynamo + TorchInductor.

    This is kept as a separate function from the aten variant so it gets its
    own dynamo cache entry. The caller sets ``config.triton.mm = "triton"``
    before the first (compiling) call, selecting the Triton matmul lowering.
    ``nopython=True`` forbids eager fallback on graph breaks.
    """
    product = torch.mm(a, b)
    return product
def torch_mm(a, b):
    """Eager-mode baseline: plain ``torch.mm`` with no compilation."""
    product = torch.mm(a, b)
    return product
def triton_mm(a, b):
    """Matrix product via Triton's bundled ``triton.ops.matmul`` kernel."""
    product = triton.ops.matmul(a, b)
    return product
def test_total_time(shapes):
    """Print end-to-end (host-observed) matmul latency for each shape pair.

    For every ``(a_shape, b_shape)`` in *shapes*, random float16 CUDA operands
    are created and four implementations are timed with
    ``time_with_torch_timer``:
    - eager ``torch.mm``;
    - ``triton.ops.matmul``;
    - the inductor wrapper compiled with ``config.triton.mm = "aten"``;
    - the inductor wrapper compiled with ``config.triton.mm = "triton"``.
    One semicolon-separated row of mean times (scaled by 1000, i.e.
    milliseconds) is printed per shape.

    The caller's ``config.triton.mm`` value is restored on exit. The dynamo
    cache is reset after every shape, even if a benchmark raises.
    """
    print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
    saved_mm_backend = config.triton.mm
    try:
        for a_shape, b_shape in shapes:
            print(a_shape, "x", b_shape, end="; ")
            a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
            b = torch.randn(b_shape, device="cuda", dtype=a.dtype)
            try:
                # Warm-up: the first call compiles each inductor wrapper under
                # whichever matmul backend config.triton.mm selects right now.
                config.triton.mm = "aten"
                inductor_aten_mm(a, b)
                config.triton.mm = "triton"
                inductor_triton_mm(a, b)

                # Measurement.mean is presumably in seconds (torch.utils.benchmark);
                # scale to milliseconds.
                torch_ms = time_with_torch_timer(torch_mm, (a, b)).mean * 1000
                triton_ms = time_with_torch_timer(triton_mm, (a, b)).mean * 1000
                config.triton.mm = "aten"
                ind_aten_ms = time_with_torch_timer(inductor_aten_mm, (a, b)).mean * 1000
                config.triton.mm = "triton"
                ind_triton_ms = time_with_torch_timer(inductor_triton_mm, (a, b)).mean * 1000
                print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")
            finally:
                # Drop compiled graphs so the next shape compiles from scratch
                # (and nothing leaks if a benchmark raised).
                torch._dynamo.reset()
    finally:
        # Don't leave the global inductor config pointing at "triton".
        config.triton.mm = saved_mm_backend
def test_GPU_time(shapes):
    """Print GPU-side matmul time for each shape pair using Triton's do_bench.

    The implementations and output format match the total-time benchmark:
    eager ``torch.mm``, ``triton.ops.matmul``, and the inductor wrappers
    compiled with ``config.triton.mm`` set to ``"aten"`` and ``"triton"``.
    Each is timed with ``triton.testing.do_bench``, which reports
    milliseconds. One semicolon-separated row is printed per shape.

    The caller's ``config.triton.mm`` value is restored on exit. The dynamo
    cache is reset after every shape, even if a benchmark raises.
    """
    print("shape; torch mm; triton mm; inductor aten mm; inductor triton mm")
    saved_mm_backend = config.triton.mm
    try:
        for a_shape, b_shape in shapes:
            print(a_shape, "x", b_shape, end="; ")
            a = torch.randn(a_shape, device="cuda", dtype=torch.float16)
            b = torch.randn(b_shape, device="cuda", dtype=a.dtype)
            try:
                # Warm-up: the first call compiles each inductor wrapper under
                # whichever matmul backend config.triton.mm selects right now.
                config.triton.mm = "aten"
                inductor_aten_mm(a, b)
                config.triton.mm = "triton"
                inductor_triton_mm(a, b)

                # NOTE(review): the 3-way unpack assumes do_bench returns a
                # (median, p20, p80) tuple, as in the Triton version this
                # script targets; newer Triton may return a single float by
                # default. The first element is used as the timing.
                torch_ms, _, _ = triton.testing.do_bench(lambda: torch_mm(a, b))
                triton_ms, _, _ = triton.testing.do_bench(lambda: triton_mm(a, b))
                # Select the matching backend while timing each inductor
                # wrapper (as the total-time benchmark does), so any
                # recompilation would use the intended lowering.
                config.triton.mm = "aten"
                ind_aten_ms, _, _ = triton.testing.do_bench(lambda: inductor_aten_mm(a, b))
                config.triton.mm = "triton"
                ind_triton_ms, _, _ = triton.testing.do_bench(lambda: inductor_triton_mm(a, b))
                print(torch_ms, triton_ms, ind_aten_ms, ind_triton_ms, sep="; ")
            finally:
                # Drop compiled graphs so the next shape compiles from scratch
                # (and nothing leaks if a benchmark raised).
                torch._dynamo.reset()
    finally:
        # Don't leave the global inductor config pointing at "triton".
        config.triton.mm = saved_mm_backend
if __name__ == "__main__":
    # (A, B) operand shapes for A @ B, grouped by the model they come from.
    alexnet_shapes = [
        ([128, 9216], [9216, 4096]),
        ([128, 4096], [4096, 4096]),
        ([128, 4096], [4096, 1000]),
    ]
    bert_shapes = [
        ([2048, 768], [768, 768]),
        ([2048, 768], [768, 3072]),
        ([2048, 3072], [3072, 768]),
    ]
    hf_gpt2_shapes = [
        ([1024, 768], [768, 768]),
        ([1024, 768], [768, 3072]),
        ([1024, 3072], [3072, 768]),
        ([1024, 768], [768, 2304]),
    ]
    shapes = [*alexnet_shapes, *bert_shapes, *hf_gpt2_shapes]

    # Run the host-observed benchmark first, then the GPU-time benchmark.
    for title, benchmark in (
        ("test total time", test_total_time),
        ("test GPU time", test_GPU_time),
    ):
        print(title)
        benchmark(shapes)
# Sample results from a run on the AWS AI cluster (all times in milliseconds)
"""
test total time
shape; torch mm; triton mm; inductor aten mm; inductor triton mm
[128, 9216] x [9216, 4096]; 0.07240759208798409; 0.10885953903198242; 0.20063146017491817; 0.20054904278367758
[128, 4096] x [4096, 4096]; 0.03640300128608942; 0.10960095096379519; 0.09948539081960917; 0.0996188772842288
[128, 4096] x [4096, 1000]; 0.02215010579675436; 0.12592008337378502; 0.031120930798351765; 0.0370654184371233
[2048, 768] x [768, 768]; 0.023501068353652954; 0.10804693214595318; 0.03004650119692087; 0.0276932492852211
[2048, 768] x [768, 3072]; 0.045639658346772194; 0.10883208829909563; 0.062736920081079; 0.06480381824076176
[2048, 3072] x [3072, 768]; 0.054093082435429096; 0.10804777964949608; 0.08744294755160809; 0.07766005117446184
[1024, 768] x [768, 768]; 0.021525858901441097; 0.10909941978752613; 0.02656651195138693; 0.02683836966753006
[1024, 768] x [768, 3072]; 0.027319076471030712; 0.10825308971107006; 0.040118801407516; 0.039282338693737984
[1024, 3072] x [3072, 768]; 0.034132059663534164; 0.10594133753329515; 0.05069758277386427; 0.04572632722556591
[1024, 768] x [768, 2304]; 0.02529360819607973; 0.10486091021448374; 0.03724239766597748; 0.036449190229177475
test GPU time
shape; torch mm; triton mm; inductor aten mm; inductor triton mm
[128, 9216] x [9216, 4096]; 0.09113600105047226; 0.09011200070381165; 0.21606400609016418; 0.21606400609016418
[128, 4096] x [4096, 4096]; 0.053247999399900436; 0.05222399905323982; 0.1157120019197464; 0.1157120019197464
[128, 4096] x [4096, 1000]; 0.026623999699950218; 0.02969600073993206; 0.04710400104522705; 0.05222399905323982
[2048, 768] x [768, 768]; 0.02457600086927414; 0.020479999482631683; 0.04095999896526337; 0.03993599861860275
[2048, 768] x [768, 3072]; 0.05119999870657921; 0.05222399905323982; 0.07475200295448303; 0.07577600330114365
[2048, 3072] x [3072, 768]; 0.05939200147986412; 0.05222399905323982; 0.09830400347709656; 0.0870399996638298
[1024, 768] x [768, 768]; 0.01945599913597107; 0.016383999958634377; 0.03276799991726875; 0.03276799991726875
[1024, 768] x [768, 3072]; 0.03174399957060814; 0.03276799991726875; 0.053247999399900436; 0.053247999399900436
[1024, 3072] x [3072, 768]; 0.04403200000524521; 0.03379200026392937; 0.06860800087451935; 0.062463998794555664
[1024, 768] x [768, 2304]; 0.02969600073993206; 0.02969600073993206; 0.04915200173854828; 0.048128001391887665
"""