| import pytest # noqa: F401 |
| |
# RNN benchmark variants that are exercised by default.
default_rnns = [
    'cudnn',
    'aten',
    'jit',
    'jit_premul',
    'jit_premul_bias',
    'jit_simple',
    'jit_multilayer',
    'py',
]
# CNN benchmark variants that are exercised by default.
default_cnns = [
    'resnet18',
    'resnet18_jit',
    'resnet50',
    'resnet50_jit',
]
# Full list of network names TestBenchNetwork is parametrized over
# (RNN entries first, then CNN entries); a fresh list, not an alias.
all_nets = [*default_rnns, *default_cnns]
| |
| def pytest_generate_tests(metafunc): |
| # This creates lists of tests to generate, can be customized |
| if metafunc.cls.__name__ == "TestBenchNetwork": |
| metafunc.parametrize('net_name', all_nets, scope="class") |
| metafunc.parametrize("executor", [metafunc.config.getoption("executor")], scope="class") |
| metafunc.parametrize("fuser", [metafunc.config.getoption("fuser")], scope="class") |
| |
def pytest_addoption(parser):
    """Register the ``--fuser`` and ``--executor`` command-line options.

    Both are read back by ``pytest_generate_tests`` through
    ``metafunc.config.getoption``.

    Args:
        parser: pytest option parser passed to this hook.
    """
    option_table = (
        ("--fuser", "old", "fuser to use for benchmarks"),
        ("--executor", "legacy", "executor to use for benchmarks"),
    )
    for flag, fallback, description in option_table:
        parser.addoption(flag, default=fallback, help=description)