#!/usr/bin/env python3

import argparse
import json
import re
import subprocess
from bisect import bisect_right
from collections import defaultdict
from typing import (
    Callable,
    DefaultDict,
    Generic,
    List,
    Optional,
    Pattern,
    Sequence,
    TypeVar,
    cast,
)

from typing_extensions import TypedDict


class Hunk(TypedDict):
    old_start: int
    old_count: int
    new_start: int
    new_count: int


class Diff(TypedDict):
    old_filename: Optional[str]
    hunks: List[Hunk]


# @@ -start,count +start,count @@
hunk_pattern = r"^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@"
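# illustrative example (not taken from the original source): a hunk header like
#   @@ -10,3 +12,4 @@ def foo():
# matches with groups ("10", "3", "12", "4"); when a count is omitted, as in
#   @@ -7 +7 @@
# the corresponding group is None and parse_diff below defaults it to 1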


def parse_diff(diff: str) -> Diff:
    name = None
    name_found = False
    hunks: List[Hunk] = []
    for line in diff.splitlines():
        hunk_match = re.match(hunk_pattern, line)
        if name_found:
            if hunk_match:
                old_start, old_count, new_start, new_count = hunk_match.groups()
                hunks.append(
                    {
                        "old_start": int(old_start),
                        "old_count": int(old_count or "1"),
                        "new_start": int(new_start),
                        "new_count": int(new_count or "1"),
                    }
                )
        else:
            assert not hunk_match
            name_match = re.match(r"^--- (?:(?:/dev/null)|(?:a/(.*)))$", line)
            if name_match:
                name_found = True
                (name,) = name_match.groups()
    return {
        "old_filename": name,
        "hunks": hunks,
    }

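# a rough sketch of the shape parse_diff returns (example values, not taken
# from a real run): a diff beginning with
#   --- a/foo.py
#   +++ b/foo.py
#   @@ -2,3 +4,5 @@
# would parse to
#   {"old_filename": "foo.py",
#    "hunks": [{"old_start": 2, "old_count": 3, "new_start": 4, "new_count": 5}]}
# while a diff against a newly created file ("--- /dev/null") parses with
# old_filename set to None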

T = TypeVar("T")
U = TypeVar("U")


# we want to use bisect.bisect_right to find the hunk closest to a given
# line number, but the bisect module doesn't accept a key function until
# Python 3.10 (https://github.com/python/cpython/pull/20556), so we wrap
# the list of hunks in an O(1) adapter that makes it look like a plain
# list of line numbers:
# https://gist.github.com/ericremoreynolds/2d80300dabc70eebc790
class KeyifyList(Generic[T, U]):
    def __init__(self, inner: List[T], key: Callable[[T], U]) -> None:
        self.inner = inner
        self.key = key

    def __len__(self) -> int:
        return len(self.inner)

    def __getitem__(self, k: int) -> U:
        return self.key(self.inner[k])

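# a minimal sketch of roughly how KeyifyList is used below (hypothetical values):
#   hunks = [{"old_start": 1, "old_count": 2, "new_start": 1, "new_count": 3}]
#   keyified = KeyifyList(hunks, lambda hunk: hunk["new_start"])
#   keyified[0] == 1 and len(keyified) == 1
# so bisect_right can binary-search the hunks by their keyed line numbers
# without materializing a separate list of keys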

def translate(diff: Diff, line_number: int) -> Optional[int]:
    if line_number < 1:
        return None

    hunks = diff["hunks"]
    if not hunks:
        return line_number

    keyified = KeyifyList(
        hunks, lambda hunk: hunk["new_start"] + (0 if hunk["new_count"] > 0 else 1)
    )
    i = bisect_right(cast(Sequence[int], keyified), line_number)
    if i < 1:
        return line_number

    hunk = hunks[i - 1]
    d = line_number - (hunk["new_start"] + (hunk["new_count"] or 1))
    return None if d < 0 else hunk["old_start"] + (hunk["old_count"] or 1) + d

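# worked example (illustrative numbers only): for a single hunk
# @@ -5,2 +5,3 @@, i.e.
#   {"old_start": 5, "old_count": 2, "new_start": 5, "new_count": 3}
# translate maps new line 4 to old line 4 (before the hunk, unchanged),
# returns None for new line 7 (inside the rewritten region, so there is
# no corresponding old line), and maps new line 8 to old line 7
# (8 - (5 + 3) = 0 lines past the hunk, so 5 + 2 + 0 = 7)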

# we use camelCase here because this will be output as JSON and so the
# field names need to match the group names from here:
# https://github.com/pytorch/add-annotations-github-action/blob/3ab7d7345209f5299d53303f7aaca7d3bc09e250/action.yml#L23
class Annotation(TypedDict):
    filename: str
    lineNumber: int
    columnNumber: int
    errorCode: str
    errorDesc: str


def parse_annotation(regex: Pattern[str], line: str) -> Optional[Annotation]:
    m = re.match(regex, line)
    if m:
        try:
            line_number = int(m.group("lineNumber"))
            column_number = int(m.group("columnNumber"))
        except ValueError:
            return None
        return {
            "filename": m.group("filename"),
            "lineNumber": line_number,
            "columnNumber": column_number,
            "errorCode": m.group("errorCode"),
            "errorDesc": m.group("errorDesc"),
        }
    else:
        return None

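# a hedged example of the kind of regex this expects (the actual pattern is
# supplied on the command line, not defined here): a flake8-style line such as
#   foo/bar.py:42:9: F401 'os' imported but unused
# could be parsed with a pattern along the lines of
#   r"^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): "
#   r"(?P<errorCode>\w+) (?P<errorDesc>.*)"
# as long as it defines all five named groups read above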

def translate_all(
    *, lines: List[str], regex: Pattern[str], commit: str
) -> List[Annotation]:
    ann_dict: DefaultDict[str, List[Annotation]] = defaultdict(list)
    for line in lines:
        annotation = parse_annotation(regex, line)
        if annotation is not None:
            ann_dict[annotation["filename"]].append(annotation)
    ann_list = []
    for filename, annotations in ann_dict.items():
        raw_diff = subprocess.check_output(
            ["git", "diff-index", "--unified=0", commit, filename],
            encoding="utf-8",
        )
        diff = parse_diff(raw_diff) if raw_diff.strip() else None
        # if there is a diff but it doesn't list an old filename, that
        # means the file is absent in the commit we're targeting, so we
        # skip it
        if not (diff and not diff["old_filename"]):
            for annotation in annotations:
                line_number: Optional[int] = annotation["lineNumber"]
                if diff:
                    annotation["filename"] = cast(str, diff["old_filename"])
                    line_number = translate(diff, cast(int, line_number))
                if line_number:
                    annotation["lineNumber"] = line_number
                    ann_list.append(annotation)
    return ann_list

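# illustration of the overall flow (made-up values): given an annotation on
# working-tree line 8 of foo.py, and a `git diff-index --unified=0 <commit>
# foo.py` patch whose only hunk is @@ -5,2 +5,3 @@, the annotation is kept,
# its filename is taken from the patch's "--- a/" side, and its lineNumber
# is translated to 7, the corresponding line in <commit>; annotations whose
# lines fall inside a rewritten hunk translate to None and are dropped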

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", required=True)
    parser.add_argument("--regex", required=True)
    parser.add_argument("--commit", required=True)
    args = parser.parse_args()
    with open(args.file, "r") as f:
        lines = f.readlines()
    # compile the pattern once up front; parse_annotation expects a Pattern[str]
    print(
        json.dumps(
            translate_all(
                lines=lines,
                regex=re.compile(args.regex),
                commit=args.commit,
            )
        )
    )

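# example invocation (the script name, input file, and regex below are
# placeholders, not taken from anywhere in this repo):
#   python3 translate_annotations.py \
#     --file flake8-output.txt \
#     --regex '^(?P<filename>.*?):(?P<lineNumber>\d+):(?P<columnNumber>\d+): (?P<errorCode>\w+) (?P<errorDesc>.*)' \
#     --commit HEAD~1
# this prints a JSON list of Annotation objects on stdout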

if __name__ == "__main__":
    main()