codeflash-agent/packages/codeflash-core/tests/test_client.py
Kevin Turcios 6b73b07d15 fix: deduplicate code across codeflash-core and codeflash-python
- Extract _parse_candidates helper in _client.py (used by get_candidates and optimize_with_line_profiler)
- Parameterize URL resolution in _http.py (_resolve_url_from_env replaces two near-identical functions)
- Delegate get_repo_owner_and_name to parse_repo_owner_and_name in _git.py
- Simplify _par_apply_fns to delegate to _apply_fns in danom/stream.py
- Remove duplicate performance_gain from _verification.py (use codeflash_core's version)
- Extract _extract_pytest_error helper in _verification.py (replaces duplicated 6-line block)
- Consolidate collect_names_from_annotation into collect_type_names_from_annotation in _ast_helpers.py
- Add ast.Attribute handling and relax BinOp guard in collect_type_names_from_annotation
- Add unit tests for all extracted helpers
2026-04-23 22:39:50 -05:00

497 lines
16 KiB
Python

from __future__ import annotations
from unittest.mock import ANY, MagicMock, patch
import pytest
import requests
from codeflash_core import (
AIClient,
AIServiceConnectionError,
AIServiceError,
Candidate,
InvalidAPIKeyError,
OptimizationRequest,
OptimizationReviewResult,
)
from codeflash_core._client import _parse_candidates
@pytest.fixture(name="client")
def _client():
    """
    Yield an AIClient built with a localhost base URL and a test API key.

    The client is entered as a context manager, so it is exited again
    once the requesting test has finished.
    """
    ai_client = AIClient(base_url="http://localhost", api_key="cf-test")
    with ai_client as entered:
        yield entered
@pytest.fixture(name="request_")
def _request() -> OptimizationRequest:
    """
    Build the Python OptimizationRequest shared by the client tests.
    """
    fields = {
        "source_code": "def compute(x): return x + 1",
        "context_code": "# read only",
        "language": "python",
        "language_version": "3.12.0",
    }
    return OptimizationRequest(**fields)
@pytest.fixture(name="mock_post")
def _mock_post(client):
    """
    Replace client._session.post with a mock and yield that mock.

    The mocked response's ``json()`` returns an empty optimizations list
    unless a test overrides it.
    """
    response = MagicMock()
    response.json.return_value = {"optimizations": []}
    with patch.object(
        client._session, "post", return_value=response
    ) as post_mock:
        yield post_mock
class TestParseCandidates:
    """Tests for the _parse_candidates helper."""

    def test_parses_optimizations(self):
        """Each optimization entry becomes one Candidate, in order."""
        first_item = {
            "source_code": "def f(): pass",
            "explanation": "simplified",
            "optimization_id": "id-1",
        }
        second_item = {
            "source_code": "def g(): pass",
            "explanation": "inlined",
            "optimization_id": "id-2",
        }
        candidates = _parse_candidates(
            {"optimizations": [first_item, second_item]}
        )

        assert 2 == len(candidates)
        for candidate in candidates:
            assert isinstance(candidate, Candidate)
        assert "id-1" == candidates[0].candidate_id
        assert "inlined" == candidates[1].explanation

    def test_empty_optimizations(self):
        """An empty 'optimizations' list yields an empty result."""
        parsed = _parse_candidates({"optimizations": []})
        assert [] == parsed

    def test_missing_key(self):
        """A payload lacking 'optimizations' yields an empty result."""
        parsed = _parse_candidates({})
        assert [] == parsed

    def test_missing_fields_use_defaults(self):
        """An empty item produces a Candidate with empty-string fields."""
        candidates = _parse_candidates({"optimizations": [{}]})

        assert 1 == len(candidates)
        only = candidates[0]
        assert "" == only.code
        assert "" == only.explanation
        assert "" == only.candidate_id
class TestAIClient:
    """Tests for AIClient construction and context-manager behaviour."""

    def test_default_timeout(self, monkeypatch):
        """
        Default timeout is 300 seconds.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test")
        with AIClient() as c:
            assert 300.0 == c._timeout

    def test_strips_trailing_slash(self):
        """
        Trailing slashes are stripped from the base URL.
        """
        with AIClient(
            base_url="http://localhost:8000/", api_key="cf-test"
        ) as c:
            assert "http://localhost:8000" == c._base_url

    def test_sets_auth_header(self):
        """
        An API key sets the Authorization header.
        """
        with AIClient(api_key="cf-test-key") as c:
            assert "Bearer cf-test-key" == c._session.headers["Authorization"]

    def test_no_auth_header_when_empty(self):
        """
        An empty API key sets no Authorization header.
        """
        with AIClient(api_key="") as c:
            assert "Authorization" not in c._session.headers

    def test_context_manager_closes_session(self, monkeypatch):
        """
        Exiting the context manager closes the session.

        The mock wraps the real ``close`` method, so the call is recorded
        for the assertion and also forwarded. With a plain mock the real
        ``close`` would never run and the session created here would be
        left open after the test.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test")
        c = AIClient()
        # ``wraps`` is evaluated before the patch is applied, so it holds
        # the genuine bound method rather than the mock itself.
        with (
            patch.object(
                c._session, "close", wraps=c._session.close
            ) as mock_close,
            c,
        ):
            pass
        mock_close.assert_called_once()
class TestLocalApiResolution:
    """Tests for base URL selection via CODEFLASH_AIS_SERVER."""

    @pytest.mark.parametrize(
        ("env_value", "expected_url"),
        [
            ("local", "http://localhost:8000"),
            ("LOCAL", "http://localhost:8000"),
            ("prod", "https://app.codeflash.ai"),
            (None, "https://app.codeflash.ai"),
        ],
    )
    def test_env_resolution(self, monkeypatch, env_value, expected_url):
        """
        Each CODEFLASH_AIS_SERVER value maps to the expected base URL.

        A parametrized value of None means the variable is unset.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test")
        if env_value is not None:
            monkeypatch.setenv("CODEFLASH_AIS_SERVER", env_value)
        else:
            monkeypatch.delenv("CODEFLASH_AIS_SERVER", raising=False)

        with AIClient() as ai_client:
            resolved = ai_client._base_url
            assert expected_url == resolved

    def test_explicit_url_overrides_env(self, monkeypatch):
        """
        A base_url argument wins over CODEFLASH_AIS_SERVER.
        """
        monkeypatch.setenv("CODEFLASH_AIS_SERVER", "local")
        ai_client = AIClient(base_url="https://custom.api", api_key="cf-test")
        with ai_client as entered:
            assert "https://custom.api" == entered._base_url
class TestAPIKeyResolution:
    """Tests for how AIClient obtains and validates its API key."""

    def test_reads_from_env(self, monkeypatch):
        """
        With no api_key argument, the key comes from CODEFLASH_API_KEY.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-abc123")
        with AIClient() as ai_client:
            stored_key = ai_client._api_key
            assert "cf-abc123" == stored_key

    def test_missing_key_raises(self, monkeypatch):
        """
        An unset CODEFLASH_API_KEY raises InvalidAPIKeyError ("not found").
        """
        monkeypatch.delenv("CODEFLASH_API_KEY", raising=False)
        expectation = pytest.raises(InvalidAPIKeyError, match="not found")
        with expectation:
            AIClient()

    def test_invalid_prefix_raises(self, monkeypatch):
        """
        A key lacking the cf- prefix raises InvalidAPIKeyError.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "sk-bad-prefix")
        expectation = pytest.raises(
            InvalidAPIKeyError, match="must start with"
        )
        with expectation:
            AIClient()

    def test_explicit_key_skips_env(self, monkeypatch):
        """
        A key passed as api_key is used even when the env var is unset.
        """
        monkeypatch.delenv("CODEFLASH_API_KEY", raising=False)
        with AIClient(api_key="cf-explicit") as ai_client:
            stored_key = ai_client._api_key
            assert "cf-explicit" == stored_key
class TestGetCandidates:
    """Tests for AIClient.get_candidates."""

    def test_success(self, client, request_):
        """
        The expected payload is posted once and the response is parsed
        into Candidate objects.
        """
        response = MagicMock()
        response.json.return_value = {
            "optimizations": [
                {
                    "source_code": "def compute(x): return x + 1",
                    "explanation": "optimized",
                    "optimization_id": "abc123",
                },
                {
                    "source_code": "def compute(x): return ~(~x)",
                    "explanation": "bit trick",
                    "optimization_id": "def456",
                },
            ]
        }
        expected_json = {
            "source_code": "def compute(x): return x + 1",
            "dependency_code": "# read only",
            "trace_id": ANY,
            "language": "python",
            "language_version": "3.12.0",
            "n_candidates": 5,
            "call_sequence": 1,
            "is_async": False,
            "is_numerical_code": None,
            "codeflash_version": "",
        }

        with patch.object(
            client._session, "post", return_value=response
        ) as post_mock:
            candidates = client.get_candidates(request_)

        post_mock.assert_called_once_with(
            "http://localhost/ai/optimize",
            json=expected_json,
            timeout=300.0,
        )
        assert 2 == len(candidates)
        for candidate in candidates:
            assert isinstance(candidate, Candidate)
        assert "abc123" == candidates[0].candidate_id
        assert "bit trick" == candidates[1].explanation

    def test_empty_response(self, client, request_, mock_post):
        """
        Both an empty optimizations list and a missing key give [].
        """
        from_empty_list = client.get_candidates(request_)
        assert [] == from_empty_list

        mock_post.return_value.json.return_value = {}
        from_missing_key = client.get_candidates(request_)
        assert [] == from_missing_key

    def test_non_python_language(self, client, mock_post):
        """
        The request's language and version appear unchanged in the payload.
        """
        js_request = OptimizationRequest(
            source_code="function add(a, b) { return a + b; }",
            language="javascript",
            language_version="ES2022",
        )
        client.get_candidates(js_request)

        sent_json = mock_post.call_args.kwargs["json"]
        assert "javascript" == sent_json["language"]
        assert "ES2022" == sent_json["language_version"]

    def test_http_error_raises(self, client, request_):
        """
        A failing raise_for_status surfaces as AIServiceError carrying
        the response status code.
        """
        failing = MagicMock(status_code=500, text="Internal Server Error")
        failing.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=failing):
            with pytest.raises(AIServiceError) as caught:
                client.get_candidates(request_)

        assert 500 == caught.value.status_code

    def test_connection_error_raises(self, client, request_):
        """
        A requests.ConnectionError from post becomes
        AIServiceConnectionError.
        """
        refused = requests.ConnectionError("refused")

        with patch.object(client._session, "post", side_effect=refused):
            with pytest.raises(AIServiceConnectionError):
                client.get_candidates(request_)

    def test_empty_body_raises_ai_service_error(self, client, request_):
        """
        A 200 response whose json() raises ValueError becomes
        AIServiceError with status code 200.
        """
        blank = MagicMock(status_code=200, text="")
        blank.raise_for_status.return_value = None
        blank.json.side_effect = ValueError("No JSON object")

        with patch.object(client._session, "post", return_value=blank):
            with pytest.raises(AIServiceError) as caught:
                client.get_candidates(request_)

        assert 200 == caught.value.status_code
class TestOptimizeWithLineProfiler:
    """Tests for AIClient.optimize_with_line_profiler."""

    def test_success(self, client, request_):
        """
        The profiler output is posted with the request and the response
        is parsed into Candidate objects.
        """
        profiler_text = "Line # Hits Time"
        response = MagicMock()
        response.json.return_value = {
            "optimizations": [
                {
                    "source_code": "def compute(x): return x + 1",
                    "explanation": "line-profiler guided",
                    "optimization_id": "lp-001",
                },
            ]
        }
        expected_json = {
            "source_code": "def compute(x): return x + 1",
            "dependency_code": "# read only",
            "trace_id": ANY,
            "language": "python",
            "language_version": "3.12.0",
            "n_candidates": 5,
            "line_profiler_results": "Line # Hits Time",
            "call_sequence": 1,
            "is_numerical_code": None,
            "codeflash_version": "",
        }

        with patch.object(
            client._session, "post", return_value=response
        ) as post_mock:
            candidates = client.optimize_with_line_profiler(
                request_,
                line_profiler_results=profiler_text,
            )

        post_mock.assert_called_once_with(
            "http://localhost/ai/optimize-line-profiler",
            json=expected_json,
            timeout=300.0,
        )
        assert 1 == len(candidates)
        for candidate in candidates:
            assert isinstance(candidate, Candidate)
        assert "lp-001" == candidates[0].candidate_id

    def test_empty_line_profiler_returns_early(self, client, request_):
        """
        With empty profiler results the method returns [] and never posts.
        """
        with patch.object(client._session, "post") as post_mock:
            candidates = client.optimize_with_line_profiler(
                request_, line_profiler_results=""
            )

        assert [] == candidates
        post_mock.assert_not_called()

    def test_http_error_raises(self, client, request_):
        """
        A failing raise_for_status surfaces as AIServiceError.
        """
        failing = MagicMock(status_code=500, text="Internal Server Error")
        failing.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=failing):
            with pytest.raises(AIServiceError):
                client.optimize_with_line_profiler(
                    request_,
                    line_profiler_results="Line # Hits Time",
                )
class TestGenerateExplanation:
    """Tests for AIClient.generate_explanation."""

    def test_success(self, client):
        """
        The 'explanation' value of the response is returned as-is.
        """
        response = MagicMock()
        response.json.return_value = {
            "explanation": "Replaced loop with vectorized op."
        }
        payload = {"trace_id": "t1", "source_code": "x"}

        with patch.object(client._session, "post", return_value=response):
            explanation = client.generate_explanation(payload)

        assert "Replaced loop with vectorized op." == explanation

    def test_failure_returns_empty(self, client):
        """
        When raise_for_status fails, the method returns "".
        """
        failing = MagicMock(status_code=500, text="fail")
        failing.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=failing):
            explanation = client.generate_explanation({"trace_id": "t1"})

        assert "" == explanation
class TestLogResults:
    """Tests for AIClient.log_results."""

    def test_success(self, client):
        """
        Logging against a well-formed response raises nothing.
        """
        response = MagicMock()
        response.json.return_value = {}
        payload = {"trace_id": "t1"}

        # No assertion: the test passes as long as nothing is raised.
        with patch.object(client._session, "post", return_value=response):
            client.log_results(payload)

    def test_failure_is_silent(self, client):
        """
        A failing raise_for_status does not propagate to the caller.
        """
        failing = MagicMock(status_code=500, text="fail")
        failing.raise_for_status.side_effect = requests.HTTPError("500")
        payload = {"trace_id": "t1"}

        # Passing means no exception escaped log_results.
        with patch.object(client._session, "post", return_value=failing):
            client.log_results(payload)
class TestGetOptimizationReview:
    """Tests for AIClient.get_optimization_review."""

    def test_success(self, client):
        """
        'review' and 'review_explanation' from the response populate an
        OptimizationReviewResult.
        """
        response = MagicMock()
        response.json.return_value = {
            "review": "high",
            "review_explanation": "Well-tested optimization.",
        }
        payload = {"trace_id": "t1", "original_code": "x"}

        with patch.object(client._session, "post", return_value=response):
            outcome = client.get_optimization_review(payload)

        assert isinstance(outcome, OptimizationReviewResult)
        assert "high" == outcome.review
        assert "Well-tested optimization." == outcome.explanation

    def test_failure_returns_empty(self, client):
        """
        When raise_for_status fails, review and explanation are both "".
        """
        failing = MagicMock(status_code=500, text="fail")
        failing.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=failing):
            outcome = client.get_optimization_review({"trace_id": "t1"})

        assert "" == outcome.review
        assert "" == outcome.explanation