mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
fix: deduplicate code across codeflash-core and codeflash-python
- Extract `_parse_candidates` helper in `_client.py` (used by `get_candidates` and `optimize_with_line_profiler`)
- Parameterize URL resolution in `_http.py` (`_resolve_url_from_env` replaces two near-identical functions)
- Delegate `get_repo_owner_and_name` to `parse_repo_owner_and_name` in `_git.py`
- Simplify `_par_apply_fns` to delegate to `_apply_fns` in `danom/stream.py`
- Remove duplicate `performance_gain` from `_verification.py` (use codeflash_core's version)
- Extract `_extract_pytest_error` helper in `_verification.py` (replaces a duplicated 6-line block)
- Consolidate `collect_names_from_annotation` into `collect_type_names_from_annotation` in `_ast_helpers.py`
- Add `ast.Attribute` handling and relax the `BinOp` guard in `collect_type_names_from_annotation`
- Add unit tests for all extracted helpers
This commit is contained in:
parent
ffadf16147
commit
6b73b07d15
15 changed files with 186 additions and 129 deletions
|
|
@ -26,6 +26,19 @@ from .exceptions import (
|
|||
AIServiceError,
|
||||
)
|
||||
|
||||
|
||||
def _parse_candidates(data: dict[str, Any]) -> list[Candidate]:
|
||||
"""Build a list of candidates from an AI service response."""
|
||||
return [
|
||||
Candidate(
|
||||
code=item.get("source_code", ""),
|
||||
explanation=item.get("explanation", ""),
|
||||
candidate_id=item.get("optimization_id", ""),
|
||||
)
|
||||
for item in data.get("optimizations", [])
|
||||
]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import requests as requests_types
|
||||
|
||||
|
|
@ -136,14 +149,7 @@ class AIClient:
|
|||
if request.test_input_examples is not None:
|
||||
payload["test_input_examples"] = request.test_input_examples
|
||||
data = self.post("/optimize", payload)
|
||||
return [
|
||||
Candidate(
|
||||
code=item.get("source_code", ""),
|
||||
explanation=item.get("explanation", ""),
|
||||
candidate_id=item.get("optimization_id", ""),
|
||||
)
|
||||
for item in data.get("optimizations", [])
|
||||
]
|
||||
return _parse_candidates(data)
|
||||
|
||||
def generate_ranking(
|
||||
self,
|
||||
|
|
@ -200,14 +206,7 @@ class AIClient:
|
|||
"codeflash_version": request.codeflash_version,
|
||||
}
|
||||
data = self.post("/optimize-line-profiler", payload)
|
||||
return [
|
||||
Candidate(
|
||||
code=item.get("source_code", ""),
|
||||
explanation=item.get("explanation", ""),
|
||||
candidate_id=item.get("optimization_id", ""),
|
||||
)
|
||||
for item in data.get("optimizations", [])
|
||||
]
|
||||
return _parse_candidates(data)
|
||||
|
||||
def generate_explanation(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -47,18 +47,9 @@ def get_repo_owner_and_name(
|
|||
git_remote: str = "origin",
|
||||
) -> tuple[str, str]:
|
||||
"""Return (owner, repo_name) parsed from the git remote URL."""
|
||||
remote_url = get_remote_url(repo, git_remote)
|
||||
if remote_url.endswith(".git"):
|
||||
remote_url = remote_url.removesuffix(".git")
|
||||
remote_url = remote_url.rstrip("/")
|
||||
split_url = remote_url.split("/")
|
||||
repo_owner_with_github, repo_name = split_url[-2], split_url[-1]
|
||||
repo_owner = (
|
||||
repo_owner_with_github.split(":")[1]
|
||||
if ":" in repo_owner_with_github
|
||||
else repo_owner_with_github
|
||||
)
|
||||
return repo_owner, repo_name
|
||||
from ._platform import parse_repo_owner_and_name # noqa: PLC0415
|
||||
|
||||
return parse_repo_owner_and_name(get_remote_url(repo, git_remote))
|
||||
|
||||
|
||||
def check_running_in_git_repo(module_root: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -7,29 +7,35 @@ from typing import Any
|
|||
|
||||
from .exceptions import InvalidAPIKeyError
|
||||
|
||||
_PROD_URL = "https://app.codeflash.ai"
|
||||
_LOCAL_URL = "http://localhost:8000"
|
||||
|
||||
_CFAPI_PROD_URL = "https://app.codeflash.ai"
|
||||
_CFAPI_LOCAL_URL = "http://localhost:3001"
|
||||
def _resolve_url_from_env(
|
||||
env_var: str,
|
||||
prod_url: str,
|
||||
local_url: str,
|
||||
) -> str:
|
||||
"""Return *prod_url* or *local_url* based on an environment variable."""
|
||||
server = os.environ.get(env_var, "prod")
|
||||
if server.lower() == "local":
|
||||
return local_url
|
||||
return prod_url
|
||||
|
||||
|
||||
def _resolve_base_url() -> str:
|
||||
"""
|
||||
Return the base URL based on *CODEFLASH_AIS_SERVER*.
|
||||
"""
|
||||
server = os.environ.get("CODEFLASH_AIS_SERVER", "prod")
|
||||
if server.lower() == "local":
|
||||
return _LOCAL_URL
|
||||
return _PROD_URL
|
||||
"""Return the AI service base URL from the environment."""
|
||||
return _resolve_url_from_env(
|
||||
"CODEFLASH_AIS_SERVER",
|
||||
"https://app.codeflash.ai",
|
||||
"http://localhost:8000",
|
||||
)
|
||||
|
||||
|
||||
def _resolve_cfapi_base_url() -> str:
|
||||
"""Return the platform API base URL from the environment."""
|
||||
server = os.environ.get("CODEFLASH_CFAPI_SERVER", "prod")
|
||||
if server.lower() == "local":
|
||||
return _CFAPI_LOCAL_URL
|
||||
return _CFAPI_PROD_URL
|
||||
return _resolve_url_from_env(
|
||||
"CODEFLASH_CFAPI_SERVER",
|
||||
"https://app.codeflash.ai",
|
||||
"http://localhost:3001",
|
||||
)
|
||||
|
||||
|
||||
def _strip_trailing_slash(url: str) -> str:
|
||||
|
|
|
|||
|
|
@ -338,21 +338,7 @@ def _par_apply_fns(
|
|||
"""
|
||||
Apply *ops* to *elements* eagerly.
|
||||
"""
|
||||
results: list[Any] = []
|
||||
for elem in elements:
|
||||
valid = True
|
||||
res: Any = elem
|
||||
for op, op_fn in ops:
|
||||
if op == _MAP:
|
||||
res = op_fn(res)
|
||||
elif op == _FILTER and not op_fn(res):
|
||||
valid = False
|
||||
break
|
||||
elif op == _TAP:
|
||||
op_fn(deepcopy(res))
|
||||
if valid:
|
||||
results.append(res)
|
||||
return tuple(results)
|
||||
return tuple(_apply_fns(elements, ops))
|
||||
|
||||
|
||||
async def _async_apply_fns(
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from codeflash_core import (
|
|||
OptimizationRequest,
|
||||
OptimizationReviewResult,
|
||||
)
|
||||
from codeflash_core._client import _parse_candidates
|
||||
|
||||
|
||||
@pytest.fixture(name="client")
|
||||
|
|
@ -49,6 +50,51 @@ def _mock_post(client):
|
|||
yield mock
|
||||
|
||||
|
||||
class TestParseCandidates:
|
||||
"""Tests for _parse_candidates."""
|
||||
|
||||
def test_parses_optimizations(self):
|
||||
"""Multiple optimizations are parsed into Candidate objects."""
|
||||
data = {
|
||||
"optimizations": [
|
||||
{
|
||||
"source_code": "def f(): pass",
|
||||
"explanation": "simplified",
|
||||
"optimization_id": "id-1",
|
||||
},
|
||||
{
|
||||
"source_code": "def g(): pass",
|
||||
"explanation": "inlined",
|
||||
"optimization_id": "id-2",
|
||||
},
|
||||
]
|
||||
}
|
||||
result = _parse_candidates(data)
|
||||
|
||||
assert 2 == len(result)
|
||||
assert all(isinstance(c, Candidate) for c in result)
|
||||
assert "id-1" == result[0].candidate_id
|
||||
assert "inlined" == result[1].explanation
|
||||
|
||||
def test_empty_optimizations(self):
|
||||
"""An empty optimizations list returns no candidates."""
|
||||
assert [] == _parse_candidates({"optimizations": []})
|
||||
|
||||
def test_missing_key(self):
|
||||
"""A response without 'optimizations' returns no candidates."""
|
||||
assert [] == _parse_candidates({})
|
||||
|
||||
def test_missing_fields_use_defaults(self):
|
||||
"""Missing item fields default to empty strings."""
|
||||
data = {"optimizations": [{}]}
|
||||
result = _parse_candidates(data)
|
||||
|
||||
assert 1 == len(result)
|
||||
assert "" == result[0].code
|
||||
assert "" == result[0].explanation
|
||||
assert "" == result[0].candidate_id
|
||||
|
||||
|
||||
class TestAIClient:
|
||||
"""Tests for AIClient."""
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from codeflash_core import (
|
|||
InvalidAPIKeyError,
|
||||
PrComment,
|
||||
)
|
||||
from codeflash_core._http import _resolve_url_from_env
|
||||
from codeflash_core._platform import (
|
||||
PlatformClient,
|
||||
parse_repo_owner_and_name,
|
||||
|
|
@ -87,6 +88,38 @@ class TestParseRepoOwnerAndName:
|
|||
)
|
||||
|
||||
|
||||
class TestResolveUrlFromEnv:
|
||||
"""Tests for _resolve_url_from_env."""
|
||||
|
||||
def test_returns_prod_by_default(self, monkeypatch):
|
||||
"""Unset env var returns prod URL."""
|
||||
monkeypatch.delenv("MY_SERVER", raising=False)
|
||||
assert "https://prod.example.com" == _resolve_url_from_env(
|
||||
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
|
||||
)
|
||||
|
||||
def test_returns_local_when_set(self, monkeypatch):
|
||||
"""'local' env var returns local URL."""
|
||||
monkeypatch.setenv("MY_SERVER", "local")
|
||||
assert "http://localhost:9000" == _resolve_url_from_env(
|
||||
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
|
||||
)
|
||||
|
||||
def test_case_insensitive(self, monkeypatch):
|
||||
"""'LOCAL' (uppercase) returns local URL."""
|
||||
monkeypatch.setenv("MY_SERVER", "LOCAL")
|
||||
assert "http://localhost:9000" == _resolve_url_from_env(
|
||||
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
|
||||
)
|
||||
|
||||
def test_non_local_returns_prod(self, monkeypatch):
|
||||
"""Any value other than 'local' returns prod URL."""
|
||||
monkeypatch.setenv("MY_SERVER", "staging")
|
||||
assert "https://prod.example.com" == _resolve_url_from_env(
|
||||
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
|
||||
)
|
||||
|
||||
|
||||
class TestPlatformClient:
|
||||
"""Tests for PlatformClient construction."""
|
||||
|
||||
|
|
|
|||
|
|
@ -17,11 +17,10 @@ from typing import TYPE_CHECKING, Any
|
|||
import attrs
|
||||
import libcst as cst
|
||||
|
||||
from codeflash_core import BenchmarkDetail, humanize_runtime
|
||||
from codeflash_core import BenchmarkDetail, humanize_runtime, performance_gain
|
||||
|
||||
from ..analysis._discovery import inspect_top_level_functions_or_methods
|
||||
from ..analysis._formatter import sort_imports
|
||||
from ..verification._verification import performance_gain
|
||||
from .models import (
|
||||
ProcessedBenchmarkInfo,
|
||||
get_function_alias,
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@ from codeflash_core import (
|
|||
PrComment,
|
||||
check_and_push_branch,
|
||||
get_repo_owner_and_name,
|
||||
performance_gain,
|
||||
)
|
||||
|
||||
from ..ai._tabulate import tabulate
|
||||
from ..testing._testgen import format_perf, format_time
|
||||
from ..verification._verification import performance_gain
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import git
|
||||
|
|
|
|||
|
|
@ -145,15 +145,15 @@ def collect_type_names_from_annotation(
|
|||
node: ast.expr | None,
|
||||
) -> set[str]:
|
||||
"""Recursively collect type names from an annotation node."""
|
||||
if node is None:
|
||||
return set()
|
||||
if isinstance(node, ast.Name):
|
||||
return {node.id}
|
||||
if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
return {node.value.id}
|
||||
if isinstance(node, ast.Subscript):
|
||||
names = collect_type_names_from_annotation(node.value)
|
||||
names |= collect_type_names_from_annotation(node.slice)
|
||||
return names
|
||||
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
if isinstance(node, ast.BinOp):
|
||||
return collect_type_names_from_annotation(
|
||||
node.left
|
||||
) | collect_type_names_from_annotation(node.right)
|
||||
|
|
@ -165,26 +165,6 @@ def collect_type_names_from_annotation(
|
|||
return set()
|
||||
|
||||
|
||||
def collect_names_from_annotation(
|
||||
node: ast.expr,
|
||||
names: set[str],
|
||||
) -> None:
|
||||
"""Mutating variant: add type annotation names into *names*."""
|
||||
if isinstance(node, ast.Name):
|
||||
names.add(node.id)
|
||||
elif isinstance(node, ast.Subscript):
|
||||
collect_names_from_annotation(node.value, names)
|
||||
collect_names_from_annotation(node.slice, names)
|
||||
elif isinstance(node, ast.Tuple):
|
||||
for elt in node.elts:
|
||||
collect_names_from_annotation(elt, names)
|
||||
elif isinstance(node, ast.BinOp):
|
||||
collect_names_from_annotation(node.left, names)
|
||||
collect_names_from_annotation(node.right, names)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
|
||||
|
||||
def expr_matches_name(
|
||||
node: ast.AST | None,
|
||||
import_aliases: dict[str, str],
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ from ._ast_helpers import (
|
|||
ImportCollector,
|
||||
collect_existing_class_names,
|
||||
collect_import_aliases,
|
||||
collect_names_from_annotation,
|
||||
collect_type_names_from_annotation,
|
||||
find_class_node_by_name,
|
||||
get_expr_name,
|
||||
|
|
@ -216,7 +215,7 @@ def extract_imports_for_class( # noqa: C901, PLR0912
|
|||
|
||||
for item in class_node.body:
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
collect_names_from_annotation(item.annotation, needed_names)
|
||||
needed_names |= collect_type_names_from_annotation(item.annotation)
|
||||
elif (
|
||||
isinstance(item, ast.Assign)
|
||||
and isinstance(item.value, ast.Call)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,13 @@ import libcst as cst
|
|||
from libcst import MetadataWrapper
|
||||
from libcst.metadata import PositionProvider
|
||||
|
||||
from codeflash_core import AIServiceConnectionError, AIServiceError
|
||||
from codeflash_core import (
|
||||
AIServiceConnectionError,
|
||||
AIServiceError,
|
||||
performance_gain,
|
||||
)
|
||||
|
||||
from .._constants import LANGUAGE_FIELDS, LANGUAGE_VERSION
|
||||
from ..verification._verification import performance_gain
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash_core import AIClient
|
||||
|
|
|
|||
|
|
@ -44,7 +44,16 @@ def shorten_pytest_error(pytest_error_string: str) -> str:
|
|||
)
|
||||
|
||||
|
||||
def compare_test_results( # noqa: C901, PLR0912
|
||||
def _extract_pytest_error(
    failures: dict[str, str] | None,
    qualified_name: str,
) -> str:
    """Return the shortened pytest error stored for *qualified_name*.

    *failures* maps qualified test names to raw pytest error text and may
    be ``None``.  An empty string is returned when *failures* is ``None``
    or empty, when it has no entry for *qualified_name*, or when the
    stored entry is itself empty.  Otherwise the raw text is passed
    through ``shorten_pytest_error``.
    """
    if not failures:
        return ""
    message = failures.get(qualified_name, "")
    if not message:
        return ""
    return shorten_pytest_error(message)
|
||||
|
||||
|
||||
def compare_test_results( # noqa: C901
|
||||
original_results: TestResults,
|
||||
candidate_results: TestResults,
|
||||
pass_fail_only: bool = False, # noqa: FBT001, FBT002
|
||||
|
|
@ -110,30 +119,15 @@ def compare_test_results( # noqa: C901, PLR0912
|
|||
)
|
||||
|
||||
# Gather pytest error messages
|
||||
candidate_test_failures = candidate_results.test_failures
|
||||
original_test_failures = original_results.test_failures
|
||||
cdd_pytest_error = (
|
||||
candidate_test_failures.get(
|
||||
original_test_result.id.test_fn_qualified_name(),
|
||||
"",
|
||||
)
|
||||
if candidate_test_failures
|
||||
else ""
|
||||
qname = original_test_result.id.test_fn_qualified_name()
|
||||
cdd_pytest_error = _extract_pytest_error(
|
||||
candidate_results.test_failures,
|
||||
qname,
|
||||
)
|
||||
if cdd_pytest_error:
|
||||
cdd_pytest_error = shorten_pytest_error(cdd_pytest_error)
|
||||
original_pytest_error = (
|
||||
original_test_failures.get(
|
||||
original_test_result.id.test_fn_qualified_name(),
|
||||
"",
|
||||
)
|
||||
if original_test_failures
|
||||
else ""
|
||||
original_pytest_error = _extract_pytest_error(
|
||||
original_results.test_failures,
|
||||
qname,
|
||||
)
|
||||
if original_pytest_error:
|
||||
original_pytest_error = shorten_pytest_error(
|
||||
original_pytest_error,
|
||||
)
|
||||
|
||||
# Check pass/fail mismatch
|
||||
if original_test_result.test_type in {
|
||||
|
|
@ -236,18 +230,3 @@ def compare_test_results( # noqa: C901, PLR0912
|
|||
return False, test_diffs
|
||||
|
||||
return len(test_diffs) == 0, test_diffs
|
||||
|
||||
|
||||
def performance_gain(
|
||||
*,
|
||||
original_runtime_ns: int,
|
||||
optimized_runtime_ns: int,
|
||||
) -> float:
|
||||
"""Calculate the performance gain of optimized code over the original.
|
||||
|
||||
Returns a ratio where 1.0 means 100% faster (2x speedup).
|
||||
Returns 0.0 when the optimized runtime is zero.
|
||||
"""
|
||||
if optimized_runtime_ns == 0:
|
||||
return 0.0
|
||||
return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns
|
||||
|
|
|
|||
|
|
@ -4585,12 +4585,13 @@ def test_collect_type_names_none_annotation() -> None:
|
|||
assert collect_type_names_from_annotation(None) == set()
|
||||
|
||||
|
||||
def test_collect_type_names_attribute_skipped() -> None:
|
||||
def test_collect_type_names_attribute_collects_module() -> None:
|
||||
"""Attribute like module.Foo collects the module name."""
|
||||
tree = ast.parse("def f(x: module.Foo): pass")
|
||||
func = tree.body[0]
|
||||
assert isinstance(func, ast.FunctionDef)
|
||||
ann = func.args.args[0].annotation
|
||||
assert collect_type_names_from_annotation(ann) == set()
|
||||
assert {"module"} == collect_type_names_from_annotation(ann)
|
||||
|
||||
|
||||
# --- Tests for extract_init_stub_from_class ---
|
||||
|
|
|
|||
|
|
@ -125,6 +125,14 @@ class TestCollectTypeNamesFromAnnotation:
|
|||
"""None returns empty set."""
|
||||
assert set() == collect_type_names_from_annotation(None)
|
||||
|
||||
def test_attribute_collects_module(self) -> None:
|
||||
"""An Attribute like typing.Optional collects the module name."""
|
||||
ann = ast.Attribute(
|
||||
value=ast.Name(id="typing"),
|
||||
attr="Optional",
|
||||
)
|
||||
assert {"typing"} == collect_type_names_from_annotation(ann)
|
||||
|
||||
|
||||
class TestDeclarativeClassDetection:
|
||||
"""Tests for NamedTuple, dataclass, and attrs detection."""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from pathlib import Path
|
|||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_core import performance_gain
|
||||
from codeflash_python._model import VerificationType
|
||||
from codeflash_python.test_discovery.models import TestType
|
||||
from codeflash_python.testing.models import (
|
||||
|
|
@ -13,8 +14,9 @@ from codeflash_python.testing.models import (
|
|||
TestResults,
|
||||
)
|
||||
from codeflash_python.verification._verification import (
|
||||
_extract_pytest_error,
|
||||
compare_test_results,
|
||||
performance_gain,
|
||||
shorten_pytest_error,
|
||||
)
|
||||
from codeflash_python.verification.models import (
|
||||
BehaviorDiff,
|
||||
|
|
@ -485,3 +487,28 @@ class TestGetAllUniqueInvocationLoopIds:
|
|||
results = TestResults()
|
||||
|
||||
assert set() == results.get_all_unique_invocation_loop_ids()
|
||||
|
||||
|
||||
class TestExtractPytestError:
|
||||
"""Tests for _extract_pytest_error."""
|
||||
|
||||
def test_extracts_and_shortens(self) -> None:
|
||||
"""A matching error is looked up and shortened."""
|
||||
failures = {
|
||||
"test_mod::test_fn": "some context\nE AssertionError: 1 != 2\n> assert 1 == 2",
|
||||
}
|
||||
result = _extract_pytest_error(failures, "test_mod::test_fn")
|
||||
|
||||
assert "AssertionError: 1 != 2\nassert 1 == 2" == result
|
||||
|
||||
def test_missing_key_returns_empty(self) -> None:
|
||||
"""A missing qualified name returns an empty string."""
|
||||
assert "" == _extract_pytest_error({"other::fn": "err"}, "test::fn")
|
||||
|
||||
def test_none_failures_returns_empty(self) -> None:
|
||||
"""None failures dict returns an empty string."""
|
||||
assert "" == _extract_pytest_error(None, "test::fn")
|
||||
|
||||
def test_empty_failures_returns_empty(self) -> None:
|
||||
"""Empty failures dict returns an empty string."""
|
||||
assert "" == _extract_pytest_error({}, "test::fn")
|
||||
|
|
|
|||
Loading…
Reference in a new issue