244 lines
6.8 KiB
Python
244 lines
6.8 KiB
Python
"""GitHub API helpers: fetch PR data, post reviews/comments/labels/check runs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
import stamina
|
|
|
|
from .retry import is_retryable
|
|
|
|
if TYPE_CHECKING:
|
|
import httpx
|
|
|
|
# Root URL prefixed to every GitHub REST API request issued by this module.
GITHUB_API: str = "https://api.github.com"

# Default character budget used by truncate_diff() when no limit is passed.
MAX_DIFF_CHARS: int = 60000

# Upper bound on the number of pages the paginated fetchers will request.
MAX_PAGES: int = 50

log = logging.getLogger(name=__name__)
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_pr_diff(
    client: httpx.AsyncClient,
    owner: str, repo: str, pr_number: int, token: str,
) -> str:
    """Return the raw unified diff of pull request ``owner/repo#pr_number``.

    The pull-request endpoint is requested with the
    ``application/vnd.github.diff`` media type, so the response body is
    diff text rather than JSON; that text is returned as-is.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    url = f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}"
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github.diff",
    }
    response = await client.get(url, headers=headers)
    response.raise_for_status()
    return response.text
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_pr_files(
    client: httpx.AsyncClient,
    owner: str, repo: str, pr_number: int, token: str,
) -> list[dict]:
    """Fetch the list of changed files for a pull request (paginated).

    Pages through ``GET /repos/{owner}/{repo}/pulls/{pr_number}/files``
    100 entries at a time, requesting at most ``MAX_PAGES`` pages.

    - A page holding fewer than the page size is treated as the last one.
      The loop stops there instead of issuing one more request that would
      come back empty.
    - If every allowed page was full, more data may exist. In that case a
      warning is logged and the entries gathered so far are returned.

    Because the whole loop runs inside the retry decorator, a retried call
    starts again from page 1. ``files`` is rebuilt on every call, so entries
    are not duplicated.

    Returns:
        The decoded file entries, in the order GitHub returned them.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    per_page = 100
    url = f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}/files"
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
    }
    files: list[dict] = []
    for page in range(1, MAX_PAGES + 1):
        resp = await client.get(
            url,
            headers=headers,
            params={"per_page": per_page, "page": page},
        )
        resp.raise_for_status()
        batch = resp.json()
        files.extend(batch)
        # A short (or empty) page means there is nothing further to fetch.
        if len(batch) < per_page:
            return files
    # Every permitted page was full: results may be incomplete.
    log.warning(
        "Pagination cap reached fetching files for "
        "%s/%s#%d (%d pages)",
        owner, repo, pr_number, MAX_PAGES,
    )
    return files
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_pr_details(
    client: httpx.AsyncClient,
    owner: str, repo: str, pr_number: int, token: str,
) -> dict:
    """Return the decoded JSON metadata of pull request ``owner/repo#pr_number``.

    The result is the pull-request object GitHub returns, which carries
    data such as the head/base refs and the title.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    endpoint = f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}"
    auth_headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
    }
    response = await client.get(endpoint, headers=auth_headers)
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_commit_diff(
    client: httpx.AsyncClient,
    owner: str, repo: str, sha: str, token: str,
) -> str:
    """Return the unified diff introduced by commit *sha* in ``owner/repo``.

    The commit endpoint is requested with the ``application/vnd.github.diff``
    media type, so the response body is plain diff text, returned unchanged.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    endpoint = f"{GITHUB_API}/repos/{owner}/{repo}/commits/{sha}"
    diff_headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github.diff",
    }
    response = await client.get(endpoint, headers=diff_headers)
    response.raise_for_status()
    return response.text
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_repo_labels(
    client: httpx.AsyncClient,
    owner: str, repo: str, token: str,
) -> list[str]:
    """Fetch all label names from a repository.

    Pages through ``GET /repos/{owner}/{repo}/labels`` 100 labels at a time,
    requesting at most ``MAX_PAGES`` pages.

    - A page holding fewer than the page size is treated as the last one.
      The loop stops there instead of issuing one more request that would
      come back empty.
    - If every allowed page was full, more labels may exist. In that case a
      warning is logged and the names gathered so far are returned.

    A retried call starts again from page 1. ``labels`` is rebuilt on every
    call, so names are not duplicated.

    Returns:
        The ``name`` field of each label, in the order GitHub returned them.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    per_page = 100
    url = f"{GITHUB_API}/repos/{owner}/{repo}/labels"
    headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
    }
    labels: list[str] = []
    for page in range(1, MAX_PAGES + 1):
        resp = await client.get(
            url,
            headers=headers,
            params={"per_page": per_page, "page": page},
        )
        resp.raise_for_status()
        batch = resp.json()
        labels.extend(item["name"] for item in batch)
        # A short (or empty) page means there is nothing further to fetch.
        if len(batch) < per_page:
            return labels
    # Every permitted page was full: results may be incomplete.
    log.warning(
        "Pagination cap reached fetching labels for "
        "%s/%s (%d pages)",
        owner, repo, MAX_PAGES,
    )
    return labels
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def post_review(
    client: httpx.AsyncClient,
    owner: str, repo: str, pr_number: int,
    body: str, event: str, token: str,
) -> None:
    """Submit a review on pull request ``owner/repo#pr_number``.

    Args:
        body: Review text.
        event: Review action, sent verbatim to GitHub
            (``"COMMENT"``, ``"APPROVE"`` or ``"REQUEST_CHANGES"``).

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    # NOTE(review): POST is not idempotent. If is_retryable() accepts
    # timeouts or connection errors, a request that did reach GitHub could
    # be re-sent and produce a duplicate review — confirm.
    endpoint = f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}/reviews"
    review = {"body": body, "event": event}
    response = await client.post(
        endpoint,
        headers={
            "Authorization": f"token {token}",
            "Accept": "application/vnd.github+json",
        },
        json=review,
    )
    response.raise_for_status()
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def post_comment(
    client: httpx.AsyncClient,
    owner: str, repo: str, issue_number: int,
    body: str, token: str,
) -> None:
    """Add a comment with text *body* to issue or PR *issue_number*.

    The request goes to the issues comments endpoint, which this module uses
    for both issues and pull requests.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    # NOTE(review): POST is not idempotent. If is_retryable() accepts
    # timeouts or connection errors, the same comment could be posted
    # twice — confirm.
    endpoint = f"{GITHUB_API}/repos/{owner}/{repo}/issues/{issue_number}/comments"
    request_headers = {
        "Authorization": f"token {token}",
        "Accept": "application/vnd.github+json",
    }
    response = await client.post(
        endpoint, headers=request_headers, json={"body": body},
    )
    response.raise_for_status()
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def add_labels(
    client: httpx.AsyncClient,
    owner: str, repo: str, issue_number: int,
    labels: list[str], token: str,
) -> None:
    """Attach the label names in *labels* to issue or PR *issue_number*.

    The whole list is sent in one request as ``{"labels": [...]}``.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    endpoint = f"{GITHUB_API}/repos/{owner}/{repo}/issues/{issue_number}/labels"
    payload = {"labels": labels}
    response = await client.post(
        endpoint,
        headers={
            "Authorization": f"token {token}",
            "Accept": "application/vnd.github+json",
        },
        json=payload,
    )
    response.raise_for_status()
|
|
|
|
|
|
@stamina.retry(on=is_retryable, attempts=3)
async def create_check_run(
    client: httpx.AsyncClient,
    owner: str, repo: str, head_sha: str,
    name: str, conclusion: str, output: dict,
    token: str,
) -> None:
    """Create an already-finished check run named *name* on commit *head_sha*.

    The check run is submitted with ``status="completed"``. It carries the
    given *conclusion* and *output* object, both passed through to GitHub
    unchanged.

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. The stamina decorator
            retries (up to 3 attempts in total) when ``is_retryable`` accepts
            the raised error.
    """
    # NOTE(review): POST is not idempotent. If is_retryable() accepts
    # timeouts or connection errors, a duplicate check run could be
    # created — confirm.
    check_run = {
        "name": name,
        "head_sha": head_sha,
        "status": "completed",
        "conclusion": conclusion,
        "output": output,
    }
    response = await client.post(
        f"{GITHUB_API}/repos/{owner}/{repo}/check-runs",
        headers={
            "Authorization": f"token {token}",
            "Accept": "application/vnd.github+json",
        },
        json=check_run,
    )
    response.raise_for_status()
|
|
|
|
|
|
def build_file_summary(files: list[dict]) -> str:
    """Render the changed files as newline-separated summary rows.

    Each row is a single space, then the ``status`` left-aligned in a
    10-character field, then the ``filename``, then
    ``(+additions/-deletions)``. Missing ``additions``/``deletions`` counts
    default to 0. An empty *files* list yields an empty string.

    Raises:
        KeyError: if an entry lacks ``filename`` or ``status``.
    """
    def _row(entry: dict) -> str:
        # Bind fields in a fixed order so a malformed entry fails on
        # "filename" first.
        filename = entry["filename"]
        state = entry["status"]
        added = entry.get("additions", 0)
        removed = entry.get("deletions", 0)
        return f" {state:10s} {filename} (+{added}/-{removed})"

    return "\n".join(map(_row, files))
|
|
|
|
|
|
def truncate_diff(diff: str, max_chars: int = MAX_DIFF_CHARS) -> str:
    """Limit *diff* to at most *max_chars* characters of diff content.

    A diff that already fits is returned unchanged. Otherwise the diff is
    cut at the last newline inside the limit, so no diff line is left
    half-written (a partial ``+``/``-`` line can read as a different
    change), and a truncation note is appended.

    If no newline falls inside the limit, the diff is cut hard at exactly
    *max_chars*. Because of the appended note, a truncated result is longer
    than *max_chars*.

    Raises:
        ValueError: if *max_chars* is negative.
    """
    if max_chars < 0:
        # A negative bound used to slice from the end (diff[:-k]) and keep
        # almost the whole diff; reject it explicitly instead.
        raise ValueError(f"max_chars must be non-negative, got {max_chars}")
    if len(diff) <= max_chars:
        return diff
    # Prefer ending on a line boundary; fall back to a hard cut when the
    # prefix holds no usable newline (e.g. one very long line).
    cut = diff.rfind("\n", 0, max_chars)
    if cut <= 0:
        cut = max_chars
    return (
        diff[:cut]
        + "\n\n... (diff truncated, full repo available)"
    )
|