feedback loop for unmatched test results (#2059)

fixes CF-932

# Pull Request Checklist

## Description
- [ ] **Description of PR**: Clear and concise description of what this
PR accomplishes
- [ ] **Breaking Changes**: Document any breaking changes (if
applicable)
- [ ] **Related Issues**: Link to any related issues or tickets

## Testing
- [ ] **Test cases attached**: All relevant test cases have been
added/updated
- [ ] **Manual Testing**: Manual testing completed for the changes

## Monitoring & Debugging
- [ ] **Logging in place**: Appropriate logging has been added for
debugging user issues
- [ ] **Sentry will be able to catch errors**: Error handling ensures
Sentry can capture and report errors
- [ ] **Avoid dev-based/Prisma logging**: No development-only or
Prisma-specific logging in production code

## Configuration
- [ ] **Env variables newly added**: Any new environment variables are
documented in the .env.example file or mentioned in the description

---

## Additional Notes
<!-- Add any additional context, screenshots, or notes for reviewers
here -->

---------

Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com>
Co-authored-by: ali <mohammed18200118@gmail.com>
Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
Aseem Saxena 2025-12-16 18:44:32 -08:00 committed by GitHub
parent 96ea895c99
commit 1192df12a6
25 changed files with 1533 additions and 54 deletions

View file

@@ -219,3 +219,4 @@ RANKING_MODEL: LLM = _get_openai_model()
 REFINEMENT_MODEL: LLM = _get_anthropic_model()
 EXPLANATIONS_MODEL: LLM = _get_anthropic_model()
 OPTIMIZATION_REVIEW_MODEL: LLM = _get_anthropic_model()
+CODE_REPAIR_MODEL: LLM = _get_anthropic_model()

View file

@@ -20,6 +20,7 @@ Including another URLconf
 # from django.contrib import admin
 from django.urls import path
+from code_repair.code_repair import code_repair_api
 from explanations.explanations import explanations_api
 from log_features.log_features import features_api
 from optimization_review.optimization_review import optimization_review_api
@@ -39,5 +40,6 @@ urlpatterns = [
     path("ai/explain", explanations_api.urls),
     path("ai/rank", ranker_api.urls),
     path("ai/optimization_review", optimization_review_api.urls),
+    path("ai/code_repair", code_repair_api.urls),
     path("ai/workflow-gen", workflow_gen_api.urls),
 ]

View file

@@ -0,0 +1,61 @@
You are a senior software engineer who is great at reviewing and repairing Python code for performance and behavior.
The goal of repairing code is to ensure that the optimized code remains performant while behaving the same as the original code, as judged by the provided test results.
You are provided the following information:
- original_source_code
- optimized_source_code - This is the optimized implementation of the original code.
- test_details - This has the details of the behavioral differences between the original and optimized code.
### Output format
Respond by requesting to replace sections of the optimized code in an existing file, using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
Define your output in XML-style tags like here:
<replace_in_file>
<path>src/main.py</path>
<diff>
<<<<<<< SEARCH
a = 2
=======
a = 3
>>>>>>> REPLACE
</diff>
</replace_in_file>
Always adhere to this format for tool use to ensure proper parsing and execution.
## replace_in_file
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file.
Parameters:
- path: (required) The path of the file to modify
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
```
<<<<<<< SEARCH
[exact content to find]
=======
[new content to replace with]
>>>>>>> REPLACE
```
Critical rules:
1. SEARCH content must match the associated file section to find EXACTLY:
* Match character-for-character including whitespace, indentation, line endings
* Include all comments, docstrings, etc.
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
* Include multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
3. Keep SEARCH/REPLACE blocks concise:
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
4. Special operations:
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
* To delete code: Use empty REPLACE section
Usage:
<replace_in_file>
<path>File path here</path>
<diff>
Search and replace blocks here
</diff>
</replace_in_file>
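
For illustration only (not part of the prompt file): a minimal sketch of how a consumer might apply a single SEARCH/REPLACE block under the rules above, assuming exact-match, first-occurrence semantics. The helper name is hypothetical.

```python
def apply_block(content: str, search: str, replace: str) -> str:
    """Apply one SEARCH/REPLACE block: exact match, first occurrence only."""
    if not search:
        # Empty SEARCH section: treat the REPLACE section as appended content.
        return content + replace
    start = content.find(search)  # rule 2: only the first match is replaced
    if start == -1:
        return content  # no exact match; leave the content untouched
    return content[:start] + replace + content[start + len(search):]

# The example block above rewrites "a = 2" to "a = 3".
assert apply_block("a = 1\na = 2\n", "a = 2\n", "a = 3\n") == "a = 1\na = 3\n"
```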

View file

@@ -0,0 +1,13 @@
Please fix the optimized code to match the behaviour of the original code, while trying to keep the optimization logic intact.
### The original source code:
{original_source_code}
### The optimized source code
{modified_source_code}
### The test result details
{test_details}

View file

View file

@@ -0,0 +1,6 @@
from django.apps import AppConfig
class CodeRepairConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "code_repair"

View file

@@ -0,0 +1,204 @@
from __future__ import annotations
import logging
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, llm_clients
from aiservice.models.aimodels import CODE_REPAIR_MODEL, calculate_llm_cost
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.models import OptimizedCandidateSource
from .code_repair_context import ( # noqa: TC001 (don't move CodeRepairRequestSchema to type-checking because it's the schema definition)
CodeRepairContext,
CodeRepairContextData,
CodeRepairRequestSchema,
)
if TYPE_CHECKING:
from django.http import HttpRequest as Request
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionToolMessageParam,
)
from aiservice.models.aimodels import LLM
code_repair_api = NinjaAPI(urls_namespace="code_repair")
# Get the directory of the current file
current_dir = Path(__file__).parent
SYSTEM_PROMPT = (current_dir / "CODE_REPAIR_SYSTEM_PROMPT.md").read_text()
USER_PROMPT = (current_dir / "CODE_REPAIR_USER_PROMPT.md").read_text()
async def code_repair( # noqa: D417
user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM = CODE_REPAIR_MODEL
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
"""Repair the given candidate to match the behaviour of the original code.
Parameters
----------
:param user_id:
:param optimization_id
:param optimize_model: LLM for getting the code_repairs
:param ctx: the repair context, has the data property which includes
- original code
- optimized code
- behaviour test diffs
Returns
-------
CodeRepairIntermediateResponseItemschema or CodeRepairErrorResponseSchema
"""
system_prompt = ctx.get_system_prompt()
user_prompt = ctx.get_user_prompt()
new_op_id = str(uuid.uuid4())
system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
messages: list[
ChatCompletionSystemMessageParam
| ChatCompletionUserMessageParam
| ChatCompletionAssistantMessageParam
| ChatCompletionToolMessageParam
| ChatCompletionFunctionMessageParam
] = [system_message, user_message]
debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
llm_client = llm_clients[optimize_model.model_type]
try:
output = await llm_client.with_options(max_retries=2).chat.completions.create(
model=optimize_model.name, messages=messages, n=1
)
llm_cost = calculate_llm_cost(output, optimize_model)
except Exception as e:
logging.exception("Claude Code Generation error in code_repair")
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
return CodeRepairErrorResponseSchema(error=str(e))
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.model_dump_json(indent=2)}")
if output.usage is not None:
ph(user_id, "code_repair-usage", properties={"model": optimize_model.name, "usage": output.usage.json()})
results = [content for op in output.choices if (content := op.message.content)] # will be of size 1
    # The regex can't yet separate the prose from the search/replace blocks, so use the full response as the explanation
explanation = results[0]
repaired_optimization = ""
try:
diff_patches = ctx.extract_diff_patches_from_llm_res(results[0])
repaired_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
except (ValueError, ValidationError) as exc:
sentry_sdk.capture_exception(exc)
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.modified_source_code}")
debug_log_sensitive_data(f"Traceback: {exc}")
repaired_optimization = ""
if not ctx.is_valid(repaired_optimization):
repaired_optimization = ""
return CodeRepairIntermediateResponseItemschema(
optimization_id=new_op_id,
parent_id=optimization_id,
source_code=repaired_optimization,
llm_cost=llm_cost,
explanation=explanation,
)
class CodeRepairErrorResponseSchema(Schema):
error: str
class CodeRepairIntermediateResponseItemschema(Schema):
# the key will be the optimization id and the value will be the actual refined code
optimization_id: str
parent_id: str
source_code: str
llm_cost: float
explanation: str
class CodeRepairResponseItemschema(Schema):
# the key will be the optimization id and the value will be the actual refined code
optimization_id: str
parent_id: str
source_code: str
explanation: str
@code_repair_api.post(
"/",
response={
200: CodeRepairResponseItemschema,
400: CodeRepairErrorResponseSchema,
500: CodeRepairErrorResponseSchema,
},
)
async def repair(
request: Request, data: CodeRepairRequestSchema
) -> tuple[int, CodeRepairResponseItemschema | CodeRepairErrorResponseSchema]:
ph(request.user, "aiservice-code_repair-called")
ctx_data = CodeRepairContextData(
original_source_code=data.original_source_code,
modified_source_code=data.modified_source_code,
test_diffs=data.test_diffs,
)
ctx = CodeRepairContext(ctx_data=ctx_data, base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT)
trace_id = data.trace_id
if not validate_trace_id(trace_id):
return 400, CodeRepairErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
code_repair_data = await code_repair(user_id=request.user, optimization_id=data.optimization_id, ctx=ctx)
total_llm_cost = 0.0
if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
return 500, code_repair_data
total_llm_cost += code_repair_data.llm_cost
try:
ctx.validate_python_module()
except cst.ParserSyntaxError as e:
# log exception with sentry
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"ParserSyntaxError for source:\n{code_repair_data.source_code}")
debug_log_sensitive_data(f"Traceback: {e}")
return 500, CodeRepairErrorResponseSchema(error=str(e))
except (ValueError, ValidationError) as exc:
# Another one bites the Pydantic validation dust
sentry_sdk.capture_exception(exc)
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{code_repair_data.source_code}")
debug_log_sensitive_data(f"Traceback: {exc}")
return 500, CodeRepairErrorResponseSchema(error=str(exc))
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
# explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
# optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
optimizations_origin={
code_repair_data.optimization_id: {
"source": OptimizedCandidateSource.REPAIR,
"parent": code_repair_data.parent_id,
}
},
)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
return 200, CodeRepairResponseItemschema(
source_code=code_repair_data.source_code,
optimization_id=code_repair_data.optimization_id,
parent_id=code_repair_data.parent_id,
explanation=code_repair_data.explanation,
)

View file

@@ -0,0 +1,153 @@
import logging
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
import libcst as cst
import sentry_sdk
from aiservice.env_specific import debug_log_sensitive_data
from ninja import Field, Schema
from optimizer.context_utils.constants import REPLACE_IN_FILE_TAGS_REGEX
from optimizer.context_utils.context_helpers import group_code, is_markdown_structure_changed, split_markdown_code
from optimizer.diff_patches_utils.patches_v2 import apply_patches, group_diff_patches_by_path
from optimizer.models import CodeAndExplanation
from pydantic import ValidationError
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
class TestDiffScope(str, Enum):
RETURN_VALUE = "return_value"
STDOUT = "stdout"
DID_PASS = "did_pass" # noqa: S105
TIMED_OUT = "timed_out"
SCOPE_DESCRIPTIONS = {
TestDiffScope.RETURN_VALUE: (
"The function returned a different value in the optimized code compared to the original."
),
TestDiffScope.STDOUT: ("The output printed to stdout is different in the optimized code compared to the original."),
TestDiffScope.DID_PASS: (
"The test passed in one version but failed in the other (a change in pass/fail behavior)."
),
}
class TestDiff(Schema):
scope: TestDiffScope
original_value: bool | str | int | float | dict | list | None = None
candidate_value: bool | str | int | float | dict | list | None = None
original_pass: bool
candidate_pass: bool
test_src_code: str
candidate_pytest_error: str | None = None
original_pytest_error: str | None = None
class CodeRepairRequestSchema(Schema):
trace_id: str
optimization_id: str
original_source_code: str
modified_source_code: str
test_diffs: list[TestDiff] = Field(..., alias="test_diffs")
@dataclass()
class CodeRepairContextData:
original_source_code: str
modified_source_code: str
test_diffs: list[TestDiff]
class CodeRepairContext:
def __init__(self, ctx_data: CodeRepairContextData, base_system_prompt: str, base_user_prompt: str) -> None:
self.data = ctx_data
self.base_system_prompt = base_system_prompt
self.base_user_prompt = base_user_prompt
def get_system_prompt(self) -> str:
return self.base_system_prompt
def build_test_details(self, test_diffs: list[TestDiff]) -> str:
sections = defaultdict(str)
for diff in test_diffs:
try:
if sections[diff.test_src_code] == "":
# add error strings and test def only once per test function
sections[diff.test_src_code] += f"""Test Source:
```python
{diff.test_src_code}
```
Pytest error (original code): {diff.original_pytest_error if diff.original_pytest_error else ""}
Pytest error (optimized code): {diff.candidate_pytest_error if diff.candidate_pytest_error else ""}
"""
sections[diff.test_src_code] += "\n".join(
[
f"{SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value)}",
f"Expected: {diff.original_value!r}.\nGot: {diff.candidate_value!r}."
if diff.scope != TestDiffScope.DID_PASS
else "",
f"Original code test status: {'Passed' if diff.original_pass else 'Failed'}. Optimized code test status: {'Passed' if diff.candidate_pass else 'Failed'}",
"---",
]
)
except Exception as e:
logging.exception("Some issue in parsing test diffs")
sentry_sdk.capture_exception(e)
return "\n".join(sections.values())
def get_user_prompt(self) -> str:
return self.base_user_prompt.format(
original_source_code=self.data.original_source_code,
modified_source_code=self.data.modified_source_code,
test_details=self.build_test_details(self.data.test_diffs),
)
def extract_diff_patches_from_llm_res(self, llm_res: str) -> str:
matches = REPLACE_IN_FILE_TAGS_REGEX.findall(llm_res)
replace_tags = ""
if matches and len(matches) != 0:
replace_tags = f"<replace_in_file>{matches[0]}</replace_in_file>"
return replace_tags
def apply_patches_to_optimized_code(self, replace_tags: str) -> str:
if replace_tags == "":
return ""
file_to_code = split_markdown_code(self.data.modified_source_code)
# sometimes the llm can write multiple replace tags for the same file, so we group them by path to avoid parsing & applying multiple times
file_to_diffs = group_diff_patches_by_path(replace_tags)
for path, diff in file_to_diffs.items():
scoped_code = file_to_code.get(path, None)
if scoped_code is None:
debug_log_sensitive_data(f"no scoped code for {path}, existing: {file_to_code.keys()}")
continue
new_code = apply_patches(diff, scoped_code)
file_to_code[path] = new_code
return group_code(file_to_code)
def is_valid(self, new_refined_code: str) -> bool:
if is_markdown_structure_changed(new_refined_code, self.data.modified_source_code):
return False
valid = True
for code in split_markdown_code(new_refined_code).values():
stripped_code = code.strip()
if not stripped_code:
valid = False
break
try:
parse_module_to_cst(code)
except cst.ParserSyntaxError:
valid = False
break
return valid
def validate_python_module(self) -> None:
for _code in split_markdown_code(self.data.modified_source_code).values():
try:
cst_module = parse_module_to_cst(_code)
CodeAndExplanation(cst_module, "")
except (ValueError, ValidationError, cst.ParserSyntaxError): # noqa: TRY203
raise
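
For illustration only (not part of this diff): a minimal sketch of exercising the context above, with made-up values.

```python
from code_repair.code_repair_context import (
    CodeRepairContext,
    CodeRepairContextData,
    TestDiff,
    TestDiffScope,
)

# A single behavioural difference: the optimized code returns 5 instead of 4.
diff = TestDiff(
    scope=TestDiffScope.RETURN_VALUE,
    original_value=4,
    candidate_value=5,
    original_pass=True,
    candidate_pass=False,
    test_src_code="def test_add():\n    assert add(2, 2) == 4",
)
ctx = CodeRepairContext(
    CodeRepairContextData("<original code>", "<optimized code>", [diff]),
    base_system_prompt="",
    base_user_prompt="{original_source_code}\n{modified_source_code}\n{test_details}",
)
# One section per distinct test_src_code, with scope description and values.
print(ctx.build_test_details([diff]))
```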

View file

@@ -44,6 +44,7 @@ def log_features(
     experiment_metadata: dict[str, str] | None = None,
     final_explanation: str | None = None,
     ranking: dict[str, Any] | None = None,
+    optimizations_origin: dict[str, dict[str, str]] | None = None,
 ) -> None:
     """Log features of a code optimization run to the database.
@@ -95,6 +96,7 @@ def log_features(
             "experiment_metadata": experiment_metadata,
             "final_explanation": final_explanation,
             "ranking": ranking,
+            "optimizations_origin": optimizations_origin,
         },
     )
@@ -148,6 +150,14 @@ def log_features(
             f.ranking = f.ranking | ranking if ranking is not None else f.ranking
         else:
             f.ranking = ranking if ranking is not None else f.ranking
+        if f.optimizations_origin is not None:
+            # merge the optimizations_origin with the existing ones
+            f.optimizations_origin = merge_dicts(f.optimizations_origin, optimizations_origin or {})
+        else:
+            f.optimizations_origin = (
+                optimizations_origin if optimizations_origin is not None else f.optimizations_origin
+            )
         f.speedup_ratio = speedup_ratio if speedup_ratio is not None else f.speedup_ratio
         f.original_runtime = original_runtime if original_runtime is not None else f.original_runtime
         f.optimized_runtime = optimized_runtime if optimized_runtime is not None else f.optimized_runtime
@@ -165,6 +175,22 @@ def log_features(
     f.save()


+def merge_dicts(a: dict[str, dict[str, str]], b: dict[str, dict[str, str]]) -> dict[str, dict[str, str]]:
+    result: dict[str, dict[str, str]] = {}
+    for key, inner in a.items():
+        result[key] = inner.copy()
+    for key, inner in b.items():
+        if key not in result:
+            result[key] = inner.copy()
+        else:
+            # b overrides a
+            result[key].update(inner)
+    return result
+
+
 @features_api.post("/", response={200: None, 500: LoggingErrorResponseSchema})
 async def log_features_cli(request: HttpRequest, data: LoggingSchema) -> int | tuple[int, LoggingErrorResponseSchema]:
     try:
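
The merge semantics of the new `merge_dicts` helper (inner keys from `b` override `a`; keys missing in `b` survive from `a`) can be illustrated with a small sketch, assuming the helper is importable from this module:

```python
from log_features.log_features import merge_dicts

a = {"opt-1": {"source": "OPTIMIZE", "parent": ""}}
b = {"opt-1": {"parent": "opt-0"}, "opt-2": {"source": "REFINE", "parent": ""}}

merged = merge_dicts(a, b)
# "opt-1" keeps its source from a, but b's parent wins; "opt-2" comes from b.
assert merged == {
    "opt-1": {"source": "OPTIMIZE", "parent": "opt-0"},
    "opt-2": {"source": "REFINE", "parent": ""},
}
```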

View file

@@ -31,6 +31,7 @@ class OptimizationFeatures(models.Model):
     experiment_metadata = models.JSONField(null=True, blank=True)
     final_explanation = models.TextField(null=True, blank=True)
     ranking = models.JSONField(null=True, blank=True)
+    optimizations_origin = models.JSONField(null=True, blank=True)

     # PR suggestions or create Approval fields
     approval_required = models.BooleanField(default=False)

View file

@@ -1,5 +1,7 @@
 from pydantic import ValidationError

+from optimizer.context_utils.constants import MULTI_REPLACE_IN_FILE_TAGS_REGEX
+

 class SearchReplaceBlock:
     def __init__(self, search, replace):
@@ -19,36 +21,44 @@ def parse_diff(diff: str) -> list[SearchReplaceBlock]:
     blocks: list[SearchReplaceBlock] = []
     lines = diff.splitlines(keepends=True)
+    n = len(lines)
     idx = 0
-    while idx < len(lines):
-        line = lines[idx].strip()
-        if line == "<<<<<<< SEARCH":
-            search_lines = []
+    # Precompute the "marker" strings for efficiency
+    search_marker = "<<<<<<< SEARCH"
+    delimiter_marker = "======="
+    replace_marker = ">>>>>>> REPLACE"
+    while idx < n:
+        line_stripped = lines[idx].strip()
+        if line_stripped.startswith(search_marker):
             idx += 1
+            search_start = idx
-            while idx < len(lines) and lines[idx].strip() != "=======":
-                search_lines.append(lines[idx])
+            # Find delimiter_marker line
+            while idx < n and lines[idx].strip() != delimiter_marker:
                 idx += 1
+            search_end = idx
-            if idx >= len(lines):
+            if idx >= n:
                 raise ValueError("Invalid diff format: Missing '=======' marker")
-            replace_lines = []
             idx += 1
+            replace_start = idx
-            while idx < len(lines) and lines[idx].strip() != ">>>>>>> REPLACE":
-                replace_lines.append(lines[idx])
+            while idx < n and not lines[idx].strip().startswith(replace_marker):
                 idx += 1
+            replace_end = idx
-            if idx >= len(lines):
-                raise ValueError("Invalid diff format: Missing '>>>>>>> REPLACE' marker")
+            if idx >= n:
+                raise ValueError(
+                    "Invalid diff format: Missing '>>>>>>> REPLACE' marker"
+                )
-            search_content = "".join(search_lines).rstrip()
-            replace_content = "".join(replace_lines).rstrip()
+            search_content = "".join(lines[search_start:search_end]).rstrip()
+            replace_content = "".join(lines[replace_start:replace_end]).rstrip()
             try:
-                block = SearchReplaceBlock.from_block(search=search_content, replace=replace_content)
+                block = SearchReplaceBlock.from_block(
+                    search=search_content, replace=replace_content
+                )
                 blocks.append(block)
             except ValidationError as ve:
                 raise ValueError(f"Invalid block format: {ve}")
@@ -61,13 +71,35 @@ def parse_diff(diff: str) -> list[SearchReplaceBlock]:
     return blocks


+def group_diff_patches_by_path(replace_tags_str: str) -> dict[str, str]:
+    matches = MULTI_REPLACE_IN_FILE_TAGS_REGEX.findall(replace_tags_str)
+    file_to_diffs = {}
+    current_file = None
+    current_diff = ""
+    for path, diff in matches:
+        if path != current_file:
+            if current_file:
+                file_to_diffs[current_file] = current_diff
+            current_file = path
+            current_diff = diff
+        else:
+            current_diff += diff
+    if current_file:
+        file_to_diffs[current_file] = current_diff
+    return file_to_diffs
+
+
 def apply_patches(diff_str: str, content: str) -> str:
     try:
         patch_blocks = parse_diff(diff_str)
     except ValueError:
         return content
-    for idx, block in enumerate(patch_blocks, 1):
+    for block in patch_blocks:
         if not block.search:
             if block.replace:
                 # a replacement block without a search, then just add the replace block
@@ -79,5 +111,7 @@ def apply_patches(diff_str: str, content: str) -> str:
         start_char_idx = content.find(block.search)
         if start_char_idx != -1:
             end_char_idx = start_char_idx + len(block.search)
-            content = f"{content[:start_char_idx]}{block.replace}{content[end_char_idx:]}"
+            content = (
+                f"{content[:start_char_idx]}{block.replace}{content[end_char_idx:]}"
+            )
     return content
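
For illustration only (not part of this diff): a sketch of the new grouping behaviour, assuming `MULTI_REPLACE_IN_FILE_TAGS_REGEX` captures `(path, diff)` pairs from consecutive `<replace_in_file>` tags. Two tags for the same path are concatenated so each file is parsed and patched once:

```python
from optimizer.diff_patches_utils.patches_v2 import apply_patches, group_diff_patches_by_path

tags = (
    "<replace_in_file><path>demo.py</path><diff>\n"
    "<<<<<<< SEARCH\na = 2\n=======\na = 3\n>>>>>>> REPLACE\n"
    "</diff></replace_in_file>\n"
    "<replace_in_file><path>demo.py</path><diff>\n"
    "<<<<<<< SEARCH\nb = 5\n=======\nb = 6\n>>>>>>> REPLACE\n"
    "</diff></replace_in_file>\n"
)

for path, diff in group_diff_patches_by_path(tags).items():
    # Expected: a single "demo.py" entry whose diff holds both blocks.
    print(path, apply_patches(diff, "a = 2\nb = 5\n"))  # -> "a = 3\nb = 6\n"
```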

View file

@@ -1,9 +1,18 @@
+import enum
+
 import libcst
 from ninja import Schema
 from pydantic import field_validator
 from pydantic.dataclasses import dataclass


+class OptimizedCandidateSource(str, enum.Enum):
+    OPTIMIZE = "OPTIMIZE"
+    OPTIMIZE_LP = "OPTIMIZE_LP"
+    REFINE = "REFINE"
+    REPAIR = "REPAIR"
+
+
 @dataclass(frozen=True)
 class CodeAndExplanation:
     cst_module: libcst.Module | None

View file

@@ -8,6 +8,11 @@ from typing import TYPE_CHECKING
 import libcst as cst
 import sentry_sdk
+from ninja import NinjaAPI
+from ninja.errors import HttpError
+from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
+from pydantic import ValidationError
+
 from aiservice.analytics.posthog import ph
 from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
 from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
@@ -15,11 +20,6 @@ from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
 from authapp.user import get_user_by_id
 from log_features.log_event import log_optimization_event
 from log_features.log_features import log_features
-from ninja import NinjaAPI
-from ninja.errors import HttpError
-from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
-from pydantic import ValidationError
 from optimizer.context_utils.context_helpers import group_code
 from optimizer.context_utils.optimizer_context import (
     BaseOptimizerContext,
@@ -27,10 +27,9 @@ from optimizer.context_utils.optimizer_context import (
     OptimizeResponseItemSchema,
     OptimizeResponseSchema,
 )
-from optimizer.models import OptimizeSchema  # noqa: TC001
+from optimizer.models import OptimizedCandidateSource, OptimizeSchema  # noqa: TC001

 if TYPE_CHECKING:
-    from aiservice.models.aimodels import LLM
     from django.http import HttpRequest
     from openai.types.chat import (
         ChatCompletionAssistantMessageParam,
@@ -38,6 +37,8 @@ if TYPE_CHECKING:
         ChatCompletionToolMessageParam,
     )

+    from aiservice.models.aimodels import LLM
+
 optimizations_json = [
     {
@@ -345,7 +346,7 @@ async def optimize(
             },
             explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
             experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
-            # request=request,
+            optimizations_origin={cei.optimization_id: {"source": OptimizedCandidateSource.OPTIMIZE, "parent": None} for cei in optimization_response_items},
         )
     )

View file

@@ -5,29 +5,30 @@ from pathlib import Path
 from typing import TYPE_CHECKING

 import sentry_sdk
+from ninja import NinjaAPI, Schema
+from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
+
 from aiservice.analytics.posthog import ph
 from aiservice.common_utils import parse_python_version, validate_trace_id
 from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
 from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
 from log_features.log_event import update_optimization_cost
 from log_features.log_features import log_features
-from ninja import NinjaAPI, Schema
-from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
 from optimizer.context_utils.optimizer_context import (
     BaseOptimizerContext,
     OptimizeErrorResponseSchema,
     OptimizeResponseSchema,
 )
+from optimizer.models import OptimizedCandidateSource

 if TYPE_CHECKING:
-    from aiservice.models.aimodels import LLM
     from openai.types.chat import (
         ChatCompletionAssistantMessageParam,
         ChatCompletionFunctionMessageParam,
         ChatCompletionToolMessageParam,
     )

+    from aiservice.models.aimodels import LLM
     from optimizer.context_utils.optimizer_context import OptimizeResponseItemSchema
@@ -190,6 +191,8 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
             },
             explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
             experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
+            optimizations_origin={cei.optimization_id: {"source": OptimizedCandidateSource.OPTIMIZE_LP, "parent": None} for cei in optimization_response_items},
         )

     response = OptimizeResponseSchema(optimizations=optimization_response_items)

View file

@@ -2,31 +2,34 @@ from __future__ import annotations

 import asyncio
 import logging
+import uuid
 from pathlib import Path
 from typing import TYPE_CHECKING

 import libcst as cst
 import sentry_sdk
+from ninja import NinjaAPI, Schema
+from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
+from pydantic import ValidationError
+
 from aiservice.analytics.posthog import ph
 from aiservice.common_utils import validate_trace_id
 from aiservice.env_specific import debug_log_sensitive_data, llm_clients
 from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
 from log_features.log_event import update_optimization_cost
 from log_features.log_features import log_features
-from ninja import NinjaAPI, Schema
-from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
-from pydantic import ValidationError
 from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
+from optimizer.models import OptimizedCandidateSource

 if TYPE_CHECKING:
-    from aiservice.models.aimodels import LLM
     from openai.types.chat import (
         ChatCompletionAssistantMessageParam,
         ChatCompletionFunctionMessageParam,
         ChatCompletionToolMessageParam,
     )

+    from aiservice.models.aimodels import LLM
+
 refinement_api = NinjaAPI(urls_namespace="refinement")
@@ -260,7 +263,8 @@ async def refinement(  # noqa: D417
         refined_optimization = ""

     return RefinementIntermediateResponseItemschema(
-        optimization_id=optimization_id,
+        parent_id=optimization_id,
+        optimization_id=str(uuid.uuid4()),
         source_code=refined_optimization,
         explanation=refined_explanation,
         original_explanation=ctx.data.optimized_explanation,
@@ -291,6 +295,7 @@ class OptimizeErrorResponseSchema(Schema):
 class RefinementIntermediateResponseItemschema(Schema):
     # the key will be the optimization id and the value will be the actual refined code
     explanation: str
+    parent_id: str
     optimization_id: str
     source_code: str
     original_explanation: str
@@ -301,6 +306,7 @@ class RefinementResponseItemschema(Schema):
     # the key will be the optimization id and the value will be the actual refined code
     explanation: str
     optimization_id: str
+    parent_id: str
     source_code: str
@@ -374,27 +380,30 @@ async def refine(
             trace_id=trace_id,
             user_id=request.user,
             optimizations_raw={
-                cei.optimization_id[:-4] + "refi": cei.source_code
+                cei.optimization_id: cei.source_code
                 for cei in refinement_data
                 if not isinstance(cei, OptimizeErrorResponseSchema)
             },
-            optimizations_post={
-                cei.optimization_id[:-4] + "refi": cei.source_code for cei in filtered_refined_optimizations
-            },
+            optimizations_post={cei.optimization_id: cei.source_code for cei in filtered_refined_optimizations},
             explanations_raw={
-                cei.optimization_id[:-4] + "refi": cei.explanation
+                cei.optimization_id: cei.explanation
                 for cei in refinement_data
                 if not isinstance(cei, OptimizeErrorResponseSchema)
             },
-            explanations_post={
-                cei.optimization_id[:-4] + "refi": cei.explanation for cei in filtered_refined_optimizations
-            },
+            explanations_post={cei.optimization_id: cei.explanation for cei in filtered_refined_optimizations},
+            optimizations_origin={
+                cei.optimization_id: {"source": OptimizedCandidateSource.REFINE, "parent": cei.parent_id}
+                for cei in filtered_refined_optimizations
+            },
         )

     await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
     return 200, Refinementschema(
         refinements=[
             RefinementResponseItemschema(
-                source_code=x.source_code, explanation=x.original_explanation, optimization_id=x.optimization_id
+                source_code=x.source_code,
+                explanation=x.original_explanation,
+                optimization_id=x.optimization_id,
+                parent_id=x.parent_id,
             )
             for x in filtered_refined_optimizations
         ]

View file

@@ -3,6 +3,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, cast

 from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, parse_expression, parse_module
+from functools import lru_cache

 if TYPE_CHECKING:
     from libcst import (
@@ -19,6 +20,7 @@ if TYPE_CHECKING:
     from aiservice.models.functions_to_optimize import FunctionToOptimize


+@lru_cache(maxsize=128)
 def parse_module_to_cst(module_str: str) -> Module:
     """Parse a module string into its libCST representation.

View file

@@ -0,0 +1,389 @@
from code_repair.code_repair_context import CodeRepairContext, CodeRepairContextData
from optimizer.diff_patches_utils.patches_v2 import apply_patches
def test_code_repair_single_file():
original_code = """```python:demo.py
import math
from typing import List, Tuple, Optional
def calculate_portfolio_metrics(
investments: List[Tuple[str, float, float]],
risk_free_rate: float = 0.02
) -> dict:
if not investments:
raise ValueError("Investments list cannot be empty")
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
raise ValueError("Portfolio weights must sum to 1.0")
# Calculate weighted return
weighted_return = sum(weight * ret for _, weight, ret in investments)
# Calculate portfolio volatility (simplified)
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
# Calculate Sharpe ratio
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
# Find best and worst performing assets
best_asset = max(investments, key=lambda x: x[2])
worst_asset = min(investments, key=lambda x: x[2])
return {
'weighted_return': round(weighted_return, 6),
'volatility': round(volatility, 6),
'sharpe_ratio': round(sharpe_ratio, 6),
'best_performing': (best_asset[0], round(best_asset[2], 6)),
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
'total_assets': len(investments)
}
```
"""
optimized_code = """```python:demo.py
import math
from typing import List, Tuple, Optional
def calculate_portfolio_metrics(
investments: List[Tuple[str, float, float]],
risk_free_rate: float = 0.02
) -> dict:
if not investments:
raise ValueError("Investments list cannot be empty")
total_weight = sum(w for _, w, _ in investments)
if total_weight != 1.0: # Should use tolerance check
raise ValueError("Portfolio weights must sum to 1.0")
weighted_return = 1.0
for _, weight, ret in investments:
weighted_return *= (1 + ret) ** weight
weighted_return = weighted_return - 1.0 # Convert back from geometric
returns = [r for _, _, r in investments]
mean_return = sum(returns) / len(returns)
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
def risk_adjusted_return(return_val, weight):
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
return {
"weighted_return": round(weighted_return, 6),
"volatility": 2,
"sharpe_ratio": round(sharpe_ratio, 6),
"best_performing": (best_asset[0], round(best_asset[2], 6)),
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
"total_assets": len(investments),
}
```
"""
llm_response = """Looking at the test failure and comparing the original and modified code, I can see several issues in the modified implementation:
1. **Hardcoded volatility**: The volatility is hardcoded to `2` instead of being calculated
2. **Incorrect weight validation**: Missing tolerance check for weight sum validation
3. **Wrong volatility calculation**: Using standard deviation of returns instead of weighted portfolio volatility
4. **Wrong weighted return calculation**: Using geometric mean instead of arithmetic weighted return
5. **Wrong asset selection criteria**: Using risk-adjusted return instead of simple return
Let me fix these issues:
<replace_in_file>
<path>demo.py</path>
<diff>
<<<<<<< SEARCH
total_weight = sum(w for _, w, _ in investments)
if total_weight != 1.0: # Should use tolerance check
raise ValueError("Portfolio weights must sum to 1.0")
=======
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
raise ValueError("Portfolio weights must sum to 1.0")
>>>>>>> REPLACE
</diff>
</replace_in_file>
<replace_in_file>
<path>demo.py</path>
<diff>
<<<<<<< SEARCH
weighted_return = 1.0
for _, weight, ret in investments:
weighted_return *= (1 + ret) ** weight
weighted_return = weighted_return - 1.0 # Convert back from geometric
=======
# Calculate weighted return
weighted_return = sum(weight * ret for _, weight, ret in investments)
>>>>>>> REPLACE
</diff>
</replace_in_file>
<replace_in_file>
<path>demo.py</path>
<diff>
<<<<<<< SEARCH
returns = [r for _, _, r in investments]
mean_return = sum(returns) / len(returns)
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
=======
# Calculate portfolio volatility (simplified)
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
>>>>>>> REPLACE
</diff>
</replace_in_file>
<replace_in_file>
<path>demo.py</path>
<diff>
<<<<<<< SEARCH
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
def risk_adjusted_return(return_val, weight):
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
=======
# Calculate Sharpe ratio
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
# Find best and worst performing assets
best_asset = max(investments, key=lambda x: x[2])
worst_asset = min(investments, key=lambda x: x[2])
>>>>>>> REPLACE
</diff>
</replace_in_file>
<replace_in_file>
<path>demo.py</path>
<diff>
<<<<<<< SEARCH
return {
"weighted_return": round(weighted_return, 6),
"volatility": 2,
"sharpe_ratio": round(sharpe_ratio, 6),
"best_performing": (best_asset[0], round(best_asset[2], 6)),
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
"total_assets": len(investments),
}
=======
return {
'weighted_return': round(weighted_return, 6),
'volatility': round(volatility, 6),
'sharpe_ratio': round(sharpe_ratio, 6),
'best_performing': (best_asset[0], round(best_asset[2], 6)),
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
'total_assets': len(investments)
}
>>>>>>> REPLACE
</diff>
</replace_in_file>
The key changes I made:
1. **Fixed weight validation**: Restored the tolerance-based check using `abs(sum(...) - 1.0) > 1e-10` instead of exact equality
2. **Fixed weighted return calculation**: Changed from geometric mean back to simple weighted arithmetic mean
3. **Fixed volatility calculation**: Restored the original simplified portfolio volatility formula using weighted returns
4. **Fixed volatility return value**: Removed the hardcoded `2` and properly calculated and rounded the volatility
5. **Fixed asset selection**: Restored the original logic to find best/worst assets based on simple returns rather than risk-adjusted returns
6. **Fixed dictionary formatting**: Changed from double quotes to single quotes to match original formatting
These changes align the modified code with the original implementation's behavior, ensuring that the test for zero volatility passes (when cash investment has 0% return, the volatility should indeed be 0.0)."""
    ctx = CodeRepairContext(CodeRepairContextData(original_code, optimized_code, []), "", "")
diff_patches = ctx.extract_diff_patches_from_llm_res(llm_response)
refined_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
print(refined_optimization)
assert ctx.is_valid(refined_optimization)
"""
```python:demo.py
import math
from typing import List, Tuple, Optional
def calculate_portfolio_metrics(
investments: List[Tuple[str, float, float]],
risk_free_rate: float = 0.02
) -> dict:
if not investments:
raise ValueError("Investments list cannot be empty")
total_weight = sum(w for _, w, _ in investments)
if total_weight != 1.0: # Should use tolerance check
raise ValueError("Portfolio weights must sum to 1.0")
weighted_return = 1.0
for _, weight, ret in investments:
weighted_return *= (1 + ret) ** weight
weighted_return = weighted_return - 1.0 # Convert back from geometric
returns = [r for _, _, r in investments]
mean_return = sum(returns) / len(returns)
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
def risk_adjusted_return(return_val, weight):
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
return {
"weighted_return": round(weighted_return, 6),
"volatility": 2,
"sharpe_ratio": round(sharpe_ratio, 6),
"best_performing": (best_asset[0], round(best_asset[2], 6)),
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
"total_assets": len(investments),
}
```
"""
def test_patch_apply():
patch = """<<<<<<< SEARCH
total_weight = sum(w for _, w, _ in investments)
if total_weight != 1.0: # Should use tolerance check
raise ValueError("Portfolio weights must sum to 1.0")
=======
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
raise ValueError("Portfolio weights must sum to 1.0")
>>>>>>> REPLACE
<<<<<<< SEARCH
weighted_return = 1.0
for _, weight, ret in investments:
weighted_return *= (1 + ret) ** weight
weighted_return = weighted_return - 1.0 # Convert back from geometric
=======
# Calculate weighted return
weighted_return = sum(weight * ret for _, weight, ret in investments)
>>>>>>> REPLACE
<<<<<<< SEARCH
returns = [r for _, _, r in investments]
mean_return = sum(returns) / len(returns)
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
=======
# Calculate portfolio volatility (simplified)
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
>>>>>>> REPLACE
<<<<<<< SEARCH
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
def risk_adjusted_return(return_val, weight):
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
=======
# Calculate Sharpe ratio
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
# Find best and worst performing assets
best_asset = max(investments, key=lambda x: x[2])
worst_asset = min(investments, key=lambda x: x[2])
>>>>>>> REPLACE
<<<<<<< SEARCH
return {
"weighted_return": round(weighted_return, 6),
"volatility": 2,
"sharpe_ratio": round(sharpe_ratio, 6),
"best_performing": (best_asset[0], round(best_asset[2], 6)),
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
"total_assets": len(investments),
}
=======
return {
'weighted_return': round(weighted_return, 6),
'volatility': round(volatility, 6),
'sharpe_ratio': round(sharpe_ratio, 6),
'best_performing': (best_asset[0], round(best_asset[2], 6)),
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
'total_assets': len(investments)
}
>>>>>>> REPLACE
"""
code = """import math
from typing import List, Tuple, Optional
def calculate_portfolio_metrics(
investments: List[Tuple[str, float, float]],
risk_free_rate: float = 0.02
) -> dict:
if not investments:
raise ValueError("Investments list cannot be empty")
total_weight = sum(w for _, w, _ in investments)
if total_weight != 1.0: # Should use tolerance check
raise ValueError("Portfolio weights must sum to 1.0")
weighted_return = 1.0
for _, weight, ret in investments:
weighted_return *= (1 + ret) ** weight
weighted_return = weighted_return - 1.0 # Convert back from geometric
returns = [r for _, _, r in investments]
mean_return = sum(returns) / len(returns)
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
if volatility == 0:
sharpe_ratio = 0.0
else:
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
def risk_adjusted_return(return_val, weight):
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
return {
"weighted_return": round(weighted_return, 6),
"volatility": 2,
"sharpe_ratio": round(sharpe_ratio, 6),
"best_performing": (best_asset[0], round(best_asset[2], 6)),
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
"total_assets": len(investments),
}
"""
new_code = apply_patches(patch, code)
print(new_code)

View file

@@ -0,0 +1,475 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Code Repair Logs Dashboard</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: #1a1a2e;
color: #eee;
min-height: 100vh;
padding: 20px;
}
h1 {
text-align: center;
margin-bottom: 20px;
color: #00d9ff;
}
.stats {
display: flex;
gap: 20px;
justify-content: center;
margin-bottom: 30px;
flex-wrap: wrap;
}
.stat-card {
background: #16213e;
padding: 20px 30px;
border-radius: 10px;
text-align: center;
min-width: 150px;
}
.stat-card .number {
font-size: 2em;
font-weight: bold;
color: #00d9ff;
}
.stat-card .label {
color: #888;
font-size: 0.9em;
}
.search-bar {
display: flex;
gap: 10px;
margin-bottom: 20px;
justify-content: center;
flex-wrap: wrap;
}
.search-bar input {
padding: 10px 15px;
border: none;
border-radius: 5px;
background: #16213e;
color: #eee;
width: 300px;
font-size: 1em;
}
.search-bar input:focus {
outline: 2px solid #00d9ff;
}
.trace-group {
background: #16213e;
border-radius: 10px;
margin-bottom: 20px;
overflow: hidden;
}
.trace-header {
background: #0f3460;
padding: 15px 20px;
cursor: pointer;
display: flex;
justify-content: space-between;
align-items: center;
transition: background 0.2s;
}
.trace-header:hover {
background: #1a4a80;
}
.trace-header .trace-id {
font-family: monospace;
color: #00d9ff;
font-size: 0.95em;
}
.trace-header .count-badge {
background: #e94560;
color: white;
padding: 3px 10px;
border-radius: 15px;
font-size: 0.85em;
}
.trace-header .arrow {
transition: transform 0.2s;
}
.trace-header.expanded .arrow {
transform: rotate(180deg);
}
.trace-content {
display: none;
padding: 0;
}
.trace-content.show {
display: block;
}
.log-entry {
border-top: 1px solid #0f3460;
padding: 20px;
}
.log-entry:first-child {
border-top: none;
}
.log-entry-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 15px;
flex-wrap: wrap;
gap: 10px;
}
.optimization-id {
font-family: monospace;
font-size: 0.8em;
color: #888;
}
.timestamp {
font-size: 0.8em;
color: #888;
}
.section {
margin-bottom: 15px;
}
.section-title {
font-weight: bold;
color: #00d9ff;
margin-bottom: 8px;
font-size: 0.9em;
text-transform: uppercase;
}
.code-block {
background: #0d1117;
border-radius: 5px;
padding: 15px;
overflow-x: auto;
font-family: 'Fira Code', 'Consolas', monospace;
font-size: 0.85em;
line-height: 1.5;
white-space: pre-wrap;
word-wrap: break-word;
max-height: 400px;
overflow-y: auto;
}
.explanation-block {
background: #1e3a5f;
border-radius: 5px;
padding: 15px;
line-height: 1.6;
font-size: 0.9em;
max-height: 300px;
overflow-y: auto;
}
.refined-code {
background: #0d2818;
border: 1px solid #2ea043;
}
.expand-all-btn {
background: #0f3460;
color: #eee;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
font-size: 0.9em;
transition: background 0.2s;
}
.expand-all-btn:hover {
background: #1a4a80;
}
.no-results {
text-align: center;
padding: 40px;
color: #888;
}
.copy-btn {
background: #333;
color: #888;
border: none;
padding: 5px 10px;
border-radius: 3px;
cursor: pointer;
font-size: 0.75em;
float: right;
margin-bottom: 5px;
}
.copy-btn:hover {
background: #444;
color: #eee;
}
.status-badges {
display: flex;
gap: 8px;
margin-top: 10px;
}
.status-badge {
padding: 4px 12px;
border-radius: 4px;
font-size: 0.8em;
font-weight: 600;
text-transform: uppercase;
}
.status-badge.passed-true {
background: #238636;
color: #fff;
}
.status-badge.passed-false {
background: #da3633;
color: #fff;
}
.status-badge.faster-true {
background: #1f6feb;
color: #fff;
}
.status-badge.faster-false {
background: #6e7681;
color: #fff;
}
.status-badge.pending {
background: #484f58;
color: #8b949e;
}
.trace-status-summary {
display: flex;
gap: 8px;
align-items: center;
}
.trace-status-icon {
font-size: 1.1em;
}
</style>
</head>
<body>
<h1>Code Repair Logs Dashboard</h1>
<div class="stats">
<div class="stat-card">
<div class="number" id="total-traces">0</div>
<div class="label">Trace Groups</div>
</div>
<div class="stat-card">
<div class="number" id="total-logs">0</div>
<div class="label">Total Logs</div>
</div>
<div class="stat-card">
<div class="number" id="avg-per-trace">0</div>
<div class="label">Avg per Trace</div>
</div>
<div class="stat-card">
<div class="number" id="passed-count" style="color: #238636;">0</div>
<div class="label">Passed</div>
</div>
<div class="stat-card">
<div class="number" id="faster-count" style="color: #1f6feb;">0</div>
<div class="label">Faster</div>
</div>
</div>
<div class="search-bar">
<input type="text" id="search" placeholder="Search by trace ID, optimization ID, or content...">
<button class="expand-all-btn" onclick="toggleAll()">Expand All</button>
</div>
<div id="dashboard"></div>
<script>
        // DATA_PLACEHOLDER is replaced with a JSON array of log rows by the generator script below
const data = DATA_PLACEHOLDER;
let allExpanded = false;
function escapeHtml(text) {
if (!text) return '';
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
function formatDate(dateStr) {
if (!dateStr) return '';
const date = new Date(dateStr);
return date.toLocaleString();
}
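        // Copy raw text to the clipboard and briefly flash the button label as feedback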
function copyToClipboard(text, btn) {
navigator.clipboard.writeText(text).then(() => {
const originalText = btn.textContent;
btn.textContent = 'Copied!';
setTimeout(() => btn.textContent = originalText, 1500);
});
}
function getStatusBadge(value, label, type) {
if (value === null || value === undefined) {
return `<span class="status-badge pending">${label}: Pending</span>`;
}
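            // passed/faster may arrive as real booleans or as string forms ("True"/"true"/"yes"),
            // depending on how they were written to SQLite; normalize before comparing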
const isTrue = value === 'True' || value === 'true' || value === true || value === 'yes';
const className = `${type}-${isTrue}`;
const displayValue = isTrue ? 'Yes' : 'No';
return `<span class="status-badge ${className}">${label}: ${displayValue}</span>`;
}
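        // Roll up per-trace status icons: ✓ all passed / ◐ some / ✗ none, plus ⚡ when every entry is faster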
function getTraceStatusSummary(logs) {
const withStatus = logs.filter(l => l.passed !== null && l.passed !== undefined);
if (withStatus.length === 0) return '';
const passedCount = withStatus.filter(l => l.passed === 'True' || l.passed === 'true' || l.passed === true).length;
const fasterCount = withStatus.filter(l => l.faster === 'True' || l.faster === 'true' || l.faster === true).length;
const passedIcon = passedCount === withStatus.length ? '✓' : passedCount > 0 ? '◐' : '✗';
const fasterIcon = fasterCount === withStatus.length ? '⚡' : fasterCount > 0 ? '◐' : '';
return `<span class="trace-status-summary">
<span title="Passed: ${passedCount}/${withStatus.length}">${passedIcon}</span>
<span title="Faster: ${fasterCount}/${withStatus.length}">${fasterIcon}</span>
</span>`;
}
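        // Group logs by trace_id, refresh the summary stats, and render the collapsible trace list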
function renderDashboard(filteredData = data) {
const dashboard = document.getElementById('dashboard');
// Group by trace_id
const grouped = {};
filteredData.forEach(log => {
if (!grouped[log.trace_id]) {
grouped[log.trace_id] = [];
}
grouped[log.trace_id].push(log);
});
// Sort groups by most recent
const sortedTraces = Object.entries(grouped).sort((a, b) => {
const aDate = new Date(a[1][0].created_at);
const bDate = new Date(b[1][0].created_at);
return bDate - aDate;
});
// Update stats
document.getElementById('total-traces').textContent = sortedTraces.length;
document.getElementById('total-logs').textContent = filteredData.length;
document.getElementById('avg-per-trace').textContent = sortedTraces.length > 0
? (filteredData.length / sortedTraces.length).toFixed(1)
: '0';
// Calculate passed/faster stats
const withStatus = filteredData.filter(l => l.passed !== null && l.passed !== undefined);
const passedCount = withStatus.filter(l => l.passed === 'True' || l.passed === 'true' || l.passed === true).length;
const fasterCount = withStatus.filter(l => l.faster === 'True' || l.faster === 'true' || l.faster === true).length;
document.getElementById('passed-count').textContent = withStatus.length > 0 ? `${passedCount}/${withStatus.length}` : '';
document.getElementById('faster-count').textContent = withStatus.length > 0 ? `${fasterCount}/${withStatus.length}` : '';
if (sortedTraces.length === 0) {
dashboard.innerHTML = '<div class="no-results">No results found</div>';
return;
}
let html = '';
sortedTraces.forEach(([traceId, logs], idx) => {
// Sort logs within group by created_at
logs.sort((a, b) => new Date(a.created_at) - new Date(b.created_at));
html += `
<div class="trace-group">
<div class="trace-header" onclick="toggleTrace(${idx})">
<span class="trace-id">Trace: ${escapeHtml(traceId)}</span>
<div style="display: flex; align-items: center; gap: 10px;">
${getTraceStatusSummary(logs)}
<span class="count-badge">${logs.length} entries</span>
<span class="arrow"></span>
</div>
</div>
<div class="trace-content" id="trace-${idx}">
${logs.map((log, logIdx) => `
<div class="log-entry">
<div class="log-entry-header">
<span class="optimization-id">Optimization: ${escapeHtml(log.optimization_id)}</span>
<span class="timestamp">${formatDate(log.created_at)}</span>
</div>
<div class="status-badges">
${getStatusBadge(log.passed, 'Passed', 'passed')}
${getStatusBadge(log.faster, 'Faster', 'faster')}
</div>
<div class="section" style="margin-top: 15px;">
<div class="section-title">User Prompt</div>
<button class="copy-btn" onclick="copyToClipboard(data.find(d => d.optimization_id === '${log.optimization_id}').user_prompt, this)">Copy</button>
<div class="code-block">${escapeHtml(log.user_prompt)}</div>
</div>
<div class="section">
<div class="section-title">Explanation</div>
<div class="explanation-block">${escapeHtml(log.explanation)}</div>
</div>
<div class="section">
<div class="section-title">Refined Optimization</div>
<button class="copy-btn" onclick="copyToClipboard(data.find(d => d.optimization_id === '${log.optimization_id}').refined_optimization, this)">Copy</button>
<div class="code-block refined-code">${escapeHtml(log.refined_optimization)}</div>
</div>
</div>
`).join('')}
</div>
</div>`;
});
dashboard.innerHTML = html;
}
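        // Toggle a single trace group open or closed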
function toggleTrace(idx) {
const content = document.getElementById(`trace-${idx}`);
const header = content.previousElementSibling;
content.classList.toggle('show');
header.classList.toggle('expanded');
}
function toggleAll() {
const contents = document.querySelectorAll('.trace-content');
const headers = document.querySelectorAll('.trace-header');
allExpanded = !allExpanded;
contents.forEach(c => {
if (allExpanded) {
c.classList.add('show');
} else {
c.classList.remove('show');
}
});
headers.forEach(h => {
if (allExpanded) {
h.classList.add('expanded');
} else {
h.classList.remove('expanded');
}
});
document.querySelector('.expand-all-btn').textContent = allExpanded ? 'Collapse All' : 'Expand All';
}
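        // Live substring search across trace IDs, optimization IDs, prompts, explanations, and refined code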
document.getElementById('search').addEventListener('input', function(e) {
const query = e.target.value.toLowerCase();
if (!query) {
renderDashboard(data);
return;
}
const filtered = data.filter(log =>
(log.trace_id && log.trace_id.toLowerCase().includes(query)) ||
(log.optimization_id && log.optimization_id.toLowerCase().includes(query)) ||
(log.user_prompt && log.user_prompt.toLowerCase().includes(query)) ||
(log.explanation && log.explanation.toLowerCase().includes(query)) ||
(log.refined_optimization && log.refined_optimization.toLowerCase().includes(query))
);
renderDashboard(filtered);
});
// Initial render
renderDashboard();
</script>
</body>
</html>

@ -0,0 +1,79 @@
#!/usr/bin/env python3
"""Generate an HTML dashboard from code_repair_logs SQLite database."""

import json
import sqlite3
import webbrowser
from pathlib import Path


def main():
    db_path = Path(__file__).parent / "code_repair_log.db"
    cf_db_path = Path(__file__).parent / "code_repair_logs_cf.db"
    template_path = Path(__file__).parent / "code_repair_dashboard.html"
    output_path = Path(__file__).parent / "code_repair_dashboard_live.html"

    # Connect to main database and fetch all logs
    conn = sqlite3.connect(db_path)
    conn.row_factory = sqlite3.Row
    cursor = conn.cursor()
    cursor.execute("""
        SELECT optimization_id, trace_id, user_prompt, explanation,
               refined_optimization, created_at, updated_at
        FROM code_repair_logs
        ORDER BY created_at DESC
    """)
    rows = cursor.fetchall()
    conn.close()

    # Connect to cf database and fetch passed/faster status
    cf_conn = sqlite3.connect(cf_db_path)
    cf_conn.row_factory = sqlite3.Row
    cf_cursor = cf_conn.cursor()
    cf_cursor.execute("""
        SELECT optimization_id, passed, faster
        FROM code_repair_logs_cf
    """)
    cf_rows = cf_cursor.fetchall()
    cf_conn.close()

    # Create lookup dict for cf data
    cf_data = {row["optimization_id"]: {"passed": row["passed"], "faster": row["faster"]} for row in cf_rows}

    # Convert to list of dicts and merge cf data. The cf table keys its rows
    # by a variant of the optimization ID whose last four characters are
    # swapped for the "cdrp" (presumably "code repair") suffix.
    data = []
    for row in rows:
        d = dict(row)
        opt_id = d["optimization_id"][:-4] + "cdrp"
        if opt_id in cf_data:
            d["passed"] = cf_data[opt_id]["passed"]
            d["faster"] = cf_data[opt_id]["faster"]
        else:
            d["passed"] = None
            d["faster"] = None
        data.append(d)

    # Read template
    template = template_path.read_text()

    # Replace placeholder with actual data
    json_data = json.dumps(data, default=str, indent=2)
    html_content = template.replace("DATA_PLACEHOLDER", json_data)

    # Write output
    output_path.write_text(html_content)
    print(f"Dashboard generated: {output_path}")
    print(f"Total logs: {len(data)}")
    print(f"Unique traces: {len(set(d['trace_id'] for d in data))}")

    # Open in browser
    webbrowser.open(f"file://{output_path.absolute()}")


if __name__ == "__main__":
    main()

@ -39,7 +39,7 @@ const CollapsibleCodeBlock = memo(
        setIsExpanded(false);
      }
    },
-    [isExpanded]
+    [isExpanded],
  );

  if (!hasMoreLines) {

@ -23,7 +23,9 @@ const Tabs = () => {
      >
        <span className={styles.tabText}>
          Tasks
-         {tasksCount > 0 && <span className={styles.badge}>{tasksCount}</span>}
+         {tasksCount > 0 && (
+           <span className={styles.badge}>{tasksCount}</span>
+         )}
        </span>
      </button>
    </div>

@ -180,7 +180,12 @@ function renderStepsUI(steps) {
    // Show "select different interpreter" option after install button
    if (vscodeActionCommand) {
-     const { title, btnText: vscodeBtnText, command: vscodeCmd, args = [] } = vscodeActionCommand;
+     const {
+       title,
+       btnText: vscodeBtnText,
+       command: vscodeCmd,
+       args = [],
+     } = vscodeActionCommand;
      if (title) {
        const detailsElem = document.createElement("p");
        detailsElem.className = "step-action";
@ -192,7 +197,11 @@ function renderStepsUI(steps) {
      actionBtn.textContent = vscodeBtnText;
      actionBtn.className = "step-action-btn secondary-btn";
      actionBtn.addEventListener("click", () => {
-       vscode.postMessage({ command: "vscodeCommand", cmd: vscodeCmd, args });
+       vscode.postMessage({
+         command: "vscodeCommand",
+         cmd: vscodeCmd,
+         args,
+       });
      });
      actionsContainer.appendChild(actionBtn);
    }

@ -56,9 +56,7 @@ export async function getRepositoryPublicKey(
    repo: string,
  ): Promise<{ public_key: string; key_id: string }> {
    try {
-     console.log(
-       `[secret-utils.ts:getRepositoryPublicKey] Getting public key for ${owner}/${repo}`,
-     )
+     console.log(`[secret-utils.ts:getRepositoryPublicKey] Getting public key for ${owner}/${repo}`)

      const response = await octokit.rest.actions.getRepoPublicKey({
        owner,
@ -166,4 +164,3 @@ export async function encryptAndStoreSecret(
      `[secret-utils.ts:encryptAndStoreSecret] Successfully encrypted and stored secret ${secretName} for ${owner}/${repo}`,
    )
  }
-

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "public"."optimization_features" ADD COLUMN "optimizations_origin" JSONB;

@ -73,6 +73,7 @@ model optimization_features {
   organization          String?
   repository            String?
   ranking               Json?
+  optimizations_origin  Json?
   review_quality        String? // Hight, Med, low
   review_explanation    String?
   calling_fn_details    String?