mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
feedback loop for unmatched test results (#2059)
fixes CF-932 # Pull Request Checklist ## Description - [ ] **Description of PR**: Clear and concise description of what this PR accomplishes - [ ] **Breaking Changes**: Document any breaking changes (if applicable) - [ ] **Related Issues**: Link to any related issues or tickets ## Testing - [ ] **Test cases Attached**: All relevant test cases have been added/updated - [ ] **Manual Testing**: Manual testing completed for the changes ## Monitoring & Debugging - [ ] **Logging in place**: Appropriate logging has been added for debugging user issues - [ ] **Sentry will be able to catch errors**: Error handling ensures Sentry can capture and report errors - [ ] **Avoid Dev based/Prisma logging**: No development-only or Prisma-specific logging in production code ## Configuration - [ ] **Env variables newly added**: Any new environment variables are documented in .env.example file or mentioned in description --- ## Additional Notes <!-- Add any additional context, screenshots, or notes for reviewers here --> --------- Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> Co-authored-by: ali <mohammed18200118@gmail.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
This commit is contained in:
parent
96ea895c99
commit
1192df12a6
25 changed files with 1533 additions and 54 deletions
|
|
@ -219,3 +219,4 @@ RANKING_MODEL: LLM = _get_openai_model()
|
|||
REFINEMENT_MODEL: LLM = _get_anthropic_model()
|
||||
EXPLANATIONS_MODEL: LLM = _get_anthropic_model()
|
||||
OPTIMIZATION_REVIEW_MODEL: LLM = _get_anthropic_model()
|
||||
CODE_REPAIR_MODEL: LLM = _get_anthropic_model()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ Including another URLconf
|
|||
# from django.contrib import admin
|
||||
from django.urls import path
|
||||
|
||||
from code_repair.code_repair import code_repair_api
|
||||
from explanations.explanations import explanations_api
|
||||
from log_features.log_features import features_api
|
||||
from optimization_review.optimization_review import optimization_review_api
|
||||
|
|
@ -39,5 +40,6 @@ urlpatterns = [
|
|||
path("ai/explain", explanations_api.urls),
|
||||
path("ai/rank", ranker_api.urls),
|
||||
path("ai/optimization_review", optimization_review_api.urls),
|
||||
path("ai/code_repair", code_repair_api.urls),
|
||||
path("ai/workflow-gen", workflow_gen_api.urls),
|
||||
]
|
||||
|
|
|
|||
61
django/aiservice/code_repair/CODE_REPAIR_SYSTEM_PROMPT.md
Normal file
61
django/aiservice/code_repair/CODE_REPAIR_SYSTEM_PROMPT.md
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
You are a senior software engineer who is great at reviewing and repairing python code for performance and behavior.
|
||||
The goal of repairing code is to ensure that the optimized code is performant and has the same behavior as the original code based on the provided test results.
|
||||
You are provided the following information:
|
||||
|
||||
- original_source_code
|
||||
- optimized_source_code - This is the optimized implementation of the original code.
|
||||
- test_details - This has the details of the behavioral differences between the original and optimized code.
|
||||
|
||||
### Output format
|
||||
|
||||
Request to replace sections of the optimized code in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
|
||||
Define your output in XML-style tags like here:
|
||||
|
||||
<replace_in_file>
|
||||
<path>src/main.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
a = 2
|
||||
=======
|
||||
a = 3
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
Always adhere to this format for tool use to ensure proper parsing and execution.
|
||||
|
||||
## replace_in_file
|
||||
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file.
|
||||
Parameters:
|
||||
- path: (required) The path of the file to modify
|
||||
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
|
||||
```
|
||||
<<<<<<< SEARCH
|
||||
[exact content to find]
|
||||
=======
|
||||
[new content to replace with]
|
||||
>>>>>>> REPLACE
|
||||
```
|
||||
Critical rules:
|
||||
1. SEARCH content must match the associated file section to find EXACTLY:
|
||||
* Match character-for-character including whitespace, indentation, line endings
|
||||
* Include all comments, docstrings, etc.
|
||||
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
|
||||
* Include multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
|
||||
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
|
||||
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
|
||||
3. Keep SEARCH/REPLACE blocks concise:
|
||||
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
|
||||
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
|
||||
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
|
||||
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
|
||||
4. Special operations:
|
||||
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
|
||||
* To delete code: Use empty REPLACE section
|
||||
Usage:
|
||||
<replace_in_file>
|
||||
<path>File path here</path>
|
||||
<diff>
|
||||
Search and replace blocks here
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
13
django/aiservice/code_repair/CODE_REPAIR_USER_PROMPT.md
Normal file
13
django/aiservice/code_repair/CODE_REPAIR_USER_PROMPT.md
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
Please fix the optimized code to match the behaviour of the original code, while trying to keep the optimization logic intact.
|
||||
|
||||
### The original source code:
|
||||
|
||||
{original_source_code}
|
||||
|
||||
### The optimized source code
|
||||
|
||||
{modified_source_code}
|
||||
|
||||
### The test result details
|
||||
|
||||
{test_details}
|
||||
0
django/aiservice/code_repair/__init__.py
Normal file
0
django/aiservice/code_repair/__init__.py
Normal file
6
django/aiservice/code_repair/apps.py
Normal file
6
django/aiservice/code_repair/apps.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class CodeRepairConfig(AppConfig):
    """Django app configuration for the code_repair app."""

    default_auto_field = "django.db.models.BigAutoField"
    name = "code_repair"
|
||||
204
django/aiservice/code_repair/code_repair.py
Normal file
204
django/aiservice/code_repair/code_repair.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, llm_clients
|
||||
from aiservice.models.aimodels import CODE_REPAIR_MODEL, calculate_llm_cost
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from optimizer.models import OptimizedCandidateSource
|
||||
|
||||
from .code_repair_context import ( # noqa: TC001 (don't move CodeRepairRequestSchema to type-checking because it's the schema definition)
|
||||
CodeRepairContext,
|
||||
CodeRepairContextData,
|
||||
CodeRepairRequestSchema,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.http import HttpRequest as Request
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.models.aimodels import LLM
|
||||
|
||||
code_repair_api = NinjaAPI(urls_namespace="code_repair")
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = Path(__file__).parent
|
||||
SYSTEM_PROMPT = (current_dir / "CODE_REPAIR_SYSTEM_PROMPT.md").read_text()
|
||||
|
||||
USER_PROMPT = (current_dir / "CODE_REPAIR_USER_PROMPT.md").read_text()
|
||||
|
||||
|
||||
async def code_repair(  # noqa: D417
    user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM = CODE_REPAIR_MODEL
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
    """Repair the given candidate to match the behaviour of the original code.

    Parameters
    ----------
    :param user_id: id of the requesting user (used for analytics events).
    :param optimization_id: id of the optimization being repaired; becomes the
        ``parent_id`` of the returned item.
    :param ctx: the repair context; its ``data`` property includes the original
        code, the optimized code, and the behaviour test diffs.
    :param optimize_model: LLM used for generating the code repair.

    Returns
    -------
    CodeRepairIntermediateResponseItemschema or CodeRepairErrorResponseSchema
        On success, ``source_code`` holds the repaired code ("" if the repair
        failed validation); on LLM failure an error schema is returned.

    """
    system_prompt = ctx.get_system_prompt()
    user_prompt = ctx.get_user_prompt()

    # Fresh id for the repaired candidate; the input optimization becomes its parent.
    new_op_id = str(uuid.uuid4())
    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
    llm_client = llm_clients[optimize_model.model_type]
    try:
        output = await llm_client.with_options(max_retries=2).chat.completions.create(
            model=optimize_model.name, messages=messages, n=1
        )
        llm_cost = calculate_llm_cost(output, optimize_model)
    except Exception as e:
        logging.exception("Claude Code Generation error in code_repair")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
        return CodeRepairErrorResponseSchema(error=str(e))
    debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.model_dump_json(indent=2)}")
    if output.usage is not None:
        ph(user_id, "code_repair-usage", properties={"model": optimize_model.name, "usage": output.usage.json()})
    results = [content for op in output.choices if (content := op.message.content)]  # will be of size 1
    # Guard: a completion whose choices carry no content would otherwise raise an
    # uncaught IndexError on results[0] below.
    if not results:
        return CodeRepairErrorResponseSchema(error="LLM returned no content for code repair")

    # Regex doesn't work yet in extracting everything else other than the search
    # replace block, so the raw response is surfaced as the explanation.
    explanation = results[0]

    repaired_optimization = ""
    try:
        diff_patches = ctx.extract_diff_patches_from_llm_res(results[0])
        repaired_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
    except (ValueError, ValidationError) as exc:
        sentry_sdk.capture_exception(exc)
        debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.modified_source_code}")
        debug_log_sensitive_data(f"Traceback: {exc}")
        repaired_optimization = ""

    # An invalid repair (markdown structure changed or a section fails to parse)
    # is signalled to the caller as empty source_code rather than an error.
    if not ctx.is_valid(repaired_optimization):
        repaired_optimization = ""

    return CodeRepairIntermediateResponseItemschema(
        optimization_id=new_op_id,
        parent_id=optimization_id,
        source_code=repaired_optimization,
        llm_cost=llm_cost,
        explanation=explanation,
    )
|
||||
|
||||
|
||||
class CodeRepairErrorResponseSchema(Schema):
    """Error payload returned when code repair fails."""

    error: str
|
||||
|
||||
|
||||
class CodeRepairIntermediateResponseItemschema(Schema):
    """Internal result of a single code-repair run, including its LLM cost."""

    optimization_id: str  # freshly generated id for the repaired candidate
    parent_id: str  # id of the optimization that was repaired
    source_code: str  # repaired code, or "" when the repair failed validation
    llm_cost: float
    explanation: str  # raw LLM response text (search/replace blocks not yet stripped)
|
||||
|
||||
|
||||
class CodeRepairResponseItemschema(Schema):
    """Public response payload for a successful code repair (no cost field)."""

    optimization_id: str  # id of the repaired candidate
    parent_id: str  # id of the optimization that was repaired
    source_code: str  # repaired code, or "" when the repair failed validation
    explanation: str
|
||||
|
||||
|
||||
@code_repair_api.post(
    "/",
    response={
        200: CodeRepairResponseItemschema,
        400: CodeRepairErrorResponseSchema,
        500: CodeRepairErrorResponseSchema,
    },
)
async def repair(
    request: Request, data: CodeRepairRequestSchema
) -> tuple[int, CodeRepairResponseItemschema | CodeRepairErrorResponseSchema]:
    """HTTP entry point: repair an optimized candidate to match the original's behaviour.

    Returns 400 for an invalid trace id, 500 on LLM or validation failure,
    otherwise 200 with the repaired candidate. Logs features and updates the
    optimization cost as side effects.
    """
    ph(request.user, "aiservice-code_repair-called")
    ctx_data = CodeRepairContextData(
        original_source_code=data.original_source_code,
        modified_source_code=data.modified_source_code,
        test_diffs=data.test_diffs,
    )
    ctx = CodeRepairContext(ctx_data=ctx_data, base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT)
    trace_id = data.trace_id
    if not validate_trace_id(trace_id):
        return 400, CodeRepairErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    code_repair_data = await code_repair(user_id=request.user, optimization_id=data.optimization_id, ctx=ctx)
    total_llm_cost = 0.0
    if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
        return 500, code_repair_data
    total_llm_cost += code_repair_data.llm_cost
    try:
        # NOTE(review): validate_python_module checks the *incoming* optimized code
        # (ctx.data.modified_source_code), not the repaired output, yet the error
        # logs below print code_repair_data.source_code — confirm which is intended.
        ctx.validate_python_module()
    except cst.ParserSyntaxError as e:
        # log exception with sentry
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"ParserSyntaxError for source:\n{code_repair_data.source_code}")
        debug_log_sensitive_data(f"Traceback: {e}")
        return 500, CodeRepairErrorResponseSchema(error=str(e))
    except (ValueError, ValidationError) as exc:
        # Another one bites the Pydantic validation dust
        sentry_sdk.capture_exception(exc)
        debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{code_repair_data.source_code}")
        debug_log_sensitive_data(f"Traceback: {exc}")
        return 500, CodeRepairErrorResponseSchema(error=str(exc))

    # Feature logging is opt-in per request (middleware sets should_log_features).
    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
            explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
            # explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
            # optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
            optimizations_origin={
                code_repair_data.optimization_id: {
                    "source": OptimizedCandidateSource.REPAIR,
                    "parent": code_repair_data.parent_id,
                }
            },
        )
    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
    return 200, CodeRepairResponseItemschema(
        source_code=code_repair_data.source_code,
        optimization_id=code_repair_data.optimization_id,
        parent_id=code_repair_data.parent_id,
        explanation=code_repair_data.explanation,
    )
|
||||
153
django/aiservice/code_repair/code_repair_context.py
Normal file
153
django/aiservice/code_repair/code_repair_context.py
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from ninja import Field, Schema
|
||||
from optimizer.context_utils.constants import REPLACE_IN_FILE_TAGS_REGEX
|
||||
from optimizer.context_utils.context_helpers import group_code, is_markdown_structure_changed, split_markdown_code
|
||||
from optimizer.diff_patches_utils.patches_v2 import apply_patches, group_diff_patches_by_path
|
||||
from optimizer.models import CodeAndExplanation
|
||||
from pydantic import ValidationError
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
|
||||
|
||||
|
||||
class TestDiffScope(str, Enum):
    """The aspect of a test run in which original and optimized code diverged."""

    RETURN_VALUE = "return_value"  # the function returned a different value
    STDOUT = "stdout"  # the printed output differed
    DID_PASS = "did_pass"  # noqa: S105
    TIMED_OUT = "timed_out"  # one version exceeded the time limit
|
||||
|
||||
|
||||
# Human-readable descriptions of each diff scope, interpolated into the LLM
# prompt by build_test_details. Must cover every TestDiffScope member —
# build_test_details falls back to the raw enum value for missing keys.
SCOPE_DESCRIPTIONS = {
    TestDiffScope.RETURN_VALUE: (
        "The function returned a different value in the optimized code compared to the original."
    ),
    TestDiffScope.STDOUT: ("The output printed to stdout is different in the optimized code compared to the original."),
    TestDiffScope.DID_PASS: (
        "The test passed in one version but failed in the other (a change in pass/fail behavior)."
    ),
    TestDiffScope.TIMED_OUT: (
        "The test timed out in one version but not the other (a change in timeout behavior)."
    ),
}
|
||||
|
||||
|
||||
class TestDiff(Schema):
    """One observed behavioural difference between original and optimized code for a single test."""

    scope: TestDiffScope
    # Compared values; only rendered into the prompt when scope != DID_PASS.
    original_value: bool | str | int | float | dict | list | None = None
    candidate_value: bool | str | int | float | dict | list | None = None
    original_pass: bool
    candidate_pass: bool
    test_src_code: str  # source of the test function; diffs are grouped by this
    candidate_pytest_error: str | None = None
    original_pytest_error: str | None = None
|
||||
|
||||
|
||||
class CodeRepairRequestSchema(Schema):
    """Request payload for the code-repair endpoint."""

    trace_id: str
    optimization_id: str
    original_source_code: str
    modified_source_code: str  # the optimized code to repair
    # A required list field needs no Field(..., alias=...) when the alias
    # equals the attribute name; the bare annotation is equivalent.
    test_diffs: list[TestDiff]
|
||||
|
||||
|
||||
@dataclass()
class CodeRepairContextData:
    """Inputs to a single code-repair run."""

    original_source_code: str
    modified_source_code: str  # the optimized code that needs repairing
    test_diffs: list[TestDiff]
|
||||
|
||||
|
||||
class CodeRepairContext:
    """Builds the prompts for a code-repair run and applies/validates the LLM's patches."""

    def __init__(self, ctx_data: CodeRepairContextData, base_system_prompt: str, base_user_prompt: str) -> None:
        self.data = ctx_data
        self.base_system_prompt = base_system_prompt
        self.base_user_prompt = base_user_prompt

    def get_system_prompt(self) -> str:
        """Return the (static) system prompt."""
        return self.base_system_prompt

    def build_test_details(self, test_diffs: list[TestDiff]) -> str:
        """Render the behavioural diffs into one text section per test function."""
        sections = defaultdict(str)
        for diff in test_diffs:
            try:
                if sections[diff.test_src_code] == "":
                    # add error strings and test def only once per test function
                    sections[diff.test_src_code] += f"""Test Source:
```python
{diff.test_src_code}
```
Pytest error (original code): {diff.original_pytest_error if diff.original_pytest_error else ""}
Pytest error (optimized code): {diff.candidate_pytest_error if diff.candidate_pytest_error else ""}
"""
                sections[diff.test_src_code] += "\n".join(
                    [
                        f"{SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value)}",
                        f"Expected: {diff.original_value!r}.\nGot: {diff.candidate_value!r}."
                        if diff.scope != TestDiffScope.DID_PASS
                        else "",
                        f"Original code test status: {'Passed' if diff.original_pass else 'Failed'}. Optimized code test status: {'Passed' if diff.candidate_pass else 'Failed'}",
                        "---",
                    ]
                )
            except Exception as e:
                # best-effort: a malformed diff entry is logged and skipped, not fatal
                logging.exception("Some issue in parsing test diffs")
                sentry_sdk.capture_exception(e)
        return "\n".join(sections.values())

    def get_user_prompt(self) -> str:
        """Fill the user-prompt template with the original/optimized code and test details."""
        return self.base_user_prompt.format(
            original_source_code=self.data.original_source_code,
            modified_source_code=self.data.modified_source_code,
            test_details=self.build_test_details(self.data.test_diffs),
        )

    def extract_diff_patches_from_llm_res(self, llm_res: str) -> str:
        """Extract the first <replace_in_file> tag body from the LLM response ("" if none)."""
        matches = REPLACE_IN_FILE_TAGS_REGEX.findall(llm_res)
        replace_tags = ""
        if matches and len(matches) != 0:
            replace_tags = f"<replace_in_file>{matches[0]}</replace_in_file>"

        return replace_tags

    def apply_patches_to_optimized_code(self, replace_tags: str) -> str:
        """Apply the extracted SEARCH/REPLACE patches to the optimized code.

        Returns "" when there are no patches; files the patches reference but
        the optimized code does not contain are skipped (logged).
        """
        if replace_tags == "":
            return ""

        file_to_code = split_markdown_code(self.data.modified_source_code)
        # sometimes the llm can write multiple replace tags for the same file, so we group them by path to avoid parsing & applying multiple times
        file_to_diffs = group_diff_patches_by_path(replace_tags)

        for path, diff in file_to_diffs.items():
            scoped_code = file_to_code.get(path, None)
            if scoped_code is None:
                debug_log_sensitive_data(f"no scoped code for {path}, existing: {file_to_code.keys()}")
                continue
            new_code = apply_patches(diff, scoped_code)
            file_to_code[path] = new_code
        return group_code(file_to_code)

    def is_valid(self, new_refined_code: str) -> bool:
        """Return True iff the refined code keeps the markdown structure and every file section parses."""
        if is_markdown_structure_changed(new_refined_code, self.data.modified_source_code):
            return False
        valid = True
        for code in split_markdown_code(new_refined_code).values():
            stripped_code = code.strip()
            if not stripped_code:
                # an empty section means a patch deleted a whole file's code
                valid = False
                break
            try:
                parse_module_to_cst(code)
            except cst.ParserSyntaxError:
                valid = False
                break
        return valid

    def validate_python_module(self) -> None:
        """Raise if any section of the optimized code fails to parse or validate.

        NOTE(review): this validates ``modified_source_code`` (the input optimized
        code), not the repaired output — confirm that is the intended behaviour.
        """
        for _code in split_markdown_code(self.data.modified_source_code).values():
            try:
                cst_module = parse_module_to_cst(_code)
                CodeAndExplanation(cst_module, "")
            except (ValueError, ValidationError, cst.ParserSyntaxError):  # noqa: TRY203
                raise
|
||||
|
|
@ -44,6 +44,7 @@ def log_features(
|
|||
experiment_metadata: dict[str, str] | None = None,
|
||||
final_explanation: str | None = None,
|
||||
ranking: dict[str, Any] | None = None,
|
||||
optimizations_origin: dict[str, dict[str, str]] | None = None,
|
||||
) -> None:
|
||||
"""Log features of a code optimization run to the database.
|
||||
|
||||
|
|
@ -95,6 +96,7 @@ def log_features(
|
|||
"experiment_metadata": experiment_metadata,
|
||||
"final_explanation": final_explanation,
|
||||
"ranking": ranking,
|
||||
"optimizations_origin": optimizations_origin,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -148,6 +150,14 @@ def log_features(
|
|||
f.ranking = f.ranking | ranking if ranking is not None else f.ranking
|
||||
else:
|
||||
f.ranking = ranking if ranking is not None else f.ranking
|
||||
|
||||
if f.optimizations_origin is not None:
|
||||
# merge the optimizations_origin with the existing ones
|
||||
f.optimizations_origin = merge_dicts(f.optimizations_origin, optimizations_origin or {})
|
||||
else:
|
||||
f.optimizations_origin = (
|
||||
optimizations_origin if optimizations_origin is not None else f.optimizations_origin
|
||||
)
|
||||
f.speedup_ratio = speedup_ratio if speedup_ratio is not None else f.speedup_ratio
|
||||
f.original_runtime = original_runtime if original_runtime is not None else f.original_runtime
|
||||
f.optimized_runtime = optimized_runtime if optimized_runtime is not None else f.optimized_runtime
|
||||
|
|
@ -165,6 +175,22 @@ def log_features(
|
|||
f.save()
|
||||
|
||||
|
||||
def merge_dicts(a: dict[str, dict[str, str]], b: dict[str, dict[str, str]]) -> dict[str, dict[str, str]]:
    """Merge two dicts of dicts without mutating either input.

    Inner dicts are merged key-wise; on a collision the value from ``b`` wins.
    """
    merged: dict[str, dict[str, str]] = {key: inner.copy() for key, inner in a.items()}

    for key, inner in b.items():
        existing = merged.get(key)
        if existing is None:
            merged[key] = inner.copy()
        else:
            # b overrides a
            existing.update(inner)

    return merged
|
||||
|
||||
|
||||
@features_api.post("/", response={200: None, 500: LoggingErrorResponseSchema})
|
||||
async def log_features_cli(request: HttpRequest, data: LoggingSchema) -> int | tuple[int, LoggingErrorResponseSchema]:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ class OptimizationFeatures(models.Model):
|
|||
experiment_metadata = models.JSONField(null=True, blank=True)
|
||||
final_explanation = models.TextField(null=True, blank=True)
|
||||
ranking = models.JSONField(null=True, blank=True)
|
||||
optimizations_origin = models.JSONField(null=True, blank=True)
|
||||
|
||||
# PR suggestions or create Approval fields
|
||||
approval_required = models.BooleanField(default=False)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from pydantic import ValidationError
|
||||
|
||||
from optimizer.context_utils.constants import MULTI_REPLACE_IN_FILE_TAGS_REGEX
|
||||
|
||||
|
||||
class SearchReplaceBlock:
|
||||
def __init__(self, search, replace):
|
||||
|
|
@ -19,36 +21,44 @@ def parse_diff(diff: str) -> list[SearchReplaceBlock]:
|
|||
|
||||
blocks: list[SearchReplaceBlock] = []
|
||||
lines = diff.splitlines(keepends=True)
|
||||
n = len(lines)
|
||||
idx = 0
|
||||
|
||||
while idx < len(lines):
|
||||
line = lines[idx].strip()
|
||||
if line == "<<<<<<< SEARCH":
|
||||
search_lines = []
|
||||
# Precompute the "marker" strings for efficiency
|
||||
search_marker = "<<<<<<< SEARCH"
|
||||
delimiter_marker = "======="
|
||||
replace_marker = ">>>>>>> REPLACE"
|
||||
|
||||
while idx < n:
|
||||
line_stripped = lines[idx].strip()
|
||||
if line_stripped.startswith(search_marker):
|
||||
idx += 1
|
||||
|
||||
while idx < len(lines) and lines[idx].strip() != "=======":
|
||||
search_lines.append(lines[idx])
|
||||
search_start = idx
|
||||
# Find delimiter_marker line
|
||||
while idx < n and lines[idx].strip() != delimiter_marker:
|
||||
idx += 1
|
||||
search_end = idx
|
||||
|
||||
if idx >= len(lines):
|
||||
if idx >= n:
|
||||
raise ValueError("Invalid diff format: Missing '=======' marker")
|
||||
|
||||
replace_lines = []
|
||||
idx += 1
|
||||
|
||||
while idx < len(lines) and lines[idx].strip() != ">>>>>>> REPLACE":
|
||||
replace_lines.append(lines[idx])
|
||||
replace_start = idx
|
||||
while idx < n and not lines[idx].strip().startswith(replace_marker):
|
||||
idx += 1
|
||||
replace_end = idx
|
||||
|
||||
if idx >= len(lines):
|
||||
raise ValueError("Invalid diff format: Missing '>>>>>>> REPLACE' marker")
|
||||
if idx >= n:
|
||||
raise ValueError(
|
||||
"Invalid diff format: Missing '>>>>>>> REPLACE' marker"
|
||||
)
|
||||
|
||||
search_content = "".join(search_lines).rstrip()
|
||||
replace_content = "".join(replace_lines).rstrip()
|
||||
search_content = "".join(lines[search_start:search_end]).rstrip()
|
||||
replace_content = "".join(lines[replace_start:replace_end]).rstrip()
|
||||
|
||||
try:
|
||||
block = SearchReplaceBlock.from_block(search=search_content, replace=replace_content)
|
||||
block = SearchReplaceBlock.from_block(
|
||||
search=search_content, replace=replace_content
|
||||
)
|
||||
blocks.append(block)
|
||||
except ValidationError as ve:
|
||||
raise ValueError(f"Invalid block format: {ve}")
|
||||
|
|
@ -61,13 +71,35 @@ def parse_diff(diff: str) -> list[SearchReplaceBlock]:
|
|||
return blocks
|
||||
|
||||
|
||||
def group_diff_patches_by_path(replace_tags_str: str) -> dict[str, str]:
    """Group SEARCH/REPLACE diff bodies by the file path they target.

    The LLM may emit several <replace_in_file> tags for the same file; their
    diff bodies are concatenated so each file is parsed and patched only once.

    The previous implementation only merged *adjacent* tags: for a match
    sequence like (A, d1), (B, d2), (A, d3) the final flush overwrote A's
    earlier diffs with d3, silently dropping d1. Accumulating per path fixes
    that while preserving the behaviour for consecutive groups.

    :param replace_tags_str: concatenated <replace_in_file> tag text.
    :returns: mapping of file path -> concatenated diff bodies.
    """
    matches = MULTI_REPLACE_IN_FILE_TAGS_REGEX.findall(replace_tags_str)
    file_to_diffs: dict[str, str] = {}

    for path, diff in matches:
        if not path:
            # preserve original behaviour: entries without a path are dropped
            continue
        file_to_diffs[path] = file_to_diffs.get(path, "") + diff

    return file_to_diffs
|
||||
|
||||
|
||||
def apply_patches(diff_str: str, content: str) -> str:
|
||||
try:
|
||||
patch_blocks = parse_diff(diff_str)
|
||||
except ValueError:
|
||||
return content
|
||||
|
||||
for idx, block in enumerate(patch_blocks, 1):
|
||||
for block in patch_blocks:
|
||||
if not block.search:
|
||||
if block.replace:
|
||||
# a replacement block without a search, then just add the replace block
|
||||
|
|
@ -79,5 +111,7 @@ def apply_patches(diff_str: str, content: str) -> str:
|
|||
start_char_idx = content.find(block.search)
|
||||
if start_char_idx != -1:
|
||||
end_char_idx = start_char_idx + len(block.search)
|
||||
content = f"{content[:start_char_idx]}{block.replace}{content[end_char_idx:]}"
|
||||
content = (
|
||||
f"{content[:start_char_idx]}{block.replace}{content[end_char_idx:]}"
|
||||
)
|
||||
return content
|
||||
|
|
|
|||
|
|
@ -1,9 +1,18 @@
|
|||
import enum
|
||||
|
||||
import libcst
|
||||
from ninja import Schema
|
||||
from pydantic import field_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
|
||||
class OptimizedCandidateSource(str, enum.Enum):
    """Origin of an optimized candidate, recorded in optimizations_origin logs."""

    OPTIMIZE = "OPTIMIZE"  # produced by the main optimize endpoint
    OPTIMIZE_LP = "OPTIMIZE_LP"  # produced by the LP optimize endpoint — NOTE(review): confirm what "LP" denotes
    REFINE = "REFINE"  # presumably produced by the refinement pass — TODO confirm against its caller
    REPAIR = "REPAIR"  # produced by the code-repair feedback loop
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CodeAndExplanation:
|
||||
cst_module: libcst.Module | None
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@ from typing import TYPE_CHECKING
|
|||
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
from ninja import NinjaAPI
|
||||
from ninja.errors import HttpError
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
|
||||
|
|
@ -15,11 +20,6 @@ from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
|||
from authapp.user import get_user_by_id
|
||||
from log_features.log_event import log_optimization_event
|
||||
from log_features.log_features import log_features
|
||||
from ninja import NinjaAPI
|
||||
from ninja.errors import HttpError
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import ValidationError
|
||||
|
||||
from optimizer.context_utils.context_helpers import group_code
|
||||
from optimizer.context_utils.optimizer_context import (
|
||||
BaseOptimizerContext,
|
||||
|
|
@ -27,10 +27,9 @@ from optimizer.context_utils.optimizer_context import (
|
|||
OptimizeResponseItemSchema,
|
||||
OptimizeResponseSchema,
|
||||
)
|
||||
from optimizer.models import OptimizeSchema # noqa: TC001
|
||||
from optimizer.models import OptimizedCandidateSource, OptimizeSchema # noqa: TC001
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiservice.models.aimodels import LLM
|
||||
from django.http import HttpRequest
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
|
|
@ -38,6 +37,8 @@ if TYPE_CHECKING:
|
|||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.models.aimodels import LLM
|
||||
|
||||
|
||||
optimizations_json = [
|
||||
{
|
||||
|
|
@ -345,7 +346,7 @@ async def optimize(
|
|||
},
|
||||
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
||||
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
|
||||
# request=request,
|
||||
optimizations_origin={cei.optimization_id: {"source": OptimizedCandidateSource.OPTIMIZE, "parent": None} for cei in optimization_response_items},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,29 +5,30 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
import sentry_sdk
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
|
||||
from optimizer.context_utils.optimizer_context import (
|
||||
BaseOptimizerContext,
|
||||
OptimizeErrorResponseSchema,
|
||||
OptimizeResponseSchema,
|
||||
)
|
||||
from optimizer.models import OptimizedCandidateSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiservice.models.aimodels import LLM
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.models.aimodels import LLM
|
||||
from optimizer.context_utils.optimizer_context import OptimizeResponseItemSchema
|
||||
|
||||
|
||||
|
|
@ -190,6 +191,8 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
|
|||
},
|
||||
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
||||
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
|
||||
optimizations_origin={cei.optimization_id: {"source": OptimizedCandidateSource.OPTIMIZE_LP, "parent": None} for cei in optimization_response_items},
|
||||
|
||||
)
|
||||
|
||||
response = OptimizeResponseSchema(optimizations=optimization_response_items)
|
||||
|
|
|
|||
|
|
@ -2,31 +2,34 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import ValidationError
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, llm_clients
|
||||
from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from ninja import NinjaAPI, Schema
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
from pydantic import ValidationError
|
||||
|
||||
from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
|
||||
from optimizer.models import OptimizedCandidateSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiservice.models.aimodels import LLM
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionFunctionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.models.aimodels import LLM
|
||||
|
||||
|
||||
refinement_api = NinjaAPI(urls_namespace="refinement")
|
||||
|
||||
|
|
@ -260,7 +263,8 @@ async def refinement( # noqa: D417
|
|||
refined_optimization = ""
|
||||
|
||||
return RefinementIntermediateResponseItemschema(
|
||||
optimization_id=optimization_id,
|
||||
parent_id=optimization_id,
|
||||
optimization_id=str(uuid.uuid4()),
|
||||
source_code=refined_optimization,
|
||||
explanation=refined_explanation,
|
||||
original_explanation=ctx.data.optimized_explanation,
|
||||
|
|
@ -291,6 +295,7 @@ class OptimizeErrorResponseSchema(Schema):
|
|||
class RefinementIntermediateResponseItemschema(Schema):
|
||||
# the key will be the optimization id and the value will be the actual refined code
|
||||
explanation: str
|
||||
parent_id: str
|
||||
optimization_id: str
|
||||
source_code: str
|
||||
original_explanation: str
|
||||
|
|
@ -301,6 +306,7 @@ class RefinementResponseItemschema(Schema):
|
|||
# the key will be the optimization id and the value will be the actual refined code
|
||||
explanation: str
|
||||
optimization_id: str
|
||||
parent_id: str
|
||||
source_code: str
|
||||
|
||||
|
||||
|
|
@ -374,27 +380,30 @@ async def refine(
|
|||
trace_id=trace_id,
|
||||
user_id=request.user,
|
||||
optimizations_raw={
|
||||
cei.optimization_id[:-4] + "refi": cei.source_code
|
||||
cei.optimization_id: cei.source_code
|
||||
for cei in refinement_data
|
||||
if not isinstance(cei, OptimizeErrorResponseSchema)
|
||||
},
|
||||
optimizations_post={
|
||||
cei.optimization_id[:-4] + "refi": cei.source_code for cei in filtered_refined_optimizations
|
||||
},
|
||||
optimizations_post={cei.optimization_id: cei.source_code for cei in filtered_refined_optimizations},
|
||||
explanations_raw={
|
||||
cei.optimization_id[:-4] + "refi": cei.explanation
|
||||
cei.optimization_id: cei.explanation
|
||||
for cei in refinement_data
|
||||
if not isinstance(cei, OptimizeErrorResponseSchema)
|
||||
},
|
||||
explanations_post={
|
||||
cei.optimization_id[:-4] + "refi": cei.explanation for cei in filtered_refined_optimizations
|
||||
explanations_post={cei.optimization_id: cei.explanation for cei in filtered_refined_optimizations},
|
||||
optimizations_origin={
|
||||
cei.optimization_id: {"source": OptimizedCandidateSource.REFINE, "parent": cei.parent_id}
|
||||
for cei in filtered_refined_optimizations
|
||||
},
|
||||
)
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
return 200, Refinementschema(
|
||||
refinements=[
|
||||
RefinementResponseItemschema(
|
||||
source_code=x.source_code, explanation=x.original_explanation, optimization_id=x.optimization_id
|
||||
source_code=x.source_code,
|
||||
explanation=x.original_explanation,
|
||||
optimization_id=x.optimization_id,
|
||||
parent_id=x.parent_id,
|
||||
)
|
||||
for x in filtered_refined_optimizations
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, parse_expression, parse_module
|
||||
from functools import lru_cache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst import (
|
||||
|
|
@ -19,6 +20,7 @@ if TYPE_CHECKING:
|
|||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def parse_module_to_cst(module_str: str) -> Module:
|
||||
"""Parse a module string into its libCST representation.
|
||||
|
||||
|
|
|
|||
389
django/aiservice/tests/optimizer/test_code_repair.py
Normal file
389
django/aiservice/tests/optimizer/test_code_repair.py
Normal file
|
|
@ -0,0 +1,389 @@
|
|||
|
||||
|
||||
from code_repair.code_repair_context import CodeRepairContext, CodeRepairContextData
|
||||
from optimizer.diff_patches_utils.patches_v2 import apply_patches
|
||||
|
||||
|
||||
def test_code_repair_single_file():
|
||||
|
||||
original_code = """```python:demo.py
|
||||
import math
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
def calculate_portfolio_metrics(
|
||||
investments: List[Tuple[str, float, float]],
|
||||
risk_free_rate: float = 0.02
|
||||
) -> dict:
|
||||
if not investments:
|
||||
raise ValueError("Investments list cannot be empty")
|
||||
|
||||
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
|
||||
# Calculate weighted return
|
||||
weighted_return = sum(weight * ret for _, weight, ret in investments)
|
||||
|
||||
# Calculate portfolio volatility (simplified)
|
||||
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
|
||||
|
||||
# Calculate Sharpe ratio
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
# Find best and worst performing assets
|
||||
best_asset = max(investments, key=lambda x: x[2])
|
||||
worst_asset = min(investments, key=lambda x: x[2])
|
||||
|
||||
return {
|
||||
'weighted_return': round(weighted_return, 6),
|
||||
'volatility': round(volatility, 6),
|
||||
'sharpe_ratio': round(sharpe_ratio, 6),
|
||||
'best_performing': (best_asset[0], round(best_asset[2], 6)),
|
||||
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
|
||||
'total_assets': len(investments)
|
||||
}
|
||||
```
|
||||
"""
|
||||
optimized_code = """```python:demo.py
|
||||
import math
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
def calculate_portfolio_metrics(
|
||||
investments: List[Tuple[str, float, float]],
|
||||
risk_free_rate: float = 0.02
|
||||
) -> dict:
|
||||
if not investments:
|
||||
raise ValueError("Investments list cannot be empty")
|
||||
|
||||
total_weight = sum(w for _, w, _ in investments)
|
||||
if total_weight != 1.0: # Should use tolerance check
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
|
||||
weighted_return = 1.0
|
||||
for _, weight, ret in investments:
|
||||
weighted_return *= (1 + ret) ** weight
|
||||
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||
|
||||
returns = [r for _, _, r in investments]
|
||||
mean_return = sum(returns) / len(returns)
|
||||
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||
|
||||
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
def risk_adjusted_return(return_val, weight):
|
||||
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||
|
||||
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
|
||||
return {
|
||||
"weighted_return": round(weighted_return, 6),
|
||||
"volatility": 2,
|
||||
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||
"total_assets": len(investments),
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
llm_response = """Looking at the test failure and comparing the original and modified code, I can see several issues in the modified implementation:
|
||||
|
||||
1. **Hardcoded volatility**: The volatility is hardcoded to `2` instead of being calculated
|
||||
2. **Incorrect weight validation**: Missing tolerance check for weight sum validation
|
||||
3. **Wrong volatility calculation**: Using standard deviation of returns instead of weighted portfolio volatility
|
||||
4. **Wrong weighted return calculation**: Using geometric mean instead of arithmetic weighted return
|
||||
5. **Wrong asset selection criteria**: Using risk-adjusted return instead of simple return
|
||||
|
||||
Let me fix these issues:
|
||||
|
||||
<replace_in_file>
|
||||
<path>demo.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
total_weight = sum(w for _, w, _ in investments)
|
||||
if total_weight != 1.0: # Should use tolerance check
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
=======
|
||||
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
<replace_in_file>
|
||||
<path>demo.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
weighted_return = 1.0
|
||||
for _, weight, ret in investments:
|
||||
weighted_return *= (1 + ret) ** weight
|
||||
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||
=======
|
||||
# Calculate weighted return
|
||||
weighted_return = sum(weight * ret for _, weight, ret in investments)
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
<replace_in_file>
|
||||
<path>demo.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
returns = [r for _, _, r in investments]
|
||||
mean_return = sum(returns) / len(returns)
|
||||
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||
=======
|
||||
# Calculate portfolio volatility (simplified)
|
||||
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
<replace_in_file>
|
||||
<path>demo.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
def risk_adjusted_return(return_val, weight):
|
||||
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||
|
||||
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
=======
|
||||
# Calculate Sharpe ratio
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
# Find best and worst performing assets
|
||||
best_asset = max(investments, key=lambda x: x[2])
|
||||
worst_asset = min(investments, key=lambda x: x[2])
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
<replace_in_file>
|
||||
<path>demo.py</path>
|
||||
<diff>
|
||||
<<<<<<< SEARCH
|
||||
return {
|
||||
"weighted_return": round(weighted_return, 6),
|
||||
"volatility": 2,
|
||||
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||
"total_assets": len(investments),
|
||||
}
|
||||
=======
|
||||
return {
|
||||
'weighted_return': round(weighted_return, 6),
|
||||
'volatility': round(volatility, 6),
|
||||
'sharpe_ratio': round(sharpe_ratio, 6),
|
||||
'best_performing': (best_asset[0], round(best_asset[2], 6)),
|
||||
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
|
||||
'total_assets': len(investments)
|
||||
}
|
||||
>>>>>>> REPLACE
|
||||
</diff>
|
||||
</replace_in_file>
|
||||
|
||||
The key changes I made:
|
||||
|
||||
1. **Fixed weight validation**: Restored the tolerance-based check using `abs(sum(...) - 1.0) > 1e-10` instead of exact equality
|
||||
2. **Fixed weighted return calculation**: Changed from geometric mean back to simple weighted arithmetic mean
|
||||
3. **Fixed volatility calculation**: Restored the original simplified portfolio volatility formula using weighted returns
|
||||
4. **Fixed volatility return value**: Removed the hardcoded `2` and properly calculated and rounded the volatility
|
||||
5. **Fixed asset selection**: Restored the original logic to find best/worst assets based on simple returns rather than risk-adjusted returns
|
||||
6. **Fixed dictionary formatting**: Changed from double quotes to single quotes to match original formatting
|
||||
|
||||
These changes align the modified code with the original implementation's behavior, ensuring that the test for zero volatility passes (when cash investment has 0% return, the volatility should indeed be 0.0)."""
|
||||
ctx = CodeRepairContext(CodeRepairContextData(original_code, optimized_code, ""), "" , "")
|
||||
diff_patches = ctx.extract_diff_patches_from_llm_res(llm_response)
|
||||
refined_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
|
||||
|
||||
print(refined_optimization)
|
||||
assert ctx.is_valid(refined_optimization)
|
||||
|
||||
|
||||
"""
|
||||
```python:demo.py
|
||||
import math
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
def calculate_portfolio_metrics(
|
||||
investments: List[Tuple[str, float, float]],
|
||||
risk_free_rate: float = 0.02
|
||||
) -> dict:
|
||||
if not investments:
|
||||
raise ValueError("Investments list cannot be empty")
|
||||
|
||||
total_weight = sum(w for _, w, _ in investments)
|
||||
if total_weight != 1.0: # Should use tolerance check
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
|
||||
weighted_return = 1.0
|
||||
for _, weight, ret in investments:
|
||||
weighted_return *= (1 + ret) ** weight
|
||||
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||
|
||||
returns = [r for _, _, r in investments]
|
||||
mean_return = sum(returns) / len(returns)
|
||||
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||
|
||||
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
def risk_adjusted_return(return_val, weight):
|
||||
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||
|
||||
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
|
||||
return {
|
||||
"weighted_return": round(weighted_return, 6),
|
||||
"volatility": 2,
|
||||
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||
"total_assets": len(investments),
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def test_patch_apply():
|
||||
patch = """<<<<<<< SEARCH
|
||||
total_weight = sum(w for _, w, _ in investments)
|
||||
if total_weight != 1.0: # Should use tolerance check
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
=======
|
||||
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
>>>>>>> REPLACE
|
||||
|
||||
<<<<<<< SEARCH
|
||||
weighted_return = 1.0
|
||||
for _, weight, ret in investments:
|
||||
weighted_return *= (1 + ret) ** weight
|
||||
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||
=======
|
||||
# Calculate weighted return
|
||||
weighted_return = sum(weight * ret for _, weight, ret in investments)
|
||||
>>>>>>> REPLACE
|
||||
|
||||
<<<<<<< SEARCH
|
||||
returns = [r for _, _, r in investments]
|
||||
mean_return = sum(returns) / len(returns)
|
||||
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||
=======
|
||||
# Calculate portfolio volatility (simplified)
|
||||
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
|
||||
>>>>>>> REPLACE
|
||||
|
||||
<<<<<<< SEARCH
|
||||
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
def risk_adjusted_return(return_val, weight):
|
||||
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||
|
||||
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
=======
|
||||
# Calculate Sharpe ratio
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
# Find best and worst performing assets
|
||||
best_asset = max(investments, key=lambda x: x[2])
|
||||
worst_asset = min(investments, key=lambda x: x[2])
|
||||
>>>>>>> REPLACE
|
||||
|
||||
<<<<<<< SEARCH
|
||||
return {
|
||||
"weighted_return": round(weighted_return, 6),
|
||||
"volatility": 2,
|
||||
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||
"total_assets": len(investments),
|
||||
}
|
||||
=======
|
||||
return {
|
||||
'weighted_return': round(weighted_return, 6),
|
||||
'volatility': round(volatility, 6),
|
||||
'sharpe_ratio': round(sharpe_ratio, 6),
|
||||
'best_performing': (best_asset[0], round(best_asset[2], 6)),
|
||||
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
|
||||
'total_assets': len(investments)
|
||||
}
|
||||
>>>>>>> REPLACE
|
||||
"""
|
||||
code = """import math
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
def calculate_portfolio_metrics(
|
||||
investments: List[Tuple[str, float, float]],
|
||||
risk_free_rate: float = 0.02
|
||||
) -> dict:
|
||||
if not investments:
|
||||
raise ValueError("Investments list cannot be empty")
|
||||
|
||||
total_weight = sum(w for _, w, _ in investments)
|
||||
if total_weight != 1.0: # Should use tolerance check
|
||||
raise ValueError("Portfolio weights must sum to 1.0")
|
||||
|
||||
weighted_return = 1.0
|
||||
for _, weight, ret in investments:
|
||||
weighted_return *= (1 + ret) ** weight
|
||||
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||
|
||||
returns = [r for _, _, r in investments]
|
||||
mean_return = sum(returns) / len(returns)
|
||||
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||
|
||||
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||
if volatility == 0:
|
||||
sharpe_ratio = 0.0
|
||||
else:
|
||||
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||
|
||||
def risk_adjusted_return(return_val, weight):
|
||||
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||
|
||||
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||
|
||||
return {
|
||||
"weighted_return": round(weighted_return, 6),
|
||||
"volatility": 2,
|
||||
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||
"total_assets": len(investments),
|
||||
}
|
||||
"""
|
||||
new_code = apply_patches(patch, code)
|
||||
print(new_code)
|
||||
475
experiments/code_repair_dashboard.html
Normal file
475
experiments/code_repair_dashboard.html
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Code Repair Logs Dashboard</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: #1a1a2e;
|
||||
color: #eee;
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
h1 {
|
||||
text-align: center;
|
||||
margin-bottom: 20px;
|
||||
color: #00d9ff;
|
||||
}
|
||||
.stats {
|
||||
display: flex;
|
||||
gap: 20px;
|
||||
justify-content: center;
|
||||
margin-bottom: 30px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.stat-card {
|
||||
background: #16213e;
|
||||
padding: 20px 30px;
|
||||
border-radius: 10px;
|
||||
text-align: center;
|
||||
min-width: 150px;
|
||||
}
|
||||
.stat-card .number {
|
||||
font-size: 2em;
|
||||
font-weight: bold;
|
||||
color: #00d9ff;
|
||||
}
|
||||
.stat-card .label {
|
||||
color: #888;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
.search-bar {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin-bottom: 20px;
|
||||
justify-content: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.search-bar input {
|
||||
padding: 10px 15px;
|
||||
border: none;
|
||||
border-radius: 5px;
|
||||
background: #16213e;
|
||||
color: #eee;
|
||||
width: 300px;
|
||||
font-size: 1em;
|
||||
}
|
||||
.search-bar input:focus {
|
||||
outline: 2px solid #00d9ff;
|
||||
}
|
||||
.trace-group {
|
||||
background: #16213e;
|
||||
border-radius: 10px;
|
||||
margin-bottom: 20px;
|
||||
overflow: hidden;
|
||||
}
|
||||
.trace-header {
|
||||
background: #0f3460;
|
||||
padding: 15px 20px;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
.trace-header:hover {
|
||||
background: #1a4a80;
|
||||
}
|
||||
.trace-header .trace-id {
|
||||
font-family: monospace;
|
||||
color: #00d9ff;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
.trace-header .count-badge {
|
||||
background: #e94560;
|
||||
color: white;
|
||||
padding: 3px 10px;
|
||||
border-radius: 15px;
|
||||
font-size: 0.85em;
|
||||
}
|
||||
.trace-header .arrow {
|
||||
transition: transform 0.2s;
|
||||
}
|
||||
.trace-header.expanded .arrow {
|
||||
transform: rotate(180deg);
|
||||
}
|
||||
.trace-content {
|
||||
display: none;
|
||||
padding: 0;
|
||||
}
|
||||
.trace-content.show {
|
||||
display: block;
|
||||
}
|
||||
.log-entry {
|
||||
border-top: 1px solid #0f3460;
|
||||
padding: 20px;
|
||||
}
|
||||
.log-entry:first-child {
|
||||
border-top: none;
|
||||
}
|
||||
.log-entry-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 15px;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
}
|
||||
.optimization-id {
|
||||
font-family: monospace;
|
||||
font-size: 0.8em;
|
||||
color: #888;
|
||||
}
|
||||
.timestamp {
|
||||
font-size: 0.8em;
|
||||
color: #888;
|
||||
}
|
||||
.section {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
.section-title {
|
||||
font-weight: bold;
|
||||
color: #00d9ff;
|
||||
margin-bottom: 8px;
|
||||
font-size: 0.9em;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
.code-block {
|
||||
background: #0d1117;
|
||||
border-radius: 5px;
|
||||
padding: 15px;
|
||||
overflow-x: auto;
|
||||
font-family: 'Fira Code', 'Consolas', monospace;
|
||||
font-size: 0.85em;
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.explanation-block {
|
||||
background: #1e3a5f;
|
||||
border-radius: 5px;
|
||||
padding: 15px;
|
||||
line-height: 1.6;
|
||||
font-size: 0.9em;
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.refined-code {
|
||||
background: #0d2818;
|
||||
border: 1px solid #2ea043;
|
||||
}
|
||||
.expand-all-btn {
|
||||
background: #0f3460;
|
||||
color: #eee;
|
||||
border: none;
|
||||
padding: 10px 20px;
|
||||
border-radius: 5px;
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
.expand-all-btn:hover {
|
||||
background: #1a4a80;
|
||||
}
|
||||
.no-results {
|
||||
text-align: center;
|
||||
padding: 40px;
|
||||
color: #888;
|
||||
}
|
||||
.copy-btn {
|
||||
background: #333;
|
||||
color: #888;
|
||||
border: none;
|
||||
padding: 5px 10px;
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
font-size: 0.75em;
|
||||
float: right;
|
||||
margin-bottom: 5px;
|
||||
}
|
||||
.copy-btn:hover {
|
||||
background: #444;
|
||||
color: #eee;
|
||||
}
|
||||
.status-badges {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-top: 10px;
|
||||
}
|
||||
.status-badge {
|
||||
padding: 4px 12px;
|
||||
border-radius: 4px;
|
||||
font-size: 0.8em;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
.status-badge.passed-true {
|
||||
background: #238636;
|
||||
color: #fff;
|
||||
}
|
||||
.status-badge.passed-false {
|
||||
background: #da3633;
|
||||
color: #fff;
|
||||
}
|
||||
.status-badge.faster-true {
|
||||
background: #1f6feb;
|
||||
color: #fff;
|
||||
}
|
||||
.status-badge.faster-false {
|
||||
background: #6e7681;
|
||||
color: #fff;
|
||||
}
|
||||
.status-badge.pending {
|
||||
background: #484f58;
|
||||
color: #8b949e;
|
||||
}
|
||||
.trace-status-summary {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
}
|
||||
.trace-status-icon {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Code Repair Logs Dashboard</h1>
|
||||
|
||||
<div class="stats">
|
||||
<div class="stat-card">
|
||||
<div class="number" id="total-traces">0</div>
|
||||
<div class="label">Trace Groups</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="number" id="total-logs">0</div>
|
||||
<div class="label">Total Logs</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="number" id="avg-per-trace">0</div>
|
||||
<div class="label">Avg per Trace</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="number" id="passed-count" style="color: #238636;">0</div>
|
||||
<div class="label">Passed</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="number" id="faster-count" style="color: #1f6feb;">0</div>
|
||||
<div class="label">Faster</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="search-bar">
|
||||
<input type="text" id="search" placeholder="Search by trace ID, optimization ID, or content...">
|
||||
<button class="expand-all-btn" onclick="toggleAll()">Expand All</button>
|
||||
</div>
|
||||
|
||||
<div id="dashboard"></div>
|
||||
|
||||
<script>
|
||||
// Data will be injected here
|
||||
const data = DATA_PLACEHOLDER;
|
||||
|
||||
let allExpanded = false;
|
||||
|
||||
function escapeHtml(text) {
|
||||
if (!text) return '';
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
function formatDate(dateStr) {
|
||||
if (!dateStr) return '';
|
||||
const date = new Date(dateStr);
|
||||
return date.toLocaleString();
|
||||
}
|
||||
|
||||
function copyToClipboard(text, btn) {
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
const originalText = btn.textContent;
|
||||
btn.textContent = 'Copied!';
|
||||
setTimeout(() => btn.textContent = originalText, 1500);
|
||||
});
|
||||
}
|
||||
|
||||
function getStatusBadge(value, label, type) {
|
||||
if (value === null || value === undefined) {
|
||||
return `<span class="status-badge pending">${label}: Pending</span>`;
|
||||
}
|
||||
const isTrue = value === 'True' || value === 'true' || value === true || value === 'yes';
|
||||
const className = `${type}-${isTrue}`;
|
||||
const displayValue = isTrue ? 'Yes' : 'No';
|
||||
return `<span class="status-badge ${className}">${label}: ${displayValue}</span>`;
|
||||
}
|
||||
|
||||
function getTraceStatusSummary(logs) {
|
||||
const withStatus = logs.filter(l => l.passed !== null && l.passed !== undefined);
|
||||
if (withStatus.length === 0) return '';
|
||||
|
||||
const passedCount = withStatus.filter(l => l.passed === 'True' || l.passed === 'true' || l.passed === true).length;
|
||||
const fasterCount = withStatus.filter(l => l.faster === 'True' || l.faster === 'true' || l.faster === true).length;
|
||||
|
||||
const passedIcon = passedCount === withStatus.length ? '✓' : passedCount > 0 ? '◐' : '✗';
|
||||
const fasterIcon = fasterCount === withStatus.length ? '⚡' : fasterCount > 0 ? '◐' : '−';
|
||||
|
||||
return `<span class="trace-status-summary">
|
||||
<span title="Passed: ${passedCount}/${withStatus.length}">${passedIcon}</span>
|
||||
<span title="Faster: ${fasterCount}/${withStatus.length}">${fasterIcon}</span>
|
||||
</span>`;
|
||||
}
|
||||
|
||||
// Rebuild the dashboard DOM from `filteredData` (defaults to the global
// `data` array injected by generate_dashboard.py). Groups rows by
// trace_id, refreshes the header counters, then renders one collapsible
// group per trace, newest trace first.
function renderDashboard(filteredData = data) {
    const dashboard = document.getElementById('dashboard');

    // Group by trace_id
    const grouped = {};
    filteredData.forEach(log => {
        if (!grouped[log.trace_id]) {
            grouped[log.trace_id] = [];
        }
        grouped[log.trace_id].push(log);
    });

    // Sort groups by most recent
    // (Date - Date coerces both to milliseconds, so this sorts descending
    // by the first entry's created_at.)
    const sortedTraces = Object.entries(grouped).sort((a, b) => {
        const aDate = new Date(a[1][0].created_at);
        const bDate = new Date(b[1][0].created_at);
        return bDate - aDate;
    });

    // Update stats
    document.getElementById('total-traces').textContent = sortedTraces.length;
    document.getElementById('total-logs').textContent = filteredData.length;
    document.getElementById('avg-per-trace').textContent = sortedTraces.length > 0
        ? (filteredData.length / sortedTraces.length).toFixed(1)
        : '0';

    // Calculate passed/faster stats
    // NOTE(review): unlike getStatusBadge, these counts do not treat the
    // string 'yes' as truthy — confirm which convention the DB uses.
    const withStatus = filteredData.filter(l => l.passed !== null && l.passed !== undefined);
    const passedCount = withStatus.filter(l => l.passed === 'True' || l.passed === 'true' || l.passed === true).length;
    const fasterCount = withStatus.filter(l => l.faster === 'True' || l.faster === 'true' || l.faster === true).length;
    document.getElementById('passed-count').textContent = withStatus.length > 0 ? `${passedCount}/${withStatus.length}` : '−';
    document.getElementById('faster-count').textContent = withStatus.length > 0 ? `${fasterCount}/${withStatus.length}` : '−';

    if (sortedTraces.length === 0) {
        dashboard.innerHTML = '<div class="no-results">No results found</div>';
        return;
    }

    let html = '';
    sortedTraces.forEach(([traceId, logs], idx) => {
        // Sort logs within group by created_at
        logs.sort((a, b) => new Date(a.created_at) - new Date(b.created_at));

        // id="trace-${idx}" pairs each content panel with toggleTrace(idx).
        // The copy buttons look the text up in the unfiltered global `data`
        // by optimization_id; ids are interpolated raw into the onclick
        // attribute, which assumes they never contain quotes — TODO confirm.
        html += `
            <div class="trace-group">
                <div class="trace-header" onclick="toggleTrace(${idx})">
                    <span class="trace-id">Trace: ${escapeHtml(traceId)}</span>
                    <div style="display: flex; align-items: center; gap: 10px;">
                        ${getTraceStatusSummary(logs)}
                        <span class="count-badge">${logs.length} entries</span>
                        <span class="arrow">▼</span>
                    </div>
                </div>
                <div class="trace-content" id="trace-${idx}">
                    ${logs.map((log, logIdx) => `
                        <div class="log-entry">
                            <div class="log-entry-header">
                                <span class="optimization-id">Optimization: ${escapeHtml(log.optimization_id)}</span>
                                <span class="timestamp">${formatDate(log.created_at)}</span>
                            </div>

                            <div class="status-badges">
                                ${getStatusBadge(log.passed, 'Passed', 'passed')}
                                ${getStatusBadge(log.faster, 'Faster', 'faster')}
                            </div>

                            <div class="section" style="margin-top: 15px;">
                                <div class="section-title">User Prompt</div>
                                <button class="copy-btn" onclick="copyToClipboard(data.find(d => d.optimization_id === '${log.optimization_id}').user_prompt, this)">Copy</button>
                                <div class="code-block">${escapeHtml(log.user_prompt)}</div>
                            </div>

                            <div class="section">
                                <div class="section-title">Explanation</div>
                                <div class="explanation-block">${escapeHtml(log.explanation)}</div>
                            </div>

                            <div class="section">
                                <div class="section-title">Refined Optimization</div>
                                <button class="copy-btn" onclick="copyToClipboard(data.find(d => d.optimization_id === '${log.optimization_id}').refined_optimization, this)">Copy</button>
                                <div class="code-block refined-code">${escapeHtml(log.refined_optimization)}</div>
                            </div>
                        </div>
                    `).join('')}
                </div>
            </div>`;
    });

    dashboard.innerHTML = html;
}
|
||||
|
||||
// Expand/collapse a single trace group. `idx` matches the index used by
// renderDashboard when it emitted id="trace-<idx>"; the clickable header
// is the element immediately before the content panel.
function toggleTrace(idx) {
    const panel = document.getElementById(`trace-${idx}`);
    panel.classList.toggle('show');
    panel.previousElementSibling.classList.toggle('expanded');
}
|
||||
|
||||
// Expand or collapse every trace group at once, flipping the global
// `allExpanded` flag and relabeling the toolbar button to match.
function toggleAll() {
    allExpanded = !allExpanded;
    // Apply the same classList operation to every panel/header.
    const op = allExpanded ? 'add' : 'remove';

    document.querySelectorAll('.trace-content').forEach((panel) => {
        panel.classList[op]('show');
    });
    document.querySelectorAll('.trace-header').forEach((header) => {
        header.classList[op]('expanded');
    });

    document.querySelector('.expand-all-btn').textContent = allExpanded ? 'Collapse All' : 'Expand All';
}
|
||||
|
||||
// Live search: as the user types, re-render the dashboard with only the
// logs whose displayed text fields contain the query (case-insensitive).
// An empty query restores the full, unfiltered view.
document.getElementById('search').addEventListener('input', function (e) {
    const query = e.target.value.toLowerCase();
    if (!query) {
        renderDashboard(data);
        return;
    }

    // Same fields, in the same short-circuit order, as the original || chain.
    const searchableFields = ['trace_id', 'optimization_id', 'user_prompt', 'explanation', 'refined_optimization'];
    const filtered = data.filter((log) =>
        searchableFields.some((field) => log[field] && log[field].toLowerCase().includes(query))
    );

    renderDashboard(filtered);
});

// Initial render
renderDashboard();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
79
experiments/generate_dashboard.py
Normal file
79
experiments/generate_dashboard.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Generate an HTML dashboard from code_repair_logs SQLite database."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def main():
    """Merge code-repair logs with pass/faster results and open a dashboard.

    Reads every row from code_repair_log.db, joins each one against the
    matching row of code_repair_logs_cf.db (keyed by a rewritten
    optimization id), injects the merged records as JSON into the HTML
    template, writes the result next to this script, and opens it in the
    default browser.

    Raises:
        FileNotFoundError: if either SQLite database or the template is
            missing. Without this check, sqlite3.connect() silently creates
            an empty database file and the query later fails with the far
            less helpful "no such table" error.
    """
    base_dir = Path(__file__).parent
    db_path = base_dir / "code_repair_log.db"
    cf_db_path = base_dir / "code_repair_logs_cf.db"
    template_path = base_dir / "code_repair_dashboard.html"
    output_path = base_dir / "code_repair_dashboard_live.html"

    for required in (db_path, cf_db_path, template_path):
        if not required.exists():
            raise FileNotFoundError(f"Required file not found: {required}")

    # Fetch all repair logs, newest first.
    rows = _fetch_rows(
        db_path,
        """
        SELECT optimization_id, trace_id, user_prompt, explanation,
               refined_optimization, created_at, updated_at
        FROM code_repair_logs
        ORDER BY created_at DESC
        """,
    )

    # Fetch passed/faster status and index it by optimization id.
    cf_rows = _fetch_rows(
        cf_db_path,
        "SELECT optimization_id, passed, faster FROM code_repair_logs_cf",
    )
    cf_data = {
        row["optimization_id"]: {"passed": row["passed"], "faster": row["faster"]}
        for row in cf_rows
    }

    # Merge the cf status into each log record.
    data = [_merge_row(dict(row), cf_data) for row in rows]

    # Inject the records into the template as a JSON literal.
    template = template_path.read_text()
    json_data = json.dumps(data, default=str, indent=2)
    output_path.write_text(template.replace("DATA_PLACEHOLDER", json_data))

    print(f"Dashboard generated: {output_path}")
    print(f"Total logs: {len(data)}")
    print(f"Unique traces: {len({d['trace_id'] for d in data})}")

    # Open in browser
    webbrowser.open(f"file://{output_path.absolute()}")


def _fetch_rows(path, query):
    """Run `query` against the SQLite database at `path`; return all rows.

    Rows are sqlite3.Row objects so they can be accessed by column name.
    The connection is always closed, even if the query raises.
    """
    conn = sqlite3.connect(path)
    try:
        conn.row_factory = sqlite3.Row
        return conn.execute(query).fetchall()
    finally:
        conn.close()


def _merge_row(d, cf_data):
    """Attach passed/faster status from `cf_data` to one log dict, in place.

    The cf table stores ids whose last four characters are replaced by
    'cdrp', so the key is rewritten before the lookup; ids with no cf
    match get passed/faster of None (rendered as "Pending").
    NOTE(review): the 4-char suffix swap is an observed convention of the
    two databases — confirm it holds for every id format.
    """
    opt_id = d["optimization_id"][:-4] + "cdrp"
    status = cf_data.get(opt_id)
    d["passed"] = status["passed"] if status else None
    d["faster"] = status["faster"] if status else None
    return d
|
||||
|
||||
|
||||
# Allow running this file directly as a script.
if __name__ == "__main__":
    main()
|
||||
|
|
@ -39,7 +39,7 @@ const CollapsibleCodeBlock = memo(
|
|||
setIsExpanded(false);
|
||||
}
|
||||
},
|
||||
[isExpanded]
|
||||
[isExpanded],
|
||||
);
|
||||
|
||||
if (!hasMoreLines) {
|
||||
|
|
|
|||
|
|
@ -23,7 +23,9 @@ const Tabs = () => {
|
|||
>
|
||||
<span className={styles.tabText}>
|
||||
Tasks
|
||||
{tasksCount > 0 && <span className={styles.badge}>{tasksCount}</span>}
|
||||
{tasksCount > 0 && (
|
||||
<span className={styles.badge}>{tasksCount}</span>
|
||||
)}
|
||||
</span>
|
||||
</button>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -180,7 +180,12 @@ function renderStepsUI(steps) {
|
|||
|
||||
// Show "select different interpreter" option after install button
|
||||
if (vscodeActionCommand) {
|
||||
const { title, btnText: vscodeBtnText, command: vscodeCmd, args = [] } = vscodeActionCommand;
|
||||
const {
|
||||
title,
|
||||
btnText: vscodeBtnText,
|
||||
command: vscodeCmd,
|
||||
args = [],
|
||||
} = vscodeActionCommand;
|
||||
if (title) {
|
||||
const detailsElem = document.createElement("p");
|
||||
detailsElem.className = "step-action";
|
||||
|
|
@ -192,7 +197,11 @@ function renderStepsUI(steps) {
|
|||
actionBtn.textContent = vscodeBtnText;
|
||||
actionBtn.className = "step-action-btn secondary-btn";
|
||||
actionBtn.addEventListener("click", () => {
|
||||
vscode.postMessage({ command: "vscodeCommand", cmd: vscodeCmd, args });
|
||||
vscode.postMessage({
|
||||
command: "vscodeCommand",
|
||||
cmd: vscodeCmd,
|
||||
args,
|
||||
});
|
||||
});
|
||||
actionsContainer.appendChild(actionBtn);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,9 +56,7 @@ export async function getRepositoryPublicKey(
|
|||
repo: string,
|
||||
): Promise<{ public_key: string; key_id: string }> {
|
||||
try {
|
||||
console.log(
|
||||
`[secret-utils.ts:getRepositoryPublicKey] Getting public key for ${owner}/${repo}`,
|
||||
)
|
||||
console.log(`[secret-utils.ts:getRepositoryPublicKey] Getting public key for ${owner}/${repo}`)
|
||||
|
||||
const response = await octokit.rest.actions.getRepoPublicKey({
|
||||
owner,
|
||||
|
|
@ -166,4 +164,3 @@ export async function encryptAndStoreSecret(
|
|||
`[secret-utils.ts:encryptAndStoreSecret] Successfully encrypted and stored secret ${secretName} for ${owner}/${repo}`,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
-- AlterTable
|
||||
ALTER TABLE "public"."optimization_features" ADD COLUMN "optimizations_origin" JSONB;
|
||||
|
|
@ -73,6 +73,7 @@ model optimization_features {
|
|||
organization String?
|
||||
repository String?
|
||||
ranking Json?
|
||||
optimizations_origin Json?
|
||||
review_quality String? // Hight, Med, low
|
||||
review_explanation String?
|
||||
calling_fn_details String?
|
||||
|
|
|
|||
Loading…
Reference in a new issue