| import os |
| import subprocess |
| |
| from typing import Dict, List, Tuple |
| |
| from tools.stats.import_test_stats import get_disabled_tests, get_slow_tests |
| |
| |
def calculate_shards(
    num_shards: int, tests: List[str], job_times: Dict[str, float]
) -> List[Tuple[float, List[str]]]:
    """Split ``tests`` into ``num_shards`` shards with balanced known run times.

    Tests that have an entry in ``job_times`` are placed with a greedy
    number-partitioning heuristic: longest test first, always onto the shard
    with the smallest accumulated time. Tests without a recorded time are then
    dealt round-robin, starting at the currently lightest shard; they do not
    contribute to a shard's recorded time.

    Args:
        num_shards: Number of shards to produce; must be at least 1.
        tests: Test names to distribute.
        job_times: Known duration per test name. Entries for names that are
            not in ``tests`` are ignored.

    Returns:
        One ``(total_known_time, test_names)`` tuple per shard, in shard order.

    Raises:
        ValueError: If ``num_shards`` is less than 1.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be a positive integer, got {num_shards}")

    filtered_job_times: Dict[str, float] = {}
    unknown_jobs: List[str] = []
    for test in tests:
        if test in job_times:
            filtered_job_times[test] = job_times[test]
        else:
            unknown_jobs.append(test)

    sharded_jobs: List[Tuple[float, List[str]]] = [(0.0, []) for _ in range(num_shards)]

    def lightest_shard_index() -> int:
        # min() returns the first minimal index, which is the same tie-breaking
        # a stable sort followed by [0] would give, without sorting every shard.
        return min(range(num_shards), key=lambda i: sharded_jobs[i][0])

    # The following attempts to implement a partition approximation greedy algorithm
    # See more at https://en.wikipedia.org/wiki/Greedy_number_partitioning
    sorted_jobs = sorted(
        filtered_job_times, key=filtered_job_times.__getitem__, reverse=True
    )
    for job in sorted_jobs:
        target = lightest_shard_index()
        shard_time, shard_jobs = sharded_jobs[target]
        shard_jobs.append(job)
        sharded_jobs[target] = (shard_time + filtered_job_times[job], shard_jobs)

    # Round robin the unknown jobs starting with the smallest shard
    index = lightest_shard_index()
    for job in unknown_jobs:
        sharded_jobs[index][1].append(job)
        index = (index + 1) % num_shards
    return sharded_jobs
| |
| |
def _query_changed_test_files() -> List[str]:
    """Return the paths git reports as changed between the default branch and HEAD.

    Runs ``git diff --name-only origin/<branch> HEAD``, where ``<branch>`` is
    taken from the ``GIT_DEFAULT_BRANCH`` environment variable and falls back
    to ``master``.

    Returns:
        One whitespace-stripped path per changed file; an empty list when the
        diff is empty.

    Raises:
        RuntimeError: If the git command exits with a non-zero status.
    """
    default_branch = f"origin/{os.environ.get('GIT_DEFAULT_BRANCH', 'master')}"
    cmd = ["git", "diff", "--name-only", default_branch, "HEAD"]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

    if proc.returncode != 0:
        raise RuntimeError("Unable to get changed files")

    stripped = (line.strip() for line in proc.stdout.decode().split("\n"))
    # Splitting empty output yields [""], so drop blank entries to make an
    # empty diff produce an empty list rather than one bogus path.
    return [line for line in stripped if line]
| |
| |
def get_reordered_tests(tests: List[str]) -> List[str]:
    """Move tests whose source file changed to the front of the list.

    The changed files come from ``_query_changed_test_files()``. A changed
    path of the form ``test/<name>.py`` prioritizes the test called ``<name>``.
    Relative order is preserved within the prioritized group and within the
    remaining group. Progress messages are printed to stdout.

    Args:
        tests: Test names in their current order.

    Returns:
        A new list with prioritized tests first, or ``tests`` itself,
        unmodified, when the changed files cannot be determined.
    """
    try:
        changed_files = _query_changed_test_files()
    except Exception:
        # If unable to get changed files from git, quit without doing any sorting
        return tests

    # git prints "/"-separated paths on every platform, so match that form
    # explicitly; the OS separator is still accepted for compatibility.
    # Both prefixes have the same length, so one slice offset serves both.
    prefixes = ("test/", f"test{os.path.sep}")
    prefix_len = len("test/")
    suffix = ".py"
    prioritized_tests = {
        f[prefix_len : -len(suffix)]
        for f in changed_files
        if f.startswith(prefixes) and f.endswith(suffix)
    }
    print("Prioritized test from test file changes.")

    bring_to_front = [test for test in tests if test in prioritized_tests]
    the_rest = [test for test in tests if test not in prioritized_tests]
    print(
        f"reordering tests for PR:\n"
        f"prioritized: {bring_to_front}\nthe rest: {the_rest}\n"
    )
    return bring_to_front + the_rest
| |
| |
def get_test_case_configs(dirpath: str) -> None:
    """Call ``get_slow_tests`` and then ``get_disabled_tests`` for ``dirpath``.

    Both helpers come from ``tools.stats.import_test_stats`` and their return
    values are discarded here.
    NOTE(review): presumably they fetch or cache the slow/disabled test lists
    under ``dirpath`` -- confirm against that module.
    """
    # Order matters only in that it mirrors the original call sequence.
    for load_config in (get_slow_tests, get_disabled_tests):
        load_config(dirpath=dirpath)