codeflash-internal/django/aiservice/code_repair/code_repair_context.py
2026-01-28 22:19:40 +02:00

150 lines
5.9 KiB
Python

import logging
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
from ninja import Field, Schema
from pydantic import ValidationError
from aiservice.common.cst_utils import parse_module_to_cst
from optimizer.context_utils.context_helpers import group_code, is_markdown_structure_changed, split_markdown_code
from optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
from optimizer.models import CodeAndExplanation
if TYPE_CHECKING:
from optimizer.diff_patches_utils.diff import Diff
class TestDiffScope(str, Enum):
    """Which observable aspect of a test run differed between the original and optimized code.

    Mixing in ``str`` makes each member compare equal to its string value, so request payloads
    can carry the plain strings below (``TestDiff.scope`` is typed with this enum).
    """

    RETURN_VALUE = "return_value"
    STDOUT = "stdout"
    DID_PASS = "did_pass"  # noqa: S105
    TIMED_OUT = "timed_out"


# Human-readable explanation of each scope, inserted into the repair prompt.
# TIMED_OUT intentionally has no entry here: `build_test_details` falls back to
# the member's raw value (`scope.value`) for any scope missing from this table.
SCOPE_DESCRIPTIONS: dict[TestDiffScope, str] = {
    TestDiffScope.RETURN_VALUE: "The function returned a different value in the optimized code "
    "compared to the original.",
    TestDiffScope.STDOUT: "The output printed to stdout is different in the optimized code "
    "compared to the original.",
    TestDiffScope.DID_PASS: "The test passed in one version but failed in the other "
    "(a change in pass/fail behavior).",
}
class TestDiff(Schema):
    """One behavioral difference seen when running a test against the original and optimized code.

    ``CodeRepairContext.build_test_details`` groups these by ``test_src_code`` and renders them
    into the user prompt of the repair request.
    """

    # Which aspect of the test run differed.
    scope: TestDiffScope
    # Value observed with the original code; rendered (via repr) as "Expected" unless scope is DID_PASS.
    original_value: bool | str | int | float | dict | list | None = None
    # Value observed with the optimized code; rendered (via repr) as "Got" unless scope is DID_PASS.
    candidate_value: bool | str | int | float | dict | list | None = None
    # Whether the test passed against the original code.
    original_pass: bool
    # Whether the test passed against the optimized code.
    candidate_pass: bool
    # Source of the test; also the key used to group diffs that belong to the same test.
    test_src_code: str
    # Error output from running the test against the optimized code, if any.
    candidate_pytest_error: str | None = None
    # Error output from running the test against the original code, if any.
    original_pytest_error: str | None = None
class CodeRepairRequestSchema(Schema):
    """Request payload for a code-repair call.

    NOTE(review): ``trace_id`` and ``optimization_id`` are not read anywhere in this module;
    presumably they correlate the request with an optimization run — confirm with the handler.
    """

    trace_id: str
    optimization_id: str
    # Source code before optimization.
    original_source_code: str
    # Optimized (candidate) source code whose behavior differs from the original.
    modified_source_code: str
    # Required (`...`); the alias equals the field name, so the payload key is "test_diffs".
    test_diffs: list[TestDiff] = Field(..., alias="test_diffs")
    language: str = "python"  # python, javascript, typescript
@dataclass()
class CodeRepairContextData:
    """Inputs ``CodeRepairContext`` needs to build prompts and to patch/validate repaired code."""

    # Source code before optimization.
    original_source_code: str
    # Optimized (candidate) code; split with `split_markdown_code` and patched by
    # `CodeRepairContext.apply_patches_to_optimized_code`.
    modified_source_code: str
    # Behavioral differences to describe in the user prompt.
    test_diffs: list[TestDiff]
    # Language of the code blocks; only "python" gets libcst syntax validation.
    language: str = "python"
class CodeRepairContext:
    """Prompt construction and response post-processing for the code-repair LLM call.

    Builds the system/user prompts from the base templates and the observed test diffs,
    applies the LLM's search-and-replace edits to the optimized code, and validates the result.
    """

    def __init__(self, ctx_data: CodeRepairContextData, base_system_prompt: str, base_user_prompt: str) -> None:
        self.data = ctx_data
        self.base_system_prompt = base_system_prompt
        self.base_user_prompt = base_user_prompt

    def get_system_prompt(self) -> str:
        """Return the system prompt (the base template, unchanged)."""
        return self.base_system_prompt

    @staticmethod
    def _format_test_header(diff: TestDiff, language: str, test_error_label: str) -> str:
        """Render the once-per-test header: the test source and both runs' error output."""
        original_error = diff.original_pytest_error or ""
        candidate_error = diff.candidate_pytest_error or ""
        # Built line by line (not a triple-quoted literal) so source indentation can never
        # leak into the prompt text.
        return (
            "Test Source:\n"
            f"```{language}\n"
            f"{diff.test_src_code}\n"
            "```\n"
            f"{test_error_label} (original code): {original_error}\n"
            f"{test_error_label} (optimized code): {candidate_error}\n"
        )

    @staticmethod
    def _format_diff_entry(diff: TestDiff) -> str:
        """Render one diff: description, expected/got values, pass status, and a '---' separator."""
        lines = [SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value)]
        # For a pass/fail flip, the status line below already says everything; no values to show.
        if diff.scope != TestDiffScope.DID_PASS:
            lines.append(f"Expected: {diff.original_value!r}.\nGot: {diff.candidate_value!r}.")
        original_status = "Passed" if diff.original_pass else "Failed"
        candidate_status = "Passed" if diff.candidate_pass else "Failed"
        lines.append(f"Original code test status: {original_status}. Optimized code test status: {candidate_status}")
        lines.append("---")
        return "\n".join(lines)

    def build_test_details(self, test_diffs: list[TestDiff]) -> str:
        """Render ``test_diffs`` as prompt text, grouped by test source.

        Each distinct test gets one header followed by one entry per diff; entries (and groups)
        are separated by newlines. A diff that fails to render is logged, reported to Sentry,
        and skipped, so one malformed diff does not abort the whole prompt.

        Args:
            test_diffs: Differences to describe, in the order they should appear.

        Returns:
            The concatenated per-test sections (empty string for no diffs).
        """
        language = self.data.language
        test_error_label = "Pytest error" if language == "python" else "Test error"
        # Insertion-ordered: groups appear in the order their test is first seen.
        headers: dict[str, str] = {}
        entries: defaultdict[str, list[str]] = defaultdict(list)
        for diff in test_diffs:
            try:
                key = diff.test_src_code
                if key not in headers:
                    headers[key] = self._format_test_header(diff, language, test_error_label)
                entries[key].append(self._format_diff_entry(diff))
            except Exception as e:
                logging.exception("Some issue in parsing test diffs")
                sentry_sdk.capture_exception(e)
        # Joining entries with "\n" keeps consecutive diffs of one test from running into the
        # previous entry's "---" separator.
        return "\n".join(header + "\n".join(entries[key]) for key, header in headers.items())

    def get_user_prompt(self) -> str:
        """Fill the user template's {original_source_code}, {modified_source_code} and {test_details} fields."""
        return self.base_user_prompt.format(
            original_source_code=self.data.original_source_code,
            modified_source_code=self.data.modified_source_code,
            test_details=self.build_test_details(self.data.test_diffs),
        )

    def apply_patches_to_optimized_code(self, llm_res: str) -> str:
        """Apply the search-and-replace edits in ``llm_res`` to the optimized code and regroup it.

        NOTE(review): ``split_markdown_code`` presumably returns a file -> code mapping (the local
        is named ``file_to_code``); ``group_code`` reassembles it into the same combined form.
        """
        file_to_code = split_markdown_code(self.data.modified_source_code, self.data.language)
        diff: Diff = SearchAndReplaceDiff(content=llm_res, source_code=file_to_code)
        file_to_code = diff.run()
        return group_code(file_to_code, self.data.language)

    def is_valid(self, new_refined_code: str) -> bool:
        """Return True if ``new_refined_code`` is an acceptable repair of the optimized code.

        It must keep the optimized code's markdown structure and every code block must be
        non-empty. Python blocks must additionally parse with libcst. JavaScript/TypeScript
        currently only get the non-empty check; more validation could be added later.
        """
        language = self.data.language
        if is_markdown_structure_changed(new_refined_code, self.data.modified_source_code, language):
            return False
        for code in split_markdown_code(new_refined_code, language).values():
            if not code.strip():
                return False
            if language == "python":
                try:
                    parse_module_to_cst(code)
                except cst.ParserSyntaxError:
                    return False
        return True

    def validate_module(self) -> None:
        """Validate the optimized code's syntax; currently a no-op for non-Python languages.

        Each Python code block is parsed with libcst and wrapped in ``CodeAndExplanation``.

        Raises:
            ValueError, ValidationError, cst.ParserSyntaxError: propagated from parsing/wrapping.
        """
        # TODO: have some way to validate the syntax of the code for other languages like js & ts
        if self.data.language != "python":
            return
        # Split with the same language argument as `is_valid`/`apply_patches_to_optimized_code`
        # so all three see the same code blocks.
        for code in split_markdown_code(self.data.modified_source_code, self.data.language).values():
            try:
                cst_module = parse_module_to_cst(code)
                CodeAndExplanation(cst_module, "")
            # Re-raised unchanged; the tuple documents the failures callers should expect.
            except (ValueError, ValidationError, cst.ParserSyntaxError):  # noqa: TRY203
                raise