fix: deduplicate code across codeflash-core and codeflash-python

- Extract _parse_candidates helper in _client.py (used by get_candidates and optimize_with_line_profiler)
- Parameterize URL resolution in _http.py (_resolve_url_from_env replaces two near-identical functions)
- Delegate get_repo_owner_and_name to parse_repo_owner_and_name in _git.py
- Simplify _par_apply_fns to delegate to _apply_fns in danom/stream.py
- Remove duplicate performance_gain from _verification.py (use codeflash_core's version)
- Extract _extract_pytest_error helper in _verification.py (replaces duplicated 6-line block)
- Consolidate collect_names_from_annotation into collect_type_names_from_annotation in _ast_helpers.py
- Add ast.Attribute handling and relax BinOp guard in collect_type_names_from_annotation
- Add unit tests for all extracted helpers
This commit is contained in:
Kevin Turcios 2026-04-23 22:39:50 -05:00
parent ffadf16147
commit 6b73b07d15
15 changed files with 186 additions and 129 deletions

View file

@ -26,6 +26,19 @@ from .exceptions import (
AIServiceError,
)
def _parse_candidates(data: dict[str, Any]) -> list[Candidate]:
"""Build a list of candidates from an AI service response."""
return [
Candidate(
code=item.get("source_code", ""),
explanation=item.get("explanation", ""),
candidate_id=item.get("optimization_id", ""),
)
for item in data.get("optimizations", [])
]
if TYPE_CHECKING:
import requests as requests_types
@ -136,14 +149,7 @@ class AIClient:
if request.test_input_examples is not None:
payload["test_input_examples"] = request.test_input_examples
data = self.post("/optimize", payload)
return [
Candidate(
code=item.get("source_code", ""),
explanation=item.get("explanation", ""),
candidate_id=item.get("optimization_id", ""),
)
for item in data.get("optimizations", [])
]
return _parse_candidates(data)
def generate_ranking(
self,
@ -200,14 +206,7 @@ class AIClient:
"codeflash_version": request.codeflash_version,
}
data = self.post("/optimize-line-profiler", payload)
return [
Candidate(
code=item.get("source_code", ""),
explanation=item.get("explanation", ""),
candidate_id=item.get("optimization_id", ""),
)
for item in data.get("optimizations", [])
]
return _parse_candidates(data)
def generate_explanation(
self,

View file

@ -47,18 +47,9 @@ def get_repo_owner_and_name(
git_remote: str = "origin",
) -> tuple[str, str]:
"""Return (owner, repo_name) parsed from the git remote URL."""
remote_url = get_remote_url(repo, git_remote)
if remote_url.endswith(".git"):
remote_url = remote_url.removesuffix(".git")
remote_url = remote_url.rstrip("/")
split_url = remote_url.split("/")
repo_owner_with_github, repo_name = split_url[-2], split_url[-1]
repo_owner = (
repo_owner_with_github.split(":")[1]
if ":" in repo_owner_with_github
else repo_owner_with_github
)
return repo_owner, repo_name
from ._platform import parse_repo_owner_and_name # noqa: PLC0415
return parse_repo_owner_and_name(get_remote_url(repo, git_remote))
def check_running_in_git_repo(module_root: str) -> bool:

View file

@ -7,29 +7,35 @@ from typing import Any
from .exceptions import InvalidAPIKeyError
_PROD_URL = "https://app.codeflash.ai"
_LOCAL_URL = "http://localhost:8000"
_CFAPI_PROD_URL = "https://app.codeflash.ai"
_CFAPI_LOCAL_URL = "http://localhost:3001"
def _resolve_url_from_env(
env_var: str,
prod_url: str,
local_url: str,
) -> str:
"""Return *prod_url* or *local_url* based on an environment variable."""
server = os.environ.get(env_var, "prod")
if server.lower() == "local":
return local_url
return prod_url
def _resolve_base_url() -> str:
"""
Return the base URL based on *CODEFLASH_AIS_SERVER*.
"""
server = os.environ.get("CODEFLASH_AIS_SERVER", "prod")
if server.lower() == "local":
return _LOCAL_URL
return _PROD_URL
"""Return the AI service base URL from the environment."""
return _resolve_url_from_env(
"CODEFLASH_AIS_SERVER",
"https://app.codeflash.ai",
"http://localhost:8000",
)
def _resolve_cfapi_base_url() -> str:
"""Return the platform API base URL from the environment."""
server = os.environ.get("CODEFLASH_CFAPI_SERVER", "prod")
if server.lower() == "local":
return _CFAPI_LOCAL_URL
return _CFAPI_PROD_URL
return _resolve_url_from_env(
"CODEFLASH_CFAPI_SERVER",
"https://app.codeflash.ai",
"http://localhost:3001",
)
def _strip_trailing_slash(url: str) -> str:

View file

@ -338,21 +338,7 @@ def _par_apply_fns(
"""
Apply *ops* to *elements* eagerly.
"""
results: list[Any] = []
for elem in elements:
valid = True
res: Any = elem
for op, op_fn in ops:
if op == _MAP:
res = op_fn(res)
elif op == _FILTER and not op_fn(res):
valid = False
break
elif op == _TAP:
op_fn(deepcopy(res))
if valid:
results.append(res)
return tuple(results)
return tuple(_apply_fns(elements, ops))
async def _async_apply_fns(

View file

@ -14,6 +14,7 @@ from codeflash_core import (
OptimizationRequest,
OptimizationReviewResult,
)
from codeflash_core._client import _parse_candidates
@pytest.fixture(name="client")
@ -49,6 +50,51 @@ def _mock_post(client):
yield mock
class TestParseCandidates:
"""Tests for _parse_candidates."""
def test_parses_optimizations(self):
"""Multiple optimizations are parsed into Candidate objects."""
data = {
"optimizations": [
{
"source_code": "def f(): pass",
"explanation": "simplified",
"optimization_id": "id-1",
},
{
"source_code": "def g(): pass",
"explanation": "inlined",
"optimization_id": "id-2",
},
]
}
result = _parse_candidates(data)
assert 2 == len(result)
assert all(isinstance(c, Candidate) for c in result)
assert "id-1" == result[0].candidate_id
assert "inlined" == result[1].explanation
def test_empty_optimizations(self):
"""An empty optimizations list returns no candidates."""
assert [] == _parse_candidates({"optimizations": []})
def test_missing_key(self):
"""A response without 'optimizations' returns no candidates."""
assert [] == _parse_candidates({})
def test_missing_fields_use_defaults(self):
"""Missing item fields default to empty strings."""
data = {"optimizations": [{}]}
result = _parse_candidates(data)
assert 1 == len(result)
assert "" == result[0].code
assert "" == result[0].explanation
assert "" == result[0].candidate_id
class TestAIClient:
"""Tests for AIClient."""

View file

@ -9,6 +9,7 @@ from codeflash_core import (
InvalidAPIKeyError,
PrComment,
)
from codeflash_core._http import _resolve_url_from_env
from codeflash_core._platform import (
PlatformClient,
parse_repo_owner_and_name,
@ -87,6 +88,38 @@ class TestParseRepoOwnerAndName:
)
class TestResolveUrlFromEnv:
"""Tests for _resolve_url_from_env."""
def test_returns_prod_by_default(self, monkeypatch):
"""Unset env var returns prod URL."""
monkeypatch.delenv("MY_SERVER", raising=False)
assert "https://prod.example.com" == _resolve_url_from_env(
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
)
def test_returns_local_when_set(self, monkeypatch):
"""'local' env var returns local URL."""
monkeypatch.setenv("MY_SERVER", "local")
assert "http://localhost:9000" == _resolve_url_from_env(
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
)
def test_case_insensitive(self, monkeypatch):
"""'LOCAL' (uppercase) returns local URL."""
monkeypatch.setenv("MY_SERVER", "LOCAL")
assert "http://localhost:9000" == _resolve_url_from_env(
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
)
def test_non_local_returns_prod(self, monkeypatch):
"""Any value other than 'local' returns prod URL."""
monkeypatch.setenv("MY_SERVER", "staging")
assert "https://prod.example.com" == _resolve_url_from_env(
"MY_SERVER", "https://prod.example.com", "http://localhost:9000"
)
class TestPlatformClient:
"""Tests for PlatformClient construction."""

View file

@ -17,11 +17,10 @@ from typing import TYPE_CHECKING, Any
import attrs
import libcst as cst
from codeflash_core import BenchmarkDetail, humanize_runtime
from codeflash_core import BenchmarkDetail, humanize_runtime, performance_gain
from ..analysis._discovery import inspect_top_level_functions_or_methods
from ..analysis._formatter import sort_imports
from ..verification._verification import performance_gain
from .models import (
ProcessedBenchmarkInfo,
get_function_alias,

View file

@ -13,11 +13,11 @@ from codeflash_core import (
PrComment,
check_and_push_branch,
get_repo_owner_and_name,
performance_gain,
)
from ..ai._tabulate import tabulate
from ..testing._testgen import format_perf, format_time
from ..verification._verification import performance_gain
if TYPE_CHECKING:
import git

View file

@ -145,15 +145,15 @@ def collect_type_names_from_annotation(
node: ast.expr | None,
) -> set[str]:
"""Recursively collect type names from an annotation node."""
if node is None:
return set()
if isinstance(node, ast.Name):
return {node.id}
if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
return {node.value.id}
if isinstance(node, ast.Subscript):
names = collect_type_names_from_annotation(node.value)
names |= collect_type_names_from_annotation(node.slice)
return names
if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
if isinstance(node, ast.BinOp):
return collect_type_names_from_annotation(
node.left
) | collect_type_names_from_annotation(node.right)
@ -165,26 +165,6 @@ def collect_type_names_from_annotation(
return set()
def collect_names_from_annotation(
node: ast.expr,
names: set[str],
) -> None:
"""Mutating variant: add type annotation names into *names*."""
if isinstance(node, ast.Name):
names.add(node.id)
elif isinstance(node, ast.Subscript):
collect_names_from_annotation(node.value, names)
collect_names_from_annotation(node.slice, names)
elif isinstance(node, ast.Tuple):
for elt in node.elts:
collect_names_from_annotation(elt, names)
elif isinstance(node, ast.BinOp):
collect_names_from_annotation(node.left, names)
collect_names_from_annotation(node.right, names)
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
names.add(node.value.id)
def expr_matches_name(
node: ast.AST | None,
import_aliases: dict[str, str],

View file

@ -16,7 +16,6 @@ from ._ast_helpers import (
ImportCollector,
collect_existing_class_names,
collect_import_aliases,
collect_names_from_annotation,
collect_type_names_from_annotation,
find_class_node_by_name,
get_expr_name,
@ -216,7 +215,7 @@ def extract_imports_for_class( # noqa: C901, PLR0912
for item in class_node.body:
if isinstance(item, ast.AnnAssign) and item.annotation:
collect_names_from_annotation(item.annotation, needed_names)
needed_names |= collect_type_names_from_annotation(item.annotation)
elif (
isinstance(item, ast.Assign)
and isinstance(item.value, ast.Call)

View file

@ -19,10 +19,13 @@ import libcst as cst
from libcst import MetadataWrapper
from libcst.metadata import PositionProvider
from codeflash_core import AIServiceConnectionError, AIServiceError
from codeflash_core import (
AIServiceConnectionError,
AIServiceError,
performance_gain,
)
from .._constants import LANGUAGE_FIELDS, LANGUAGE_VERSION
from ..verification._verification import performance_gain
if TYPE_CHECKING:
from codeflash_core import AIClient

View file

@ -44,7 +44,16 @@ def shorten_pytest_error(pytest_error_string: str) -> str:
)
def compare_test_results( # noqa: C901, PLR0912
def _extract_pytest_error(
failures: dict[str, str] | None,
qualified_name: str,
) -> str:
"""Look up and shorten a pytest error for *qualified_name*."""
raw = failures.get(qualified_name, "") if failures else ""
return shorten_pytest_error(raw) if raw else ""
def compare_test_results( # noqa: C901
original_results: TestResults,
candidate_results: TestResults,
pass_fail_only: bool = False, # noqa: FBT001, FBT002
@ -110,30 +119,15 @@ def compare_test_results( # noqa: C901, PLR0912
)
# Gather pytest error messages
candidate_test_failures = candidate_results.test_failures
original_test_failures = original_results.test_failures
cdd_pytest_error = (
candidate_test_failures.get(
original_test_result.id.test_fn_qualified_name(),
"",
)
if candidate_test_failures
else ""
qname = original_test_result.id.test_fn_qualified_name()
cdd_pytest_error = _extract_pytest_error(
candidate_results.test_failures,
qname,
)
if cdd_pytest_error:
cdd_pytest_error = shorten_pytest_error(cdd_pytest_error)
original_pytest_error = (
original_test_failures.get(
original_test_result.id.test_fn_qualified_name(),
"",
)
if original_test_failures
else ""
original_pytest_error = _extract_pytest_error(
original_results.test_failures,
qname,
)
if original_pytest_error:
original_pytest_error = shorten_pytest_error(
original_pytest_error,
)
# Check pass/fail mismatch
if original_test_result.test_type in {
@ -236,18 +230,3 @@ def compare_test_results( # noqa: C901, PLR0912
return False, test_diffs
return len(test_diffs) == 0, test_diffs
def performance_gain(
*,
original_runtime_ns: int,
optimized_runtime_ns: int,
) -> float:
"""Calculate the performance gain of optimized code over the original.
Returns a ratio where 1.0 means 100% faster (2x speedup).
Returns 0.0 when the optimized runtime is zero.
"""
if optimized_runtime_ns == 0:
return 0.0
return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

View file

@ -4585,12 +4585,13 @@ def test_collect_type_names_none_annotation() -> None:
assert collect_type_names_from_annotation(None) == set()
def test_collect_type_names_attribute_skipped() -> None:
def test_collect_type_names_attribute_collects_module() -> None:
"""Attribute like module.Foo collects the module name."""
tree = ast.parse("def f(x: module.Foo): pass")
func = tree.body[0]
assert isinstance(func, ast.FunctionDef)
ann = func.args.args[0].annotation
assert collect_type_names_from_annotation(ann) == set()
assert {"module"} == collect_type_names_from_annotation(ann)
# --- Tests for extract_init_stub_from_class ---

View file

@ -125,6 +125,14 @@ class TestCollectTypeNamesFromAnnotation:
"""None returns empty set."""
assert set() == collect_type_names_from_annotation(None)
def test_attribute_collects_module(self) -> None:
"""An Attribute like typing.Optional collects the module name."""
ann = ast.Attribute(
value=ast.Name(id="typing"),
attr="Optional",
)
assert {"typing"} == collect_type_names_from_annotation(ann)
class TestDeclarativeClassDetection:
"""Tests for NamedTuple, dataclass, and attrs detection."""

View file

@ -5,6 +5,7 @@ from pathlib import Path
import attrs
import pytest
from codeflash_core import performance_gain
from codeflash_python._model import VerificationType
from codeflash_python.test_discovery.models import TestType
from codeflash_python.testing.models import (
@ -13,8 +14,9 @@ from codeflash_python.testing.models import (
TestResults,
)
from codeflash_python.verification._verification import (
_extract_pytest_error,
compare_test_results,
performance_gain,
shorten_pytest_error,
)
from codeflash_python.verification.models import (
BehaviorDiff,
@ -485,3 +487,28 @@ class TestGetAllUniqueInvocationLoopIds:
results = TestResults()
assert set() == results.get_all_unique_invocation_loop_ids()
class TestExtractPytestError:
"""Tests for _extract_pytest_error."""
def test_extracts_and_shortens(self) -> None:
"""A matching error is looked up and shortened."""
failures = {
"test_mod::test_fn": "some context\nE AssertionError: 1 != 2\n> assert 1 == 2",
}
result = _extract_pytest_error(failures, "test_mod::test_fn")
assert "AssertionError: 1 != 2\nassert 1 == 2" == result
def test_missing_key_returns_empty(self) -> None:
"""A missing qualified name returns an empty string."""
assert "" == _extract_pytest_error({"other::fn": "err"}, "test::fn")
def test_none_failures_returns_empty(self) -> None:
"""None failures dict returns an empty string."""
assert "" == _extract_pytest_error(None, "test::fn")
def test_empty_failures_returns_empty(self) -> None:
"""Empty failures dict returns an empty string."""
assert "" == _extract_pytest_error({}, "test::fn")