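"""Benchmark dense vs. Triton-based block-sparse (BSR) matmul kernels from
torch.sparse._triton_ops.

Example invocations (the script name is illustrative; all flags are defined
by the argument parser below):

    python triton_ops.py --ops dense_dense_mm,bsr_dense_mm --m 1024,2048 --sparsity 0.8
    python triton_ops.py --ops bsr_scatter_mm6 --star --m 4096 --bm 32
"""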
import torch
def create_blocked_tensor(B, M, N, blocksize, sparsity, dtype, device):
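    """Return an (optionally batched) M x N tensor of zeros and ones whose
    nonzero pattern forms blocks of shape ``blocksize``, with the number of
    nonzero blocks adjusted to match ``sparsity`` exactly.

    ``B == 0`` means the result has no batch dimension.
    """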
assert (
sparsity <= 1.0 and sparsity >= 0.0
), "sparsity should be a value between 0 and 1"
assert M % blocksize[0] == 0
assert N % blocksize[1] == 0
shape = (B, M // blocksize[0], N // blocksize[1])[int(B == 0) :]
A = torch.bernoulli(torch.full(shape, 1 - sparsity, dtype=dtype, device=device))
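    # Bernoulli sampling hits the target sparsity only in expectation; flip
    # blocks on or off below so that the realized nnz is exact.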
expected_nnz = int((1 - sparsity) * M * N / (blocksize[0] * blocksize[1]))
nonzero_indices = A.flatten().nonzero()
actual_nnz = nonzero_indices.shape[0]
if actual_nnz > expected_nnz:
selected_nonzeros = torch.randperm(actual_nnz)[: actual_nnz - expected_nnz]
A.flatten()[nonzero_indices[selected_nonzeros]] = 0
elif actual_nnz < expected_nnz:
zero_indices = (A == 0).flatten().nonzero()
selected_zeros = torch.randperm(zero_indices.shape[0])[
: expected_nnz - actual_nnz
]
A.flatten()[zero_indices[selected_zeros]] = 1
A = torch.repeat_interleave(A, blocksize[0], dim=-2)
A = torch.repeat_interleave(A, blocksize[1], dim=-1)
return A
def _test_worker(test_func):
import triton
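    # NB: assumes a Triton version whose do_bench returns a triple of timings
    # (in ms); newer Triton versions return a single float by default.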
ms, ms_min, ms_max = triton.testing.do_bench(
test_func, warmup=500, rep=100, fast_flush=False
)
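    # NB: m, k, n are globals assigned in the __main__ loop below.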
tflops = 2 * m * k * n * 1e-12 / (ms * 1e-3)
return ms, tflops
def test_dense_dense_mm(x, y, **meta):
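    # The default argument densifies x once, at definition time, keeping the
    # conversion cost outside the timed region.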
def test_func(x=x.to_dense(), y=y):
return torch.matmul(x, y)
return _test_worker(test_func)
def test_torch_matmul(x, y, **meta):
def test_func(x=x, y=y):
return torch.matmul(x, y)
return _test_worker(test_func)
def test_bsr_dense_mm(x, y, **meta):
from torch.sparse._triton_ops import bsr_dense_mm
def test_func(x=x, y=y):
return bsr_dense_mm(
x, y, meta=dict(GROUP_SIZE_ROW=4, num_stages=1, num_warps=4)
)
return _test_worker(test_func)
def test_bsr_dense_mm_with_meta(x, y, **meta):
from torch.sparse._triton_ops import bsr_dense_mm
def test_func(x=x, y=y, meta=meta):
return bsr_dense_mm(x, y, meta=meta)
return _test_worker(test_func)
def test_bsr_scatter_mm2(x, y, **meta):
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
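    # Precompute indices_data so that only the matmul itself is timed.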
indices_data = bsr_scatter_mm_indices_data(
x, y, indices_format="scatter_mm", **meta
)
def test_func(x=x, y=y):
return bsr_scatter_mm(x, y, indices_data=indices_data)
return _test_worker(test_func)
def test_bsr_scatter_mm6(x, y, **meta):
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
indices_data = bsr_scatter_mm_indices_data(
x, y, indices_format="bsr_strided_mm_compressed", **meta
)
def test_func(x=x, y=y):
return bsr_scatter_mm(x, y, indices_data=indices_data)
return _test_worker(test_func)
def test_bsr_scatter_mm(x, y, **meta):
from torch.sparse._triton_ops import bsr_scatter_mm, bsr_scatter_mm_indices_data
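    # Unlike test_bsr_scatter_mm2/6 above, indices_data is built inside the
    # timed function, so its construction cost is part of the measurement.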
def test_func(x=x, y=y):
indices_data = bsr_scatter_mm_indices_data(
x, y, indices_format="bsr_strided_mm_compressed", **meta
)
return bsr_scatter_mm(x, y, indices_data=indices_data)
return _test_worker(test_func)
def test_linear(x, y, **meta):
import torch.nn.functional as F
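    # Exercises the F.linear path with the BSR tensor as the weight argument.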
def test_func(x=x, y=y.transpose(-2, -1)):
return F.linear(y, x)
return _test_worker(test_func)
if __name__ == "__main__":
import argparse
import atexit
import itertools
import sys
import triton
from torch.testing import make_tensor
torch.manual_seed(0)
def integer_list(a):
return list(map(int, a.split(",")))
def float_list(a):
return list(map(float, a.split(",")))
def integer_or_float_list(a):
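        # Comma-separated items: "start:end" and "start:end:step" expand to
        # integer ranges, items containing "." parse as floats, anything else
        # parses as an int.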
lst = []
for n in a.split(","):
if n.count(":") == 1:
start, end = map(int, n.split(":"))
lst.extend(range(start, end))
elif n.count(":") == 2:
start, end, step = map(int, n.split(":"))
lst.extend(range(start, end, step))
elif "." in n:
lst.append(float(n))
else:
lst.append(int(n))
return lst
parser = argparse.ArgumentParser(description="SpTritonOps")
parser.add_argument(
"--ops",
default="dense_dense_mm,bsr_dense_mm,bsr_scatter_mm6",
type=str,
)
parser.add_argument("--b", default="0", type=int)
parser.add_argument("--m", default="1024", type=integer_list)
parser.add_argument("--k", default=None, type=integer_list)
parser.add_argument("--n", default=None, type=integer_list)
parser.add_argument("--bm", default="16", type=integer_list)
parser.add_argument("--bk", default=None, type=integer_list)
parser.add_argument("--tile_m", default=None, type=integer_list)
parser.add_argument("--tile_n", default=None, type=integer_list)
parser.add_argument("--split_n", default=None, type=integer_list)
parser.add_argument("--group_size", default=None, type=integer_list)
parser.add_argument("--num_warps", default=None, type=integer_list)
parser.add_argument("--num_stages", default=None, type=integer_list)
parser.add_argument("--sparsity", default="0.5", type=integer_or_float_list)
parser.add_argument("--dtype", default="float16", type=str)
parser.add_argument("--device", default="cuda", type=str)
parser.add_argument("--repeat", default="1", type=int)
parser.add_argument("--outfile", default="stdout", type=str)
parser.add_argument("--star", default=False, action="store_true")
args = parser.parse_args()
if args.outfile == "stdout":
outfile = sys.stdout
elif args.outfile == "stderr":
outfile = sys.stderr
else:
outfile = open(args.outfile, "a")
ops = args.ops.split(",")
b = args.b
m_list = args.m or [1024]
n_list = args.n or [None]
k_list = args.k or [None]
bm_list = args.bm or [16]
bk_list = args.bk or [None]
split_n_list = args.split_n or [None]
tile_m_list = args.tile_m or [None]
tile_n_list = args.tile_n or [None]
group_size_list = args.group_size or [None]
num_warps_list = args.num_warps or [None]
num_stages_list = args.num_stages or [None]
sparsity_list = args.sparsity or [0.5]
dtype = getattr(torch, args.dtype)
    if args.star:
import torch.sparse._triton_ops
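        # --star sweeps kernel meta-parameters around precomputed defaults,
        # so it requires a single (m, n, k, bm, bk) configuration.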
assert {len(m_list), len(n_list), len(k_list), len(bm_list), len(bk_list)} == {
1
}
m = m_list[0]
n = n_list[0] or m
k = k_list[0] or m
bm = bm_list[0]
bk = bk_list[0] or bm
if "bsr_scatter_mm6" in ops:
meta = torch.sparse._triton_ops.scatter_mm_meta(m, k, n, bm, bk)
elif "bsr_dense_mm_with_meta" in ops:
meta = torch.sparse._triton_ops.bsr_dense_mm_meta(m, k, n, bm, bk)
else:
raise NotImplementedError(f"--star not implemented for operations in {ops}")
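        # For each tunable parameter below: None expands to a neighborhood
        # around the precomputed default, 0 pins the default, and explicit
        # values are used as given.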
if "bsr_scatter_mm6" in ops:
if split_n_list[0] is None:
split_n_list = [
meta["SPLIT_N"] // 2,
meta["SPLIT_N"],
meta["SPLIT_N"] * 2,
][int(meta["SPLIT_N"] == 1) :]
elif split_n_list[0] == 0:
split_n_list = [meta["SPLIT_N"]]
if tile_m_list[0] is None:
tile_m_list = [meta["TILE_M"] // 2, meta["TILE_M"], meta["TILE_M"] * 2][
int(meta["TILE_M"] == 16) :
]
elif tile_m_list[0] == 0:
tile_m_list = [meta["TILE_M"]]
if tile_n_list[0] is None:
tile_n_list = [meta["TILE_N"] // 2, meta["TILE_N"], meta["TILE_N"] * 2][
int(meta["TILE_N"] == 16) :
]
elif tile_n_list[0] == 0:
tile_n_list = [meta["TILE_N"]]
if group_size_list[0] is None:
group_size_list = [
meta["GROUP_SIZE"] - 1,
meta["GROUP_SIZE"],
meta["GROUP_SIZE"] + 1,
][int(meta["GROUP_SIZE"] == 1) :]
elif group_size_list[0] == 0:
group_size_list = [meta["GROUP_SIZE"]]
if "bsr_dense_mm_with_meta" in ops:
if group_size_list[0] is None:
group_size_list = [
meta["GROUP_SIZE_ROW"] - 1,
meta["GROUP_SIZE_ROW"],
meta["GROUP_SIZE_ROW"] + 1,
][int(meta["GROUP_SIZE_ROW"] == 1) :]
elif group_size_list[0] == 0:
group_size_list = [meta["GROUP_SIZE_ROW"]]
if num_warps_list[0] is None:
num_warps_list = [
meta["num_warps"] // 2,
meta["num_warps"],
meta["num_warps"] * 2,
][int(meta["num_warps"] == 1) :]
elif num_warps_list[0] == 0:
num_warps_list = [meta["num_warps"]]
if num_stages_list[0] is None:
num_stages_list = [
meta["num_stages"] - 1,
meta["num_stages"],
meta["num_stages"] + 1,
][int(meta["num_stages"] == 1) :]
elif num_stages_list[0] == 0:
num_stages_list = [meta["num_stages"]]
device = args.device
dense_dense_mm_sizes = set()
target_performance = None
performance_rtol = 1e-2
best_messages = []
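    # Report the ten best configurations observed when the script exits.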
@atexit.register
def show_best_messages(best_messages=best_messages):
print("TOP 10:")
for m in best_messages[-10:]:
print(m)
sys.stdout.flush()
for m, k, n, bm, bk, sparsity in itertools.product(
m_list, k_list, n_list, bm_list, bk_list, sparsity_list
):
k = k or m
n = n or m
bk = bk or bm
if bm > m or bk > k:
# Skip invalid parameter combinations
continue
blocksize = (bm, bk)
if isinstance(sparsity, int):
# integer sparsity value corresponds to desired nnz value
sparsity = 1 - bk * bm * sparsity / (m * k)
if sparsity > 1 or sparsity < 0:
continue
x = create_blocked_tensor(
b, m, k, blocksize, sparsity, dtype, device
).to_sparse_bsr(blocksize)
# recompute sparsity
sparsity = 1 - bk * bm * x._nnz() / (m * k)
y = make_tensor(k, n, dtype=dtype, device=device)
        bsr_size = f"{b}x{m}x{k}" if b > 0 else f"{m}x{k}"
for op in ops:
if op == "dense_dense_mm":
if (m, k, n) in dense_dense_mm_sizes:
# Skip already benchmarked cases
continue
dense_dense_mm_sizes.add((m, k, n))
best_tflops = 0
for (
split_n,
num_warps,
num_stages,
tile_m,
tile_n,
group_size,
) in itertools.product(
split_n_list,
num_warps_list,
num_stages_list,
tile_m_list,
tile_n_list,
group_size_list,
):
if (
(tile_m or 0) > bm
or (tile_n or 0) > n // (split_n or 1)
or n % (split_n or 1) != 0
or (split_n or 0) > n
):
# Skip invalid parameter combinations
continue
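                # Dispatch to the test_<op> function defined above.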
test_func = globals()["test_" + op]
meta = dict(
bsr_scatter_mm6=dict(
SPLIT_N=split_n,
TILE_M=tile_m,
TILE_N=tile_n,
GROUP_SIZE=group_size,
num_stages=num_stages,
num_warps=num_warps,
),
bsr_dense_mm_with_meta=dict(
GROUP_SIZE_ROW=group_size,
num_stages=num_stages,
num_warps=num_warps,
),
).get(op, dict())
meta_str = ";".join(
f"{k}={v}" for k, v in meta.items() if v is not None
)
time_ms_lst = []
performance_tflops_lst = []
for r in range(args.repeat):
try:
time_ms, performance_tflops = test_func(x, y, **meta)
except triton.compiler.OutOfResources as msg:
print(
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
f" blocksize={bm}x{bk} OutOfResources",
file=outfile,
)
continue
except AssertionError:
raise
except Exception as msg:
msg = str(msg).split("\n", 1)[0]
print(
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
f" blocksize={bm}x{bk} {msg}",
file=outfile,
)
continue
time_ms_lst.append(time_ms)
performance_tflops_lst.append(performance_tflops)
mark = ""
if op == "dense_dense_mm":
if target_performance is None:
target_performance = performance_tflops
elif target_performance is not None:
if (
abs(1 - performance_tflops / target_performance)
< performance_rtol
):
mark += " @@@"
if best_tflops < performance_tflops:
best_tflops = performance_tflops
best_message = (
f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})"
f" blocksize={bm}x{bk} time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS"
)
if best_message not in best_messages:
best_messages.append(best_message)
mark += " !!!"
print(
f"op={op}[{meta_str}]({bsr_size},x{n}) dtype={args.dtype} {sparsity=:.4f}(nnz={x._nnz()})"
f" blocksize={bm}x{bk}"
f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS{mark}",
file=outfile,
)
outfile.flush()
if args.repeat > 1:
avg_time_ms = sum(time_ms_lst) / len(time_ms_lst)
avg_performance_tflops = sum(performance_tflops_lst) / len(
performance_tflops_lst
)
print(
f"op={op}[{meta_str}]({bsr_size},{k}x{n}) dtype={args.dtype} {sparsity=}(nnz={x._nnz()})"
f" blocksize={bm}x{bk}"
f" time={time_ms:.3f} ms performance={performance_tflops:.3f} TFLOPS [AVERAGE]",
file=outfile,
)
outfile.flush()
if op not in {"bsr_scatter_mm6", "bsr_dense_mm_with_meta"}:
# Break on operations that do not consume parameters
break