Fix CI: mypy errors, ruff formatting, switch to prek (#22)

* Fix mypy errors and apply ruff formatting across packages

Fix ast.FunctionDef calls missing type_params for Python 3.12+,
correct type: ignore error codes in _comparator and _plugin, and
run ruff format on all package source and test files.

* Switch CI to prek for lint/typecheck checks

Use j178/prek-action for consistent lint+typecheck (ruff check,
ruff format, interrogate, mypy) matching local pre-commit config.
Keep test as a separate parallel job for test-env support.
This commit is contained in:
Kevin Turcios 2026-04-15 02:52:47 -05:00 committed by GitHub
parent a1710f7f92
commit 2caaf6af7c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 339 additions and 177 deletions

View file

@ -10,15 +10,34 @@ concurrency:
cancel-in-progress: true
jobs:
ci:
uses: codeflash-ai/github-workflows/.github/workflows/ci-python-uv.yml@main
with:
lint-command: "uv run ruff check && uv run ruff format --check"
typecheck-command: >-
uv run interrogate packages/codeflash-core/src/ packages/codeflash-python/src/ &&
uv run mypy packages/codeflash-core/src/ packages/codeflash-python/src/
test-command: "uv run pytest packages/ -v"
test-env: '{"CI": "true"}'
prek:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v6
- uses: astral-sh/setup-uv@v8.0.0
with:
python-version: "3.12"
enable-cache: true
- run: uv sync --all-packages
- uses: j178/prek-action@v2
test:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v6
- uses: astral-sh/setup-uv@v8.0.0
with:
python-version: "3.12"
enable-cache: true
- run: uv sync --all-packages
- name: Test
run: uv run pytest packages/ -v
env:
CI: "true"
version:
if: github.event_name == 'pull_request'

View file

@ -46,8 +46,8 @@ class PythonState(LanguageState[PythonConfiguration]):
_reference_graph: ReferenceGraph | None = None
_module_asts: dict[Path, ast.Module] = attrs.Factory(dict)
_validated_code: dict[Path, ValidCode] = attrs.Factory(dict)
_function_to_tests: dict[str, set[FunctionCalledInTest]] = (
attrs.Factory(dict)
_function_to_tests: dict[str, set[FunctionCalledInTest]] = attrs.Factory(
dict
)
def reference_graph(self) -> ReferenceGraph:

View file

@ -115,9 +115,7 @@ def find_class_node_by_name(
if item.name == class_name:
return item
stack.append(item)
elif isinstance(
item, (ast.FunctionDef, ast.AsyncFunctionDef)
):
elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
stack.append(item)
return None
@ -183,9 +181,7 @@ def collect_names_from_annotation(
elif isinstance(node, ast.BinOp):
collect_names_from_annotation(node.left, names)
collect_names_from_annotation(node.right, names)
elif isinstance(node, ast.Attribute) and isinstance(
node.value, ast.Name
):
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
names.add(node.value.id)

View file

@ -177,9 +177,7 @@ def has_non_property_method_decorator(
if expr_matches_name(decorator, import_aliases, "property"):
continue
decorator_name = get_expr_name(decorator)
if decorator_name and decorator_name.endswith(
(".setter", ".deleter")
):
if decorator_name and decorator_name.endswith((".setter", ".deleter")):
continue
return True
return False

View file

@ -579,9 +579,7 @@ def enrich_testgen_context( # noqa: C901, PLR0912, PLR0915
lines = module_source.split("\n")
class_source = "\n".join(
lines[
get_class_start_line(class_node) - 1 : class_node.end_lineno
]
lines[get_class_start_line(class_node) - 1 : class_node.end_lineno]
)
code_strings.append(
@ -617,9 +615,7 @@ def enrich_testgen_context( # noqa: C901, PLR0912, PLR0915
if not is_proj and not is_third_party:
continue
mod_result = get_module_source_and_tree(
module_path, module_cache
)
mod_result = get_module_source_and_tree(module_path, module_cache)
if mod_result is None:
continue
module_source, module_tree = mod_result

View file

@ -620,7 +620,9 @@ class PythonFunctionOptimizer:
diff_lengths: list[int] = []
async_eval = self._make_async_evaluator()
test_env = build_test_env(fn_input, self.ctx.project_root, self.ctx.test_cfg)
test_env = build_test_env(
fn_input, self.ctx.project_root, self.ctx.test_cfg
)
def _try_candidate(c: Candidate) -> None:
"""Evaluate *c* and append to *valid* if it improves."""

View file

@ -61,7 +61,7 @@ def _build_capabilities() -> dict[str, object]:
def _lazy_detect_numerical() -> object:
"""Placeholder — actual binding happens at call site."""
from .._function_optimizer import ( # noqa: PLC0415
from .._function_optimizer import ( # type: ignore[import-not-found] # noqa: PLC0415
is_numerical_code,
)
@ -174,9 +174,10 @@ class PythonPlugin:
# -- Static type assertion -----------------------------------------
def _assert_protocol_compliance() -> None:
"""Compile-time check that PythonPlugin satisfies LanguagePlugin."""
_: LanguagePlugin = PythonPlugin( # type: ignore[call-arg]
_: LanguagePlugin = PythonPlugin( # type: ignore[assignment]
configuration=None, # type: ignore[arg-type]
state=None, # type: ignore[arg-type]
)

View file

@ -373,6 +373,7 @@ class InitDecorator(ast.NodeTransformer):
body=[super_call],
decorator_list=[decorator],
returns=None,
type_params=[],
)
node.body.insert(0, init_func)
@ -433,6 +434,7 @@ class InitDecorator(ast.NodeTransformer):
),
decorator_list=cast("list[ast.expr]", []),
returns=None,
type_params=[],
)
# ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init)

View file

@ -1247,4 +1247,5 @@ def create_wrapper_function(
lineno=lineno,
decorator_list=[],
returns=None,
type_params=[],
)

View file

@ -671,7 +671,7 @@ def _comparator_inner( # noqa: C901, PLR0911, PLR0912, PLR0915
new_reduce = new.__reduce__()
orig_remaining = list(orig_reduce[1][0])
new_remaining = list(new_reduce[1][0])
orig_saved, orig_started = orig_reduce[2] # type: ignore[misc]
orig_saved, orig_started = orig_reduce[2] # type: ignore[str-unpack]
new_saved, new_started = new_reduce[2]
if orig_started != new_started:
return False

View file

@ -86,9 +86,8 @@ def cleanup_artifacts(cwd: Path) -> None:
dir_prefixes = ("tmp", "codeflash_replay_tests_")
for child in cwd.rglob("*"):
if (
child.is_dir()
and any(child.name.startswith(p) for p in dir_prefixes)
if child.is_dir() and any(
child.name.startswith(p) for p in dir_prefixes
):
shutil.rmtree(child, ignore_errors=True)
log.debug("Removed artifact dir: %s", child)

View file

@ -4,6 +4,7 @@ import os
from pathlib import Path
from unittest.mock import Mock
from codeflash_core import performance_gain
from codeflash_python.analysis._coverage import (
CoverageData,
CoverageStatus,
@ -18,7 +19,6 @@ from codeflash_python.testing.models import (
InvocationId,
TestResults,
)
from codeflash_core import performance_gain
from codeflash_python.verification._critic import (
concurrency_gain,
coverage_critic,

View file

@ -127,7 +127,9 @@ class TestGenerateExplanation:
def test_returns_ai_explanation(self) -> None:
"""AI service explanation is returned when available."""
opt = _make_optimizer()
opt.ctx.ai_client.generate_explanation.return_value = "Better explanation"
opt.ctx.ai_client.generate_explanation.return_value = (
"Better explanation"
)
winner = _make_candidate()
baseline = _make_baseline()
@ -301,8 +303,11 @@ class TestLogEvaluationResults:
baseline = _make_baseline()
log_evaluation_results(
opt.ctx.ai_client, opt.function_trace_id,
winner, eval_ctx, baseline,
opt.ctx.ai_client,
opt.function_trace_id,
winner,
eval_ctx,
baseline,
)
opt.ctx.ai_client.log_results.assert_called_once()
@ -321,8 +326,11 @@ class TestLogEvaluationResults:
baseline = _make_baseline()
log_evaluation_results(
opt.ctx.ai_client, opt.function_trace_id,
winner, eval_ctx, baseline,
opt.ctx.ai_client,
opt.function_trace_id,
winner,
eval_ctx,
baseline,
)
payload = opt.ctx.ai_client.log_results.call_args[0][0]
@ -339,7 +347,13 @@ class TestBuildBenchmarkDetails:
baseline = _make_baseline()
result = build_benchmark_details(
winner, baseline, {}, {}, {}, None, Path("/tmp"),
winner,
baseline,
{},
{},
{},
None,
Path("/tmp"),
)
assert result is None
@ -354,9 +368,13 @@ class TestBuildBenchmarkDetails:
baseline = _make_baseline()
result = build_benchmark_details(
winner, baseline,
{bk: 50_000}, {bk: 200_000},
{}, None, Path("/tmp"),
winner,
baseline,
{bk: 50_000},
{bk: 200_000},
{},
None,
Path("/tmp"),
)
assert result is None
@ -374,9 +392,13 @@ class TestBuildBenchmarkDetails:
baseline = _make_baseline(runtime=200_000)
result = build_benchmark_details(
winner, baseline,
{bk: 50_000}, {bk: 200_000},
candidate_bench_results, None, Path("/tmp"),
winner,
baseline,
{bk: 50_000},
{bk: 200_000},
candidate_bench_results,
None,
Path("/tmp"),
)
assert result is not None

View file

@ -53,10 +53,12 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
yield
if running_tasks:
log.info(
"Draining %d background tasks...", len(running_tasks),
"Draining %d background tasks...",
len(running_tasks),
)
await asyncio.gather(
*running_tasks, return_exceptions=True,
*running_tasks,
return_exceptions=True,
)
@ -79,10 +81,12 @@ async def webhook(
request: Request,
x_github_event: str = Header(..., alias="X-GitHub-Event"),
x_hub_signature_256: str = Header(
..., alias="X-Hub-Signature-256",
...,
alias="X-Hub-Signature-256",
),
x_github_delivery: str = Header(
..., alias="X-GitHub-Delivery",
...,
alias="X-GitHub-Delivery",
),
) -> dict[str, str]:
"""Receive and dispatch GitHub webhook events."""
@ -100,7 +104,9 @@ async def webhook(
payload = await request.json()
log.info(
"Event %s delivery=%s action=%s",
x_github_event, x_github_delivery, payload.get("action"),
x_github_event,
x_github_delivery,
payload.get("action"),
)
handler = EVENT_HANDLERS.get(x_github_event)
@ -112,7 +118,8 @@ async def webhook(
task = asyncio.create_task(
safe_handle(
handler, payload,
handler,
payload,
config=cfg,
http_client=http_client,
),
@ -153,20 +160,28 @@ async def dispatch_issues(
installation_id = payload["installation"]["id"]
token = await get_installation_token(
config, installation_id, client=http_client,
config,
installation_id,
client=http_client,
)
repo_dir = await clone_repo(
owner, repo, repo_info["default_branch"],
token, config.workspace_dir,
owner,
repo,
repo_info["default_branch"],
token,
config.workspace_dir,
)
_write_ci_context(str(repo_dir), {
"event_type": "issues",
"action": action,
"owner": owner,
"repo": repo,
"number": payload["issue"]["number"],
})
_write_ci_context(
str(repo_dir),
{
"event_type": "issues",
"action": action,
"owner": owner,
"repo": repo,
"number": payload["issue"]["number"],
},
)
await run_agent(config, repo_dir, token)
log.info("Agent handled issue %s/%s#%d", owner, repo, payload["issue"]["number"])
@ -191,22 +206,30 @@ async def dispatch_pr(
installation_id = payload["installation"]["id"]
token = await get_installation_token(
config, installation_id, client=http_client,
config,
installation_id,
client=http_client,
)
repo_dir = await clone_repo(
owner, repo, pr["head"]["ref"],
token, config.workspace_dir,
owner,
repo,
pr["head"]["ref"],
token,
config.workspace_dir,
)
_write_ci_context(str(repo_dir), {
"event_type": "pull_request",
"action": action,
"owner": owner,
"repo": repo,
"number": pr["number"],
"base_ref": pr["base"]["ref"],
"head_ref": pr["head"]["ref"],
})
_write_ci_context(
str(repo_dir),
{
"event_type": "pull_request",
"action": action,
"owner": owner,
"repo": repo,
"number": pr["number"],
"base_ref": pr["base"]["ref"],
"head_ref": pr["head"]["ref"],
},
)
ci_prompt = (
"AUTONOMOUS MODE: Do NOT ask the user any questions — work fully "
@ -248,21 +271,29 @@ async def dispatch_push(
return
token = await get_installation_token(
config, installation_id, client=http_client,
config,
installation_id,
client=http_client,
)
repo_dir = await clone_repo(
owner, repo, default_branch,
token, config.workspace_dir,
owner,
repo,
default_branch,
token,
config.workspace_dir,
)
_write_ci_context(str(repo_dir), {
"event_type": "push",
"action": None,
"owner": owner,
"repo": repo,
"head_sha": payload.get("after", ""),
"ref": ref,
})
_write_ci_context(
str(repo_dir),
{
"event_type": "push",
"action": None,
"owner": owner,
"repo": repo,
"head_sha": payload.get("after", ""),
"ref": ref,
},
)
await run_agent(config, repo_dir, token)
log.info("Agent handled push to %s/%s ref=%s", owner, repo, ref)

View file

@ -23,7 +23,8 @@ GITHUB_API = "https://api.github.com"
# Cache installation tokens for 50 min (tokens last 1 hour).
# Keyed by (app_id, installation_id) to prevent cross-app leakage.
token_cache: TTLCache[tuple[str | int, int], str] = TTLCache(
maxsize=64, ttl=3000,
maxsize=64,
ttl=3000,
)
@ -40,7 +41,9 @@ def generate_jwt(cfg: Config) -> str:
@stamina.retry(on=is_retryable, attempts=3)
async def get_installation_token(
cfg: Config, installation_id: int, *,
cfg: Config,
installation_id: int,
*,
client: httpx.AsyncClient,
) -> str:
"""Exchange the JWT for an installation access token.
@ -54,8 +57,7 @@ async def get_installation_token(
token = generate_jwt(cfg)
resp = await client.post(
f"{GITHUB_API}/app/installations/"
f"{installation_id}/access_tokens",
f"{GITHUB_API}/app/installations/{installation_id}/access_tokens",
headers={
"Authorization": f"Bearer {token}",
"Accept": "application/vnd.github+json",
@ -69,12 +71,16 @@ async def get_installation_token(
def verify_signature(
payload: bytes, signature: str, secret: str,
payload: bytes,
signature: str,
secret: str,
) -> bool:
"""Verify the X-Hub-Signature-256 header."""
if not signature.startswith("sha256="):
return False
expected = hmac.new(
secret.encode(), payload, hashlib.sha256,
secret.encode(),
payload,
hashlib.sha256,
).hexdigest()
return hmac.compare_digest(f"sha256={expected}", signature)

View file

@ -50,8 +50,11 @@ class BackendSpec(ABC):
that need extra flags to enable file editing.
"""
return self.build_cmd(
cli=cli, model=model, prompt=prompt,
repo_dir=repo_dir, plugin_dir=plugin_dir,
cli=cli,
model=model,
prompt=prompt,
repo_dir=repo_dir,
plugin_dir=plugin_dir,
)
@ -84,10 +87,15 @@ class ClaudeBackend(BackendSpec):
agent: str = "codeflash-ci",
) -> tuple[list[str], str | None]:
cmd = [
cli, "-p", prompt,
"--model", model,
"--agent", agent,
"--max-turns", "200",
cli,
"-p",
prompt,
"--model",
model,
"--agent",
agent,
"--max-turns",
"200",
"--dangerously-skip-permissions",
]
if plugin_dir:
@ -109,11 +117,15 @@ class CodexBackend(BackendSpec):
plugin_dir: Path | None = None,
) -> tuple[list[str], str | None]:
cmd = [
cli, "exec",
"--model", model,
cli,
"exec",
"--model",
model,
"--full-auto",
"-C", str(repo_dir),
"-o", "/dev/stdout",
"-C",
str(repo_dir),
"-o",
"/dev/stdout",
prompt,
]
return cmd, None

View file

@ -45,7 +45,8 @@ class Config:
)
claude_model: str = field(
default_factory=lambda: os.environ.get(
"CLAUDE_MODEL", "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
"CLAUDE_MODEL",
"us.anthropic.claude-sonnet-4-5-20250929-v1:0",
),
)
plugin_dir: Path = field(default_factory=default_plugin_dir)
@ -56,7 +57,8 @@ class Config:
)
codex_model: str = field(
default_factory=lambda: os.environ.get(
"CODEX_MODEL", "gpt-5.4",
"CODEX_MODEL",
"gpt-5.4",
),
)
@ -77,7 +79,8 @@ class Config:
workspace_dir: Path = field(
default_factory=lambda: Path(
os.environ.get(
"WORKSPACE_DIR", "/tmp/codeflash-workspaces",
"WORKSPACE_DIR",
"/tmp/codeflash-workspaces",
),
),
)

View file

@ -11,7 +11,9 @@ log = logging.getLogger(__name__)
def _validate_clone_args(
owner: str, repo: str, workspace: Path,
owner: str,
repo: str,
workspace: Path,
) -> None:
"""Reject owner/repo values that could escape the workspace."""
for name, value in [("owner", owner), ("repo", repo)]:
@ -36,7 +38,8 @@ async def clone_repo(
# Atomic unique directory -- avoids race conditions and rmtree.
repo_dir = Path(
tempfile.mkdtemp(
prefix=f"{owner}_{repo}_", dir=workspace,
prefix=f"{owner}_{repo}_",
dir=workspace,
),
)
@ -46,13 +49,15 @@ async def clone_repo(
msg = f"Path escapes workspace: {repo_dir}"
raise ValueError(msg)
clone_url = (
f"https://x-access-token:{token}@github.com"
f"/{owner}/{repo}.git"
)
clone_url = f"https://x-access-token:{token}@github.com/{owner}/{repo}.git"
proc = await asyncio.create_subprocess_exec(
"git", "clone", "--depth=1", "--branch", ref,
clone_url, str(repo_dir),
"git",
"clone",
"--depth=1",
"--branch",
ref,
clone_url,
str(repo_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
@ -60,7 +65,8 @@ async def clone_repo(
if proc.returncode != 0:
log.error(
"git clone failed (rc=%d): %s",
proc.returncode, stderr.decode(),
proc.returncode,
stderr.decode(),
)
msg = f"git clone failed for {owner}/{repo} ref={ref}"
raise RuntimeError(msg)
@ -80,7 +86,9 @@ async def commit_and_push(
"""
# Stage everything
proc = await asyncio.create_subprocess_exec(
"git", "add", "-A",
"git",
"add",
"-A",
cwd=str(repo_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@ -89,7 +97,10 @@ async def commit_and_push(
# Check if there are staged changes
proc = await asyncio.create_subprocess_exec(
"git", "diff", "--cached", "--quiet",
"git",
"diff",
"--cached",
"--quiet",
cwd=str(repo_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@ -101,7 +112,10 @@ async def commit_and_push(
# Commit
proc = await asyncio.create_subprocess_exec(
"git", "commit", "-m", message,
"git",
"commit",
"-m",
message,
cwd=str(repo_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@ -115,7 +129,11 @@ async def commit_and_push(
# installation token (which may lack push permission).
plain_url = f"https://github.com/{owner}/{repo}.git"
proc = await asyncio.create_subprocess_exec(
"git", "remote", "set-url", "origin", plain_url,
"git",
"remote",
"set-url",
"origin",
plain_url,
cwd=str(repo_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@ -124,7 +142,10 @@ async def commit_and_push(
# Push to the PR branch
proc = await asyncio.create_subprocess_exec(
"git", "push", "origin", f"HEAD:{branch}",
"git",
"push",
"origin",
f"HEAD:{branch}",
cwd=str(repo_dir),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,

View file

@ -22,12 +22,14 @@ MAX_PAGES = 50
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_pr_diff(
client: httpx.AsyncClient,
owner: str, repo: str, pr_number: int, token: str,
owner: str,
repo: str,
pr_number: int,
token: str,
) -> str:
"""Fetch the unified diff for a pull request via GitHub API."""
resp = await client.get(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/pulls/{pr_number}",
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github.diff",
@ -40,15 +42,17 @@ async def fetch_pr_diff(
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_pr_files(
client: httpx.AsyncClient,
owner: str, repo: str, pr_number: int, token: str,
owner: str,
repo: str,
pr_number: int,
token: str,
) -> list[dict]:
"""Fetch the list of changed files for a pull request (paginated)."""
files: list[dict] = []
page = 1
while True:
resp = await client.get(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/pulls/{pr_number}/files",
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}/files",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
@ -63,9 +67,11 @@ async def fetch_pr_files(
page += 1
if page > MAX_PAGES:
log.warning(
"Pagination cap reached fetching files for "
"%s/%s#%d (%d pages)",
owner, repo, pr_number, MAX_PAGES,
"Pagination cap reached fetching files for %s/%s#%d (%d pages)",
owner,
repo,
pr_number,
MAX_PAGES,
)
break
return files
@ -74,12 +80,14 @@ async def fetch_pr_files(
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_pr_details(
client: httpx.AsyncClient,
owner: str, repo: str, pr_number: int, token: str,
owner: str,
repo: str,
pr_number: int,
token: str,
) -> dict:
"""Fetch PR metadata (head/base refs, title, etc.)."""
resp = await client.get(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/pulls/{pr_number}",
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
@ -92,12 +100,14 @@ async def fetch_pr_details(
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_commit_diff(
client: httpx.AsyncClient,
owner: str, repo: str, sha: str, token: str,
owner: str,
repo: str,
sha: str,
token: str,
) -> str:
"""Fetch the unified diff for a single commit via GitHub API."""
resp = await client.get(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/commits/{sha}",
f"{GITHUB_API}/repos/{owner}/{repo}/commits/{sha}",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github.diff",
@ -110,7 +120,9 @@ async def fetch_commit_diff(
@stamina.retry(on=is_retryable, attempts=3)
async def fetch_repo_labels(
client: httpx.AsyncClient,
owner: str, repo: str, token: str,
owner: str,
repo: str,
token: str,
) -> list[str]:
"""Fetch all label names from a repository."""
labels: list[str] = []
@ -132,9 +144,10 @@ async def fetch_repo_labels(
page += 1
if page > MAX_PAGES:
log.warning(
"Pagination cap reached fetching labels for "
"%s/%s (%d pages)",
owner, repo, MAX_PAGES,
"Pagination cap reached fetching labels for %s/%s (%d pages)",
owner,
repo,
MAX_PAGES,
)
break
return labels
@ -143,13 +156,16 @@ async def fetch_repo_labels(
@stamina.retry(on=is_retryable, attempts=3)
async def post_review(
client: httpx.AsyncClient,
owner: str, repo: str, pr_number: int,
body: str, event: str, token: str,
owner: str,
repo: str,
pr_number: int,
body: str,
event: str,
token: str,
) -> None:
"""Submit a PR review (COMMENT, APPROVE, or REQUEST_CHANGES)."""
resp = await client.post(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/pulls/{pr_number}/reviews",
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}/reviews",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
@ -162,13 +178,15 @@ async def post_review(
@stamina.retry(on=is_retryable, attempts=3)
async def post_comment(
client: httpx.AsyncClient,
owner: str, repo: str, issue_number: int,
body: str, token: str,
owner: str,
repo: str,
issue_number: int,
body: str,
token: str,
) -> None:
"""Post a comment on a PR or issue."""
resp = await client.post(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/issues/{issue_number}/comments",
f"{GITHUB_API}/repos/{owner}/{repo}/issues/{issue_number}/comments",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
@ -181,13 +199,15 @@ async def post_comment(
@stamina.retry(on=is_retryable, attempts=3)
async def add_labels(
client: httpx.AsyncClient,
owner: str, repo: str, issue_number: int,
labels: list[str], token: str,
owner: str,
repo: str,
issue_number: int,
labels: list[str],
token: str,
) -> None:
"""Add labels to an issue or PR."""
resp = await client.post(
f"{GITHUB_API}/repos/{owner}/{repo}"
f"/issues/{issue_number}/labels",
f"{GITHUB_API}/repos/{owner}/{repo}/issues/{issue_number}/labels",
headers={
"Authorization": f"token {token}",
"Accept": "application/vnd.github+json",
@ -200,8 +220,12 @@ async def add_labels(
@stamina.retry(on=is_retryable, attempts=3)
async def create_check_run(
client: httpx.AsyncClient,
owner: str, repo: str, head_sha: str,
name: str, conclusion: str, output: dict,
owner: str,
repo: str,
head_sha: str,
name: str,
conclusion: str,
output: dict,
token: str,
) -> None:
"""Create a check run on a commit."""
@ -238,7 +262,4 @@ def truncate_diff(diff: str, max_chars: int = MAX_DIFF_CHARS) -> str:
"""Truncate diff to max_chars, appending a note if cut."""
if len(diff) <= max_chars:
return diff
return (
diff[:max_chars]
+ "\n\n... (diff truncated, full repo available)"
)
return diff[:max_chars] + "\n\n... (diff truncated, full repo available)"

View file

@ -16,8 +16,9 @@ def is_retryable(exc: Exception) -> bool:
code = exc.response.status_code
return code == 429 or code >= 500
return isinstance(
exc, (httpx.ConnectError, httpx.TimeoutException),
exc,
(httpx.ConnectError, httpx.TimeoutException),
)
# https://smee.io/ACAUooTvHulETive
# https://smee.io/ACAUooTvHulETive

View file

@ -24,7 +24,9 @@ def test_generate_jwt_structure(mock_config):
def test_generate_jwt_claims(mock_config):
token = generate_jwt(mock_config)
claims = pyjwt.decode(
token, options={"verify_signature": False}, algorithms=["RS256"],
token,
options={"verify_signature": False},
algorithms=["RS256"],
)
# PyJWT requires iss as string; Config.app_id is int, converted in generate_jwt.
assert claims["iss"] == "12345"
@ -40,7 +42,9 @@ def test_verify_signature_valid():
import hmac
sig = hmac.new(
WEBHOOK_SECRET.encode(), payload, hashlib.sha256,
WEBHOOK_SECRET.encode(),
payload,
hashlib.sha256,
).hexdigest()
assert verify_signature(payload, f"sha256={sig}", WEBHOOK_SECRET)
@ -61,7 +65,9 @@ async def test_get_installation_token_fetches(mock_config):
async with httpx.AsyncClient() as client:
token = await get_installation_token(
mock_config, 99, client=client,
mock_config,
99,
client=client,
)
assert token == "ghs_test123"

View file

@ -19,9 +19,13 @@ def test_claude_backend_build_cmd():
plugin_dir=Path("/tmp/plugins"),
)
assert cmd == [
"claude", "-p", "review this",
"--model", "claude-sonnet-4-6",
"--plugin-dir", "/tmp/plugins",
"claude",
"-p",
"review this",
"--model",
"claude-sonnet-4-6",
"--plugin-dir",
"/tmp/plugins",
]
assert cwd == "/tmp/repo"
@ -48,12 +52,18 @@ def test_claude_backend_build_edit_cmd_default_agent():
plugin_dir=Path("/tmp/plugins"),
)
assert cmd == [
"claude", "-p", "CI: process .codeflash/ci-context.json",
"--model", "claude-sonnet-4-6",
"--agent", "codeflash-ci",
"--max-turns", "200",
"claude",
"-p",
"CI: process .codeflash/ci-context.json",
"--model",
"claude-sonnet-4-6",
"--agent",
"codeflash-ci",
"--max-turns",
"200",
"--dangerously-skip-permissions",
"--plugin-dir", "/tmp/plugins",
"--plugin-dir",
"/tmp/plugins",
]
assert cwd == "/tmp/repo"
@ -83,11 +93,15 @@ def test_codex_backend_build_cmd():
plugin_dir=Path("/tmp/plugins"),
)
assert cmd == [
"codex", "exec",
"--model", "gpt-5.4",
"codex",
"exec",
"--model",
"gpt-5.4",
"--full-auto",
"-C", "/tmp/repo",
"-o", "/dev/stdout",
"-C",
"/tmp/repo",
"-o",
"/dev/stdout",
"review this",
]
assert cwd is None

View file

@ -29,9 +29,14 @@ def test_load_private_key_from_file(tmp_path):
def test_load_private_key_missing():
with patch.dict(
os.environ, {}, clear=True,
), pytest.raises(ValueError, match="GITHUB_PRIVATE_KEY"):
with (
patch.dict(
os.environ,
{},
clear=True,
),
pytest.raises(ValueError, match="GITHUB_PRIVATE_KEY"),
):
load_private_key()

View file

@ -52,7 +52,11 @@ async def test_clone_repo_success(tmp_path):
with patch(PATCH_TARGET, return_value=mock_proc):
result = await clone_repo(
"owner", "repo", "main", "tok", tmp_path,
"owner",
"repo",
"main",
"tok",
tmp_path,
)
assert result.parent == tmp_path

View file

@ -14,7 +14,9 @@ def _make_status_error(status_code: int) -> httpx.HTTPStatusError:
response = Mock(spec=httpx.Response)
response.status_code = status_code
return httpx.HTTPStatusError(
"error", request=Mock(), response=response,
"error",
request=Mock(),
response=response,
)