- Extract _parse_candidates helper in _client.py (used by get_candidates and optimize_with_line_profiler)
- Parameterize URL resolution in _http.py (_resolve_url_from_env replaces two near-identical functions)
- Delegate get_repo_owner_and_name to parse_repo_owner_and_name in _git.py
- Simplify _par_apply_fns to delegate to _apply_fns in danom/stream.py
- Remove duplicate performance_gain from _verification.py (use codeflash_core's version)
- Extract _extract_pytest_error helper in _verification.py (replaces duplicated 6-line block)
- Consolidate collect_names_from_annotation into collect_type_names_from_annotation in _ast_helpers.py
- Add ast.Attribute handling and relax BinOp guard in collect_type_names_from_annotation
- Add unit tests for all extracted helpers
497 lines
16 KiB
Python
497 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
from unittest.mock import ANY, MagicMock, patch
|
|
|
|
import pytest
|
|
import requests
|
|
|
|
from codeflash_core import (
|
|
AIClient,
|
|
AIServiceConnectionError,
|
|
AIServiceError,
|
|
Candidate,
|
|
InvalidAPIKeyError,
|
|
OptimizationRequest,
|
|
OptimizationReviewResult,
|
|
)
|
|
from codeflash_core._client import _parse_candidates
|
|
|
|
|
|
@pytest.fixture(name="client")
def _client():
    """
    Yield an AIClient bound to http://localhost with a test API key.

    The client is entered as a context manager, so it is exited again
    once the requesting test has finished.
    """
    local_client = AIClient(base_url="http://localhost", api_key="cf-test")
    with local_client as entered:
        yield entered
|
|
|
|
|
|
@pytest.fixture(name="request_")
def _request() -> OptimizationRequest:
    """
    Build the small Python OptimizationRequest shared by the tests.
    """
    fields = {
        "source_code": "def compute(x): return x + 1",
        "context_code": "# read only",
        "language": "python",
        "language_version": "3.12.0",
    }
    return OptimizationRequest(**fields)
|
|
|
|
|
|
@pytest.fixture(name="mock_post")
def _mock_post(client):
    """
    Swap client._session.post for a mock and yield that mock.

    The mocked response's json() returns an empty "optimizations" list
    unless a test overrides it.
    """
    with patch.object(client._session, "post") as post_mock:
        response = MagicMock()
        response.json.return_value = {"optimizations": []}
        post_mock.return_value = response
        yield post_mock
|
|
|
|
|
|
class TestParseCandidates:
    """Unit tests for the _parse_candidates helper."""

    def test_parses_optimizations(self):
        """Every entry under 'optimizations' becomes a Candidate."""
        first = {
            "source_code": "def f(): pass",
            "explanation": "simplified",
            "optimization_id": "id-1",
        }
        second = {
            "source_code": "def g(): pass",
            "explanation": "inlined",
            "optimization_id": "id-2",
        }

        candidates = _parse_candidates({"optimizations": [first, second]})

        assert 2 == len(candidates)
        for candidate in candidates:
            assert isinstance(candidate, Candidate)
        assert "id-1" == candidates[0].candidate_id
        assert "inlined" == candidates[1].explanation

    def test_empty_optimizations(self):
        """No candidates come back for an empty optimizations list."""
        payload = {"optimizations": []}
        assert [] == _parse_candidates(payload)

    def test_missing_key(self):
        """No candidates come back when 'optimizations' is absent."""
        payload = {}
        assert [] == _parse_candidates(payload)

    def test_missing_fields_use_defaults(self):
        """An item without fields yields empty-string attributes."""
        candidates = _parse_candidates({"optimizations": [{}]})

        assert 1 == len(candidates)
        only = candidates[0]
        assert "" == only.code
        assert "" == only.explanation
        assert "" == only.candidate_id
|
|
|
|
|
|
class TestAIClient:
    """Construction and lifecycle tests for AIClient."""

    def test_default_timeout(self, monkeypatch):
        """
        A client built without arguments uses a 300 second timeout.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test")
        with AIClient() as ai_client:
            timeout = ai_client._timeout
            assert 300.0 == timeout

    def test_strips_trailing_slash(self):
        """
        The stored base URL has no trailing slash.
        """
        url_with_slash = "http://localhost:8000/"
        with AIClient(base_url=url_with_slash, api_key="cf-test") as ai_client:
            assert "http://localhost:8000" == ai_client._base_url

    def test_sets_auth_header(self):
        """
        A non-empty API key is sent as a Bearer Authorization header.
        """
        with AIClient(api_key="cf-test-key") as ai_client:
            headers = ai_client._session.headers
            assert "Bearer cf-test-key" == headers["Authorization"]

    def test_no_auth_header_when_empty(self):
        """
        No Authorization header is set for an empty API key.
        """
        with AIClient(api_key="") as ai_client:
            headers = ai_client._session.headers
            assert "Authorization" not in headers

    def test_context_manager_closes_session(self, monkeypatch):
        """
        Leaving the client's with-block calls close() on its session.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test")
        ai_client = AIClient()
        with patch.object(ai_client._session, "close") as close_mock:
            # Enter and leave the client while close() is still patched.
            with ai_client:
                pass
        close_mock.assert_called_once()
|
|
|
|
|
|
class TestLocalApiResolution:
    """Base URL selection driven by CODEFLASH_AIS_SERVER."""

    @pytest.mark.parametrize(
        ("env_value", "expected_url"),
        [
            ("local", "http://localhost:8000"),
            ("LOCAL", "http://localhost:8000"),
            ("prod", "https://app.codeflash.ai"),
            (None, "https://app.codeflash.ai"),
        ],
    )
    def test_env_resolution(self, monkeypatch, env_value, expected_url):
        """
        Each CODEFLASH_AIS_SERVER value maps to its expected base URL.

        A None env_value means the variable is removed from the
        environment rather than set.
        """
        server_var = "CODEFLASH_AIS_SERVER"
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-test")
        if env_value is not None:
            monkeypatch.setenv(server_var, env_value)
        else:
            monkeypatch.delenv(server_var, raising=False)

        with AIClient() as ai_client:
            assert expected_url == ai_client._base_url

    def test_explicit_url_overrides_env(self, monkeypatch):
        """
        A base_url argument wins over CODEFLASH_AIS_SERVER.
        """
        monkeypatch.setenv("CODEFLASH_AIS_SERVER", "local")
        custom_url = "https://custom.api"
        with AIClient(base_url=custom_url, api_key="cf-test") as ai_client:
            assert "https://custom.api" == ai_client._base_url
|
|
|
|
|
|
class TestAPIKeyResolution:
    """How AIClient obtains and validates its API key."""

    def test_reads_from_env(self, monkeypatch):
        """
        Without an explicit key, CODEFLASH_API_KEY supplies it.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "cf-abc123")
        with AIClient() as ai_client:
            resolved_key = ai_client._api_key
            assert "cf-abc123" == resolved_key

    def test_missing_key_raises(self, monkeypatch):
        """
        InvalidAPIKeyError is raised when CODEFLASH_API_KEY is unset.
        """
        monkeypatch.delenv("CODEFLASH_API_KEY", raising=False)
        expect_error = pytest.raises(InvalidAPIKeyError, match="not found")
        with expect_error:
            AIClient()

    def test_invalid_prefix_raises(self, monkeypatch):
        """
        InvalidAPIKeyError is raised for a key lacking the cf- prefix.
        """
        monkeypatch.setenv("CODEFLASH_API_KEY", "sk-bad-prefix")
        expect_error = pytest.raises(
            InvalidAPIKeyError, match="must start with"
        )
        with expect_error:
            AIClient()

    def test_explicit_key_skips_env(self, monkeypatch):
        """
        A key passed as api_key is used even with the env var unset.
        """
        monkeypatch.delenv("CODEFLASH_API_KEY", raising=False)
        with AIClient(api_key="cf-explicit") as ai_client:
            resolved_key = ai_client._api_key
            assert "cf-explicit" == resolved_key
|
|
|
|
|
|
class TestGetCandidates:
    """Behaviour of AIClient.get_candidates."""

    def test_success(self, client, request_):
        """
        Candidates are parsed from a successful response and the
        expected payload is posted to the optimize endpoint.
        """
        optimizations = [
            {
                "source_code": "def compute(x): return x + 1",
                "explanation": "optimized",
                "optimization_id": "abc123",
            },
            {
                "source_code": "def compute(x): return ~(~x)",
                "explanation": "bit trick",
                "optimization_id": "def456",
            },
        ]
        response = MagicMock()
        response.json.return_value = {"optimizations": optimizations}

        # trace_id is matched with ANY rather than a fixed value.
        expected_payload = {
            "source_code": "def compute(x): return x + 1",
            "dependency_code": "# read only",
            "trace_id": ANY,
            "language": "python",
            "language_version": "3.12.0",
            "n_candidates": 5,
            "call_sequence": 1,
            "is_async": False,
            "is_numerical_code": None,
            "codeflash_version": "",
        }

        with patch.object(
            client._session, "post", return_value=response
        ) as post_mock:
            candidates = client.get_candidates(request_)

        post_mock.assert_called_once_with(
            "http://localhost/ai/optimize",
            json=expected_payload,
            timeout=300.0,
        )
        assert 2 == len(candidates)
        for candidate in candidates:
            assert isinstance(candidate, Candidate)
        assert "abc123" == candidates[0].candidate_id
        assert "bit trick" == candidates[1].explanation

    def test_empty_response(self, client, request_, mock_post):
        """
        Both an empty optimizations list and a response without the
        key produce no candidates.
        """
        # The mock_post fixture starts with {"optimizations": []}.
        assert [] == client.get_candidates(request_)

        mock_post.return_value.json.return_value = {}
        assert [] == client.get_candidates(request_)

    def test_non_python_language(self, client, mock_post):
        """
        The request's language and version appear in the payload.
        """
        request_fields = {
            "source_code": "function add(a, b) { return a + b; }",
            "language": "javascript",
            "language_version": "ES2022",
        }
        client.get_candidates(OptimizationRequest(**request_fields))

        sent = mock_post.call_args.kwargs["json"]
        assert "javascript" == sent["language"]
        assert "ES2022" == sent["language_version"]

    def test_http_error_raises(self, client, request_):
        """
        A failing HTTP status surfaces as AIServiceError carrying
        the status code.
        """
        response = MagicMock(status_code=500, text="Internal Server Error")
        response.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=response):
            with pytest.raises(AIServiceError) as exc_info:
                client.get_candidates(request_)

        assert 500 == exc_info.value.status_code

    def test_connection_error_raises(self, client, request_):
        """
        A requests.ConnectionError surfaces as
        AIServiceConnectionError.
        """
        refused = requests.ConnectionError("refused")
        with patch.object(client._session, "post", side_effect=refused):
            with pytest.raises(AIServiceConnectionError):
                client.get_candidates(request_)

    def test_empty_body_raises_ai_service_error(self, client, request_):
        """
        A 200 response whose body is not JSON raises AIServiceError.
        """
        response = MagicMock(status_code=200, text="")
        response.raise_for_status.return_value = None
        response.json.side_effect = ValueError("No JSON object")

        with patch.object(client._session, "post", return_value=response):
            with pytest.raises(AIServiceError) as exc_info:
                client.get_candidates(request_)

        assert 200 == exc_info.value.status_code
|
|
|
|
|
|
class TestOptimizeWithLineProfiler:
    """Behaviour of AIClient.optimize_with_line_profiler."""

    def test_success(self, client, request_):
        """
        Candidates are parsed from a successful response and the
        expected payload is posted to the line-profiler endpoint.
        """
        profile_text = "Line # Hits Time"
        optimization = {
            "source_code": "def compute(x): return x + 1",
            "explanation": "line-profiler guided",
            "optimization_id": "lp-001",
        }
        response = MagicMock()
        response.json.return_value = {"optimizations": [optimization]}

        # trace_id is matched with ANY rather than a fixed value.
        expected_payload = {
            "source_code": "def compute(x): return x + 1",
            "dependency_code": "# read only",
            "trace_id": ANY,
            "language": "python",
            "language_version": "3.12.0",
            "n_candidates": 5,
            "line_profiler_results": "Line # Hits Time",
            "call_sequence": 1,
            "is_numerical_code": None,
            "codeflash_version": "",
        }

        with patch.object(
            client._session, "post", return_value=response
        ) as post_mock:
            candidates = client.optimize_with_line_profiler(
                request_,
                line_profiler_results=profile_text,
            )

        post_mock.assert_called_once_with(
            "http://localhost/ai/optimize-line-profiler",
            json=expected_payload,
            timeout=300.0,
        )
        assert 1 == len(candidates)
        for candidate in candidates:
            assert isinstance(candidate, Candidate)
        assert "lp-001" == candidates[0].candidate_id

    def test_empty_line_profiler_returns_early(self, client, request_):
        """
        With empty profiler output the result is [] and nothing is
        posted.
        """
        with patch.object(client._session, "post") as post_mock:
            candidates = client.optimize_with_line_profiler(
                request_, line_profiler_results=""
            )

        assert [] == candidates
        post_mock.assert_not_called()

    def test_http_error_raises(self, client, request_):
        """
        A failing HTTP status surfaces as AIServiceError.
        """
        response = MagicMock(status_code=500, text="Internal Server Error")
        response.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=response):
            with pytest.raises(AIServiceError):
                client.optimize_with_line_profiler(
                    request_,
                    line_profiler_results="Line # Hits Time",
                )
|
|
|
|
|
|
class TestGenerateExplanation:
    """Behaviour of AIClient.generate_explanation."""

    def test_success(self, client):
        """
        The 'explanation' field of the response is returned.
        """
        response = MagicMock()
        response.json.return_value = {
            "explanation": "Replaced loop with vectorized op."
        }
        payload = {"trace_id": "t1", "source_code": "x"}

        with patch.object(client._session, "post", return_value=response):
            explanation = client.generate_explanation(payload)

        assert "Replaced loop with vectorized op." == explanation

    def test_failure_returns_empty(self, client):
        """
        An HTTP error yields an empty string instead of raising.
        """
        response = MagicMock(status_code=500, text="fail")
        response.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=response):
            explanation = client.generate_explanation({"trace_id": "t1"})

        assert "" == explanation
|
|
|
|
|
|
class TestLogResults:
    """Behaviour of AIClient.log_results."""

    def test_success(self, client):
        """
        Logging against a successful response raises nothing.
        """
        response = MagicMock()
        response.json.return_value = {}
        payload = {"trace_id": "t1"}

        with patch.object(client._session, "post", return_value=response):
            client.log_results(payload)

    def test_failure_is_silent(self, client):
        """
        An HTTP error during logging does not propagate.
        """
        response = MagicMock(status_code=500, text="fail")
        response.raise_for_status.side_effect = requests.HTTPError("500")
        payload = {"trace_id": "t1"}

        with patch.object(client._session, "post", return_value=response):
            # The test passes simply by this call not raising.
            client.log_results(payload)
|
|
|
|
|
|
class TestGetOptimizationReview:
    """Behaviour of AIClient.get_optimization_review."""

    def test_success(self, client):
        """
        The response's review and review_explanation populate an
        OptimizationReviewResult.
        """
        response = MagicMock()
        response.json.return_value = {
            "review": "high",
            "review_explanation": "Well-tested optimization.",
        }
        payload = {"trace_id": "t1", "original_code": "x"}

        with patch.object(client._session, "post", return_value=response):
            review_result = client.get_optimization_review(payload)

        assert isinstance(review_result, OptimizationReviewResult)
        assert "high" == review_result.review
        assert "Well-tested optimization." == review_result.explanation

    def test_failure_returns_empty(self, client):
        """
        An HTTP error yields a result with empty review and
        explanation.
        """
        response = MagicMock(status_code=500, text="fail")
        response.raise_for_status.side_effect = requests.HTTPError("500")

        with patch.object(client._session, "post", return_value=response):
            review_result = client.get_optimization_review({"trace_id": "t1"})

        assert "" == review_result.review
        assert "" == review_result.explanation
|