| import os |
| import subprocess |
| |
| from typing import Callable, Dict, List, Optional, Tuple |
| |
| from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests |
| |
# Number of process slots ShardJob.get_total_time spreads parallel tests over:
# one when PYTORCH_TEST_CUDA_MEM_LEAK_CHECK is set to "1", otherwise two.
NUM_PROCS = 2 if os.getenv("PYTORCH_TEST_CUDA_MEM_LEAK_CHECK", "0") != "1" else 1
| |
| |
class ShardJob:
    """Tests assigned to one shard, split into serial and parallel groups.

    ``test_times`` maps a test name to its recorded duration; tests missing
    from the mapping are counted as taking zero time.
    """

    def __init__(self, test_times: Dict[str, float]):
        self.test_times = test_times
        # Tests whose durations are summed one after another.
        self.serial: List[str] = []
        # Tests whose durations are spread over NUM_PROCS slots.
        self.parallel: List[str] = []

    def get_total_time(self) -> float:
        """Estimate the shard's duration.

        Each parallel test, in list order, is added to the currently
        least-loaded of ``NUM_PROCS`` slots. The estimate is the heaviest
        slot plus the sum of all serial test times.
        """
        slot_loads = [0.0] * NUM_PROCS
        for name in self.parallel:
            # First slot with the smallest load, same tie-break as list.index.
            lightest = min(range(NUM_PROCS), key=slot_loads.__getitem__)
            slot_loads[lightest] += self.test_times.get(name, 0)
        serial_total = sum(self.test_times.get(name, 0) for name in self.serial)
        return max(slot_loads) + serial_total

    def convert_to_tuple(self) -> Tuple[float, List[str]]:
        """Return ``(estimated time, serial tests followed by parallel tests)``."""
        estimated = self.get_total_time()
        ordered = [*self.serial, *self.parallel]
        return (estimated, ordered)
| |
| |
def calculate_shards(
    num_shards: int,
    tests: List[str],
    test_file_times: Dict[str, float],
    must_serial: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[float, List[str]]]:
    """Distribute ``tests`` over ``num_shards`` shards, balancing estimated time.

    Tests with an entry in ``test_file_times`` are handled longest-first, each
    going to the shard whose estimated total time is currently smallest. Tests
    without an entry are then dealt out round-robin, starting at the smallest
    shard, and are always placed in the serial group.

    Args:
        num_shards: number of shards to produce; must be at least 1.
        tests: test names to distribute.
        test_file_times: recorded duration for each known test.
        must_serial: predicate selecting tests that go in a shard's serial
            group. When omitted, every test is treated as serial.

    Returns:
        One ``(estimated total time, test names)`` tuple per shard.

    Raises:
        ValueError: if ``num_shards`` is smaller than 1.
    """
    # Without this check the failure surfaces later as an opaque
    # "min() arg is an empty sequence" ValueError.
    if num_shards < 1:
        raise ValueError(f"num_shards must be at least 1, got {num_shards}")

    must_serial = must_serial or (lambda x: True)

    known_tests = [x for x in tests if x in test_file_times]
    # Check the dict directly; testing against the known_tests list was quadratic.
    unknown_tests: List[str] = [x for x in tests if x not in test_file_times]

    sorted_tests = sorted(known_tests, key=lambda j: test_file_times[j], reverse=True)

    sharded_jobs: List[ShardJob] = [
        ShardJob(test_file_times) for _ in range(num_shards)
    ]
    for test in sorted_tests:
        is_serial = must_serial(test)
        min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
        target = min_sharded_job.serial if is_serial else min_sharded_job.parallel
        target.append(test)

    # Round robin the unknown jobs starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for test in unknown_tests:
        sharded_jobs[index].serial.append(test)
        index = (index + 1) % num_shards
    return [job.convert_to_tuple() for job in sharded_jobs]
| |
| |
def _query_changed_test_files() -> List[str]:
    """Return the paths that differ between the default branch and HEAD.

    Runs ``git diff --name-only origin/<branch> HEAD``, where ``<branch>``
    comes from the ``GIT_DEFAULT_BRANCH`` environment variable and falls back
    to ``master``.

    Returns:
        One whitespace-stripped path per non-blank output line; an empty list
        when git prints nothing.

    Raises:
        RuntimeError: if git exits with a non-zero status.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
    cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    # "".split("\n") is [""], so blank entries are dropped to keep an empty
    # diff from producing a bogus empty filename.
    stripped = (line.strip() for line in proc.stdout.decode().split("\n"))
    return [line for line in stripped if line]
| |
| |
def get_reordered_tests(tests: List[str]) -> List[str]:
    """Move tests whose source file changed to the front of the list.

    Changed paths come from ``_query_changed_test_files``. A path of the form
    ``test/<name>.py`` selects the test called ``<name>``. Within the
    prioritized group and within the remainder, tests keep the relative order
    they have in ``tests``.

    Args:
        tests: test names to reorder.

    Returns:
        A new list with the prioritized tests first, or ``tests`` itself,
        unchanged, when the changed files cannot be determined.
    """
    try:
        changed_files = _query_changed_test_files()
    except Exception:
        # If unable to get changed files from git, quit without doing any sorting
        return tests

    # git prints "/"-separated paths on every platform, so matching only the
    # native separator never fired on Windows. The native form is still
    # accepted; both prefixes have the same length.
    prefixes = ("test/", f"test{os.path.sep}")
    prefix_len = len(prefixes[0])
    suffix = ".py"
    # A set keeps the membership checks below constant-time.
    prioritized_tests = {
        f[prefix_len : -len(suffix)]
        for f in changed_files
        if f.startswith(prefixes) and f.endswith(suffix)
    }
    print("Prioritized test from test file changes.")

    bring_to_front = [test for test in tests if test in prioritized_tests]
    the_rest = [test for test in tests if test not in prioritized_tests]

    print(
        f"reordering tests for PR:\n"
        f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
    )
    return bring_to_front + the_rest
| |
| |
def get_test_case_configs(dirpath: str) -> None:
    """Call ``get_slow_tests`` and then ``get_disabled_tests`` for ``dirpath``.

    The values returned by both helpers are discarded.
    NOTE(review): presumably they are called for a side effect tied to
    ``dirpath`` -- confirm in ``tools.stats.import_test_stats``.
    """
    for fetch in (get_slow_tests, get_disabled_tests):
        fetch(dirpath=dirpath)