mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Fix CI: mypy errors, ruff formatting, switch to prek (#22)
* Fix mypy errors and apply ruff formatting across packages

  Fix ast.FunctionDef calls missing type_params for Python 3.12+, correct type: ignore error codes in _comparator and _plugin, and run ruff format on all package source and test files.

* Switch CI to prek for lint/typecheck checks

  Use j178/prek-action for consistent lint+typecheck (ruff check, ruff format, interrogate, mypy) matching local pre-commit config. Keep test as a separate parallel job for test-env support.
This commit is contained in:
parent
a1710f7f92
commit
2caaf6af7c
25 changed files with 339 additions and 177 deletions
37
.github/workflows/ci.yml
vendored
37
.github/workflows/ci.yml
vendored
|
|
@ -10,15 +10,34 @@ concurrency:
|
|||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
ci:
|
||||
uses: codeflash-ai/github-workflows/.github/workflows/ci-python-uv.yml@main
|
||||
with:
|
||||
lint-command: "uv run ruff check && uv run ruff format --check"
|
||||
typecheck-command: >-
|
||||
uv run interrogate packages/codeflash-core/src/ packages/codeflash-python/src/ &&
|
||||
uv run mypy packages/codeflash-core/src/ packages/codeflash-python/src/
|
||||
test-command: "uv run pytest packages/ -v"
|
||||
test-env: '{"CI": "true"}'
|
||||
prek:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
enable-cache: true
|
||||
- run: uv sync --all-packages
|
||||
- uses: j178/prek-action@v2
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: astral-sh/setup-uv@v8.0.0
|
||||
with:
|
||||
python-version: "3.12"
|
||||
enable-cache: true
|
||||
- run: uv sync --all-packages
|
||||
- name: Test
|
||||
run: uv run pytest packages/ -v
|
||||
env:
|
||||
CI: "true"
|
||||
|
||||
version:
|
||||
if: github.event_name == 'pull_request'
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ class PythonState(LanguageState[PythonConfiguration]):
|
|||
_reference_graph: ReferenceGraph | None = None
|
||||
_module_asts: dict[Path, ast.Module] = attrs.Factory(dict)
|
||||
_validated_code: dict[Path, ValidCode] = attrs.Factory(dict)
|
||||
_function_to_tests: dict[str, set[FunctionCalledInTest]] = (
|
||||
attrs.Factory(dict)
|
||||
_function_to_tests: dict[str, set[FunctionCalledInTest]] = attrs.Factory(
|
||||
dict
|
||||
)
|
||||
|
||||
def reference_graph(self) -> ReferenceGraph:
|
||||
|
|
|
|||
|
|
@ -115,9 +115,7 @@ def find_class_node_by_name(
|
|||
if item.name == class_name:
|
||||
return item
|
||||
stack.append(item)
|
||||
elif isinstance(
|
||||
item, (ast.FunctionDef, ast.AsyncFunctionDef)
|
||||
):
|
||||
elif isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
stack.append(item)
|
||||
return None
|
||||
|
||||
|
|
@ -183,9 +181,7 @@ def collect_names_from_annotation(
|
|||
elif isinstance(node, ast.BinOp):
|
||||
collect_names_from_annotation(node.left, names)
|
||||
collect_names_from_annotation(node.right, names)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(
|
||||
node.value, ast.Name
|
||||
):
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -177,9 +177,7 @@ def has_non_property_method_decorator(
|
|||
if expr_matches_name(decorator, import_aliases, "property"):
|
||||
continue
|
||||
decorator_name = get_expr_name(decorator)
|
||||
if decorator_name and decorator_name.endswith(
|
||||
(".setter", ".deleter")
|
||||
):
|
||||
if decorator_name and decorator_name.endswith((".setter", ".deleter")):
|
||||
continue
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -579,9 +579,7 @@ def enrich_testgen_context( # noqa: C901, PLR0912, PLR0915
|
|||
|
||||
lines = module_source.split("\n")
|
||||
class_source = "\n".join(
|
||||
lines[
|
||||
get_class_start_line(class_node) - 1 : class_node.end_lineno
|
||||
]
|
||||
lines[get_class_start_line(class_node) - 1 : class_node.end_lineno]
|
||||
)
|
||||
|
||||
code_strings.append(
|
||||
|
|
@ -617,9 +615,7 @@ def enrich_testgen_context( # noqa: C901, PLR0912, PLR0915
|
|||
if not is_proj and not is_third_party:
|
||||
continue
|
||||
|
||||
mod_result = get_module_source_and_tree(
|
||||
module_path, module_cache
|
||||
)
|
||||
mod_result = get_module_source_and_tree(module_path, module_cache)
|
||||
if mod_result is None:
|
||||
continue
|
||||
module_source, module_tree = mod_result
|
||||
|
|
|
|||
|
|
@ -620,7 +620,9 @@ class PythonFunctionOptimizer:
|
|||
diff_lengths: list[int] = []
|
||||
|
||||
async_eval = self._make_async_evaluator()
|
||||
test_env = build_test_env(fn_input, self.ctx.project_root, self.ctx.test_cfg)
|
||||
test_env = build_test_env(
|
||||
fn_input, self.ctx.project_root, self.ctx.test_cfg
|
||||
)
|
||||
|
||||
def _try_candidate(c: Candidate) -> None:
|
||||
"""Evaluate *c* and append to *valid* if it improves."""
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ def _build_capabilities() -> dict[str, object]:
|
|||
|
||||
def _lazy_detect_numerical() -> object:
|
||||
"""Placeholder — actual binding happens at call site."""
|
||||
from .._function_optimizer import ( # noqa: PLC0415
|
||||
from .._function_optimizer import ( # type: ignore[import-not-found] # noqa: PLC0415
|
||||
is_numerical_code,
|
||||
)
|
||||
|
||||
|
|
@ -174,9 +174,10 @@ class PythonPlugin:
|
|||
|
||||
# -- Static type assertion -----------------------------------------
|
||||
|
||||
|
||||
def _assert_protocol_compliance() -> None:
|
||||
"""Compile-time check that PythonPlugin satisfies LanguagePlugin."""
|
||||
_: LanguagePlugin = PythonPlugin( # type: ignore[call-arg]
|
||||
_: LanguagePlugin = PythonPlugin( # type: ignore[assignment]
|
||||
configuration=None, # type: ignore[arg-type]
|
||||
state=None, # type: ignore[arg-type]
|
||||
)
|
||||
|
|
|
|||
|
|
@ -373,6 +373,7 @@ class InitDecorator(ast.NodeTransformer):
|
|||
body=[super_call],
|
||||
decorator_list=[decorator],
|
||||
returns=None,
|
||||
type_params=[],
|
||||
)
|
||||
|
||||
node.body.insert(0, init_func)
|
||||
|
|
@ -433,6 +434,7 @@ class InitDecorator(ast.NodeTransformer):
|
|||
),
|
||||
decorator_list=cast("list[ast.expr]", []),
|
||||
returns=None,
|
||||
type_params=[],
|
||||
)
|
||||
|
||||
# ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init)
|
||||
|
|
|
|||
|
|
@ -1247,4 +1247,5 @@ def create_wrapper_function(
|
|||
lineno=lineno,
|
||||
decorator_list=[],
|
||||
returns=None,
|
||||
type_params=[],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -671,7 +671,7 @@ def _comparator_inner( # noqa: C901, PLR0911, PLR0912, PLR0915
|
|||
new_reduce = new.__reduce__()
|
||||
orig_remaining = list(orig_reduce[1][0])
|
||||
new_remaining = list(new_reduce[1][0])
|
||||
orig_saved, orig_started = orig_reduce[2] # type: ignore[misc]
|
||||
orig_saved, orig_started = orig_reduce[2] # type: ignore[str-unpack]
|
||||
new_saved, new_started = new_reduce[2]
|
||||
if orig_started != new_started:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -86,9 +86,8 @@ def cleanup_artifacts(cwd: Path) -> None:
|
|||
|
||||
dir_prefixes = ("tmp", "codeflash_replay_tests_")
|
||||
for child in cwd.rglob("*"):
|
||||
if (
|
||||
child.is_dir()
|
||||
and any(child.name.startswith(p) for p in dir_prefixes)
|
||||
if child.is_dir() and any(
|
||||
child.name.startswith(p) for p in dir_prefixes
|
||||
):
|
||||
shutil.rmtree(child, ignore_errors=True)
|
||||
log.debug("Removed artifact dir: %s", child)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import os
|
|||
from pathlib import Path
|
||||
from unittest.mock import Mock
|
||||
|
||||
from codeflash_core import performance_gain
|
||||
from codeflash_python.analysis._coverage import (
|
||||
CoverageData,
|
||||
CoverageStatus,
|
||||
|
|
@ -18,7 +19,6 @@ from codeflash_python.testing.models import (
|
|||
InvocationId,
|
||||
TestResults,
|
||||
)
|
||||
from codeflash_core import performance_gain
|
||||
from codeflash_python.verification._critic import (
|
||||
concurrency_gain,
|
||||
coverage_critic,
|
||||
|
|
|
|||
|
|
@ -127,7 +127,9 @@ class TestGenerateExplanation:
|
|||
def test_returns_ai_explanation(self) -> None:
|
||||
"""AI service explanation is returned when available."""
|
||||
opt = _make_optimizer()
|
||||
opt.ctx.ai_client.generate_explanation.return_value = "Better explanation"
|
||||
opt.ctx.ai_client.generate_explanation.return_value = (
|
||||
"Better explanation"
|
||||
)
|
||||
|
||||
winner = _make_candidate()
|
||||
baseline = _make_baseline()
|
||||
|
|
@ -301,8 +303,11 @@ class TestLogEvaluationResults:
|
|||
baseline = _make_baseline()
|
||||
|
||||
log_evaluation_results(
|
||||
opt.ctx.ai_client, opt.function_trace_id,
|
||||
winner, eval_ctx, baseline,
|
||||
opt.ctx.ai_client,
|
||||
opt.function_trace_id,
|
||||
winner,
|
||||
eval_ctx,
|
||||
baseline,
|
||||
)
|
||||
|
||||
opt.ctx.ai_client.log_results.assert_called_once()
|
||||
|
|
@ -321,8 +326,11 @@ class TestLogEvaluationResults:
|
|||
baseline = _make_baseline()
|
||||
|
||||
log_evaluation_results(
|
||||
opt.ctx.ai_client, opt.function_trace_id,
|
||||
winner, eval_ctx, baseline,
|
||||
opt.ctx.ai_client,
|
||||
opt.function_trace_id,
|
||||
winner,
|
||||
eval_ctx,
|
||||
baseline,
|
||||
)
|
||||
|
||||
payload = opt.ctx.ai_client.log_results.call_args[0][0]
|
||||
|
|
@ -339,7 +347,13 @@ class TestBuildBenchmarkDetails:
|
|||
baseline = _make_baseline()
|
||||
|
||||
result = build_benchmark_details(
|
||||
winner, baseline, {}, {}, {}, None, Path("/tmp"),
|
||||
winner,
|
||||
baseline,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
Path("/tmp"),
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
|
@ -354,9 +368,13 @@ class TestBuildBenchmarkDetails:
|
|||
baseline = _make_baseline()
|
||||
|
||||
result = build_benchmark_details(
|
||||
winner, baseline,
|
||||
{bk: 50_000}, {bk: 200_000},
|
||||
{}, None, Path("/tmp"),
|
||||
winner,
|
||||
baseline,
|
||||
{bk: 50_000},
|
||||
{bk: 200_000},
|
||||
{},
|
||||
None,
|
||||
Path("/tmp"),
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
|
@ -374,9 +392,13 @@ class TestBuildBenchmarkDetails:
|
|||
baseline = _make_baseline(runtime=200_000)
|
||||
|
||||
result = build_benchmark_details(
|
||||
winner, baseline,
|
||||
{bk: 50_000}, {bk: 200_000},
|
||||
candidate_bench_results, None, Path("/tmp"),
|
||||
winner,
|
||||
baseline,
|
||||
{bk: 50_000},
|
||||
{bk: 200_000},
|
||||
candidate_bench_results,
|
||||
None,
|
||||
Path("/tmp"),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
|
|
|||
|
|
@ -53,10 +53,12 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
|||
yield
|
||||
if running_tasks:
|
||||
log.info(
|
||||
"Draining %d background tasks...", len(running_tasks),
|
||||
"Draining %d background tasks...",
|
||||
len(running_tasks),
|
||||
)
|
||||
await asyncio.gather(
|
||||
*running_tasks, return_exceptions=True,
|
||||
*running_tasks,
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -79,10 +81,12 @@ async def webhook(
|
|||
request: Request,
|
||||
x_github_event: str = Header(..., alias="X-GitHub-Event"),
|
||||
x_hub_signature_256: str = Header(
|
||||
..., alias="X-Hub-Signature-256",
|
||||
...,
|
||||
alias="X-Hub-Signature-256",
|
||||
),
|
||||
x_github_delivery: str = Header(
|
||||
..., alias="X-GitHub-Delivery",
|
||||
...,
|
||||
alias="X-GitHub-Delivery",
|
||||
),
|
||||
) -> dict[str, str]:
|
||||
"""Receive and dispatch GitHub webhook events."""
|
||||
|
|
@ -100,7 +104,9 @@ async def webhook(
|
|||
payload = await request.json()
|
||||
log.info(
|
||||
"Event %s delivery=%s action=%s",
|
||||
x_github_event, x_github_delivery, payload.get("action"),
|
||||
x_github_event,
|
||||
x_github_delivery,
|
||||
payload.get("action"),
|
||||
)
|
||||
|
||||
handler = EVENT_HANDLERS.get(x_github_event)
|
||||
|
|
@ -112,7 +118,8 @@ async def webhook(
|
|||
|
||||
task = asyncio.create_task(
|
||||
safe_handle(
|
||||
handler, payload,
|
||||
handler,
|
||||
payload,
|
||||
config=cfg,
|
||||
http_client=http_client,
|
||||
),
|
||||
|
|
@ -153,20 +160,28 @@ async def dispatch_issues(
|
|||
installation_id = payload["installation"]["id"]
|
||||
|
||||
token = await get_installation_token(
|
||||
config, installation_id, client=http_client,
|
||||
config,
|
||||
installation_id,
|
||||
client=http_client,
|
||||
)
|
||||
repo_dir = await clone_repo(
|
||||
owner, repo, repo_info["default_branch"],
|
||||
token, config.workspace_dir,
|
||||
owner,
|
||||
repo,
|
||||
repo_info["default_branch"],
|
||||
token,
|
||||
config.workspace_dir,
|
||||
)
|
||||
|
||||
_write_ci_context(str(repo_dir), {
|
||||
"event_type": "issues",
|
||||
"action": action,
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"number": payload["issue"]["number"],
|
||||
})
|
||||
_write_ci_context(
|
||||
str(repo_dir),
|
||||
{
|
||||
"event_type": "issues",
|
||||
"action": action,
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"number": payload["issue"]["number"],
|
||||
},
|
||||
)
|
||||
|
||||
await run_agent(config, repo_dir, token)
|
||||
log.info("Agent handled issue %s/%s#%d", owner, repo, payload["issue"]["number"])
|
||||
|
|
@ -191,22 +206,30 @@ async def dispatch_pr(
|
|||
installation_id = payload["installation"]["id"]
|
||||
|
||||
token = await get_installation_token(
|
||||
config, installation_id, client=http_client,
|
||||
config,
|
||||
installation_id,
|
||||
client=http_client,
|
||||
)
|
||||
repo_dir = await clone_repo(
|
||||
owner, repo, pr["head"]["ref"],
|
||||
token, config.workspace_dir,
|
||||
owner,
|
||||
repo,
|
||||
pr["head"]["ref"],
|
||||
token,
|
||||
config.workspace_dir,
|
||||
)
|
||||
|
||||
_write_ci_context(str(repo_dir), {
|
||||
"event_type": "pull_request",
|
||||
"action": action,
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"number": pr["number"],
|
||||
"base_ref": pr["base"]["ref"],
|
||||
"head_ref": pr["head"]["ref"],
|
||||
})
|
||||
_write_ci_context(
|
||||
str(repo_dir),
|
||||
{
|
||||
"event_type": "pull_request",
|
||||
"action": action,
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"number": pr["number"],
|
||||
"base_ref": pr["base"]["ref"],
|
||||
"head_ref": pr["head"]["ref"],
|
||||
},
|
||||
)
|
||||
|
||||
ci_prompt = (
|
||||
"AUTONOMOUS MODE: Do NOT ask the user any questions — work fully "
|
||||
|
|
@ -248,21 +271,29 @@ async def dispatch_push(
|
|||
return
|
||||
|
||||
token = await get_installation_token(
|
||||
config, installation_id, client=http_client,
|
||||
config,
|
||||
installation_id,
|
||||
client=http_client,
|
||||
)
|
||||
repo_dir = await clone_repo(
|
||||
owner, repo, default_branch,
|
||||
token, config.workspace_dir,
|
||||
owner,
|
||||
repo,
|
||||
default_branch,
|
||||
token,
|
||||
config.workspace_dir,
|
||||
)
|
||||
|
||||
_write_ci_context(str(repo_dir), {
|
||||
"event_type": "push",
|
||||
"action": None,
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"head_sha": payload.get("after", ""),
|
||||
"ref": ref,
|
||||
})
|
||||
_write_ci_context(
|
||||
str(repo_dir),
|
||||
{
|
||||
"event_type": "push",
|
||||
"action": None,
|
||||
"owner": owner,
|
||||
"repo": repo,
|
||||
"head_sha": payload.get("after", ""),
|
||||
"ref": ref,
|
||||
},
|
||||
)
|
||||
|
||||
await run_agent(config, repo_dir, token)
|
||||
log.info("Agent handled push to %s/%s ref=%s", owner, repo, ref)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,8 @@ GITHUB_API = "https://api.github.com"
|
|||
# Cache installation tokens for 50 min (tokens last 1 hour).
|
||||
# Keyed by (app_id, installation_id) to prevent cross-app leakage.
|
||||
token_cache: TTLCache[tuple[str | int, int], str] = TTLCache(
|
||||
maxsize=64, ttl=3000,
|
||||
maxsize=64,
|
||||
ttl=3000,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -40,7 +41,9 @@ def generate_jwt(cfg: Config) -> str:
|
|||
|
||||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def get_installation_token(
|
||||
cfg: Config, installation_id: int, *,
|
||||
cfg: Config,
|
||||
installation_id: int,
|
||||
*,
|
||||
client: httpx.AsyncClient,
|
||||
) -> str:
|
||||
"""Exchange the JWT for an installation access token.
|
||||
|
|
@ -54,8 +57,7 @@ async def get_installation_token(
|
|||
|
||||
token = generate_jwt(cfg)
|
||||
resp = await client.post(
|
||||
f"{GITHUB_API}/app/installations/"
|
||||
f"{installation_id}/access_tokens",
|
||||
f"{GITHUB_API}/app/installations/{installation_id}/access_tokens",
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
|
|
@ -69,12 +71,16 @@ async def get_installation_token(
|
|||
|
||||
|
||||
def verify_signature(
|
||||
payload: bytes, signature: str, secret: str,
|
||||
payload: bytes,
|
||||
signature: str,
|
||||
secret: str,
|
||||
) -> bool:
|
||||
"""Verify the X-Hub-Signature-256 header."""
|
||||
if not signature.startswith("sha256="):
|
||||
return False
|
||||
expected = hmac.new(
|
||||
secret.encode(), payload, hashlib.sha256,
|
||||
secret.encode(),
|
||||
payload,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return hmac.compare_digest(f"sha256={expected}", signature)
|
||||
|
|
|
|||
|
|
@ -50,8 +50,11 @@ class BackendSpec(ABC):
|
|||
that need extra flags to enable file editing.
|
||||
"""
|
||||
return self.build_cmd(
|
||||
cli=cli, model=model, prompt=prompt,
|
||||
repo_dir=repo_dir, plugin_dir=plugin_dir,
|
||||
cli=cli,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
repo_dir=repo_dir,
|
||||
plugin_dir=plugin_dir,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -84,10 +87,15 @@ class ClaudeBackend(BackendSpec):
|
|||
agent: str = "codeflash-ci",
|
||||
) -> tuple[list[str], str | None]:
|
||||
cmd = [
|
||||
cli, "-p", prompt,
|
||||
"--model", model,
|
||||
"--agent", agent,
|
||||
"--max-turns", "200",
|
||||
cli,
|
||||
"-p",
|
||||
prompt,
|
||||
"--model",
|
||||
model,
|
||||
"--agent",
|
||||
agent,
|
||||
"--max-turns",
|
||||
"200",
|
||||
"--dangerously-skip-permissions",
|
||||
]
|
||||
if plugin_dir:
|
||||
|
|
@ -109,11 +117,15 @@ class CodexBackend(BackendSpec):
|
|||
plugin_dir: Path | None = None,
|
||||
) -> tuple[list[str], str | None]:
|
||||
cmd = [
|
||||
cli, "exec",
|
||||
"--model", model,
|
||||
cli,
|
||||
"exec",
|
||||
"--model",
|
||||
model,
|
||||
"--full-auto",
|
||||
"-C", str(repo_dir),
|
||||
"-o", "/dev/stdout",
|
||||
"-C",
|
||||
str(repo_dir),
|
||||
"-o",
|
||||
"/dev/stdout",
|
||||
prompt,
|
||||
]
|
||||
return cmd, None
|
||||
|
|
|
|||
|
|
@ -45,7 +45,8 @@ class Config:
|
|||
)
|
||||
claude_model: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"CLAUDE_MODEL", "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
"CLAUDE_MODEL",
|
||||
"us.anthropic.claude-sonnet-4-5-20250929-v1:0",
|
||||
),
|
||||
)
|
||||
plugin_dir: Path = field(default_factory=default_plugin_dir)
|
||||
|
|
@ -56,7 +57,8 @@ class Config:
|
|||
)
|
||||
codex_model: str = field(
|
||||
default_factory=lambda: os.environ.get(
|
||||
"CODEX_MODEL", "gpt-5.4",
|
||||
"CODEX_MODEL",
|
||||
"gpt-5.4",
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -77,7 +79,8 @@ class Config:
|
|||
workspace_dir: Path = field(
|
||||
default_factory=lambda: Path(
|
||||
os.environ.get(
|
||||
"WORKSPACE_DIR", "/tmp/codeflash-workspaces",
|
||||
"WORKSPACE_DIR",
|
||||
"/tmp/codeflash-workspaces",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _validate_clone_args(
|
||||
owner: str, repo: str, workspace: Path,
|
||||
owner: str,
|
||||
repo: str,
|
||||
workspace: Path,
|
||||
) -> None:
|
||||
"""Reject owner/repo values that could escape the workspace."""
|
||||
for name, value in [("owner", owner), ("repo", repo)]:
|
||||
|
|
@ -36,7 +38,8 @@ async def clone_repo(
|
|||
# Atomic unique directory -- avoids race conditions and rmtree.
|
||||
repo_dir = Path(
|
||||
tempfile.mkdtemp(
|
||||
prefix=f"{owner}_{repo}_", dir=workspace,
|
||||
prefix=f"{owner}_{repo}_",
|
||||
dir=workspace,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -46,13 +49,15 @@ async def clone_repo(
|
|||
msg = f"Path escapes workspace: {repo_dir}"
|
||||
raise ValueError(msg)
|
||||
|
||||
clone_url = (
|
||||
f"https://x-access-token:{token}@github.com"
|
||||
f"/{owner}/{repo}.git"
|
||||
)
|
||||
clone_url = f"https://x-access-token:{token}@github.com/{owner}/{repo}.git"
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"git", "clone", "--depth=1", "--branch", ref,
|
||||
clone_url, str(repo_dir),
|
||||
"git",
|
||||
"clone",
|
||||
"--depth=1",
|
||||
"--branch",
|
||||
ref,
|
||||
clone_url,
|
||||
str(repo_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
|
@ -60,7 +65,8 @@ async def clone_repo(
|
|||
if proc.returncode != 0:
|
||||
log.error(
|
||||
"git clone failed (rc=%d): %s",
|
||||
proc.returncode, stderr.decode(),
|
||||
proc.returncode,
|
||||
stderr.decode(),
|
||||
)
|
||||
msg = f"git clone failed for {owner}/{repo} ref={ref}"
|
||||
raise RuntimeError(msg)
|
||||
|
|
@ -80,7 +86,9 @@ async def commit_and_push(
|
|||
"""
|
||||
# Stage everything
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"git", "add", "-A",
|
||||
"git",
|
||||
"add",
|
||||
"-A",
|
||||
cwd=str(repo_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
|
|
@ -89,7 +97,10 @@ async def commit_and_push(
|
|||
|
||||
# Check if there are staged changes
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"git", "diff", "--cached", "--quiet",
|
||||
"git",
|
||||
"diff",
|
||||
"--cached",
|
||||
"--quiet",
|
||||
cwd=str(repo_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
|
|
@ -101,7 +112,10 @@ async def commit_and_push(
|
|||
|
||||
# Commit
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"git", "commit", "-m", message,
|
||||
"git",
|
||||
"commit",
|
||||
"-m",
|
||||
message,
|
||||
cwd=str(repo_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
|
|
@ -115,7 +129,11 @@ async def commit_and_push(
|
|||
# installation token (which may lack push permission).
|
||||
plain_url = f"https://github.com/{owner}/{repo}.git"
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"git", "remote", "set-url", "origin", plain_url,
|
||||
"git",
|
||||
"remote",
|
||||
"set-url",
|
||||
"origin",
|
||||
plain_url,
|
||||
cwd=str(repo_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
|
|
@ -124,7 +142,10 @@ async def commit_and_push(
|
|||
|
||||
# Push to the PR branch
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"git", "push", "origin", f"HEAD:{branch}",
|
||||
"git",
|
||||
"push",
|
||||
"origin",
|
||||
f"HEAD:{branch}",
|
||||
cwd=str(repo_dir),
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
|
|
|
|||
|
|
@ -22,12 +22,14 @@ MAX_PAGES = 50
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def fetch_pr_diff(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, pr_number: int, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
token: str,
|
||||
) -> str:
|
||||
"""Fetch the unified diff for a pull request via GitHub API."""
|
||||
resp = await client.get(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/pulls/{pr_number}",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github.diff",
|
||||
|
|
@ -40,15 +42,17 @@ async def fetch_pr_diff(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def fetch_pr_files(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, pr_number: int, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
token: str,
|
||||
) -> list[dict]:
|
||||
"""Fetch the list of changed files for a pull request (paginated)."""
|
||||
files: list[dict] = []
|
||||
page = 1
|
||||
while True:
|
||||
resp = await client.get(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/pulls/{pr_number}/files",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}/files",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
|
|
@ -63,9 +67,11 @@ async def fetch_pr_files(
|
|||
page += 1
|
||||
if page > MAX_PAGES:
|
||||
log.warning(
|
||||
"Pagination cap reached fetching files for "
|
||||
"%s/%s#%d (%d pages)",
|
||||
owner, repo, pr_number, MAX_PAGES,
|
||||
"Pagination cap reached fetching files for %s/%s#%d (%d pages)",
|
||||
owner,
|
||||
repo,
|
||||
pr_number,
|
||||
MAX_PAGES,
|
||||
)
|
||||
break
|
||||
return files
|
||||
|
|
@ -74,12 +80,14 @@ async def fetch_pr_files(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def fetch_pr_details(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, pr_number: int, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
token: str,
|
||||
) -> dict:
|
||||
"""Fetch PR metadata (head/base refs, title, etc.)."""
|
||||
resp = await client.get(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/pulls/{pr_number}",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
|
|
@ -92,12 +100,14 @@ async def fetch_pr_details(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def fetch_commit_diff(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, sha: str, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
sha: str,
|
||||
token: str,
|
||||
) -> str:
|
||||
"""Fetch the unified diff for a single commit via GitHub API."""
|
||||
resp = await client.get(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/commits/{sha}",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/commits/{sha}",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github.diff",
|
||||
|
|
@ -110,7 +120,9 @@ async def fetch_commit_diff(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def fetch_repo_labels(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
token: str,
|
||||
) -> list[str]:
|
||||
"""Fetch all label names from a repository."""
|
||||
labels: list[str] = []
|
||||
|
|
@ -132,9 +144,10 @@ async def fetch_repo_labels(
|
|||
page += 1
|
||||
if page > MAX_PAGES:
|
||||
log.warning(
|
||||
"Pagination cap reached fetching labels for "
|
||||
"%s/%s (%d pages)",
|
||||
owner, repo, MAX_PAGES,
|
||||
"Pagination cap reached fetching labels for %s/%s (%d pages)",
|
||||
owner,
|
||||
repo,
|
||||
MAX_PAGES,
|
||||
)
|
||||
break
|
||||
return labels
|
||||
|
|
@ -143,13 +156,16 @@ async def fetch_repo_labels(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def post_review(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, pr_number: int,
|
||||
body: str, event: str, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
pr_number: int,
|
||||
body: str,
|
||||
event: str,
|
||||
token: str,
|
||||
) -> None:
|
||||
"""Submit a PR review (COMMENT, APPROVE, or REQUEST_CHANGES)."""
|
||||
resp = await client.post(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/pulls/{pr_number}/reviews",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/pulls/{pr_number}/reviews",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
|
|
@ -162,13 +178,15 @@ async def post_review(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def post_comment(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, issue_number: int,
|
||||
body: str, token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
issue_number: int,
|
||||
body: str,
|
||||
token: str,
|
||||
) -> None:
|
||||
"""Post a comment on a PR or issue."""
|
||||
resp = await client.post(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/issues/{issue_number}/comments",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/issues/{issue_number}/comments",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
|
|
@ -181,13 +199,15 @@ async def post_comment(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def add_labels(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, issue_number: int,
|
||||
labels: list[str], token: str,
|
||||
owner: str,
|
||||
repo: str,
|
||||
issue_number: int,
|
||||
labels: list[str],
|
||||
token: str,
|
||||
) -> None:
|
||||
"""Add labels to an issue or PR."""
|
||||
resp = await client.post(
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}"
|
||||
f"/issues/{issue_number}/labels",
|
||||
f"{GITHUB_API}/repos/{owner}/{repo}/issues/{issue_number}/labels",
|
||||
headers={
|
||||
"Authorization": f"token {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
|
|
@ -200,8 +220,12 @@ async def add_labels(
|
|||
@stamina.retry(on=is_retryable, attempts=3)
|
||||
async def create_check_run(
|
||||
client: httpx.AsyncClient,
|
||||
owner: str, repo: str, head_sha: str,
|
||||
name: str, conclusion: str, output: dict,
|
||||
owner: str,
|
||||
repo: str,
|
||||
head_sha: str,
|
||||
name: str,
|
||||
conclusion: str,
|
||||
output: dict,
|
||||
token: str,
|
||||
) -> None:
|
||||
"""Create a check run on a commit."""
|
||||
|
|
@ -238,7 +262,4 @@ def truncate_diff(diff: str, max_chars: int = MAX_DIFF_CHARS) -> str:
|
|||
"""Truncate diff to max_chars, appending a note if cut."""
|
||||
if len(diff) <= max_chars:
|
||||
return diff
|
||||
return (
|
||||
diff[:max_chars]
|
||||
+ "\n\n... (diff truncated, full repo available)"
|
||||
)
|
||||
return diff[:max_chars] + "\n\n... (diff truncated, full repo available)"
|
||||
|
|
|
|||
|
|
@ -16,8 +16,9 @@ def is_retryable(exc: Exception) -> bool:
|
|||
code = exc.response.status_code
|
||||
return code == 429 or code >= 500
|
||||
return isinstance(
|
||||
exc, (httpx.ConnectError, httpx.TimeoutException),
|
||||
exc,
|
||||
(httpx.ConnectError, httpx.TimeoutException),
|
||||
)
|
||||
|
||||
|
||||
# https://smee.io/ACAUooTvHulETive
|
||||
# https://smee.io/ACAUooTvHulETive
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ def test_generate_jwt_structure(mock_config):
|
|||
def test_generate_jwt_claims(mock_config):
|
||||
token = generate_jwt(mock_config)
|
||||
claims = pyjwt.decode(
|
||||
token, options={"verify_signature": False}, algorithms=["RS256"],
|
||||
token,
|
||||
options={"verify_signature": False},
|
||||
algorithms=["RS256"],
|
||||
)
|
||||
# PyJWT requires iss as string; Config.app_id is int, converted in generate_jwt.
|
||||
assert claims["iss"] == "12345"
|
||||
|
|
@ -40,7 +42,9 @@ def test_verify_signature_valid():
|
|||
import hmac
|
||||
|
||||
sig = hmac.new(
|
||||
WEBHOOK_SECRET.encode(), payload, hashlib.sha256,
|
||||
WEBHOOK_SECRET.encode(),
|
||||
payload,
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
assert verify_signature(payload, f"sha256={sig}", WEBHOOK_SECRET)
|
||||
|
||||
|
|
@ -61,7 +65,9 @@ async def test_get_installation_token_fetches(mock_config):
|
|||
|
||||
async with httpx.AsyncClient() as client:
|
||||
token = await get_installation_token(
|
||||
mock_config, 99, client=client,
|
||||
mock_config,
|
||||
99,
|
||||
client=client,
|
||||
)
|
||||
|
||||
assert token == "ghs_test123"
|
||||
|
|
|
|||
|
|
@ -19,9 +19,13 @@ def test_claude_backend_build_cmd():
|
|||
plugin_dir=Path("/tmp/plugins"),
|
||||
)
|
||||
assert cmd == [
|
||||
"claude", "-p", "review this",
|
||||
"--model", "claude-sonnet-4-6",
|
||||
"--plugin-dir", "/tmp/plugins",
|
||||
"claude",
|
||||
"-p",
|
||||
"review this",
|
||||
"--model",
|
||||
"claude-sonnet-4-6",
|
||||
"--plugin-dir",
|
||||
"/tmp/plugins",
|
||||
]
|
||||
assert cwd == "/tmp/repo"
|
||||
|
||||
|
|
@ -48,12 +52,18 @@ def test_claude_backend_build_edit_cmd_default_agent():
|
|||
plugin_dir=Path("/tmp/plugins"),
|
||||
)
|
||||
assert cmd == [
|
||||
"claude", "-p", "CI: process .codeflash/ci-context.json",
|
||||
"--model", "claude-sonnet-4-6",
|
||||
"--agent", "codeflash-ci",
|
||||
"--max-turns", "200",
|
||||
"claude",
|
||||
"-p",
|
||||
"CI: process .codeflash/ci-context.json",
|
||||
"--model",
|
||||
"claude-sonnet-4-6",
|
||||
"--agent",
|
||||
"codeflash-ci",
|
||||
"--max-turns",
|
||||
"200",
|
||||
"--dangerously-skip-permissions",
|
||||
"--plugin-dir", "/tmp/plugins",
|
||||
"--plugin-dir",
|
||||
"/tmp/plugins",
|
||||
]
|
||||
assert cwd == "/tmp/repo"
|
||||
|
||||
|
|
@ -83,11 +93,15 @@ def test_codex_backend_build_cmd():
|
|||
plugin_dir=Path("/tmp/plugins"),
|
||||
)
|
||||
assert cmd == [
|
||||
"codex", "exec",
|
||||
"--model", "gpt-5.4",
|
||||
"codex",
|
||||
"exec",
|
||||
"--model",
|
||||
"gpt-5.4",
|
||||
"--full-auto",
|
||||
"-C", "/tmp/repo",
|
||||
"-o", "/dev/stdout",
|
||||
"-C",
|
||||
"/tmp/repo",
|
||||
"-o",
|
||||
"/dev/stdout",
|
||||
"review this",
|
||||
]
|
||||
assert cwd is None
|
||||
|
|
|
|||
|
|
@ -29,9 +29,14 @@ def test_load_private_key_from_file(tmp_path):
|
|||
|
||||
|
||||
def test_load_private_key_missing():
|
||||
with patch.dict(
|
||||
os.environ, {}, clear=True,
|
||||
), pytest.raises(ValueError, match="GITHUB_PRIVATE_KEY"):
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{},
|
||||
clear=True,
|
||||
),
|
||||
pytest.raises(ValueError, match="GITHUB_PRIVATE_KEY"),
|
||||
):
|
||||
load_private_key()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,11 @@ async def test_clone_repo_success(tmp_path):
|
|||
|
||||
with patch(PATCH_TARGET, return_value=mock_proc):
|
||||
result = await clone_repo(
|
||||
"owner", "repo", "main", "tok", tmp_path,
|
||||
"owner",
|
||||
"repo",
|
||||
"main",
|
||||
"tok",
|
||||
tmp_path,
|
||||
)
|
||||
|
||||
assert result.parent == tmp_path
|
||||
|
|
|
|||
|
|
@ -14,7 +14,9 @@ def _make_status_error(status_code: int) -> httpx.HTTPStatusError:
|
|||
response = Mock(spec=httpx.Response)
|
||||
response.status_code = status_code
|
||||
return httpx.HTTPStatusError(
|
||||
"error", request=Mock(), response=response,
|
||||
"error",
|
||||
request=Mock(),
|
||||
response=response,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue