blob: 8e4e3005ef371307e32b2948fb06c4da62e1c0d9 [file] [log] [blame]
#!/usr/bin/env python3
import argparse
import json
import re
import subprocess
from bisect import bisect_right
from collections import defaultdict
from typing import (
Callable,
DefaultDict,
Generic,
List,
Optional,
Pattern,
Sequence,
TypeVar,
cast,
)
from typing_extensions import TypedDict
class Hunk(TypedDict):
    """Line ranges taken from one unified-diff hunk header.

    Corresponds to a header of the form
    ``@@ -old_start,old_count +new_start,new_count @@``.
    """

    # first line of the hunk on the old ("-") side
    old_start: int
    # number of lines the hunk covers on the old side
    old_count: int
    # first line of the hunk on the new ("+") side
    new_start: int
    # number of lines the hunk covers on the new side
    new_count: int
class Diff(TypedDict):
    """The parts of a single-file unified diff that line translation needs."""

    # path captured from the "--- a/<path>" header line; None when that
    # header names /dev/null or no such header was found
    old_filename: Optional[str]
    # one entry per "@@ ... @@" hunk header, in the order they appear
    hunks: List[Hunk]
# Regular expression for a unified-diff hunk header:
#   @@ -start,count +start,count @@
# The four groups capture, in order, the old start, old count, new start
# and new count. Each ",count" part is optional; when it is omitted the
# corresponding group is None.
hunk_pattern = r"^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@"
def parse_diff(diff: str) -> Diff:
    """Extract the old filename and the hunk ranges from a unified diff.

    The text is scanned line by line. Everything up to the first
    ``--- /dev/null`` or ``--- a/<path>`` line is treated as preamble, and
    a hunk header appearing there trips an assertion. After that line,
    every line matching ``hunk_pattern`` contributes one hunk; a count that
    is omitted from the header defaults to 1.

    Returns a dict with ``old_filename`` (``None`` for ``/dev/null`` or when
    no ``---`` line was found) and the list of ``hunks``.
    """
    old_filename = None
    header_seen = False
    collected: List[Hunk] = []
    for text in diff.splitlines():
        hunk_header = re.match(hunk_pattern, text)
        if not header_seen:
            # still in the preamble: hunks may only follow the "---" line
            assert not hunk_header
            old_file_header = re.match(
                r"^--- (?:(?:/dev/null)|(?:a/(.*)))$", text
            )
            if old_file_header:
                header_seen = True
                # the single group is None for the /dev/null alternative
                (old_filename,) = old_file_header.groups()
            continue
        if not hunk_header:
            # content line of a hunk; only the headers matter here
            continue
        old_start, old_count, new_start, new_count = hunk_header.groups()
        collected.append(
            {
                "old_start": int(old_start),
                "old_count": int(old_count or "1"),
                "new_start": int(new_start),
                "new_count": int(new_count or "1"),
            }
        )
    return {"old_filename": old_filename, "hunks": collected}
# T is the element type of the wrapped list; U is what the key function
# returns for each element.
T = TypeVar("T")
U = TypeVar("U")


# bisect.bisect_right is the natural way to find the hunk closest to a
# given line number, but the bisect module only gains a ``key`` argument
# in Python 3.10 (https://github.com/python/cpython/pull/20556). Until
# then, wrap the list of hunks in this O(1) adapter, which makes it look
# like a plain sequence of keys.
# https://gist.github.com/ericremoreynolds/2d80300dabc70eebc790
class KeyifyList(Generic[T, U]):
    """Sequence-like view of a list that yields ``key(item)`` for each item.

    Nothing is copied or precomputed: the key function is applied on every
    lookup, and the wrapped list is referenced rather than duplicated.
    """

    def __init__(self, inner: List[T], key: Callable[[T], U]) -> None:
        self.key = key
        self.inner = inner

    def __len__(self) -> int:
        wrapped = self.inner
        return len(wrapped)

    def __getitem__(self, k: int) -> U:
        item = self.inner[k]
        return self.key(item)
def translate(diff: Diff, line_number: int) -> Optional[int]:
    """Map a line number on the new side of ``diff`` to the old side.

    Returns ``None`` when ``line_number`` is below 1 or lies within the
    new-side range of a hunk (such a line has no counterpart on the old
    side). Lines before the first hunk, or any line when there are no
    hunks, are returned unchanged. A line after a hunk is shifted by the
    difference between where that hunk ends on the old and new sides.
    """
    if line_number < 1:
        return None
    hunks = diff["hunks"]
    if not hunks:
        return line_number

    # Key for the binary search: the first new-side line a hunk can affect.
    # A hunk with new_count == 0 leaves no lines on the new side, so it
    # only affects lines strictly after its new_start.
    boundaries = KeyifyList(
        hunks,
        lambda h: h["new_start"] if h["new_count"] > 0 else h["new_start"] + 1,
    )
    index = bisect_right(cast(Sequence[int], boundaries), line_number)
    if index == 0:
        # every hunk starts after this line, so nothing has shifted it
        return line_number

    preceding = hunks[index - 1]
    # first line past the hunk on each side ("or 1" covers a zero count,
    # where the start refers to the line just before the change)
    new_end = preceding["new_start"] + (preceding["new_count"] or 1)
    offset = line_number - new_end
    if offset < 0:
        # the line is inside the hunk's new-side range
        return None
    old_end = preceding["old_start"] + (preceding["old_count"] or 1)
    return old_end + offset
# we use camelCase here because this will be output as JSON and so the
# field names need to match the group names from here:
# https://github.com/pytorch/add-annotations-github-action/blob/3ab7d7345209f5299d53303f7aaca7d3bc09e250/action.yml#L23
class Annotation(TypedDict):
    """One linter message, keyed by the named groups of the parsing regex."""

    # path of the file the message refers to
    filename: str
    # line number reported for the message
    lineNumber: int
    # column number reported for the message
    columnNumber: int
    # linter-specific code identifying the kind of message
    errorCode: str
    # human-readable description of the message
    errorDesc: str
def parse_annotation(regex: Pattern[str], line: str) -> Optional[Annotation]:
    """Parse one line of linter output into an ``Annotation``.

    ``regex`` is matched against the start of ``line`` and must define the
    named groups ``filename``, ``lineNumber``, ``columnNumber``,
    ``errorCode`` and ``errorDesc``.

    Returns ``None`` when the line does not match, or when the line or
    column number cannot be converted to an integer.
    """
    m = re.match(regex, line)
    if m is None:
        return None
    try:
        line_number = int(m.group("lineNumber"))
        column_number = int(m.group("columnNumber"))
    except (ValueError, TypeError):
        # ValueError: the captured text is not an integer.
        # TypeError: the group is optional in the regex and did not take
        # part in the match, so group() returned None.
        return None
    return {
        "filename": m.group("filename"),
        "lineNumber": line_number,
        "columnNumber": column_number,
        "errorCode": m.group("errorCode"),
        "errorDesc": m.group("errorDesc"),
    }
def translate_all(
    *, lines: List[str], regex: Pattern[str], commit: str
) -> List[Annotation]:
    """Translate linter annotations so they refer to lines in ``commit``.

    Every entry of ``lines`` is parsed with ``regex``; lines that do not
    parse are ignored. For each file mentioned, a zero-context diff against
    ``commit`` is obtained from ``git diff-index`` and used to rewrite the
    annotation's filename and line number.

    Annotations are dropped when the file has no old filename in the diff,
    or when the translated line number is ``None`` or 0. The returned
    annotations are the same dict objects that were parsed, updated in
    place.

    Raises ``subprocess.CalledProcessError`` if git exits with an error.
    """
    ann_dict: DefaultDict[str, List[Annotation]] = defaultdict(list)
    for line in lines:
        annotation = parse_annotation(regex, line)
        if annotation is not None:
            ann_dict[annotation["filename"]].append(annotation)

    ann_list: List[Annotation] = []
    for filename, annotations in ann_dict.items():
        raw_diff = subprocess.check_output(
            # "--" ends option parsing: the filename comes from linter
            # output and must never be interpreted as a git option or a
            # revision, even if it starts with "-".
            ["git", "diff-index", "--unified=0", commit, "--", filename],
            encoding="utf-8",
        )
        # empty output means the file is identical to the one in `commit`
        diff = parse_diff(raw_diff) if raw_diff.strip() else None
        # if there is a diff but it doesn't list an old filename, that
        # means the file is absent in the commit we're targeting, so we
        # skip it
        if diff is not None and not diff["old_filename"]:
            continue
        for annotation in annotations:
            line_number: Optional[int] = annotation["lineNumber"]
            if diff is not None:
                annotation["filename"] = cast(str, diff["old_filename"])
                line_number = translate(diff, cast(int, line_number))
            # a falsy result (None from translate, or a reported line of 0)
            # has no usable position in `commit`, so drop the annotation
            if line_number:
                annotation["lineNumber"] = line_number
                ann_list.append(annotation)
    return ann_list
def main() -> None:
    """Command-line entry point.

    Reads linter output from ``--file``, parses each line with ``--regex``,
    translates the resulting annotations to line numbers in ``--commit``
    and prints them to stdout as a JSON array.
    """
    parser = argparse.ArgumentParser(
        description="Translate linter annotations to line numbers in a commit."
    )
    # All three options are mandatory; marking them as such makes argparse
    # report a missing one up front instead of failing later on a None.
    parser.add_argument(
        "--file",
        required=True,
        help="file containing the linter output, one message per line",
    )
    parser.add_argument(
        "--regex",
        required=True,
        help=(
            "regex with the named groups filename, lineNumber, "
            "columnNumber, errorCode and errorDesc"
        ),
    )
    parser.add_argument(
        "--commit",
        required=True,
        help="commit whose line numbers the annotations should refer to",
    )
    args = parser.parse_args()

    with open(args.file, "r") as f:
        lines = f.readlines()

    annotations = translate_all(
        lines=lines,
        # compile once so the value matches translate_all's Pattern[str]
        regex=re.compile(args.regex),
        commit=args.commit,
    )
    print(json.dumps(annotations))


if __name__ == "__main__":
    main()