import json
import os
import subprocess
import time
from pathlib import Path
from typing import Any, cast, Dict, Optional, Tuple

import requests
import rockset  # type: ignore[import]

REPO_ROOT = Path(__file__).resolve().parent.parent.parent

QUERY = """
WITH most_recent_strict_commits AS (
SELECT
        push.head_commit.id as sha
FROM
commons.push
WHERE
push.ref = 'refs/heads/viable/strict'
AND push.repository.full_name = 'pytorch/pytorch'
ORDER BY
push._event_time DESC
LIMIT
3
), workflows AS (
SELECT
id
FROM
commons.workflow_run w
INNER JOIN most_recent_strict_commits c on w.head_sha = c.sha
WHERE
w.name != 'periodic'
),
job AS (
SELECT
j.id
FROM
commons.workflow_job j
INNER JOIN workflows w on w.id = j.run_id
WHERE
j.name NOT LIKE '%asan%'
),
duration_per_job AS (
SELECT
test_run.classname,
test_run.name,
job.id,
SUM(time) as time
FROM
commons.test_run_s3 test_run
/* `test_run` is ginormous and `job` is small, so lookup join is essential */
INNER JOIN job ON test_run.job_id = job.id HINT(join_strategy = lookup)
WHERE
/* cpp tests do not populate `file` for some reason. */
/* Exclude them as we don't include them in our slow test infra */
test_run.file IS NOT NULL
/* do some more filtering to cut down on the test_run size */
AND test_run.skipped IS NULL
AND test_run.failure IS NULL
AND test_run.error IS NULL
GROUP BY
test_run.classname,
test_run.name,
job.id
)
SELECT
CONCAT(
name,
' (__main__.',
classname,
')'
) as test_name,
AVG(time) as avg_duration_sec
FROM
duration_per_job
GROUP BY
CONCAT(
name,
' (__main__.',
classname,
')'
)
HAVING
AVG(time) > 60.0
ORDER BY
test_name
"""

UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"]
PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"]


def git_api(
url: str, params: Dict[str, str], type: str = "get", token: str = UPDATEBOT_TOKEN
) -> Any:
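    """Call the GitHub API and return the parsed JSON response.

    `type` selects the HTTP method: "get" (default), "post", or "patch".
    """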
headers = {
"Accept": "application/vnd.github.v3+json",
"Authorization": f"token {token}",
}
if type == "post":
return requests.post(
f"https://api.github.com{url}",
data=json.dumps(params),
headers=headers,
).json()
elif type == "patch":
return requests.patch(
f"https://api.github.com{url}",
data=json.dumps(params),
headers=headers,
).json()
else:
return requests.get(
f"https://api.github.com{url}",
params=params,
headers=headers,
).json()


def make_pr(source_repo: str, params: Dict[str, Any]) -> int:
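    """Open a pull request on source_repo and return its number."""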
response = git_api(f"/repos/{source_repo}/pulls", params, type="post")
print(f"made pr {response['html_url']}")
return cast(int, response["number"])


def approve_pr(source_repo: str, pr_number: int) -> None:
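    """Approve the pull request as pytorchbot."""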
params = {"event": "APPROVE"}
# use pytorchbot to approve the pr
git_api(
f"/repos/{source_repo}/pulls/{pr_number}/reviews",
params,
type="post",
token=PYTORCHBOT_TOKEN,
)


def make_comment(source_repo: str, pr_number: int, msg: str) -> None:
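    """Post msg as a comment on the pull request, as pytorchbot."""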
params = {"body": msg}
# comment with pytorchbot because pytorchmergebot gets ignored
git_api(
f"/repos/{source_repo}/issues/{pr_number}/comments",
params,
type="post",
token=PYTORCHBOT_TOKEN,
)


def search_for_open_pr(
source_repo: str, search_string: str
) -> Optional[Tuple[int, str]]:
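    """Find the newest open PR by pytorchupdatebot whose title matches search_string.

    Returns (pr_number, branch_name), or None if no such PR is open.
    """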
params = {
"q": f"is:pr is:open in:title author:pytorchupdatebot repo:{source_repo} {search_string}",
"sort": "created",
}
response = git_api("/search/issues", params)
if response["total_count"] != 0:
# pr does exist
pr_num = response["items"][0]["number"]
link = response["items"][0]["html_url"]
response = git_api(f"/repos/{source_repo}/pulls/{pr_num}", {})
branch_name = response["head"]["ref"]
print(
f"pr does exist, number is {pr_num}, branch name is {branch_name}, link is {link}"
)
return pr_num, branch_name
return None


if __name__ == "__main__":
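    # Compute per-test average durations from recent viable/strict CI runs;
    # the query only keeps tests that average more than 60 seconds.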
rs_client = rockset.RocksetClient(
host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
)
results = rs_client.sql(QUERY).results
slow_tests = {row["test_name"]: row["avg_duration_sec"] for row in results}
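    # Rewrite test/slow_tests.json with the refreshed numbers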
with open(REPO_ROOT / "test" / "slow_tests.json", "w") as f:
json.dump(slow_tests, f, indent=2)
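    # Reuse the branch of an existing open update PR if there is one,
    # otherwise push a new timestamped branch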
branch_name = f"update_slow_tests_{int(time.time())}"
pr_num = None
open_pr = search_for_open_pr("pytorch/pytorch", "Update slow tests")
if open_pr is not None:
pr_num, branch_name = open_pr
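    # Check out a working branch, commit the updated file, and force-push it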
subprocess.run(["git", "checkout", "-b", branch_name], cwd=REPO_ROOT)
subprocess.run(["git", "add", "test/slow_tests.json"], cwd=REPO_ROOT)
subprocess.run(["git", "commit", "-m", "Update slow tests"], cwd=REPO_ROOT)
subprocess.run(
f"git push --set-upstream origin {branch_name} -f".split(), cwd=REPO_ROOT
)
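    # Metadata for the pull request that updates the list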
params = {
"title": "Update slow tests",
"head": branch_name,
"base": "main",
"body": "This PR is auto-generated weekly by [this action](https://github.com/pytorch/pytorch/blob/main/"
+ ".github/workflows/weekly.yml).\nUpdate the list of slow tests.",
}
if pr_num is None:
# no existing pr, so make a new one and approve it
pr_num = make_pr("pytorch/pytorch", params)
        time.sleep(5)  # give GitHub a moment to register the new PR before approving
approve_pr("pytorch/pytorch", pr_num)
make_comment("pytorch/pytorch", pr_num, "@pytorchbot merge")