| import argparse |
| import random |
| |
| import pandas as pd |
| from tqdm import tqdm |
| |
| import torch |
| import torch.utils.benchmark as benchmark |
| from torch import nn |
| from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured |
| |
| |
# Global tensor print settings for this script: two decimal places, no
# scientific notation, wide lines and 16 edge items per dimension.
# `threshold=None` and `profile=None` leave those settings untouched.
torch.set_printoptions(
    linewidth=480,
    edgeitems=16,
    precision=2,
    sci_mode=False,
    threshold=None,
    profile=None,
)
| |
| |
| # helper model definition for pruner |
class Model(nn.Module):
    """
    A single ``nn.Linear`` layer mapping ``k`` input features to ``m`` outputs.

    ``nn.Linear(in_features, out_features)`` stores its weight with shape
    ``(out_features, in_features)``, so ``Model(m, k).linear.weight`` is an
    ``(m, k)`` matrix.

    Args:
        m: number of output features.
        k: number of input features.
        dtype: dtype for the layer's parameters; ``None`` uses the torch default.
    """

    def __init__(self, m, k, dtype=None):
        super().__init__()
        # Arguments are (in, out), i.e. reversed w.r.t. the (m, k) weight shape.
        # Forward `dtype` so the requested parameter dtype is actually honoured.
        self.linear = nn.Linear(k, m, dtype=dtype)

    def forward(self, x):
        """Apply the linear layer to ``x`` (last dimension of size ``k``)."""
        return self.linear(x)
| |
| |
def rand_sparse_semi_structured_mask(
    r, c, dtype=torch.float16, device="cuda", choice=None
):
    """
    Return a random 1:2 sparse 0/1 mask of size (r, c).

    The flattened mask is made of ``r * c // 2`` consecutive pairs, each either
    ``[0, 1]`` or ``[1, 0]``, so exactly one element of every aligned pair is
    kept. With an even ``c`` each row is therefore 1:2 sparse, which also
    satisfies the 2:4 and 4:8 patterns.

    Args:
        r: number of rows.
        c: number of columns; ``r * c`` must be even for the reshape to succeed.
        dtype: dtype of the returned mask.
        device: device the mask is created on.
        choice: optional fixed pair (e.g. ``[1, 0]``) repeated for every
            position instead of drawing randomly; a falsy value means random.

    Returns:
        A contiguous ``(r, c)`` tensor of zeros and ones.
    """
    num_pairs = r * c // 2

    if choice:
        # Tile the fixed pattern directly instead of building a Python list.
        flat = torch.tensor(choice, dtype=dtype, device=device).repeat(num_pairs)
    else:
        # keep_first == 1 -> [1, 0], keep_first == 0 -> [0, 1]. Generated on
        # device in one shot: a Python list of ~100M pairs (the largest
        # benchmark shapes) is prohibitively slow and memory-hungry.
        keep_first = torch.randint(0, 2, (num_pairs,), device=device)
        flat = torch.stack((keep_first, 1 - keep_first), dim=1).to(dtype)

    return flat.reshape(r, c).contiguous()
| |
| |
def test_linear(m, k, n, dtype, contiguous, backend):
    """
    Time an ``nn.Linear`` forward pass with a dense vs. a semi-structured weight.

    A random ``(m, k)`` weight is masked with ``rand_sparse_semi_structured_mask``
    and used for BOTH runs: first as a regular dense parameter, then converted
    with ``to_sparse_semi_structured``.

    Args:
        m: output features (rows of the weight).
        k: input features (columns of the weight).
        n: number of input rows.
        dtype: dtype of the weight, the input and the model.
        contiguous: unused here; kept so the signature matches ``test_tensor``.
        backend: ``"cutlass"`` sets ``SparseSemiStructuredTensor._FORCE_CUTLASS``;
            any other value clears it.

    Returns:
        A dict with the shape, dtype, backend, median latencies in ms, the
        dense/sparse speedup, whether the two outputs are close, and whether the
        sparse output is contiguous.
    """
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"
    mask = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    sparse_weight = torch.rand(m, k).to(dtype).cuda() * mask
    # NOTE(review): with an all-zero input the layer output equals its bias, so
    # the `correct` flag below is only a weak sanity check.
    input_tensor = torch.zeros(n, k).to(dtype).cuda()
    model = Model(m, k).to(dtype).cuda().eval()

    # Give the dense baseline the same masked weight the sparse run uses, so the
    # timings and the output comparison are like-for-like. Otherwise the dense
    # run would use nn.Linear's freshly initialised (different) weight.
    model.linear.weight = nn.Parameter(sparse_weight)

    timer_globals = {"model": model, "input_tensor": input_tensor}

    dense_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=timer_globals,
    ).blocked_autorange()

    dense_output = model(input_tensor)
    print(dense_output.shape)

    # Swap in the compressed semi-structured representation of the same weight.
    model.linear.weight = nn.Parameter(to_sparse_semi_structured(sparse_weight))

    sparse_output = model(input_tensor)
    print(sparse_output.shape)

    sparse_measurement = benchmark.Timer(
        stmt="model(input_tensor)",
        globals=timer_globals,
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "linear",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
| |
| |
def test_tensor(m, k, n, dtype, contiguous, backend):
    """
    Time ``torch.mm`` with a dense vs. a semi-structured sparse left operand.

    ``A`` is an ``(m, k)`` 0/1 mask from ``rand_sparse_semi_structured_mask``,
    ``sA`` is its ``to_sparse_semi_structured`` form, and ``B`` is an all-zero
    ``(k, n)`` matrix.

    Args:
        m: rows of the left operand.
        k: inner dimension.
        n: columns of the right operand.
        dtype: dtype of both operands.
        contiguous: unused here; kept so the signature matches ``test_linear``.
        backend: ``"cutlass"`` sets ``SparseSemiStructuredTensor._FORCE_CUTLASS``;
            any other value clears it (same as ``test_linear``).

    Returns:
        A dict with the shape, dtype, backend, median latencies in ms, the
        dense/sparse speedup, whether the two outputs are close, and whether the
        sparse output is contiguous.
    """
    # Apply the requested backend before converting A; previously this function
    # only recorded `backend` in the results without ever selecting it.
    SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass"

    A = rand_sparse_semi_structured_mask(m, k, dtype=dtype)
    B = torch.zeros(k, n).to(dtype).cuda()
    sA = to_sparse_semi_structured(A)

    if dtype is torch.int8:
        # This script has no dense int8 baseline, so the "dense" numbers are
        # taken from the sparse operand too (speedup ~1 by construction).
        print("int8 baseline not supported")
        dense_lhs = sA
    else:
        dense_lhs = A

    dense_output = torch.mm(dense_lhs, B)
    dense_measurement = benchmark.Timer(
        stmt="torch.mm(dense_lhs, B)",
        globals={"torch": torch, "dense_lhs": dense_lhs, "B": B},
    ).blocked_autorange()

    sparse_output = torch.mm(sA, B)
    sparse_measurement = benchmark.Timer(
        stmt="torch.mm(sA, B)",
        globals={"torch": torch, "sA": sA, "B": B},
    ).blocked_autorange()

    correct = torch.allclose(dense_output, sparse_output, rtol=1e-3, atol=1e-3)

    return {
        "test_function": "tensor",
        "m": m,
        "k": k,
        "n": n,
        "dtype": str(dtype),
        "backend": backend,
        "sparse_latency (ms)": sparse_measurement.median * 1000,
        "dense_latency (ms)": dense_measurement.median * 1000,
        "speedup (d/s)": dense_measurement.median / sparse_measurement.median,
        "correct": correct,
        "contiguous": sparse_output.is_contiguous(),
    }
| |
| |
if __name__ == "__main__":
    # Command-line entry point: run one benchmark sweep and print (optionally
    # save) the per-shape results as a DataFrame.
    dtype_lookup = {
        "int8": torch.int8,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }

    # (m, k, n) shapes for each sweep. The fixed-k sweep uses m == n from 3072
    # to 20480 in steps of 1024 (k = 10240); the fixed-mn sweep uses k from
    # 2560 to 20480 in steps of 1280 (m == n == 10240).
    shapes_by_mode = {
        "nvidia-bert": [
            (3072, 1024, 16384),
            (4096, 1024, 16384),
            (1024, 1024, 16384),
            (1024, 4096, 16384),
        ],
        "nvidia-fixed-k": [
            (mn, 10240, mn) for mn in range(3072, 20480 + 1, 1024)
        ],
        "nvidia-fixed-mn": [
            (10240, k, 10240) for k in range(2560, 20480 + 1, 1280)
        ],
    }

    parser = argparse.ArgumentParser(description="Semi-Structured Sparsity Benchmarks")
    # Required: without a mode there is no shape list and nothing to run.
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=list(shapes_by_mode),
    )
    parser.add_argument(
        "--dtype",
        type=str,
        choices=list(dtype_lookup),
        default="fp16",
    )
    parser.add_argument(
        "--backend", type=str, choices=["cutlass", "cusparselt"], default="cusparselt"
    )
    # The single-dash spellings are kept for existing invocations; the "--"
    # aliases are the conventional argparse form. `dest` is unchanged.
    parser.add_argument("-contiguous", "--contiguous", action="store_true")
    parser.add_argument("-e2e", "--e2e", action="store_true")
    parser.add_argument("-save", "--save", action="store_true")
    args = parser.parse_args()

    eval_fn = test_linear if args.e2e else test_tensor

    print(f"Started benchmark: {args.mode} | dtype: {args.dtype}")
    dtype = dtype_lookup[args.dtype]

    results = (
        eval_fn(m, k, n, dtype, args.contiguous, args.backend)
        for (m, k, n) in tqdm(shapes_by_mode[args.mode])
    )

    df = pd.DataFrame.from_records(results)
    if args.save:
        save_file = f"{args.mode}_{args.dtype}_{args.backend}.csv"
        df.to_csv(save_file)
        print(f"Finished benchmark: {args.mode} saved results to {save_file}")
    print(df)