import argparse
import csv
import dataclasses
import os

from generate import run_llama2_7b_bf16, run_llama2_7b_int8, run_mixtral_8x7b_int8
from triton.testing import do_bench

import torch
import torch.nn as nn
from torch.utils.flop_counter import FlopCounterMode

WARMUP_ITER = 5
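# Peak dense BF16 throughput of an A100 80GB GPU, in TFLOP/s; used as the
# denominator for the FLOPS-utilization numbers reported below.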
A100_80G_BF16_TFLOPS = 312


@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    is_model: bool = False
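

# Small Linear + LayerNorm stack used by the mlp_layer_norm_gelu micro-benchmark below.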
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.Linear(input_dim, hidden_dim, dtype=dtype),
                nn.LayerNorm(hidden_dim, dtype=dtype),
                nn.Linear(hidden_dim, output_dim, dtype=dtype),
                nn.LayerNorm(output_dim, dtype=dtype),
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
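

# Reports FLOPS utilization of the compiled SimpleMLP relative to A100 peak BF16
# throughput, averaged over several input sizes.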
def run_mlp_layer_norm_gelu(device: str = "cuda"):
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.71",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    intermediate_size = 14336
    results = []
    for dtype, expected_flops_utilization in dtype_flops_utilization_map.items():
        flops_utilization = 0
        for D in input_shapes:
            mod = SimpleMLP(
                input_dim=D, hidden_dim=intermediate_size, output_dim=D, dtype=dtype
            ).to(device)

            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            with FlopCounterMode(display=False) as mode:
                mod(x)
            flops = mode.get_total_flops()

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            # do_bench reports milliseconds; convert to microseconds per iteration.
            us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
            flops_utilization += us_per_iter * flops / 1e9 / A100_80G_BF16_TFLOPS

        flops_utilization = flops_utilization / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "mlp_layer_norm_gelu",
                "flops_utilization",
                expected_flops_utilization,
                f"{flops_utilization:.02f}",
                dtype_str,
                device,
            )
        )
    return results
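

# Reports achieved memory bandwidth (GB/s) of a compiled LayerNorm on a (4096, D)
# input, counting one read and one write of the activation per iteration.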
def run_layer_norm(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "1017",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    BS = 4096
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:
            mod = nn.LayerNorm(D).to(device)

            x = torch.randn(BS, D, device=device, dtype=dtype)

            compiled_mod = torch.compile(mod, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_mod(x)

            us_per_iter = do_bench(lambda: compiled_mod(x)) * 1000
            # iterations/second * bytes moved per iteration (read + write) / 1e9 -> GB/s
            memory_bandwidth += (1e6 / us_per_iter) * 2 * BS * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "layer_norm",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results
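

# Mixture-of-experts-style gather + matrix-vector multiply: two of E expert weight
# matrices are selected by index and applied to a vector; reports achieved memory
# bandwidth (GB/s) from reading the two gathered matrices.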
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
    E = 8
    dtype_memory_bandwidth_map = {
        torch.int8: "1113",
        torch.bfloat16: "1249",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gather_gemv(W, score_idxs, x):
                return W[score_idxs].to(x.dtype) @ x

            W = torch.randn(E, D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)
            score_idxs = torch.tensor([3, 5], device=device)

            compiled_fn = torch.compile(gather_gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, score_idxs, x)

            us_per_iter = do_bench(lambda: compiled_fn(W, score_idxs, x)) * 1000
            # Traffic counted: the two gathered D x D expert weight matrices.
            memory_bandwidth += (1e6 / us_per_iter) * 2 * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gather_gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results
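

# Plain matrix-vector multiply with the weight upcast to the activation dtype;
# reports achieved memory bandwidth (GB/s) from reading the weight matrix.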
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.int8: "990",
        torch.bfloat16: "1137",
    }
    input_shapes = [1024, 4096, 8192, 16384]
    results = []
    for dtype, expected_memory_bandwidth in dtype_memory_bandwidth_map.items():
        memory_bandwidth = 0
        for D in input_shapes:

            def gemv(W, x):
                return W.to(x.dtype) @ x

            W = torch.randn(D, D, device=device).to(dtype=dtype)
            x = torch.randn(D, device=device, dtype=torch.bfloat16)

            compiled_fn = torch.compile(gemv, dynamic=False)

            for _ in range(WARMUP_ITER):
                compiled_fn(W, x)

            us_per_iter = do_bench(lambda: compiled_fn(W, x)) * 1000
            memory_bandwidth += (1e6 / us_per_iter) * D * D * dtype.itemsize / 1e9

        memory_bandwidth = memory_bandwidth / len(input_shapes)
        dtype_str = str(dtype).replace("torch.", "")
        results.append(
            Experiment(
                "gemv",
                "memory_bandwidth(GB/s)",
                expected_memory_bandwidth,
                f"{memory_bandwidth:.02f}",
                dtype_str,
                device,
            )
        )
    return results
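

# Appends one result row to the CSV file, creating it (and widening the header)
# as needed; short rows are padded with "0" to the header width.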
def output_csv(output_file, headers, row):
    if os.path.exists(output_file):
        with open(output_file) as fd:
            lines = list(csv.reader(fd)) or [[]]
            if headers and len(headers) > len(lines[0]):
                # if prior results failed the header might not be filled in yet
                lines[0] = headers
            else:
                headers = lines[0]
    else:
        lines = [headers]

    if output_file != DEFAULT_OUTPUT_FILE:
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
    lines.append([(f"{x:.6f}" if isinstance(x, float) else x) for x in row])
    with open(output_file, "w") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        for line in lines:
            writer.writerow(list(line) + ["0"] * (len(headers) - len(line)))


DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"
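

# Experiments run by main(); each callable is expected to return a list of
# Experiment results (see main below).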
all_experiments = {
    # End-to-end GPT model benchmarks: Llama, Mixtral, etc.
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # Micro-benchmarks.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
}
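

# Runs every registered experiment and writes one CSV row per result.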
def main(output_file=DEFAULT_OUTPUT_FILE):
    results = []

    for func in all_experiments:
        lst = func()
        for x in lst:
            results.append(dataclasses.astuple(x))

    headers = [field.name for field in dataclasses.fields(Experiment)]

    for row in results:
        output_csv(output_file, headers, row)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--output",
        default=DEFAULT_OUTPUT_FILE,
        help="Set the output CSV file to save the benchmark results",
    )
    args = parser.parse_args()
    main(output_file=args.output)