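# Pointwise-operator fusion microbenchmark.
#
# For every (shape, operator) pair below, this script times the op in eager
# mode and as a TorchScript trace under the NNC (TensorExpr), nvFuser, and
# legacy fusers, printing one CSV row per measurement:
#     fuser,device,operator,shape,time   (time in microseconds)
# Inputs are moved to CUDA, so a GPU is required. --operators and --shapes
# accept comma-separated names from the OPERATORS and SHAPES lists, e.g.
# --operators add,mul --shapes small,medium.
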
import click
import sys
import time
import torch
import inspect
import itertools

# Single-threaded CPU execution and no fusion-group inlining keep the
# measurements focused on the fusers' generated kernels.
torch.set_num_threads(1)
torch._C._debug_set_fusion_group_inlining(False)


def rand(*shape):
    # Values land in [1, 17), keeping log/div-style operators away from zero.
    return torch.rand(*shape).mul(16).add(1)


# ------------------------------------------------------------------------------
# Shape test cases
# ------------------------------------------------------------------------------
def scalar():
    return (rand(1), rand(1))


def small():
    return (rand(32), rand(32))


def small_2d():
    return (rand(1, 32), rand(1, 32))


def small_broadcast():
    return (rand(4, 32), rand(32))


def medium():
    return (rand(32, 12, 64, 64), rand(32, 12, 64, 64))


def medium_sliced():
    return (rand(32, 12, 64, 64)[..., ::2], rand(32, 12, 64, 64)[..., ::2])


def medium_transpose():
    return (
        rand(32, 12, 64, 64).transpose(-1, -2),
        rand(32, 12, 64, 64).transpose(-1, -2),
    )


def medium2():
    return (rand(32, 3, 224, 224), rand(32, 3, 224, 224))


def medium3d():
    return (rand(16, 32, 64), rand(16, 32, 64))


def medium_channels_last():
    return (
        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
        rand(32, 3, 224, 224).to(memory_format=torch.channels_last),
    )


def medium_broadcast():
    return (rand(32, 12, 64, 64), rand(64))


def medium_broadcast_channels_last():
    return (rand(32, 3, 223, 223).to(memory_format=torch.channels_last), rand(3, 1, 1))


def large():
    return (rand(8192, 8192), rand(8192, 8192))


def large_transpose():
    return (rand(8192, 8192).transpose(0, 1), rand(8192, 8192).transpose(0, 1))


def large_channels_last():
    return (
        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
        rand(32, 32, 256, 256).to(memory_format=torch.channels_last),
    )


def broadcast_narrow_57611():
    return (rand(1, 32, 32, 2), rand(1024, 1, 1, 2))


def large_broadcast_66816():
    return (rand(64, 8, 256, 162), rand(256, 162))


# ------------------------------------------------------------------------------
# Operator test cases
# ------------------------------------------------------------------------------
def add(a, b):
    return 3 * a + b


def sub(a, b):
    return 3 * a - b


def mul(a, b):
    return 3 * a * b


def div(a, b):
    return 3 * a / b


def relu(a):
    return (3 * a).relu()


def sigmoid(a):
    return (3 * a).sigmoid()


def tanh(a):
    return (3 * a).tanh()


def log(a):
    return (3 * a).log()


def exp(a):
    return (3 * a).exp()


def square(a):
    return (3 * a) ** 2


def fma(a, b):
    return a * b + b


def mul_mul_add_66816(a, b, c):
    return (a * b) + (a * c)


def hardswish_int(a):
    return a * (a + 3).clamp(0, 6) / 6


def hardswish(a):
    return a * (a + 3).clamp(0.0, 6.0) / 6


def native_hardswish(a):
    return torch._C._nn.hardswish(a * 3)


def softplus(a):
    return (a * 1.0).exp().log1p() / 1.0


def mish(a):
    return a * ((a * 1.0).exp().log1p() / 1.0).tanh()


SHAPES = [
    scalar,
    small,
    small_2d,
    small_broadcast,
    medium,
    medium2,
    medium3d,
    medium_sliced,
    medium_transpose,
    medium_channels_last,
    medium_broadcast,
    medium_broadcast_channels_last,
    large,
    large_transpose,
    large_channels_last,
    broadcast_narrow_57611,
    large_broadcast_66816,
]

OPERATORS = [
    add,
    sub,
    mul,
    div,
    relu,
    sigmoid,
    tanh,
    log,
    exp,
    square,
    fma,
    mul_mul_add_66816,
    hardswish_int,
    hardswish,
    native_hardswish,
    softplus,
    mish,
]
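

# ------------------------------------------------------------------------------
# Timing helpers: wall-clock timing on CPU, CUDA events on GPU, with the
# iteration count calibrated so each measurement runs for roughly one second.
# ------------------------------------------------------------------------------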
def time_cpu(fn, args, iters):
    s = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    e = time.perf_counter()
    return e - s


def time_cuda(fn, args, iters):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    # CUDA events report milliseconds; convert to seconds to match time_cpu.
    return start.elapsed_time(end) / 1e3


def benchmark_with_timer(fn, args, timer):
    # Warm up (and let the JIT profile/compile), then calibrate the iteration
    # count so the measured loop runs for roughly one second.
    timer(fn, args, 3)
    calibration = timer(fn, args, 1)
    # At least one timed iteration, even if a single call takes over a second.
    iters = max(int(1.0 / calibration), 1)
    return timer(fn, args, iters) / iters


def benchmark(fn, args):
    timer = time_cpu if args[0].device.type == "cpu" else time_cuda
    return benchmark_with_timer(fn, args, timer)


def micros(s):
    return f"{s * 1e6:.1f}"
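

# ------------------------------------------------------------------------------
# Fuser configurations: each helper flips the TorchScript JIT flags so that a
# single fusion backend (NNC/TensorExpr, nvFuser, or the legacy fuser) is used.
# ------------------------------------------------------------------------------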
def with_nvfuser():
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(True)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def with_nnc():
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_fuser_enabled(True)
    torch._C._jit_set_nvfuser_enabled(False)
    torch._C._jit_set_profiling_executor(True)
    torch._C._jit_set_profiling_mode(True)


def with_legacy():
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_set_profiling_mode(False)
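

# ------------------------------------------------------------------------------
# Benchmark driver: one CSV row per fuser for every selected operator/shape.
# ------------------------------------------------------------------------------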
@click.command()
@click.option("--operators", default=None)
@click.option("--shapes", default=None)
def run_benchmarks(operators, shapes):
    if operators is None:
        operators = OPERATORS
    else:
        operators = [globals()[k] for k in operators.split(",")]
    if shapes is None:
        shapes = SHAPES
    else:
        shapes = [globals()[k] for k in shapes.split(",")]

    print("fuser,device,operator,shape,time")
    results = []
    for shape, operator in itertools.product(shapes, operators):
        # Pad or trim the generated inputs to match the operator's arity, then
        # move everything to the GPU.
        nargs = len(inspect.signature(operator).parameters)
        args = shape()
        if nargs > len(args):
            args = list(args)
            args += [args[-1]] * (nargs - len(args))
        args = args[:nargs]
        args = [arg.to("cuda") for arg in args]

        # Eager-mode baseline.
        result = benchmark(operator, args)
        print(
            ",".join(
                [
                    "eager",
                    args[0].device.type,
                    operator.__name__,
                    shape.__name__,
                    micros(result),
                ]
            )
        )

        def bench(name):
            # Re-trace under the currently selected fuser so each configuration
            # starts from a fresh, unoptimized graph.
            traced = torch.jit.trace(operator, args)
            result = benchmark(traced, args)
            print(
                ",".join(
                    [
                        name,
                        args[0].device.type,
                        operator.__name__,
                        shape.__name__,
                        micros(result),
                    ]
                )
            )
            sys.stdout.flush()

        with_nnc()
        bench("nnc")
        with_nvfuser()
        bench("nvfuser")
        with_legacy()
        bench("legacy")


if __name__ == "__main__":
    run_benchmarks()