| """ |
| Custom pytest shard plugin |
| https://github.com/AdamGleave/pytest-shard/blob/64610a08dac6b0511b6d51cf895d0e1040d162ad/pytest_shard/pytest_shard.py#L1 |
| Modifications: |
| * shard IDs are now 1-indexed instead of 0-indexed |
| * added an option to print the items in this shard |
| """ |
| |
| import hashlib |
| |
| from _pytest.config.argparsing import Parser |
| |
| |
def pytest_addoptions(parser: Parser):
    """Register the command-line options that control sharding.

    All options are added to the ``shard`` option group:

    * ``--shard-id`` (int, default 1): which shard to run.
    * ``--num-shards`` (int, default 1): how many shards the suite is split into.
    * ``--print-items`` (flag, default False): also list the node IDs of the
      items selected for this shard.

    NOTE(review): pytest's option hook is spelled ``pytest_addoption``; this
    name differs, so pytest will not treat this function as that hook.
    Presumably it is called explicitly (e.g. from a conftest's
    ``pytest_addoption``) -- confirm with callers before renaming.
    """
    shard_group = parser.getgroup("shard")
    # (flag, keyword arguments for ``addoption``), in registration order.
    option_specs = (
        (
            "--shard-id",
            {
                "dest": "shard_id",
                "type": int,
                "default": 1,
                "help": "Number of this shard.",
            },
        ),
        (
            "--num-shards",
            {
                "dest": "num_shards",
                "type": int,
                "default": 1,
                "help": "Total number of shards.",
            },
        ),
        (
            "--print-items",
            {
                "dest": "print_items",
                "action": "store_true",
                "default": False,
                "help": "Print out the items being tested in this shard.",
            },
        ),
    )
    for flag, spec in option_specs:
        shard_group.addoption(flag, **spec)
| |
| |
class PytestShardPlugin:
    """Keep only the collected test items that belong to one shard.

    Each item is assigned to a shard by hashing its ``nodeid`` with SHA-256
    and taking the result modulo the number of shards. Unlike the builtin
    ``hash``, SHA-256 does not change between interpreter runs, so the
    assignment is stable for a given node ID and shard count. Shard IDs are
    1-indexed.
    """

    def __init__(self, config):
        # Retained on the instance; the hook methods below read their options
        # from the ``config`` argument pytest passes them instead.
        self.config = config

    def pytest_report_collectionfinish(self, config, items) -> str:
        """Return a collection-summary line for this shard.

        The line states how many items will run; when the ``print_items``
        option is set it also lists their node IDs, comma-separated.
        """
        summary = f"Running {len(items)} items in this shard"
        if not config.getoption("print_items"):
            return summary
        listing = ", ".join(item.nodeid for item in items)
        return f"{summary}: {listing}"

    def sha256hash(self, x: str) -> int:
        """Return the SHA-256 digest of ``x`` (UTF-8 encoded) as a little-endian int."""
        digest = hashlib.sha256(x.encode()).digest()
        return int.from_bytes(digest, byteorder="little")

    def filter_items_by_shard(self, items, shard_id: int, num_shards: int):
        """Return a new list of the ``items`` assigned to shard ``shard_id``.

        ``shard_id`` is 1-indexed: shard ``k`` receives the items whose hash
        leaves remainder ``k - 1`` modulo ``num_shards``. Input order is kept.
        """

        def in_this_shard(item) -> bool:
            return self.sha256hash(item.nodeid) % num_shards == shard_id - 1

        return list(filter(in_this_shard, items))

    def pytest_collection_modifyitems(self, config, items):
        """Narrow ``items`` in place to the ones belonging to the selected shard.

        Raises:
            ValueError: if the ``shard_id`` option is not between 1 and the
                ``num_shards`` option inclusive (so any ``num_shards < 1`` is
                rejected before a modulo by it can happen).
        """
        wanted_shard = config.getoption("shard_id")
        total_shards = config.getoption("num_shards")
        if not (1 <= wanted_shard <= total_shards):
            raise ValueError(
                f"{wanted_shard} is not a valid shard ID out of {total_shards} total shards"
            )
        # NOTE(review): dropped items are not passed to the ``pytest_deselected``
        # hook, so pytest's summary does not count them as deselected.
        # Slice assignment mutates pytest's list object instead of rebinding.
        items[:] = self.filter_items_by_shard(items, wanted_shard, total_shards)