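"""Driver for the operator benchmarks: runs benchmarks in specific shapes.

Example invocations (mirroring the argparse description in main()):

    benchmark.py                                   # all benchmarks, default configs
    benchmark.py reduce                            # every benchmark matching 'reduce'
    benchmark.py layernorm_fwd_cpu_128_32_128_128  # one benchmark, one config
"""
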
import argparse
import itertools
import os
# from . import conv # noqa: F401
# from . import normalization # noqa: F401
# from . import pooling # noqa: F401
from . import (  # noqa: F401
    attention,
    benchmark,
    broadcast,
    concat,
    elementwise,
    matmul,
    reduction,
    rnn_eltwise,
    softmax,
    swish,
    tensor_engine,
)
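

# Entry point: parse the command line, configure the PyTorch JIT / fuser
# settings, then dispatch the selected benchmarks.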
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""Benchmark operators in specific shapes.
Works only with Python3.\nA few examples:
* benchmark.py: runs all the default configs with all the benchmarks.
* benchmark.py reduce: runs all the default configs with every benchmark whose name contains 'reduce'.
* benchmark.py layernorm_fwd_cpu_128_32_128_128: runs a particular benchmark in that config.""",
    )
    parser.add_argument(
        "benchmark_names",
        type=str,
        default=None,
        nargs="*",
        help="names of the benchmarks to run",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu,cuda",
        help="a comma-separated list of device names",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="fwd,both",
        help="a comma-separated list of running modes: {fwd, both}",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        help="a comma-separated list of data types: {float32[default], float16}",
    )
    parser.add_argument(
        "--input-iter",
        type=str,
        default=None,
        help="a comma-separated list of tensor dimension specs, each giving a "
        "start, stop, and increment; the increment is a constant step or a "
        "power-of-two progression: {start:stop:inc,start:stop:pow2}",
    )
    parser.add_argument(
        "--engine",
        type=str,
        default="pt",
        help="the underlying tensor engine; only 'pt' for now",
    )
    parser.add_argument(
        "--jit-mode",
        "--jit_mode",
        type=str,
        default="trace",
        help="the jit mode to use: one of {trace, none}",
    )
    parser.add_argument(
        "--cuda-pointwise-loop-levels",
        "--cuda_pointwise_loop_levels",
        type=int,
        default=None,
        help="number of loop levels for CUDA pointwise operations: 2 or 3",
    )
    parser.add_argument(
        "--cuda-pointwise-block-count",
        "--cuda_pointwise_block_count",
        type=int,
        default=None,
        help="number of blocks for CUDA pointwise operations",
    )
    parser.add_argument(
        "--cuda-pointwise-block-size",
        "--cuda_pointwise_block_size",
        type=int,
        default=None,
        help="block size for CUDA pointwise operations",
    )
    parser.add_argument(
        "--cuda-fuser",
        "--cuda_fuser",
        type=str,
        default="te",
        help="the CUDA fuser backend to use: one of {te, nvf, old, none}",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="stdout",
        help="the output format of the benchmark run: {stdout[default], json}",
    )
    parser.add_argument(
        "--print-ir",
        action="store_true",
        help="print the IR graph of the fusion",
    )
    parser.add_argument(
        "--print-kernel",
        action="store_true",
        help="print generated kernel(s)",
    )
    parser.add_argument(
        "--no-dynamic-shape",
        action="store_true",
        help="disable shape randomization in dynamic benchmarks",
    )
    parser.add_argument(
        "--cpu-fusion",
        "--cpu_fusion",
        default=False,
        action="store_true",
        help="enable CPU fusion",
    )
    parser.add_argument(
        "--cat-wo-conditionals",
        "--cat_wo_conditionals",
        default=False,
        action="store_true",
        help="enable cat without conditionals",
    )
    args = parser.parse_args()
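
    # Configure the TorchScript executor for the requested fuser backend:
    # "te" enables the TensorExpr (NNC) fuser, "nvf" enables nvFuser,
    # "old" selects the legacy fuser, and "none" leaves fusion untouched.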
    if args.cuda_fuser == "te":
        import torch

        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._get_graph_executor_optimize(True)
    elif args.cuda_fuser == "old":
        import torch

        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
    elif args.cuda_fuser == "nvf":
        import torch

        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
        torch._C._jit_set_nvfuser_enabled(True)
        torch._C._get_graph_executor_optimize(True)
    elif args.cuda_fuser != "none":
        raise ValueError(f"Undefined fuser: {args.cuda_fuser}")
    if args.cpu_fusion:
        import torch

        torch._C._jit_override_can_fuse_on_cpu(True)
    else:
        import torch

        torch._C._jit_override_can_fuse_on_cpu(False)

    if args.cat_wo_conditionals:
        import torch

        torch._C._jit_cat_wo_conditionals(True)
    else:
        import torch

        torch._C._jit_cat_wo_conditionals(False)
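
    # Pin the CPU thread count for all threading runtimes via environment
    # variables; these generally must be set before the runtimes initialize
    # their thread pools.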
    def set_global_threads(num_threads):
        os.environ["OMP_NUM_THREADS"] = str(num_threads)
        os.environ["MKL_NUM_THREADS"] = str(num_threads)
        os.environ["TVM_NUM_THREADS"] = str(num_threads)
        os.environ["NNC_NUM_THREADS"] = str(num_threads)
    devices = args.device.split(",")
    # accept 'gpu' as an alias for the 'cuda' device
    devices = ["cuda" if device == "gpu" else device for device in devices]
    cpu_count = 0
    for index, device in enumerate(devices):
        if device.startswith("cpu"):
            cpu_count += 1
            if cpu_count > 1:
                raise ValueError(f"more than one CPU device is not allowed: {cpu_count}")
            if device == "cpu":
                continue
            num_threads_str = device[3:]
            try:
                # see if the device is in 'cpu1' or 'cpu4' format
                num_threads = int(num_threads_str)
                set_global_threads(num_threads)
                devices[index] = "cpu"
            except ValueError:
                continue
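
    # Expand the --mode list, and resolve each --dtype name (e.g. "float32")
    # to the corresponding torch dtype.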
    modes = args.mode.split(",")
    datatypes = args.dtype.split(",")
    for index, dtype in enumerate(datatypes):
        datatypes[index] = getattr(torch, dtype, None)
        if not datatypes[index]:
            raise AttributeError(f"DataType: {dtype} is not valid!")

    tensor_engine.set_engine_mode(args.engine)
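
    # Run every (mode, device, dtype, config) combination from a benchmark
    # class's default configs, optionally skipping unsupported combinations.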
    def run_default_configs(bench_cls, allow_skip=True):
        for mode, device, dtype, config in itertools.product(
            modes, devices, datatypes, bench_cls.default_configs()
        ):
            bench = bench_cls(mode, device, dtype, *config)
            bench.output_type = args.output
            bench.jit_mode = args.jit_mode
            if not bench.is_supported():
                if allow_skip:
                    continue
                raise ValueError(
                    f"attempted to run an unsupported benchmark: {bench.desc()}"
                )
            bench.run(args)
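
    # Expand an --input-iter spec ("start:stop:inc[,...]") into per-dimension
    # value lists and benchmark the cross product of all dimension values.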
    def run_with_input_iter(bench_cls, input_iter, allow_skip=True):
        tensor_dim_specs = input_iter.split(",")
        tensor_dim_specs = [dim.split(":") for dim in tensor_dim_specs]

        configs = []
        for start, stop, inc in tensor_dim_specs:
            dim_list = []
            if inc == "pow2":
                # powers of two: e.g. "1:8:pow2" -> [1, 2, 4, 8]
                curr = int(start)
                while curr <= int(stop):
                    dim_list.append(curr)
                    curr <<= 1
            elif inc == "pow2+1":
                # powers of two plus one: e.g. "9:33:pow2+1" -> [9, 17, 33]
                curr = int(start)
                while curr <= int(stop):
                    dim_list.append(curr)
                    curr -= 1
                    curr <<= 1
                    curr += 1
            else:
                # constant increment: e.g. "2:10:2" -> [2, 4, 6, 8, 10]
                dim_list = list(range(int(start), int(stop) + int(inc), int(inc)))
            configs.append(dim_list)
        configs = itertools.product(*configs)

        for mode, device, dtype, config in itertools.product(
            modes, devices, datatypes, list(configs)
        ):
            bench = bench_cls(mode, device, dtype, *config)
            bench.output_type = args.output
            bench.jit_mode = args.jit_mode
            if not bench.is_supported():
                if allow_skip:
                    continue
                raise ValueError(
                    f"attempted to run an unsupported benchmark: {bench.desc()}"
                )
            bench.run(args)
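
    # With the helpers defined, dispatch: no names runs everything with the
    # default configs; otherwise each name selects a class or a full config.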
    benchmark_classes = benchmark.benchmark_classes
    if not args.benchmark_names:
        # by default, run all the benchmarks
        for benchmark_cls in benchmark_classes:
            run_default_configs(benchmark_cls, allow_skip=True)
    else:
        for name in args.benchmark_names:
            # if the name matches a benchmark class module (by substring),
            # run all the benchmarks for that class
            match_class_name = False
            for bench_cls in benchmark_classes:
                if name in bench_cls.module():
                    match_class_name = True
                    if (args.input_iter is not None) and bench_cls.input_iterable():
                        run_with_input_iter(bench_cls, args.input_iter, allow_skip=True)
                    else:
                        if args.input_iter is not None:
                            print(
                                f"WARNING: Incompatible benchmark class called with input_iter arg: {name}"
                            )
                        run_default_configs(bench_cls, allow_skip=True)
            if match_class_name:
                continue

            # if not a class module, parse the name as a full config, e.g.
            # "layernorm_fwd_cpu_128_32_128_128" -> module "layernorm",
            # mode "fwd", device "cpu", shape (128, 32, 128, 128)
            match_class_name = False
            for bench_cls in benchmark_classes:
                cls_module = bench_cls.module()
                if name.startswith(cls_module):
                    match_class_name = True
                    if name[len(cls_module)] != "_":
                        raise ValueError(f"invalid name: {name}")
                    config_str = name[(len(cls_module) + 1) :]
                    config = config_str.split("_")
                    if len(config) < 2:
                        raise ValueError(f"invalid config: {config}")
                    mode, device = config[0:2]
                    # TODO: make sure virtual devices such as 'cpu1' and 'cpu4' are supported.
                    if mode not in ["fwd", "both"]:
                        raise ValueError(f"invalid mode: {mode}")
                    for i, entry in enumerate(config):
                        try:
                            # convert numeric dimension entries to ints
                            config[i] = int(entry)
                        except ValueError:
                            pass
                    # TODO: output dtype in the config and parse it back from the str
                    bench = bench_cls(config[0], config[1], torch.float32, *config[2:])
                    bench.jit_mode = args.jit_mode
                    bench.output_type = args.output
                    bench.run(args)
            if not match_class_name:
                available_classes = ", ".join(
                    bench_cls.module() for bench_cls in benchmark_classes
                )
                raise ValueError(
                    f"invalid name: {name}\n"
                    f"Available benchmark classes:\n{available_classes}"
                )


if __name__ == "__main__":
    main()