# Owner(s): ["module: dynamo"]

import sys

import pytest

import torch._numpy as tnp

def pytest_configure(config):
config.addinivalue_line("markers", "slow: very slow tests")
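
# A hypothetical test opting in to the "slow" marker registered above:
#
#     @pytest.mark.slow
#     def test_expensive_path():
#         ...
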
def pytest_addoption(parser):
parser.addoption("--runslow", action="store_true", help="run slow tests")
parser.addoption("--nonp", action="store_true", help="error when NumPy is accessed")
class Inaccessible:
    """Stand-in for NumPy-proper under --nonp: any attribute access raises."""

    def __getattribute__(self, attr):
        raise RuntimeError(f"Using --nonp but accessed np.{attr}")

def pytest_sessionstart(session):
if session.config.getoption("--nonp"):
sys.modules["numpy"] = Inaccessible()
def pytest_generate_tests(metafunc):
"""
Hook to parametrize test cases
See https://docs.pytest.org/en/6.2.x/parametrize.html#pytest-generate-tests
The logic here allows us to test with both NumPy-proper and torch._numpy.
Normally we'd just test torch._numpy, e.g.
import torch._numpy as np
...
def test_foo():
np.array([42])
...
but this hook allows us to test NumPy-proper as well, e.g.
def test_foo(np):
np.array([42])
...
np is a pytest parameter, which is either NumPy-proper or torch._numpy. This
allows us to sanity check our own tests, so that tested behaviour is
consistent with NumPy-proper.
pytest will have test names respective to the library being tested, e.g.
$ pytest --collect-only
test_foo[torch._numpy]
test_foo[numpy]
"""
np_params = [tnp]
try:
import numpy as np
except ImportError:
pass
else:
        if not isinstance(np, Inaccessible):  # i.e. --nonp was not used
np_params.append(np)
if "np" in metafunc.fixturenames:
metafunc.parametrize("np", np_params)
def pytest_collection_modifyitems(config, items):
if not config.getoption("--runslow"):
skip_slow = pytest.mark.skip(reason="slow test, use --runslow to run")
for item in items:
if "slow" in item.keywords:
item.add_marker(skip_slow)
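
# With the hook above, a @pytest.mark.slow test collected without --runslow is
# reported as skipped; illustrative output:
#
#     test_foo.py::test_expensive_path SKIPPED (slow test, use --runslow to run)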