import argparse
import operator
import time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch._C._te as te
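# Keep an NNC KernelScope alive for the duration of a `with` block so that
# tensor-expression IR nodes allocated inside it are freed on exit.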
class kernel_arena_scope:
def __enter__(self):
self.scope = te.KernelScope()
def __exit__(self, typ, val, traceback):
self.scope = None
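# (name, torch op) pairs. The name doubles as the NNC method looked up on a
# loaded element in gen_unary_nnc_fun, so it must match an NNC intrinsic.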
unary_ops = [
("sin", torch.sin),
("cos", torch.cos),
("tan", torch.tan),
("asin", torch.asin),
("acos", torch.acos),
("atan", torch.atan),
("sinh", torch.sinh),
("cosh", torch.cosh),
("tanh", torch.tanh),
("sigmoid", torch.sigmoid),
("exp", torch.exp),
("expm1", torch.expm1),
("expm1", torch.expm1),
("abs", torch.abs),
("log", torch.log),
("fast_log", torch.log),
("log2", torch.log2),
("log10", torch.log10),
("log1p", torch.log1p),
("erf", torch.erf),
("erfc", torch.erfc),
("sqrt", torch.sqrt),
("rsqrt", torch.rsqrt),
("ceil", torch.ceil),
("floor", torch.floor),
("round", torch.round),
("trunc", torch.trunc),
("lgamma", torch.lgamma),
# ("frac", torch.frac), # seems unimplemented
# ("isnan", torch.isnan), # no out variant
]
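# Build an NNC compute function applying the named intrinsic elementwise.
# B is accepted but unused so unary and binary ops share one signature.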
def gen_unary_nnc_fun(nnc_name):
def nnc_fun(A, B):
def compute(i, j):
return getattr(A.load([i, j]), nnc_name)()
return compute
return nnc_fun
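# Wrap a torch op into a zero-argument thunk with the common (a, b, out) shape.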
def gen_unary_torch_fun(torch_op):
def torch_fun(a, b, out):
def fun():
return torch_op(a, out=out)
return fun
return torch_fun
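# Binary counterparts of the two wrappers above.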
def gen_binary_nnc_fun(fn):
def nnc_fun(A, B):
def compute(i, j):
return fn(A.load([i, j]), B.load([i, j]))
return compute
return nnc_fun
def gen_binary_torch_fun(fn):
def pt_fun(a, b, out):
def fun():
return fn(a, b, out=out)
return fun
return pt_fun
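# Comparison ops write boolean outputs, so their operands and out tensor are
# generated specially rather than with the default float tensors.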
def gen_int_comparison_tensors(N, M):
return (
torch.randint(0, 3, (N, M)),
torch.randint(0, 3, (N, M)),
torch.empty((N, M), dtype=torch.bool),
)
def gen_float_comparison_tensors(N, M):
return (torch.rand(N, M), torch.rand(N, M), torch.empty((N, M), dtype=torch.bool))
te_bool = te.Dtype.Bool
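# Entries are (name, NNC lambda, torch fn[, input generator]); comparisons
# cast the NNC result to bool explicitly.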
binary_ops = [
("add", operator.add, torch.add),
("mul", operator.mul, torch.mul),
("sub", operator.sub, torch.sub),
("div", operator.truediv, torch.div),
(
"eq",
(lambda a, b: te.Cast.make(te_bool, a == b)),
torch.eq,
gen_int_comparison_tensors,
),
(
"gt",
(lambda a, b: te.Cast.make(te_bool, a > b)),
torch.gt,
gen_float_comparison_tensors,
),
(
"lt",
(lambda a, b: te.Cast.make(te_bool, a < b)),
torch.lt,
gen_float_comparison_tensors,
),
(
"gte",
(lambda a, b: te.Cast.make(te_bool, a >= b)),
torch.greater_equal,
gen_float_comparison_tensors,
),
(
"lte",
(lambda a, b: te.Cast.make(te_bool, a <= b)),
torch.less_equal,
gen_float_comparison_tensors,
),
    # ("neq", (lambda a, b: a != b), None),  # no one-op equivalent
    # ("&", (lambda a, b: a & b), torch.bitwise_and),  # requires more work to test
]
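# relu written directly in TE IR: ifThenElse(x < 0, 0, x).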
def nnc_relu(A, B):
    def f(i, j):
        return te.ifThenElse(
            A.load([i, j]) < te.ExprHandle.float(0),
            te.ExprHandle.float(0),
            A.load([i, j]),
        )
    return f
def pt_relu(a, b, c):
return torch.relu(a)
custom_ops = [
("relu", nnc_relu, pt_relu),
# ('nnc_mul_relu', nnc_mul_relu, pt_mul_relu)
# ('manual_sigmoid', nnc_manual_sigmoid, lambda a, b, c: torch.sigmoid(a, out=c))
]
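# Custom NNC entries already produce compute functions; only the torch side
# needs wrapping into a thunk.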
def gen_custom_torch_fun(fn):
def pt_fun(a, b, out):
def fun():
return fn(a, b, out)
return fun
return pt_fun
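# Pad 3-tuples with a None shape_fn so every benchmark entry is a 4-tuple.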
def normalize_benchmarks(ops):
return [i + (None,) if len(i) == 3 else i for i in ops]
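# Flatten all op tables into parallel lists consumed by run_benchmarks.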
names = []
nnc_fns = []
pt_fns = []
shape_fns = []
for nnc_name, pt_op in unary_ops:
names.append(nnc_name)
nnc_fns.append(gen_unary_nnc_fun(nnc_name))
pt_fns.append(gen_unary_torch_fun(pt_op))
shape_fns.append(None)
for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(binary_ops):
names.append(name)
nnc_fns.append(gen_binary_nnc_fun(lmbda))
pt_fns.append(gen_binary_torch_fun(pt_fn))
shape_fns.append(shape_fn)
for name, lmbda, pt_fn, shape_fn in normalize_benchmarks(custom_ops):
names.append(name)
nnc_fns.append(lmbda)
pt_fns.append(gen_custom_torch_fun(pt_fn))
shape_fns.append(shape_fn)
benchmarks = list(zip(names, nnc_fns, pt_fns, shape_fns))
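# Time NNC-compiled kernels against eager PyTorch for each op and size, and
# collect the results into a DataFrame.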
def run_benchmarks(benchmarks, sizes):
df = pd.DataFrame(columns=["name", "N", "M", "nnc_time", "torch_time", "ratio"])
with torch.no_grad():
for name, nnc_fun, torch_fun, shape_fn in benchmarks:
for N, M in sizes:
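                # Scale the iteration count down as the problem size grows.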
iters = int(1e6 / (N + M))
with kernel_arena_scope():
if shape_fn is None:
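                        # Clamp inputs to (0.01, 0.99) so ops like log and
                        # asin stay within their domains.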
tA = torch.rand(M, N).clamp(0.01, 0.99)
tB = torch.rand(M, N).clamp(0.01, 0.99)
tX = torch.empty(M, N)
tR = torch.empty(M, N)
else:
tA, tB, tX = shape_fn(M, N)
tR = tX.clone()
                    def get_nnc_type(dtype):
                        if dtype == torch.float:
                            return te.Dtype.Float
                        elif dtype == torch.long:
                            return te.Dtype.Long
                        raise ValueError(f"unsupported dtype: {dtype}")
                    dtype = get_nnc_type(tA.dtype)
                    dM = te.ExprHandle.int(M)
                    dN = te.ExprHandle.int(N)
                    A = te.Placeholder("A", dtype, [dM, dN])
                    B = te.Placeholder("B", dtype, [dM, dN])
                    dim_args = [te.DimArg(*args) for args in [(dM, "m"), (dN, "n")]]
                    compute = nnc_fun(A, B)
                    X = te.Compute("X", dim_args, compute)
                    # Lower the tensor expression to a loop nest, simplify it,
                    # and JIT-compile with the LLVM backend.
                    loopnest = te.LoopNest([X])
                    loopnest.prepare_for_codegen()
                    stmt = te.simplify(loopnest.root_stmt())
                    cg = te.construct_codegen(
                        "llvm", stmt, [te.BufferArg(x) for x in [A, B, X]]
                    )
# warmup
for _ in range(10):
cg.call([tA, tB, tX])
start = time.time()
for it in range(iters):
cg.call([tA, tB, tX])
time1 = time.time() - start
fn = torch_fun(tA, tB, tR)
# warmup
for _ in range(10):
tR = fn()
start = time.time()
for it in range(iters):
tR = fn()
time2 = time.time() - start
                    # DataFrame.append was removed in pandas 2.0; concatenate
                    # a one-row frame instead.
                    df = pd.concat(
                        [
                            df,
                            pd.DataFrame(
                                [
                                    {
                                        "name": name,
                                        "N": N,
                                        "M": M,
                                        "nnc_time": time1,
                                        "torch_time": time2,
                                        "ratio": time2 / time1,
                                    }
                                ]
                            ),
                        ],
                        ignore_index=True,
                    )
print(name, N, M)
print(time2 / time1, time1, time2)
print()
                    def check_correctness(a, b):
                        # Print the op name before failing so the culprit is
                        # visible in the benchmark output.
                        if not np.allclose(a, b):
                            print(name)
                        assert np.allclose(a, b)
                    check_correctness(tX, tR)
return df
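# Render a heatmap of torch_time / nnc_time for the square (N == M) results.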
def dump_plot(df, sizes):
keys = []
vals = []
indexed = df[df["N"] == df["M"]]
for index, row in indexed.iterrows():
keys.append(row["name"])
vals.append(row["ratio"])
keys = keys[:: len(sizes)]
sns.set(rc={"figure.figsize": (5.0, len(keys) * 0.5)})
cmap = sns.diverging_palette(10, 120, n=9, as_cmap=True)
np_vals = np.array([vals]).reshape(-1, len(sizes))
g = sns.heatmap(np_vals, annot=True, cmap=cmap, center=1.0, yticklabels=True)
plt.yticks(rotation=0)
plt.title("PyTorch performance divided by NNC performance (single core)")
plt.xlabel("Size of NxN matrix")
plt.ylabel("Operation")
g.set_yticklabels(keys)
g.set_xticklabels(sizes)
plt.savefig("nnc.png")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Runs NNC microbenchmarks")
parser.add_argument(
"--multi-threaded",
"--multi_threaded",
action="store_true",
help="Run with more than one thread",
)
args = parser.parse_args()
if not args.multi_threaded:
torch.set_num_threads(1)
sizes = [1, 4, 16, 64, 256, 1024]
df = run_benchmarks(benchmarks, [(i, i) for i in sizes])
dump_plot(df, sizes)