| import pytest # noqa: F401 |
| |
# Network names grouped by the list they belong to; `all_nets` is the
# concatenation of both groups (RNN names first, then CNN names).
default_rnns = [
    "cudnn",
    "aten",
    "jit",
    "jit_premul",
    "jit_premul_bias",
    "jit_simple",
    "jit_multilayer",
    "py",
]
default_cnns = [
    "resnet18",
    "resnet18_jit",
    "resnet50",
    "resnet50_jit",
]
all_nets = [*default_rnns, *default_cnns]
| |
| |
def pytest_generate_tests(metafunc):
    """Parametrize the tests of the ``TestBenchNetwork`` class.

    For tests defined in a class named ``TestBenchNetwork`` this adds three
    class-scoped parameters:

    * ``net_name`` -- one value per entry of ``all_nets``;
    * ``executor`` -- the single value of the ``executor`` command-line option;
    * ``fuser`` -- the single value of the ``fuser`` command-line option.

    Tests in any other class, and tests that are not defined in a class at
    all, are left untouched.

    Args:
        metafunc: The object pytest passes to this hook; ``cls``, ``config``
            and ``parametrize`` are the only attributes used here.
    """
    # `metafunc.cls` is None for module-level test functions, so the class
    # name must be looked up defensively instead of via `cls.__name__`.
    class_name = getattr(metafunc.cls, "__name__", None)
    if class_name != "TestBenchNetwork":
        return

    # This creates lists of tests to generate, can be customized
    metafunc.parametrize("net_name", all_nets, scope="class")
    metafunc.parametrize(
        "executor", [metafunc.config.getoption("executor")], scope="class"
    )
    metafunc.parametrize(
        "fuser", [metafunc.config.getoption("fuser")], scope="class"
    )
| |
| |
def pytest_addoption(parser):
    """Register the ``--fuser`` and ``--executor`` command-line options.

    Args:
        parser: Object exposing ``addoption``; each option is registered with
            a default value and a help string.
    """
    # (flag, default value, help text), registered in this order.
    option_specs = (
        ("--fuser", "old", "fuser to use for benchmarks"),
        ("--executor", "legacy", "executor to use for benchmarks"),
    )
    for flag, fallback, description in option_specs:
        parser.addoption(flag, default=fallback, help=description)