mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
150 lines
5.9 KiB
Python
150 lines
5.9 KiB
Python
import logging
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING
|
|
|
|
import libcst as cst
|
|
import sentry_sdk
|
|
from ninja import Field, Schema
|
|
from pydantic import ValidationError
|
|
|
|
from aiservice.common.cst_utils import parse_module_to_cst
|
|
from optimizer.context_utils.context_helpers import group_code, is_markdown_structure_changed, split_markdown_code
|
|
from optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
|
|
from optimizer.models import CodeAndExplanation
|
|
|
|
if TYPE_CHECKING:
|
|
from optimizer.diff_patches_utils.diff import Diff
|
|
|
|
|
|
# Aspect of a test run in which the original and the optimized code were
# observed to differ. Members are also plain strings (str mixin), so they
# compare equal to their string values.
class TestDiffScope(str, Enum):
    # The compared return values differ.
    RETURN_VALUE = "return_value"
    # The captured stdout output differs.
    STDOUT = "stdout"
    # Pass/fail status differs between the two versions.
    DID_PASS = "did_pass"  # noqa: S105
    # No entry exists for this member in SCOPE_DESCRIPTIONS; code that looks
    # it up there with a fallback ends up using the raw value "timed_out".
    TIMED_OUT = "timed_out"
|
|
|
|
|
|
# Human-readable sentence for each diff scope, inserted into the repair prompt.
# TestDiffScope.TIMED_OUT has no entry here; lookups that pass a fallback
# (see CodeRepairContext.build_test_details) use the raw enum value instead.
SCOPE_DESCRIPTIONS: dict[TestDiffScope, str] = {
    TestDiffScope.RETURN_VALUE: "The function returned a different value in the optimized code compared to the original.",
    TestDiffScope.STDOUT: "The output printed to stdout is different in the optimized code compared to the original.",
    TestDiffScope.DID_PASS: "The test passed in one version but failed in the other (a change in pass/fail behavior).",
}
|
|
|
|
|
|
# One observed behavioral difference between the original and the optimized
# (candidate) code for a single test.
class TestDiff(Schema):
    # Which aspect of the test run differed.
    scope: TestDiffScope
    # Value observed with the original code; rendered as "Expected" in the prompt.
    original_value: bool | str | int | float | dict | list | None = None
    # Value observed with the optimized code; rendered as "Got" in the prompt.
    candidate_value: bool | str | int | float | dict | list | None = None
    # Whether the test passed against the original code.
    original_pass: bool
    # Whether the test passed against the optimized code.
    candidate_pass: bool
    # Source of the test; also used as the grouping key when building prompt details.
    test_src_code: str
    # Test-runner error text for the optimized code, if any.
    candidate_pytest_error: str | None = None
    # Test-runner error text for the original code, if any.
    original_pytest_error: str | None = None
|
|
|
|
|
|
# Request payload for a code-repair call: the original code, the optimized
# code to be repaired, and the test differences observed between them.
class CodeRepairRequestSchema(Schema):
    # Identifiers supplied by the caller; not interpreted in this module.
    trace_id: str
    optimization_id: str
    # Code before optimization.
    original_source_code: str
    # Optimized code that should be repaired.
    modified_source_code: str
    # Required list of observed differences (the alias equals the field name).
    test_diffs: list[TestDiff] = Field(..., alias="test_diffs")
    language: str = "python"  # python, javascript, typescript
|
|
|
|
|
|
@dataclass
class CodeRepairContextData:
    """Plain data holder for the inputs consumed by ``CodeRepairContext``.

    Attributes:
        original_source_code: Code before optimization.
        modified_source_code: Optimized code that is being repaired.
        test_diffs: Differences observed between the two versions under test.
        language: Language of the code; defaults to ``"python"``.
    """

    original_source_code: str
    modified_source_code: str
    test_diffs: list[TestDiff]
    language: str = "python"
|
|
|
|
|
|
class CodeRepairContext:
    """Build the repair prompts and apply/validate the model's patches for one request.

    Holds a ``CodeRepairContextData`` together with the system prompt and the
    user prompt template, and offers helpers to render the prompts, apply the
    model response to the optimized code, and check the result.
    """

    def __init__(self, ctx_data: CodeRepairContextData, base_system_prompt: str, base_user_prompt: str) -> None:
        """Store the request data and the two prompt templates.

        Args:
            ctx_data: Original code, optimized code, test diffs and language.
            base_system_prompt: System prompt, returned unchanged by ``get_system_prompt``.
            base_user_prompt: ``str.format`` template; it is formatted with the keyword
                arguments ``original_source_code``, ``modified_source_code`` and
                ``test_details``.
        """
        self.data = ctx_data
        self.base_system_prompt = base_system_prompt
        self.base_user_prompt = base_user_prompt

    def get_system_prompt(self) -> str:
        """Return the system prompt exactly as it was supplied."""
        return self.base_system_prompt

    def build_test_details(self, test_diffs: list[TestDiff]) -> str:
        """Render the test diffs as prompt text, grouped by test source.

        Each distinct ``test_src_code`` produces one section: a header (test
        source plus the error text for both versions) written once, followed by
        one entry per diff. Entries of the same test are separated by a newline
        so that a following entry never runs into the previous ``---`` marker.

        A diff that fails to render is logged and reported to Sentry, then
        skipped; the remaining diffs are still rendered.

        Args:
            test_diffs: Differences to describe.

        Returns:
            All sections joined with newlines; an empty string for no diffs.
        """
        # Per test source: chunks[0] is the header, chunks[1:] are diff entries.
        sections: defaultdict[str, list[str]] = defaultdict(list)
        language = self.data.language
        test_error_label = "Pytest error" if language == "python" else "Test error"

        for diff in test_diffs:
            try:
                chunks = sections[diff.test_src_code]
                if not chunks:
                    # Add error strings and test source only once per test function.
                    chunks.append(
                        f"Test Source:\n```{language}\n{diff.test_src_code}\n```\n"
                        f"{test_error_label} (original code): {diff.original_pytest_error or ''}\n"
                        f"{test_error_label} (optimized code): {diff.candidate_pytest_error or ''}\n"
                    )

                # Expected/Got values are not shown for pure pass/fail differences.
                if diff.scope != TestDiffScope.DID_PASS:
                    values_line = f"Expected: {diff.original_value!r}.\nGot: {diff.candidate_value!r}."
                else:
                    values_line = ""
                original_status = "Passed" if diff.original_pass else "Failed"
                candidate_status = "Passed" if diff.candidate_pass else "Failed"

                chunks.append(
                    "\n".join(
                        [
                            # Scopes without a description fall back to the raw enum value.
                            SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value),
                            values_line,
                            f"Original code test status: {original_status}. Optimized code test status: {candidate_status}",
                            "---",
                        ]
                    )
                )
            except Exception as e:
                # Best effort: one malformed diff must not prevent building the prompt.
                logging.exception("Some issue in parsing test diffs")
                sentry_sdk.capture_exception(e)

        rendered = []
        for chunks in sections.values():
            if chunks:
                # The header already ends with a newline, so the first entry follows directly.
                rendered.append(chunks[0] + "\n".join(chunks[1:]))
            else:
                # The section was registered but rendering its header failed.
                rendered.append("")
        return "\n".join(rendered)

    def get_user_prompt(self) -> str:
        """Fill the user prompt template with both code versions and the test details."""
        return self.base_user_prompt.format(
            original_source_code=self.data.original_source_code,
            modified_source_code=self.data.modified_source_code,
            test_details=self.build_test_details(self.data.test_diffs),
        )

    def apply_patches_to_optimized_code(self, llm_res: str) -> str:
        """Apply the patches contained in the model response to the optimized code.

        The optimized code is split with ``split_markdown_code``, the response is
        run through ``SearchAndReplaceDiff`` against the split result, and the
        outcome is regrouped with ``group_code``.

        Args:
            llm_res: Raw model response containing the patches.

        Returns:
            The patched code as returned by ``group_code``.
        """
        file_to_code = split_markdown_code(self.data.modified_source_code, self.data.language)
        diff: Diff = SearchAndReplaceDiff(content=llm_res, source_code=file_to_code)
        file_to_code = diff.run()
        return group_code(file_to_code, self.data.language)

    def is_valid(self, new_refined_code: str) -> bool:
        """Check whether repaired code is structurally and syntactically acceptable.

        Args:
            new_refined_code: Repaired code in the same markdown layout as the
                optimized code.

        Returns:
            ``False`` when the markdown structure differs from the optimized
            code, when any code section is empty or whitespace-only, or (Python
            only) when a section fails to parse with libcst; ``True`` otherwise.
        """
        if is_markdown_structure_changed(new_refined_code, self.data.modified_source_code, self.data.language):
            return False

        for code in split_markdown_code(new_refined_code, self.data.language).values():
            if not code.strip():
                return False
            # Only Python syntax is validated (with libcst). For JavaScript/TypeScript
            # the non-empty check above is the only validation; more sophisticated
            # validation could be added later.
            if self.data.language == "python":
                try:
                    parse_module_to_cst(code)
                except cst.ParserSyntaxError:
                    return False
        return True

    def validate_module(self) -> None:
        """Validate the syntax of the optimized code based on language.

        Raises:
            ValueError, ValidationError, cst.ParserSyntaxError: Propagated
                unchanged when a Python code section cannot be parsed or
                wrapped in ``CodeAndExplanation``.
        """
        # Skip validation for non-Python languages for now
        # TODO: have some way to validate the syntax of the code for other languages like js & ts
        if self.data.language != "python":
            return
        # Pass the language explicitly, as every other split_markdown_code call here does.
        for _code in split_markdown_code(self.data.modified_source_code, self.data.language).values():
            try:
                cst_module = parse_module_to_cst(_code)
                CodeAndExplanation(cst_module, "")
            except (ValueError, ValidationError, cst.ParserSyntaxError):  # noqa: TRY203
                # Kept explicit to document the expected failure types; they propagate as-is.
                raise
|