| """Microbenchmarks comparing NNC (torch._C._te) kernels against eager PyTorch |
| ops, elementwise over square matrices, plotted as a heatmap of time ratios.""" |
| |
| import argparse |
| import operator |
| import time |
| |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import pandas as pd |
| import seaborn as sns |
| |
| import torch |
| import torch._C._te as te |
| |
| |
| class kernel_arena_scope: |
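| """Context manager that keeps a te.KernelScope alive while NNC IR is built. |
| |
| NNC expressions are allocated in the innermost KernelScope's arena and freed |
| when that scope dies, so all kernel construction happens inside this manager. |
| """ |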
| def __enter__(self): |
| self.scope = te.KernelScope() |
| |
| def __exit__(self, typ, val, traceback): |
| self.scope = None |
| |
| |
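| # (nnc_name, torch_op) pairs: nnc_name is looked up as a method on an NNC |
| # ExprHandle; torch_op is the eager-mode reference implementation. |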
| unary_ops = [ |
| ("sin", torch.sin), |
| ("cos", torch.cos), |
| ("tan", torch.tan), |
| ("asin", torch.asin), |
| ("acos", torch.acos), |
| ("atan", torch.atan), |
| ("sinh", torch.sinh), |
| ("cosh", torch.cosh), |
| ("tanh", torch.tanh), |
| ("sigmoid", torch.sigmoid), |
| ("exp", torch.exp), |
| ("expm1", torch.expm1), |
| ("abs", torch.abs), |
| ("log", torch.log), |
| ("fast_log", torch.log),  # NNC's fast_log intrinsic, measured against torch.log |
| ("log2", torch.log2), |
| ("log10", torch.log10), |
| ("log1p", torch.log1p), |
| ("erf", torch.erf), |
| ("erfc", torch.erfc), |
| ("sqrt", torch.sqrt), |
| ("rsqrt", torch.rsqrt), |
| ("ceil", torch.ceil), |
| ("floor", torch.floor), |
| ("round", torch.round), |
| ("trunc", torch.trunc), |
| ("lgamma", torch.lgamma), |
| # ("frac", torch.frac), # seems unimplemented |
| # ("isnan", torch.isnan), # no out variant |
| ] |
| |
| |
| def gen_unary_nnc_fun(nnc_name): |
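| """Build an NNC compute function applying the named intrinsic elementwise. |
| |
| A.load([i, j]) yields an ExprHandle; getattr picks its method by name. B is |
| accepted but unused so unary and binary benchmarks share one signature. |
| """ |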
| def nnc_fun(A, B): |
| def compute(i, j): |
| return getattr(A.load([i, j]), nnc_name)() |
| |
| return compute |
| |
| return nnc_fun |
| |
| |
| def gen_unary_torch_fun(torch_op): |
| def torch_fun(a, b, out): |
| def fun(): |
| return torch_op(a, out=out) |
| |
| return fun |
| |
| return torch_fun |
| |
| |
| def gen_binary_nnc_fun(fn): |
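| """Build an NNC compute function applying fn to matching elements of A, B.""" |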
| def nnc_fun(A, B): |
| def compute(i, j): |
| return fn(A.load([i, j]), B.load([i, j])) |
| |
| return compute |
| |
| return nnc_fun |
| |
| |
| def gen_binary_torch_fun(fn): |
| def pt_fun(a, b, out): |
| def fun(): |
| return fn(a, b, out=out) |
| |
| return fun |
| |
| return pt_fun |
| |
| |
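| # Comparison ops need integer or float inputs and a bool output buffer, so |
| # they supply their own tensor generators instead of the default float setup. |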
| def gen_int_comparison_tensors(N, M): |
| return ( |
| torch.randint(0, 3, (N, M)), |
| torch.randint(0, 3, (N, M)), |
| torch.empty((N, M), dtype=torch.bool), |
| ) |
| |
| |
| def gen_float_comparison_tensors(N, M): |
| return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool)) |
| |
| |
| te_bool = te.Dtype.Bool |
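| # Entries are (name, nnc_lambda, torch_op[, shape_fn]); comparison results are |
| # cast back to Bool so the NNC output dtype matches torch's bool tensors. |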
| binary_ops = [ |
| ("add", operator.add, torch.add), |
| ("mul", operator.mul, torch.mul), |
| ("sub", operator.sub, torch.sub), |
| ("div", operator.truediv, torch.div), |
| ( |
| "eq", |
| (lambda a, b: te.Cast.make(te_bool, a == b)), |
| torch.eq, |
| gen_int_comparison_tensors, |
| ), |
| ( |
| "gt", |
| (lambda a, b: te.Cast.make(te_bool, a > b)), |
| torch.gt, |
| gen_float_comparison_tensors, |
| ), |
| ( |
| "lt", |
| (lambda a, b: te.Cast.make(te_bool, a < b)), |
| torch.lt, |
| gen_float_comparison_tensors, |
| ), |
| ( |
| "gte", |
| (lambda a, b: te.Cast.make(te_bool, a >= b)), |
| torch.greater_equal, |
| gen_float_comparison_tensors, |
| ), |
| ( |
| "lte", |
| (lambda a, b: te.Cast.make(te_bool, a <= b)), |
| torch.less_equal, |
| gen_float_comparison_tensors, |
| ), |
| # ("neq", (lambda a, b: a != b), None),  # no one-op equivalent |
| # ('&', (lambda a, b: a & b), torch.bitwise_and), # requires more work to test |
| ] |
| |
| |
| def nnc_relu(A, B): |
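| """relu(x) = max(x, 0), expressed as an NNC ifThenElse select.""" |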
| def f(i, j): |
| return te.ifThenElse( |
| A.load([i, j]) < te.ExprHandle.float(0), |
| te.ExprHandle.float(0), |
| A.load([i, j]), |
| ) |
| |
| return f |
| |
| |
| def pt_relu(a, b, out):  # b and out are unused; kept for the shared signature |
| return torch.relu(a) |
| |
| |
| custom_ops = [ |
| ("relu", nnc_relu, pt_relu), |
| # ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu) |
| # ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c)) |
| ] |
| |
| |
| def gen_custom_torch_fun(fn): |
| def pt_fun(a, b, out): |
| def fun(): |
| return fn(a, b, out) |
| |
| return fun |
| |
| return pt_fun |
| |
| |
| def normalize_benchmarks(ops): |
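| """Pad 3-tuples with a trailing None shape_fn so every entry is a 4-tuple.""" |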
| return [i + (None,) if len(i) == 3 else i for i in ops] |
| |
| |
| names = [] |
| nnc_fns = [] |
| pt_fns = [] |
| shape_fns = [] |
| |
| for nnc_name, pt_op in unary_ops: |
| names.append(nnc_name) |
| nnc_fns.append(gen_unary_nnc_fun(nnc_name)) |
| pt_fns.append(gen_unary_torch_fun(pt_op)) |
| shape_fns.append(None) |
| |
| for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops): |
| names.append(name) |
| nnc_fns.append(gen_binary_nnc_fun(lmbda)) |
| pt_fns.append(gen_binary_torch_fun(pt_fn)) |
| shape_fns.append(shape_fn) |
| |
| for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops): |
| names.append(name) |
| nnc_fns.append(lmbda) |
| pt_fns.append(gen_custom_torch_fun(pt_fn)) |
| shape_fns.append(shape_fn) |
| |
| benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns)) |
| |
| |
| def run_benchmarks(benchmarks, sizes): |
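| """Time each benchmark's NNC kernel against its eager torch op and return a |
| DataFrame of per-(name, N, M) timings and torch/NNC time ratios.""" |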
| records = []  # DataFrame.append was removed in pandas 2.0; collect rows instead |
| with torch.no_grad(): |
| for name, nnc_fun, torch_fun, shape_fn in benchmarks: |
| for N, M in sizes: |
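| # Scale the iteration count down for larger tensors so every benchmark |
| # runs for roughly comparable wall time. |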
| iters = int(1e6 / (N + M)) |
| with kernel_arena_scope(): |
| if shape_fn is None: |
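| # clamp keeps inputs inside the domains of ops like log, asin, and rsqrt |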
| tA = torch.rand(M, N).clamp(0.01, 0.99) |
| tB = torch.rand(M, N).clamp(0.01, 0.99) |
| tX = torch.empty(M, N) |
| tR = torch.empty(M, N) |
| else: |
| tA, tB, tX = shape_fn(M, N) |
| tR = tX.clone() |
| |
| def get_nnc_type(dtype): |
| if dtype == torch.float: |
| return te.Dtype.Float |
| elif dtype == torch.long: |
| return te.Dtype.Long |
| raise RuntimeError(f"unsupported input dtype: {dtype}") |
| |
| dtype = get_nnc_type(tA.dtype) |
| |
| dM = te.ExprHandle.int(M) |
| dN = te.ExprHandle.int(N) |
| |
| A = te.Placeholder("A", dtype, [dM, dN]) |
| B = te.Placeholder("B", dtype, [dM, dN]) |
| |
| dim_args = [te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]] |
| |
| # Build the NNC pipeline: Compute -> LoopNest -> simplified stmt -> |
| # LLVM codegen, which can then be called on real tensors. |
| compute = nnc_fun(A, B) |
| X = te.Compute("X", dim_args, compute) |
| loopnest = te.LoopNest([X]) |
| loopnest.prepare_for_codegen() |
| stmt = te.simplify(loopnest.root_stmt()) |
| cg = te.construct_codegen( |
| "llvm", stmt, [te.BufferArg(x) for x in [A, B, X]] |
| ) |
| |
| # warmup |
| for _ in range(10): |
| cg.call([tA, tB, tX]) |
| start = time.time() |
| for _ in range(iters): |
| cg.call([tA, tB, tX]) |
| time1 = time.time() - start |
| |
| fn = torch_fun(tA, tB, tR) |
| # warmup |
| for _ in range(10): |
| tR = fn() |
| start = time.time() |
| for _ in range(iters): |
| tR = fn() |
| time2 = time.time() - start |
| |
| records.append( |
| { |
| "name": name, |
| "N": N, |
| "M": M, |
| "nnc_time": time1, |
| "torch_time": time2, |
| "ratio": time2 / time1, |
| } |
| ) |
| print(name, N, M) |
| |
| print(time2 / time1, time1, time2) |
| print() |
| |
| def check_correctness(a, b): |
| assert np.allclose(a, b), f"NNC and torch disagree for {name}" |
| |
| check_correctness(tX, tR) |
| return pd.DataFrame( |
| records, columns=["name", "N", "M", "nnc_time", "torch_time", "ratio"] |
| ) |
| |
| |
| def dump_plot(df, sizes): |
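| """Plot a heatmap of torch/NNC time ratios for square (N == M) sizes and |
| save it as nnc.png.""" |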
| keys = [] |
| vals = [] |
| indexed = df[df["N"] == df["M"]] |
| for _, row in indexed.iterrows(): |
| keys.append(row["name"]) |
| vals.append(row["ratio"]) |
| |
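| # Each op appears once per size; keep every len(sizes)-th name so the |
| # y-axis gets exactly one label per op. |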
| keys = keys[:: len(sizes)] |
| sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)}) |
| |
| cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True) |
| np_vals = np.array(vals).reshape(-1, len(sizes)) |
| g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True) |
| plt.yticks(rotation=0) |
| plt.title("PyTorch time divided by NNC time (single core)") |
| plt.xlabel("Size of NxN matrix") |
| plt.ylabel("Operation") |
| g.set_yticklabels(keys) |
| g.set_xticklabels(sizes) |
| |
| plt.savefig("nnc.png") |
| |
| |
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks") |
| parser.add_argument( |
| "--multi-threaded", |
| "--multi_threaded", |
| action="store_true", |
| help="Run with more than one thread", |
| ) |
| args = parser.parse_args() |
| if not args.multi_threaded: |
| torch.set_num_threads(1) |
| |
| sizes = [1, 4, 16, 64, 256, 1024] |
| df = run_benchmarks(benchmarks, [(i, i) for i in sizes]) |
| dump_plot(df, sizes) |