from __future__ import annotations

import math
import os
import subprocess
from pathlib import Path
from typing import Callable, Sequence

from tools.stats.import_test_stats import get_disabled_tests
from tools.testing.test_run import ShardedTest, TestRun


REPO_ROOT = Path(__file__).resolve().parent.parent.parent

IS_MEM_LEAK_CHECK = os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") == "1"
BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "")
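# Use 3 test procs on sm86 runners and on non-CUDA builds; other CUDA builds use 2.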
USE_3_PROCS = "sm86" in BUILD_ENVIRONMENT or "cuda" not in BUILD_ENVIRONMENT

# NUM_PROCS_FOR_SHARDING_CALC must remain consistent across all shards of a job
# so that the sharding calculation itself is consistent. NUM_PROCS is the actual
# number of procs used to run tests. If the two differ, the only consequence
# should be unevenly sized shards.
IS_ROCM = os.path.exists("/opt/rocm")
NUM_PROCS = 1 if IS_MEM_LEAK_CHECK else 3 if USE_3_PROCS else 2
NUM_PROCS_FOR_SHARDING_CALC = NUM_PROCS if not IS_ROCM or IS_MEM_LEAK_CHECK else 2
THRESHOLD = 60 * 10  # 10 minutes; longer files are split via pytest-shard

# See Note [ROCm parallel CI testing]
# Special logic for ROCm GHA runners to query the number of available GPUs.
# torch.version.hip is not available here to detect a ROCm self-hosted runner,
# so we check for the /opt/rocm directory instead.
if IS_ROCM and not IS_MEM_LEAK_CHECK:
    try:
        # This is the same logic used in GHA health check, see .github/templates/common.yml.j2
        lines = (
            subprocess.check_output(["rocminfo"], encoding="ascii").strip().split("\n")
        )
        count = 0
        for line in lines:
            if " gfx" in line:
                count += 1
        assert count > 0  # there must be at least 1 GPU
        # Limit to 8 GPUs (procs)
        NUM_PROCS = min(count, 8)
    except subprocess.CalledProcessError:
        # The safe default for ROCm GHA runners is to run tests serially.
        NUM_PROCS = 1


class ShardJob:
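    """One CI shard's worth of work: tests that must run serially plus tests
    that may run in parallel across test processes."""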
    def __init__(self) -> None:
        self.serial: list[ShardedTest] = []
        self.parallel: list[ShardedTest] = []

    def get_total_time(self) -> float:
        """Estimate this job's wall-clock time: parallel tests are greedily
        assigned to the least-loaded of NUM_PROCS_FOR_SHARDING_CALC processes,
        and serial tests run on top of that."""
        procs = [0.0 for _ in range(NUM_PROCS_FOR_SHARDING_CALC)]
        for test in self.parallel:
            # Place each parallel test on the proc with the least work so far
            min_index = procs.index(min(procs))
            procs[min_index] += test.get_time()
        time = max(procs) + sum(test.get_time() for test in self.serial)
        return time

    def convert_to_tuple(self) -> tuple[float, list[ShardedTest]]:
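        """Return (estimated total time, serial tests followed by parallel tests)."""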
        return (self.get_total_time(), self.serial + self.parallel)


def get_with_pytest_shard(
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
) -> list[ShardedTest]:
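    """Convert TestRuns into ShardedTests, splitting any test whose estimated
    duration exceeds THRESHOLD into multiple pytest shards of roughly equal
    length; tests at or under the threshold become a single 1/1 shard."""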
    sharded_tests: list[ShardedTest] = []

    for test in tests:
        duration = get_duration(test, test_file_times, test_class_times or {})

        if duration and duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
            for i in range(num_shards):
                sharded_tests.append(
                    ShardedTest(test, i + 1, num_shards, duration / num_shards)
                )
        else:
            sharded_tests.append(ShardedTest(test, 1, 1, duration))
    return sharded_tests


def get_duration(
    test: TestRun,
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]],
) -> float | None:
    """Calculate the time for a TestRun based on the given test_file_times and
    test_class_times. Returns None if the time is unknown."""
    file_duration = test_file_times.get(test.test_file, None)
    if test.is_full_file():
        return file_duration

    def get_duration_for_classes(
        test_file: str, test_classes: frozenset[str]
    ) -> float | None:
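        """Sum the durations of the given classes; None if any is unknown."""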
        duration: float = 0

        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            duration += class_duration
        return duration

    included = test.included()
    excluded = test.excluded()
    included_classes_duration = get_duration_for_classes(test.test_file, included)
    excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

    if included_classes_duration is None or excluded_classes_duration is None:
        # Didn't get the time for all classes, so time is unknown
        return None

    if included:
        return included_classes_duration
    assert (
        excluded
    ), f"TestRun {test} is not full file but doesn't have included or excluded classes"
    if file_duration is None:
        return None
    return file_duration - excluded_classes_duration


def shard(
    sharded_jobs: list[ShardJob],
    pytest_sharded_tests: Sequence[ShardedTest],
    estimated_time_limit: float | None = None,
    serial: bool = False,
) -> None:
    """Distribute pytest_sharded_tests across sharded_jobs, modifying the jobs in
    place. Serial sharding honors estimated_time_limit; parallel sharding fills
    the job with the smallest estimated total time."""
    if len(sharded_jobs) == 0:
        assert (
            len(pytest_sharded_tests) == 0
        ), "No shards provided but there are tests to shard"
        return

    round_robin_index = 0

    def _get_min_sharded_job(
        sharded_jobs: list[ShardJob], test: ShardedTest
    ) -> ShardJob:
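        # Tests with an unknown duration are distributed round-robin rather than
        # always landing on the job with the smallest estimated total time.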
        if test.time is None:
            nonlocal round_robin_index
            job = sharded_jobs[round_robin_index % len(sharded_jobs)]
            round_robin_index += 1
            return job
        return min(sharded_jobs, key=lambda j: j.get_total_time())

    def _shard_serial(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        assert estimated_time_limit is not None, "Estimated time limit must be provided"
        new_sharded_jobs = sharded_jobs
        for test in tests:
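            # Once the last candidate shard is already past the estimated time
            # limit, stop assigning serial tests to it so it keeps room for
            # parallel tests.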
            if (
                len(sharded_jobs) > 1
                and sharded_jobs[-1].get_total_time() > estimated_time_limit
            ):
                new_sharded_jobs = sharded_jobs[:-1]
            min_sharded_job = _get_min_sharded_job(new_sharded_jobs, test)
            min_sharded_job.serial.append(test)

    def _shard_parallel(
        tests: Sequence[ShardedTest], sharded_jobs: list[ShardJob]
    ) -> None:
        for test in tests:
            min_sharded_job = _get_min_sharded_job(sharded_jobs, test)
            min_sharded_job.parallel.append(test)

    if serial:
        _shard_serial(pytest_sharded_tests, sharded_jobs)
    else:
        _shard_parallel(pytest_sharded_tests, sharded_jobs)


def calculate_shards(
    num_shards: int,
    tests: Sequence[TestRun],
    test_file_times: dict[str, float],
    test_class_times: dict[str, dict[str, float]] | None,
    must_serial: Callable[[str], bool] | None = None,
    sort_by_time: bool = True,
) -> list[tuple[float, list[ShardedTest]]]:
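    """Distribute tests across num_shards shards and return, for each shard, a
    tuple of (estimated total time, tests assigned to it). Tests for which
    must_serial(test.name) is true are sharded as serial tests."""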
    must_serial = must_serial or (lambda x: True)
    test_class_times = test_class_times or {}

    # Divide tests into pytest shards
    if sort_by_time:
        known_tests = [
            x
            for x in tests
            if get_duration(x, test_file_times, test_class_times) is not None
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

        pytest_sharded_tests = sorted(
            get_with_pytest_shard(known_tests, test_file_times, test_class_times),
            key=lambda j: j.get_time(),
            reverse=True,
        ) + get_with_pytest_shard(unknown_tests, test_file_times, test_class_times)
    else:
        pytest_sharded_tests = get_with_pytest_shard(
            tests, test_file_times, test_class_times
        )
    del tests

    serial_tests = [test for test in pytest_sharded_tests if must_serial(test.name)]
    parallel_tests = [test for test in pytest_sharded_tests if test not in serial_tests]

    serial_time = sum(test.get_time() for test in serial_tests)
    parallel_time = sum(test.get_time() for test in parallel_tests)
    total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
    estimated_time_per_shard = total_time / num_shards
    # Separate serial tests from parallel tests as much as possible to maximize
    # parallelism by putting all the serial tests on the first num_serial_shards
    # shards. The estimated_time_limit is the estimated time it should take for
    # the least filled serial shard. For example, with 8 min of serial tests,
    # 20 min of parallel tests, 6 shards, and 2 procs per machine, we would
    # expect each machine to take 3 min and should aim for 3 serial shards, with
    # shards 1 and 2 taking 3 min and shard 3 taking 2 min. The estimated time
    # limit would be 2 min. This ensures that the first few shards contain as
    # many serial tests and as few parallel tests as possible. The least
    # filled/last shard (the 3rd in the example) may contain a lot of both
    # serial and parallel tests.
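    # A sketch of the arithmetic below for that example:
    #   total_time = 8 + 20 / 2 = 18 min
    #   estimated_time_per_shard = 18 / 6 = 3 min
    #   estimated_time_limit = 8 % 3 = 2 min
    #   num_serial_shards = ceil(8 / 18 * 6) = 3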
    estimated_time_limit = 0.0
    if estimated_time_per_shard != 0:
        estimated_time_limit = serial_time % estimated_time_per_shard
    if estimated_time_limit <= 0.01:
        estimated_time_limit = estimated_time_per_shard
    if total_time == 0:
        num_serial_shards = num_shards
    else:
        num_serial_shards = max(math.ceil(serial_time / total_time * num_shards), 1)

    sharded_jobs = [ShardJob() for _ in range(num_shards)]
    shard(
        sharded_jobs=sharded_jobs[:num_serial_shards],
        pytest_sharded_tests=serial_tests,
        estimated_time_limit=estimated_time_limit,
        serial=True,
    )
    shard(
        sharded_jobs=sharded_jobs,
        pytest_sharded_tests=parallel_tests,
        serial=False,
    )

    return [job.convert_to_tuple() for job in sharded_jobs]


def get_test_case_configs(dirpath: str) -> None:
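    """Fetch test case configs for the given directory; currently this only
    fetches the list of disabled tests."""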
    get_disabled_tests(dirpath=dirpath)