# blob: 12d766ae74862c1b8f1740a725c1d3c52b20bbe9 [file] [log] [blame]
import timeit
import torch
import torch.nn.functional as F
# Benchmark setup. The next two calls use private torch._C hooks; judging by
# their names they let the JIT fuser run on CPU and turn off inlining of
# fusion groups — NOTE(review): confirm against the torch version in use.
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._debug_set_fusion_group_inlining(False)
# Pin torch to one thread so the eager and traced timings below are taken
# with the same thread budget.
torch.set_num_threads(1)
def hardswish(x):
    """Return hardswish of ``x``, built from elementwise tensor ops.

    Computes ``x * clamp(x + 3, 0, 6) / 6`` and returns a tensor with the
    same shape as ``x``.
    """
    # Shift by 3 and clamp to [0, 6]; this is the gating term of hardswish.
    gate = torch.clamp(x + 3.0, min=0.0, max=6.0)
    # Multiply first, then divide, to keep the original evaluation order.
    return x * gate / 6.0
# Elementwise unary ops to benchmark. Each entry is a callable taking one
# tensor; the list order is kept exactly as before.
unary_ops = [
    # Hardswish: the hand-written version first, then the built-in one.
    hardswish, torch._C._nn.hardswish,
    # Sigmoid, basic elementwise ops and the NaN predicate.
    torch.sigmoid, torch.reciprocal, torch.neg, torch.relu, torch.isnan,
    # Logarithms and exponentials.
    torch.log, torch.log10, torch.log1p, torch.log2, torch.exp, torch.expm1,
    # Error functions.
    torch.erf, torch.erfc,
    # Trigonometric and hyperbolic functions.
    torch.cos, torch.sin, torch.tan, torch.acos, torch.asin,
    torch.cosh, torch.sinh, torch.atan, torch.tanh,
    # Roots and magnitude.
    torch.sqrt, torch.rsqrt, torch.abs,
    # Rounding.
    torch.ceil, torch.floor, torch.round, torch.trunc,
    # Log-gamma.
    torch.lgamma,
]
print("{:20s} {:>10s} {:>10s} {:>10s}".format("op", "eager", "nnc", "speedup"))

# Compare each unary op in eager mode against its traced (JIT) counterpart.
for op in unary_ops:
    x = torch.rand((1024, 1024))
    traced = torch.jit.trace(lambda x: op(x), (x))

    # Run both versions a few times, untimed, before validating and timing.
    warmup_iters = 8
    for _ in range(warmup_iters):
        op(x)
        traced(x)

    # The traced function has to agree with eager before timing it.
    torch.testing.assert_close(op(x), traced(x))

    # The statement strings are evaluated against the module globals, so the
    # names `op`, `x` and `traced` must remain module-level variables.
    bench_iters = 100
    teager = timeit.Timer("op(x)", globals=globals()).timeit(bench_iters)
    tjit = timeit.Timer("traced(x)", globals=globals()).timeit(bench_iters)

    # Columns: op name, eager seconds, traced seconds, eager/traced ratio.
    print(f"{op.__name__:20s} {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}")
def test_batch_norm(shapes=None, warmup_iters=8, bench_iters=100):
    """Benchmark ``F.batch_norm`` in eager mode against a traced version.

    For every input shape this traces the op, runs both versions a few times
    untimed, checks that they produce matching results, then times
    ``bench_iters`` calls of each with ``timeit`` and prints one row holding
    the two totals (in seconds) and their ratio.

    Args:
        shapes: Iterable of ``(n, c, h, w)`` integer input shapes. ``None``
            (the default) uses the built-in list of shapes.
        warmup_iters: Untimed calls of each version made before validation.
        bench_iters: Timed calls of each version per shape.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when the eager
            and traced results differ.
    """
    op = F.batch_norm
    if shapes is None:
        shapes = [
            [1, 64, 112, 112],
            [1, 256, 14, 14],
            [1, 128, 28, 28],
            [1, 64, 56, 56],
            [1, 512, 7, 7],
            [5, 64, 112, 112],
            [5, 256, 14, 14],
            [5, 128, 28, 28],
            [5, 64, 56, 56],
            [5, 512, 7, 7],
        ]

    print("{:20s} {:20s} {:>10s} {:>10s} {:>10s}".format("op", "shape", "eager", "nnc", "speedup"))

    for n, c, h, w in shapes:
        x = torch.rand((n, c, h, w))
        # y and z are passed positionally as the second and third arguments
        # of F.batch_norm (running_mean and running_var in its signature).
        y = torch.rand((c))
        z = torch.rand((c))
        traced = torch.jit.trace(lambda x, y, z: op(x, y, z), (x, y, z))

        # Warmup.
        for _ in range(warmup_iters):
            op(x, y, z)
            traced(x, y, z)

        # Validate result.
        torch.testing.assert_close(op(x, y, z), traced(x, y, z))

        # Benchmark. Hand timeit exactly the names its statements use instead
        # of a snapshot of every local in this function.
        namespace = {"op": op, "traced": traced, "x": x, "y": y, "z": z}
        teager = timeit.timeit(stmt="op(x, y, z)", globals=namespace, number=bench_iters)
        tjit = timeit.timeit(stmt="traced(x, y, z)", globals=namespace, number=bench_iters)

        # A zero timing (e.g. bench_iters=0) would otherwise divide by zero.
        speedup = teager / tjit if tjit > 0 else float("nan")

        # The shape column is 20 characters wide, matching the header.
        shape = f"({n:>3d}, {c:>3d}, {h:>3d}, {w:>3d})"
        print(f"{op.__name__:20s} {shape} {teager:10.3f} {tjit:10.3f} {speedup:10.2f}")
# Run the batch-norm benchmark. This is a plain module-level call with no
# __main__ guard, so it also runs whenever the file is imported.
test_batch_norm()