mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
feedback loop for unmatched test results (#2059)
fixes CF-932 # Pull Request Checklist ## Description - [ ] **Description of PR**: Clear and concise description of what this PR accomplishes - [ ] **Breaking Changes**: Document any breaking changes (if applicable) - [ ] **Related Issues**: Link to any related issues or tickets ## Testing - [ ] **Test cases Attached**: All relevant test cases have been added/updated - [ ] **Manual Testing**: Manual testing completed for the changes ## Monitoring & Debugging - [ ] **Logging in place**: Appropriate logging has been added for debugging user issues - [ ] **Sentry will be able to catch errors**: Error handling ensures Sentry can capture and report errors - [ ] **Avoid Dev based/Prisma logging**: No development-only or Prisma-specific logging in production code ## Configuration - [ ] **Env variables newly added**: Any new environment variables are documented in .env.example file or mentioned in description --- ## Additional Notes <!-- Add any additional context, screenshots, or notes for reviewers here --> --------- Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> Co-authored-by: ali <mohammed18200118@gmail.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
This commit is contained in:
parent
96ea895c99
commit
1192df12a6
25 changed files with 1533 additions and 54 deletions
|
|
@ -219,3 +219,4 @@ RANKING_MODEL: LLM = _get_openai_model()
|
||||||
REFINEMENT_MODEL: LLM = _get_anthropic_model()
|
REFINEMENT_MODEL: LLM = _get_anthropic_model()
|
||||||
EXPLANATIONS_MODEL: LLM = _get_anthropic_model()
|
EXPLANATIONS_MODEL: LLM = _get_anthropic_model()
|
||||||
OPTIMIZATION_REVIEW_MODEL: LLM = _get_anthropic_model()
|
OPTIMIZATION_REVIEW_MODEL: LLM = _get_anthropic_model()
|
||||||
|
CODE_REPAIR_MODEL: LLM = _get_anthropic_model()
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,7 @@ Including another URLconf
|
||||||
# from django.contrib import admin
|
# from django.contrib import admin
|
||||||
from django.urls import path
|
from django.urls import path
|
||||||
|
|
||||||
|
from code_repair.code_repair import code_repair_api
|
||||||
from explanations.explanations import explanations_api
|
from explanations.explanations import explanations_api
|
||||||
from log_features.log_features import features_api
|
from log_features.log_features import features_api
|
||||||
from optimization_review.optimization_review import optimization_review_api
|
from optimization_review.optimization_review import optimization_review_api
|
||||||
|
|
@ -39,5 +40,6 @@ urlpatterns = [
|
||||||
path("ai/explain", explanations_api.urls),
|
path("ai/explain", explanations_api.urls),
|
||||||
path("ai/rank", ranker_api.urls),
|
path("ai/rank", ranker_api.urls),
|
||||||
path("ai/optimization_review", optimization_review_api.urls),
|
path("ai/optimization_review", optimization_review_api.urls),
|
||||||
|
path("ai/code_repair", code_repair_api.urls),
|
||||||
path("ai/workflow-gen", workflow_gen_api.urls),
|
path("ai/workflow-gen", workflow_gen_api.urls),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
61
django/aiservice/code_repair/CODE_REPAIR_SYSTEM_PROMPT.md
Normal file
61
django/aiservice/code_repair/CODE_REPAIR_SYSTEM_PROMPT.md
Normal file
|
|
@ -0,0 +1,61 @@
|
||||||
|
You are a senior software engineer who is great at reviewing and repairing python code for performance and behavior.
|
||||||
|
The goal of repairing code is to ensure that the optimized code is performant and has the same behavior as the original code based on the provided test results.
|
||||||
|
You are provided the following information:
|
||||||
|
|
||||||
|
- original_source_code
|
||||||
|
- optimized_source_code - This is the optimized implementation of the original code.
|
||||||
|
- test_details - This has the details of the behavioral differences between the original and optimized code.
|
||||||
|
|
||||||
|
### Output format
|
||||||
|
|
||||||
|
Request to replace sections of the optimized code in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
|
||||||
|
Define your output in XML-style tags like here:
|
||||||
|
|
||||||
|
<replace_in_file>
|
||||||
|
<path>src/main.py</path>
|
||||||
|
<diff>
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
a = 2
|
||||||
|
=======
|
||||||
|
a = 3
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
|
|
||||||
|
Always adhere to this format for tool use to ensure proper parsing and execution.
|
||||||
|
|
||||||
|
## replace_in_file
|
||||||
|
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file.
|
||||||
|
Parameters:
|
||||||
|
- path: (required) The path of the file to modify
|
||||||
|
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
|
||||||
|
```
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
[exact content to find]
|
||||||
|
=======
|
||||||
|
[new content to replace with]
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
```
|
||||||
|
Critical rules:
|
||||||
|
1. SEARCH content must match the associated file section to find EXACTLY:
|
||||||
|
* Match character-for-character including whitespace, indentation, line endings
|
||||||
|
* Include all comments, docstrings, etc.
|
||||||
|
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
|
||||||
|
* Including multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
|
||||||
|
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
|
||||||
|
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
|
||||||
|
3. Keep SEARCH/REPLACE blocks concise:
|
||||||
|
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
|
||||||
|
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
|
||||||
|
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
|
||||||
|
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
|
||||||
|
4. Special operations:
|
||||||
|
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
|
||||||
|
* To delete code: Use empty REPLACE section
|
||||||
|
Usage:
|
||||||
|
<replace_in_file>
|
||||||
|
<path>File path here</path>
|
||||||
|
<diff>
|
||||||
|
Search and replace blocks here
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
13
django/aiservice/code_repair/CODE_REPAIR_USER_PROMPT.md
Normal file
13
django/aiservice/code_repair/CODE_REPAIR_USER_PROMPT.md
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
Please fix the optimized code to match the behaviour of the original code, while trying to keep the optimization logic intact.
|
||||||
|
|
||||||
|
### The original source code:
|
||||||
|
|
||||||
|
{original_source_code}
|
||||||
|
|
||||||
|
### The optimized source code
|
||||||
|
|
||||||
|
{modified_source_code}
|
||||||
|
|
||||||
|
### The test result details
|
||||||
|
|
||||||
|
{test_details}
|
||||||
0
django/aiservice/code_repair/__init__.py
Normal file
0
django/aiservice/code_repair/__init__.py
Normal file
6
django/aiservice/code_repair/apps.py
Normal file
6
django/aiservice/code_repair/apps.py
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRepairConfig(AppConfig):
|
||||||
|
default_auto_field = "django.db.models.BigAutoField"
|
||||||
|
name = "code_repair"
|
||||||
204
django/aiservice/code_repair/code_repair.py
Normal file
204
django/aiservice/code_repair/code_repair.py
Normal file
|
|
@ -0,0 +1,204 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import libcst as cst
|
||||||
|
import sentry_sdk
|
||||||
|
from ninja import NinjaAPI, Schema
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from aiservice.analytics.posthog import ph
|
||||||
|
from aiservice.common_utils import validate_trace_id
|
||||||
|
from aiservice.env_specific import debug_log_sensitive_data, llm_clients
|
||||||
|
from aiservice.models.aimodels import CODE_REPAIR_MODEL, calculate_llm_cost
|
||||||
|
from log_features.log_event import update_optimization_cost
|
||||||
|
from log_features.log_features import log_features
|
||||||
|
from optimizer.models import OptimizedCandidateSource
|
||||||
|
|
||||||
|
from .code_repair_context import ( # noqa: TC001 (don't move CodeRepairRequestSchema to type-checking because it's the schema definition)
|
||||||
|
CodeRepairContext,
|
||||||
|
CodeRepairContextData,
|
||||||
|
CodeRepairRequestSchema,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from django.http import HttpRequest as Request
|
||||||
|
from openai.types.chat import (
|
||||||
|
ChatCompletionAssistantMessageParam,
|
||||||
|
ChatCompletionFunctionMessageParam,
|
||||||
|
ChatCompletionToolMessageParam,
|
||||||
|
)
|
||||||
|
|
||||||
|
from aiservice.models.aimodels import LLM
|
||||||
|
|
||||||
|
code_repair_api = NinjaAPI(urls_namespace="code_repair")
|
||||||
|
|
||||||
|
# Get the directory of the current file
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
SYSTEM_PROMPT = (current_dir / "CODE_REPAIR_SYSTEM_PROMPT.md").read_text()
|
||||||
|
|
||||||
|
USER_PROMPT = (current_dir / "CODE_REPAIR_USER_PROMPT.md").read_text()
|
||||||
|
|
||||||
|
|
||||||
|
async def code_repair( # noqa: D417
|
||||||
|
user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM = CODE_REPAIR_MODEL
|
||||||
|
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
|
||||||
|
"""Repair the given candidate to match the behaviour of the original code.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
:param user_id:
|
||||||
|
:param optimization_id
|
||||||
|
:param optimize_model: LLM for getting the code_repairs
|
||||||
|
:param ctx: the repair context, has the data property which includes
|
||||||
|
- original code
|
||||||
|
- optimized code
|
||||||
|
- behaviour test diffs
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
CodeRepairIntermediateResponseItemschema or CodeRepairErrorResponseSchema
|
||||||
|
"""
|
||||||
|
system_prompt = ctx.get_system_prompt()
|
||||||
|
user_prompt = ctx.get_user_prompt()
|
||||||
|
|
||||||
|
new_op_id = str(uuid.uuid4())
|
||||||
|
system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
|
||||||
|
user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
|
||||||
|
messages: list[
|
||||||
|
ChatCompletionSystemMessageParam
|
||||||
|
| ChatCompletionUserMessageParam
|
||||||
|
| ChatCompletionAssistantMessageParam
|
||||||
|
| ChatCompletionToolMessageParam
|
||||||
|
| ChatCompletionFunctionMessageParam
|
||||||
|
] = [system_message, user_message]
|
||||||
|
debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
|
||||||
|
llm_client = llm_clients[optimize_model.model_type]
|
||||||
|
try:
|
||||||
|
output = await llm_client.with_options(max_retries=2).chat.completions.create(
|
||||||
|
model=optimize_model.name, messages=messages, n=1
|
||||||
|
)
|
||||||
|
llm_cost = calculate_llm_cost(output, optimize_model)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("Claude Code Generation error in code_repair")
|
||||||
|
sentry_sdk.capture_exception(e)
|
||||||
|
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
|
||||||
|
return CodeRepairErrorResponseSchema(error=str(e))
|
||||||
|
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.model_dump_json(indent=2)}")
|
||||||
|
if output.usage is not None:
|
||||||
|
ph(user_id, "code_repair-usage", properties={"model": optimize_model.name, "usage": output.usage.json()})
|
||||||
|
results = [content for op in output.choices if (content := op.message.content)] # will be of size 1
|
||||||
|
|
||||||
|
# Regex doesn't work yet in extracting everything else other than the search replace block
|
||||||
|
explanation = results[0]
|
||||||
|
|
||||||
|
repaired_optimization = ""
|
||||||
|
try:
|
||||||
|
diff_patches = ctx.extract_diff_patches_from_llm_res(results[0])
|
||||||
|
repaired_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
|
||||||
|
except (ValueError, ValidationError) as exc:
|
||||||
|
sentry_sdk.capture_exception(exc)
|
||||||
|
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.modified_source_code}")
|
||||||
|
debug_log_sensitive_data(f"Traceback: {exc}")
|
||||||
|
repaired_optimization = ""
|
||||||
|
|
||||||
|
if not ctx.is_valid(repaired_optimization):
|
||||||
|
repaired_optimization = ""
|
||||||
|
|
||||||
|
return CodeRepairIntermediateResponseItemschema(
|
||||||
|
optimization_id=new_op_id,
|
||||||
|
parent_id=optimization_id,
|
||||||
|
source_code=repaired_optimization,
|
||||||
|
llm_cost=llm_cost,
|
||||||
|
explanation=explanation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRepairErrorResponseSchema(Schema):
|
||||||
|
error: str
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRepairIntermediateResponseItemschema(Schema):
|
||||||
|
# the key will be the optimization id and the value will be the actual refined code
|
||||||
|
optimization_id: str
|
||||||
|
parent_id: str
|
||||||
|
source_code: str
|
||||||
|
llm_cost: float
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRepairResponseItemschema(Schema):
|
||||||
|
# the key will be the optimization id and the value will be the actual refined code
|
||||||
|
optimization_id: str
|
||||||
|
parent_id: str
|
||||||
|
source_code: str
|
||||||
|
explanation: str
|
||||||
|
|
||||||
|
|
||||||
|
@code_repair_api.post(
|
||||||
|
"/",
|
||||||
|
response={
|
||||||
|
200: CodeRepairResponseItemschema,
|
||||||
|
400: CodeRepairErrorResponseSchema,
|
||||||
|
500: CodeRepairErrorResponseSchema,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def repair(
|
||||||
|
request: Request, data: CodeRepairRequestSchema
|
||||||
|
) -> tuple[int, CodeRepairResponseItemschema | CodeRepairErrorResponseSchema]:
|
||||||
|
ph(request.user, "aiservice-code_repair-called")
|
||||||
|
ctx_data = CodeRepairContextData(
|
||||||
|
original_source_code=data.original_source_code,
|
||||||
|
modified_source_code=data.modified_source_code,
|
||||||
|
test_diffs=data.test_diffs,
|
||||||
|
)
|
||||||
|
ctx = CodeRepairContext(ctx_data=ctx_data, base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT)
|
||||||
|
trace_id = data.trace_id
|
||||||
|
if not validate_trace_id(trace_id):
|
||||||
|
return 400, CodeRepairErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
|
||||||
|
|
||||||
|
code_repair_data = await code_repair(user_id=request.user, optimization_id=data.optimization_id, ctx=ctx)
|
||||||
|
total_llm_cost = 0.0
|
||||||
|
if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
|
||||||
|
return 500, code_repair_data
|
||||||
|
total_llm_cost += code_repair_data.llm_cost
|
||||||
|
try:
|
||||||
|
ctx.validate_python_module()
|
||||||
|
except cst.ParserSyntaxError as e:
|
||||||
|
# log exception with sentry
|
||||||
|
sentry_sdk.capture_exception(e)
|
||||||
|
debug_log_sensitive_data(f"ParserSyntaxError for source:\n{code_repair_data.source_code}")
|
||||||
|
debug_log_sensitive_data(f"Traceback: {e}")
|
||||||
|
return 500, CodeRepairErrorResponseSchema(error=str(e))
|
||||||
|
except (ValueError, ValidationError) as exc:
|
||||||
|
# Another one bites the Pydantic validation dust
|
||||||
|
sentry_sdk.capture_exception(exc)
|
||||||
|
debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{code_repair_data.source_code}")
|
||||||
|
debug_log_sensitive_data(f"Traceback: {exc}")
|
||||||
|
return 500, CodeRepairErrorResponseSchema(error=str(exc))
|
||||||
|
|
||||||
|
if hasattr(request, "should_log_features") and request.should_log_features:
|
||||||
|
await log_features(
|
||||||
|
trace_id=data.trace_id,
|
||||||
|
user_id=request.user,
|
||||||
|
optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
|
||||||
|
explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
|
||||||
|
# explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
||||||
|
# optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
|
||||||
|
optimizations_origin={
|
||||||
|
code_repair_data.optimization_id: {
|
||||||
|
"source": OptimizedCandidateSource.REPAIR,
|
||||||
|
"parent": code_repair_data.parent_id,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||||
|
return 200, CodeRepairResponseItemschema(
|
||||||
|
source_code=code_repair_data.source_code,
|
||||||
|
optimization_id=code_repair_data.optimization_id,
|
||||||
|
parent_id=code_repair_data.parent_id,
|
||||||
|
explanation=code_repair_data.explanation,
|
||||||
|
)
|
||||||
153
django/aiservice/code_repair/code_repair_context.py
Normal file
153
django/aiservice/code_repair/code_repair_context.py
Normal file
|
|
@ -0,0 +1,153 @@
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
import libcst as cst
|
||||||
|
import sentry_sdk
|
||||||
|
from aiservice.env_specific import debug_log_sensitive_data
|
||||||
|
from ninja import Field, Schema
|
||||||
|
from optimizer.context_utils.constants import REPLACE_IN_FILE_TAGS_REGEX
|
||||||
|
from optimizer.context_utils.context_helpers import group_code, is_markdown_structure_changed, split_markdown_code
|
||||||
|
from optimizer.diff_patches_utils.patches_v2 import apply_patches, group_diff_patches_by_path
|
||||||
|
from optimizer.models import CodeAndExplanation
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from testgen.instrumentation.edit_generated_test import parse_module_to_cst
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiffScope(str, Enum):
|
||||||
|
RETURN_VALUE = "return_value"
|
||||||
|
STDOUT = "stdout"
|
||||||
|
DID_PASS = "did_pass" # noqa: S105
|
||||||
|
TIMED_OUT = "timed_out"
|
||||||
|
|
||||||
|
|
||||||
|
SCOPE_DESCRIPTIONS = {
|
||||||
|
TestDiffScope.RETURN_VALUE: (
|
||||||
|
"The function returned a different value in the optimized code compared to the original."
|
||||||
|
),
|
||||||
|
TestDiffScope.STDOUT: ("The output printed to stdout is different in the optimized code compared to the original."),
|
||||||
|
TestDiffScope.DID_PASS: (
|
||||||
|
"The test passed in one version but failed in the other (a change in pass/fail behavior)."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDiff(Schema):
|
||||||
|
scope: TestDiffScope
|
||||||
|
original_value: bool | str | int | float | dict | list | None = None
|
||||||
|
candidate_value: bool | str | int | float | dict | list | None = None
|
||||||
|
original_pass: bool
|
||||||
|
candidate_pass: bool
|
||||||
|
test_src_code: str
|
||||||
|
candidate_pytest_error: str | None = None
|
||||||
|
original_pytest_error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRepairRequestSchema(Schema):
|
||||||
|
trace_id: str
|
||||||
|
optimization_id: str
|
||||||
|
original_source_code: str
|
||||||
|
modified_source_code: str
|
||||||
|
test_diffs: list[TestDiff] = Field(..., alias="test_diffs")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
|
class CodeRepairContextData:
|
||||||
|
original_source_code: str
|
||||||
|
modified_source_code: str
|
||||||
|
test_diffs: list[TestDiff]
|
||||||
|
|
||||||
|
|
||||||
|
class CodeRepairContext:
|
||||||
|
def __init__(self, ctx_data: CodeRepairContextData, base_system_prompt: str, base_user_prompt: str) -> None:
|
||||||
|
self.data = ctx_data
|
||||||
|
self.base_system_prompt = base_system_prompt
|
||||||
|
self.base_user_prompt = base_user_prompt
|
||||||
|
|
||||||
|
def get_system_prompt(self) -> str:
|
||||||
|
return self.base_system_prompt
|
||||||
|
|
||||||
|
def build_test_details(self, test_diffs: list[TestDiff]) -> str:
|
||||||
|
sections = defaultdict(str)
|
||||||
|
for diff in test_diffs:
|
||||||
|
try:
|
||||||
|
if sections[diff.test_src_code] == "":
|
||||||
|
# add error strings and test def only once per test function
|
||||||
|
sections[diff.test_src_code] += f"""Test Source:
|
||||||
|
```python
|
||||||
|
{diff.test_src_code}
|
||||||
|
```
|
||||||
|
Pytest error (original code): {diff.original_pytest_error if diff.original_pytest_error else ""}
|
||||||
|
Pytest error (optimized code): {diff.candidate_pytest_error if diff.candidate_pytest_error else ""}
|
||||||
|
"""
|
||||||
|
sections[diff.test_src_code] += "\n".join(
|
||||||
|
[
|
||||||
|
f"{SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value)}",
|
||||||
|
f"Expected: {diff.original_value!r}.\nGot: {diff.candidate_value!r}."
|
||||||
|
if diff.scope != TestDiffScope.DID_PASS
|
||||||
|
else "",
|
||||||
|
f"Original code test status: {'Passed' if diff.original_pass else 'Failed'}. Optimized code test status: {'Passed' if diff.candidate_pass else 'Failed'}",
|
||||||
|
"---",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("Some issue in parsing test diffs")
|
||||||
|
sentry_sdk.capture_exception(e)
|
||||||
|
return "\n".join(sections.values())
|
||||||
|
|
||||||
|
def get_user_prompt(self) -> str:
|
||||||
|
return self.base_user_prompt.format(
|
||||||
|
original_source_code=self.data.original_source_code,
|
||||||
|
modified_source_code=self.data.modified_source_code,
|
||||||
|
test_details=self.build_test_details(self.data.test_diffs),
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_diff_patches_from_llm_res(self, llm_res: str) -> str:
|
||||||
|
matches = REPLACE_IN_FILE_TAGS_REGEX.findall(llm_res)
|
||||||
|
replace_tags = ""
|
||||||
|
if matches and len(matches) != 0:
|
||||||
|
replace_tags = f"<replace_in_file>{matches[0]}</replace_in_file>"
|
||||||
|
|
||||||
|
return replace_tags
|
||||||
|
|
||||||
|
def apply_patches_to_optimized_code(self, replace_tags: str) -> str:
|
||||||
|
if replace_tags == "":
|
||||||
|
return ""
|
||||||
|
|
||||||
|
file_to_code = split_markdown_code(self.data.modified_source_code)
|
||||||
|
# sometimes the llm can write multiple replace tags for the same file, so we group them by path to avoid parsing & applying multiple times
|
||||||
|
file_to_diffs = group_diff_patches_by_path(replace_tags)
|
||||||
|
|
||||||
|
for path, diff in file_to_diffs.items():
|
||||||
|
scoped_code = file_to_code.get(path, None)
|
||||||
|
if scoped_code is None:
|
||||||
|
debug_log_sensitive_data(f"no scoped code for {path}, existing: {file_to_code.keys()}")
|
||||||
|
continue
|
||||||
|
new_code = apply_patches(diff, scoped_code)
|
||||||
|
file_to_code[path] = new_code
|
||||||
|
return group_code(file_to_code)
|
||||||
|
|
||||||
|
def is_valid(self, new_refined_code: str) -> bool:
|
||||||
|
if is_markdown_structure_changed(new_refined_code, self.data.modified_source_code):
|
||||||
|
return False
|
||||||
|
valid = True
|
||||||
|
for code in split_markdown_code(new_refined_code).values():
|
||||||
|
stripped_code = code.strip()
|
||||||
|
if not stripped_code:
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
parse_module_to_cst(code)
|
||||||
|
except cst.ParserSyntaxError:
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
return valid
|
||||||
|
|
||||||
|
def validate_python_module(self) -> None:
|
||||||
|
for _code in split_markdown_code(self.data.modified_source_code).values():
|
||||||
|
try:
|
||||||
|
cst_module = parse_module_to_cst(_code)
|
||||||
|
CodeAndExplanation(cst_module, "")
|
||||||
|
except (ValueError, ValidationError, cst.ParserSyntaxError): # noqa: TRY203
|
||||||
|
raise
|
||||||
|
|
@ -44,6 +44,7 @@ def log_features(
|
||||||
experiment_metadata: dict[str, str] | None = None,
|
experiment_metadata: dict[str, str] | None = None,
|
||||||
final_explanation: str | None = None,
|
final_explanation: str | None = None,
|
||||||
ranking: dict[str, Any] | None = None,
|
ranking: dict[str, Any] | None = None,
|
||||||
|
optimizations_origin: dict[str, dict[str, str]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Log features of a code optimization run to the database.
|
"""Log features of a code optimization run to the database.
|
||||||
|
|
||||||
|
|
@ -95,6 +96,7 @@ def log_features(
|
||||||
"experiment_metadata": experiment_metadata,
|
"experiment_metadata": experiment_metadata,
|
||||||
"final_explanation": final_explanation,
|
"final_explanation": final_explanation,
|
||||||
"ranking": ranking,
|
"ranking": ranking,
|
||||||
|
"optimizations_origin": optimizations_origin,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -148,6 +150,14 @@ def log_features(
|
||||||
f.ranking = f.ranking | ranking if ranking is not None else f.ranking
|
f.ranking = f.ranking | ranking if ranking is not None else f.ranking
|
||||||
else:
|
else:
|
||||||
f.ranking = ranking if ranking is not None else f.ranking
|
f.ranking = ranking if ranking is not None else f.ranking
|
||||||
|
|
||||||
|
if f.optimizations_origin is not None:
|
||||||
|
# merge the optimizations_origin with the existing ones
|
||||||
|
f.optimizations_origin = merge_dicts(f.optimizations_origin, optimizations_origin or {})
|
||||||
|
else:
|
||||||
|
f.optimizations_origin = (
|
||||||
|
optimizations_origin if optimizations_origin is not None else f.optimizations_origin
|
||||||
|
)
|
||||||
f.speedup_ratio = speedup_ratio if speedup_ratio is not None else f.speedup_ratio
|
f.speedup_ratio = speedup_ratio if speedup_ratio is not None else f.speedup_ratio
|
||||||
f.original_runtime = original_runtime if original_runtime is not None else f.original_runtime
|
f.original_runtime = original_runtime if original_runtime is not None else f.original_runtime
|
||||||
f.optimized_runtime = optimized_runtime if optimized_runtime is not None else f.optimized_runtime
|
f.optimized_runtime = optimized_runtime if optimized_runtime is not None else f.optimized_runtime
|
||||||
|
|
@ -165,6 +175,22 @@ def log_features(
|
||||||
f.save()
|
f.save()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dicts(a: dict[str, dict[str, str]], b: dict[str, dict[str, str]]) -> dict[str, dict[str, str]]:
|
||||||
|
result: dict[str, dict[str, str]] = {}
|
||||||
|
|
||||||
|
for key, inner in a.items():
|
||||||
|
result[key] = inner.copy()
|
||||||
|
|
||||||
|
for key, inner in b.items():
|
||||||
|
if key not in result:
|
||||||
|
result[key] = inner.copy()
|
||||||
|
else:
|
||||||
|
# b overrides a
|
||||||
|
result[key].update(inner)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@features_api.post("/", response={200: None, 500: LoggingErrorResponseSchema})
|
@features_api.post("/", response={200: None, 500: LoggingErrorResponseSchema})
|
||||||
async def log_features_cli(request: HttpRequest, data: LoggingSchema) -> int | tuple[int, LoggingErrorResponseSchema]:
|
async def log_features_cli(request: HttpRequest, data: LoggingSchema) -> int | tuple[int, LoggingErrorResponseSchema]:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,7 @@ class OptimizationFeatures(models.Model):
|
||||||
experiment_metadata = models.JSONField(null=True, blank=True)
|
experiment_metadata = models.JSONField(null=True, blank=True)
|
||||||
final_explanation = models.TextField(null=True, blank=True)
|
final_explanation = models.TextField(null=True, blank=True)
|
||||||
ranking = models.JSONField(null=True, blank=True)
|
ranking = models.JSONField(null=True, blank=True)
|
||||||
|
optimizations_origin = models.JSONField(null=True, blank=True)
|
||||||
|
|
||||||
# PR suggestions or create Approval fields
|
# PR suggestions or create Approval fields
|
||||||
approval_required = models.BooleanField(default=False)
|
approval_required = models.BooleanField(default=False)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from optimizer.context_utils.constants import MULTI_REPLACE_IN_FILE_TAGS_REGEX
|
||||||
|
|
||||||
|
|
||||||
class SearchReplaceBlock:
|
class SearchReplaceBlock:
|
||||||
def __init__(self, search, replace):
|
def __init__(self, search, replace):
|
||||||
|
|
@ -19,36 +21,44 @@ def parse_diff(diff: str) -> list[SearchReplaceBlock]:
|
||||||
|
|
||||||
blocks: list[SearchReplaceBlock] = []
|
blocks: list[SearchReplaceBlock] = []
|
||||||
lines = diff.splitlines(keepends=True)
|
lines = diff.splitlines(keepends=True)
|
||||||
|
n = len(lines)
|
||||||
idx = 0
|
idx = 0
|
||||||
|
|
||||||
while idx < len(lines):
|
# Precompute the "marker" strings for efficiency
|
||||||
line = lines[idx].strip()
|
search_marker = "<<<<<<< SEARCH"
|
||||||
if line == "<<<<<<< SEARCH":
|
delimiter_marker = "======="
|
||||||
search_lines = []
|
replace_marker = ">>>>>>> REPLACE"
|
||||||
|
|
||||||
|
while idx < n:
|
||||||
|
line_stripped = lines[idx].strip()
|
||||||
|
if line_stripped.startswith(search_marker):
|
||||||
idx += 1
|
idx += 1
|
||||||
|
search_start = idx
|
||||||
while idx < len(lines) and lines[idx].strip() != "=======":
|
# Find delimiter_marker line
|
||||||
search_lines.append(lines[idx])
|
while idx < n and lines[idx].strip() != delimiter_marker:
|
||||||
idx += 1
|
idx += 1
|
||||||
|
search_end = idx
|
||||||
|
|
||||||
if idx >= len(lines):
|
if idx >= n:
|
||||||
raise ValueError("Invalid diff format: Missing '=======' marker")
|
raise ValueError("Invalid diff format: Missing '=======' marker")
|
||||||
|
|
||||||
replace_lines = []
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
replace_start = idx
|
||||||
while idx < len(lines) and lines[idx].strip() != ">>>>>>> REPLACE":
|
while idx < n and not lines[idx].strip().startswith(replace_marker):
|
||||||
replace_lines.append(lines[idx])
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
replace_end = idx
|
||||||
|
|
||||||
if idx >= len(lines):
|
if idx >= n:
|
||||||
raise ValueError("Invalid diff format: Missing '>>>>>>> REPLACE' marker")
|
raise ValueError(
|
||||||
|
"Invalid diff format: Missing '>>>>>>> REPLACE' marker"
|
||||||
|
)
|
||||||
|
|
||||||
search_content = "".join(search_lines).rstrip()
|
search_content = "".join(lines[search_start:search_end]).rstrip()
|
||||||
replace_content = "".join(replace_lines).rstrip()
|
replace_content = "".join(lines[replace_start:replace_end]).rstrip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
block = SearchReplaceBlock.from_block(search=search_content, replace=replace_content)
|
block = SearchReplaceBlock.from_block(
|
||||||
|
search=search_content, replace=replace_content
|
||||||
|
)
|
||||||
blocks.append(block)
|
blocks.append(block)
|
||||||
except ValidationError as ve:
|
except ValidationError as ve:
|
||||||
raise ValueError(f"Invalid block format: {ve}")
|
raise ValueError(f"Invalid block format: {ve}")
|
||||||
|
|
@ -61,13 +71,35 @@ def parse_diff(diff: str) -> list[SearchReplaceBlock]:
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
|
|
||||||
|
def group_diff_patches_by_path(replace_tags_str: str) -> dict[str, str]:
|
||||||
|
matches = MULTI_REPLACE_IN_FILE_TAGS_REGEX.findall(replace_tags_str)
|
||||||
|
file_to_diffs = {}
|
||||||
|
|
||||||
|
current_file = None
|
||||||
|
current_diff = ""
|
||||||
|
|
||||||
|
for path, diff in matches:
|
||||||
|
if path != current_file:
|
||||||
|
if current_file:
|
||||||
|
file_to_diffs[current_file] = current_diff
|
||||||
|
current_file = path
|
||||||
|
current_diff = diff
|
||||||
|
else:
|
||||||
|
current_diff += diff
|
||||||
|
|
||||||
|
if current_file:
|
||||||
|
file_to_diffs[current_file] = current_diff
|
||||||
|
|
||||||
|
return file_to_diffs
|
||||||
|
|
||||||
|
|
||||||
def apply_patches(diff_str: str, content: str) -> str:
|
def apply_patches(diff_str: str, content: str) -> str:
|
||||||
try:
|
try:
|
||||||
patch_blocks = parse_diff(diff_str)
|
patch_blocks = parse_diff(diff_str)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
for idx, block in enumerate(patch_blocks, 1):
|
for block in patch_blocks:
|
||||||
if not block.search:
|
if not block.search:
|
||||||
if block.replace:
|
if block.replace:
|
||||||
# a replacement block without a search, then just add the replace block
|
# a replacement block without a search, then just add the replace block
|
||||||
|
|
@ -79,5 +111,7 @@ def apply_patches(diff_str: str, content: str) -> str:
|
||||||
start_char_idx = content.find(block.search)
|
start_char_idx = content.find(block.search)
|
||||||
if start_char_idx != -1:
|
if start_char_idx != -1:
|
||||||
end_char_idx = start_char_idx + len(block.search)
|
end_char_idx = start_char_idx + len(block.search)
|
||||||
content = f"{content[:start_char_idx]}{block.replace}{content[end_char_idx:]}"
|
content = (
|
||||||
|
f"{content[:start_char_idx]}{block.replace}{content[end_char_idx:]}"
|
||||||
|
)
|
||||||
return content
|
return content
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,18 @@
|
||||||
|
import enum
|
||||||
|
|
||||||
import libcst
|
import libcst
|
||||||
from ninja import Schema
|
from ninja import Schema
|
||||||
from pydantic import field_validator
|
from pydantic import field_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
class OptimizedCandidateSource(str, enum.Enum):
|
||||||
|
OPTIMIZE = "OPTIMIZE"
|
||||||
|
OPTIMIZE_LP = "OPTIMIZE_LP"
|
||||||
|
REFINE = "REFINE"
|
||||||
|
REPAIR = "REPAIR"
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class CodeAndExplanation:
|
class CodeAndExplanation:
|
||||||
cst_module: libcst.Module | None
|
cst_module: libcst.Module | None
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,11 @@ from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
from ninja import NinjaAPI
|
||||||
|
from ninja.errors import HttpError
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from aiservice.analytics.posthog import ph
|
from aiservice.analytics.posthog import ph
|
||||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
||||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
|
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
|
||||||
|
|
@ -15,11 +20,6 @@ from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
||||||
from authapp.user import get_user_by_id
|
from authapp.user import get_user_by_id
|
||||||
from log_features.log_event import log_optimization_event
|
from log_features.log_event import log_optimization_event
|
||||||
from log_features.log_features import log_features
|
from log_features.log_features import log_features
|
||||||
from ninja import NinjaAPI
|
|
||||||
from ninja.errors import HttpError
|
|
||||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from optimizer.context_utils.context_helpers import group_code
|
from optimizer.context_utils.context_helpers import group_code
|
||||||
from optimizer.context_utils.optimizer_context import (
|
from optimizer.context_utils.optimizer_context import (
|
||||||
BaseOptimizerContext,
|
BaseOptimizerContext,
|
||||||
|
|
@ -27,10 +27,9 @@ from optimizer.context_utils.optimizer_context import (
|
||||||
OptimizeResponseItemSchema,
|
OptimizeResponseItemSchema,
|
||||||
OptimizeResponseSchema,
|
OptimizeResponseSchema,
|
||||||
)
|
)
|
||||||
from optimizer.models import OptimizeSchema # noqa: TC001
|
from optimizer.models import OptimizedCandidateSource, OptimizeSchema # noqa: TC001
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiservice.models.aimodels import LLM
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam,
|
ChatCompletionAssistantMessageParam,
|
||||||
|
|
@ -38,6 +37,8 @@ if TYPE_CHECKING:
|
||||||
ChatCompletionToolMessageParam,
|
ChatCompletionToolMessageParam,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from aiservice.models.aimodels import LLM
|
||||||
|
|
||||||
|
|
||||||
optimizations_json = [
|
optimizations_json = [
|
||||||
{
|
{
|
||||||
|
|
@ -345,7 +346,7 @@ async def optimize(
|
||||||
},
|
},
|
||||||
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
||||||
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
|
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
|
||||||
# request=request,
|
optimizations_origin={cei.optimization_id: {"source": OptimizedCandidateSource.OPTIMIZE, "parent": None} for cei in optimization_response_items},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,29 +5,30 @@ from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
from ninja import NinjaAPI, Schema
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||||
|
|
||||||
from aiservice.analytics.posthog import ph
|
from aiservice.analytics.posthog import ph
|
||||||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
|
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
|
||||||
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
||||||
from log_features.log_event import update_optimization_cost
|
from log_features.log_event import update_optimization_cost
|
||||||
from log_features.log_features import log_features
|
from log_features.log_features import log_features
|
||||||
from ninja import NinjaAPI, Schema
|
|
||||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
||||||
|
|
||||||
from optimizer.context_utils.optimizer_context import (
|
from optimizer.context_utils.optimizer_context import (
|
||||||
BaseOptimizerContext,
|
BaseOptimizerContext,
|
||||||
OptimizeErrorResponseSchema,
|
OptimizeErrorResponseSchema,
|
||||||
OptimizeResponseSchema,
|
OptimizeResponseSchema,
|
||||||
)
|
)
|
||||||
|
from optimizer.models import OptimizedCandidateSource
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiservice.models.aimodels import LLM
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam,
|
ChatCompletionAssistantMessageParam,
|
||||||
ChatCompletionFunctionMessageParam,
|
ChatCompletionFunctionMessageParam,
|
||||||
ChatCompletionToolMessageParam,
|
ChatCompletionToolMessageParam,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from aiservice.models.aimodels import LLM
|
||||||
from optimizer.context_utils.optimizer_context import OptimizeResponseItemSchema
|
from optimizer.context_utils.optimizer_context import OptimizeResponseItemSchema
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -190,6 +191,8 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
|
||||||
},
|
},
|
||||||
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
|
||||||
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
|
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
|
||||||
|
optimizations_origin={cei.optimization_id: {"source": OptimizedCandidateSource.OPTIMIZE_LP, "parent": None} for cei in optimization_response_items},
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = OptimizeResponseSchema(optimizations=optimization_response_items)
|
response = OptimizeResponseSchema(optimizations=optimization_response_items)
|
||||||
|
|
|
||||||
|
|
@ -2,31 +2,34 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
from ninja import NinjaAPI, Schema
|
||||||
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from aiservice.analytics.posthog import ph
|
from aiservice.analytics.posthog import ph
|
||||||
from aiservice.common_utils import validate_trace_id
|
from aiservice.common_utils import validate_trace_id
|
||||||
from aiservice.env_specific import debug_log_sensitive_data, llm_clients
|
from aiservice.env_specific import debug_log_sensitive_data, llm_clients
|
||||||
from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
|
from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
|
||||||
from log_features.log_event import update_optimization_cost
|
from log_features.log_event import update_optimization_cost
|
||||||
from log_features.log_features import log_features
|
from log_features.log_features import log_features
|
||||||
from ninja import NinjaAPI, Schema
|
|
||||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
|
from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
|
||||||
|
from optimizer.models import OptimizedCandidateSource
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from aiservice.models.aimodels import LLM
|
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
ChatCompletionAssistantMessageParam,
|
ChatCompletionAssistantMessageParam,
|
||||||
ChatCompletionFunctionMessageParam,
|
ChatCompletionFunctionMessageParam,
|
||||||
ChatCompletionToolMessageParam,
|
ChatCompletionToolMessageParam,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from aiservice.models.aimodels import LLM
|
||||||
|
|
||||||
|
|
||||||
refinement_api = NinjaAPI(urls_namespace="refinement")
|
refinement_api = NinjaAPI(urls_namespace="refinement")
|
||||||
|
|
||||||
|
|
@ -260,7 +263,8 @@ async def refinement( # noqa: D417
|
||||||
refined_optimization = ""
|
refined_optimization = ""
|
||||||
|
|
||||||
return RefinementIntermediateResponseItemschema(
|
return RefinementIntermediateResponseItemschema(
|
||||||
optimization_id=optimization_id,
|
parent_id=optimization_id,
|
||||||
|
optimization_id=str(uuid.uuid4()),
|
||||||
source_code=refined_optimization,
|
source_code=refined_optimization,
|
||||||
explanation=refined_explanation,
|
explanation=refined_explanation,
|
||||||
original_explanation=ctx.data.optimized_explanation,
|
original_explanation=ctx.data.optimized_explanation,
|
||||||
|
|
@ -291,6 +295,7 @@ class OptimizeErrorResponseSchema(Schema):
|
||||||
class RefinementIntermediateResponseItemschema(Schema):
|
class RefinementIntermediateResponseItemschema(Schema):
|
||||||
# the key will be the optimization id and the value will be the actual refined code
|
# the key will be the optimization id and the value will be the actual refined code
|
||||||
explanation: str
|
explanation: str
|
||||||
|
parent_id: str
|
||||||
optimization_id: str
|
optimization_id: str
|
||||||
source_code: str
|
source_code: str
|
||||||
original_explanation: str
|
original_explanation: str
|
||||||
|
|
@ -301,6 +306,7 @@ class RefinementResponseItemschema(Schema):
|
||||||
# the key will be the optimization id and the value will be the actual refined code
|
# the key will be the optimization id and the value will be the actual refined code
|
||||||
explanation: str
|
explanation: str
|
||||||
optimization_id: str
|
optimization_id: str
|
||||||
|
parent_id: str
|
||||||
source_code: str
|
source_code: str
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -374,27 +380,30 @@ async def refine(
|
||||||
trace_id=trace_id,
|
trace_id=trace_id,
|
||||||
user_id=request.user,
|
user_id=request.user,
|
||||||
optimizations_raw={
|
optimizations_raw={
|
||||||
cei.optimization_id[:-4] + "refi": cei.source_code
|
cei.optimization_id: cei.source_code
|
||||||
for cei in refinement_data
|
for cei in refinement_data
|
||||||
if not isinstance(cei, OptimizeErrorResponseSchema)
|
if not isinstance(cei, OptimizeErrorResponseSchema)
|
||||||
},
|
},
|
||||||
optimizations_post={
|
optimizations_post={cei.optimization_id: cei.source_code for cei in filtered_refined_optimizations},
|
||||||
cei.optimization_id[:-4] + "refi": cei.source_code for cei in filtered_refined_optimizations
|
|
||||||
},
|
|
||||||
explanations_raw={
|
explanations_raw={
|
||||||
cei.optimization_id[:-4] + "refi": cei.explanation
|
cei.optimization_id: cei.explanation
|
||||||
for cei in refinement_data
|
for cei in refinement_data
|
||||||
if not isinstance(cei, OptimizeErrorResponseSchema)
|
if not isinstance(cei, OptimizeErrorResponseSchema)
|
||||||
},
|
},
|
||||||
explanations_post={
|
explanations_post={cei.optimization_id: cei.explanation for cei in filtered_refined_optimizations},
|
||||||
cei.optimization_id[:-4] + "refi": cei.explanation for cei in filtered_refined_optimizations
|
optimizations_origin={
|
||||||
|
cei.optimization_id: {"source": OptimizedCandidateSource.REFINE, "parent": cei.parent_id}
|
||||||
|
for cei in filtered_refined_optimizations
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||||
return 200, Refinementschema(
|
return 200, Refinementschema(
|
||||||
refinements=[
|
refinements=[
|
||||||
RefinementResponseItemschema(
|
RefinementResponseItemschema(
|
||||||
source_code=x.source_code, explanation=x.original_explanation, optimization_id=x.optimization_id
|
source_code=x.source_code,
|
||||||
|
explanation=x.original_explanation,
|
||||||
|
optimization_id=x.optimization_id,
|
||||||
|
parent_id=x.parent_id,
|
||||||
)
|
)
|
||||||
for x in filtered_refined_optimizations
|
for x in filtered_refined_optimizations
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import TYPE_CHECKING, cast
|
||||||
|
|
||||||
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, parse_expression, parse_module
|
from libcst import CSTTransformer, ImportAlias, ImportFrom, MetadataWrapper, Name, parse_expression, parse_module
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from libcst import (
|
from libcst import (
|
||||||
|
|
@ -19,6 +20,7 @@ if TYPE_CHECKING:
|
||||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=128)
|
||||||
def parse_module_to_cst(module_str: str) -> Module:
|
def parse_module_to_cst(module_str: str) -> Module:
|
||||||
"""Parse a module string into its libCST representation.
|
"""Parse a module string into its libCST representation.
|
||||||
|
|
||||||
|
|
|
||||||
389
django/aiservice/tests/optimizer/test_code_repair.py
Normal file
389
django/aiservice/tests/optimizer/test_code_repair.py
Normal file
|
|
@ -0,0 +1,389 @@
|
||||||
|
|
||||||
|
|
||||||
|
from code_repair.code_repair_context import CodeRepairContext, CodeRepairContextData
|
||||||
|
from optimizer.diff_patches_utils.patches_v2 import apply_patches
|
||||||
|
|
||||||
|
|
||||||
|
def test_code_repair_single_file():
|
||||||
|
|
||||||
|
original_code = """```python:demo.py
|
||||||
|
import math
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
def calculate_portfolio_metrics(
|
||||||
|
investments: List[Tuple[str, float, float]],
|
||||||
|
risk_free_rate: float = 0.02
|
||||||
|
) -> dict:
|
||||||
|
if not investments:
|
||||||
|
raise ValueError("Investments list cannot be empty")
|
||||||
|
|
||||||
|
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
|
||||||
|
# Calculate weighted return
|
||||||
|
weighted_return = sum(weight * ret for _, weight, ret in investments)
|
||||||
|
|
||||||
|
# Calculate portfolio volatility (simplified)
|
||||||
|
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
|
||||||
|
|
||||||
|
# Calculate Sharpe ratio
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
# Find best and worst performing assets
|
||||||
|
best_asset = max(investments, key=lambda x: x[2])
|
||||||
|
worst_asset = min(investments, key=lambda x: x[2])
|
||||||
|
|
||||||
|
return {
|
||||||
|
'weighted_return': round(weighted_return, 6),
|
||||||
|
'volatility': round(volatility, 6),
|
||||||
|
'sharpe_ratio': round(sharpe_ratio, 6),
|
||||||
|
'best_performing': (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
'total_assets': len(investments)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
optimized_code = """```python:demo.py
|
||||||
|
import math
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
def calculate_portfolio_metrics(
|
||||||
|
investments: List[Tuple[str, float, float]],
|
||||||
|
risk_free_rate: float = 0.02
|
||||||
|
) -> dict:
|
||||||
|
if not investments:
|
||||||
|
raise ValueError("Investments list cannot be empty")
|
||||||
|
|
||||||
|
total_weight = sum(w for _, w, _ in investments)
|
||||||
|
if total_weight != 1.0: # Should use tolerance check
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
|
||||||
|
weighted_return = 1.0
|
||||||
|
for _, weight, ret in investments:
|
||||||
|
weighted_return *= (1 + ret) ** weight
|
||||||
|
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||||
|
|
||||||
|
returns = [r for _, _, r in investments]
|
||||||
|
mean_return = sum(returns) / len(returns)
|
||||||
|
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||||
|
|
||||||
|
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
def risk_adjusted_return(return_val, weight):
|
||||||
|
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||||
|
|
||||||
|
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"weighted_return": round(weighted_return, 6),
|
||||||
|
"volatility": 2,
|
||||||
|
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||||
|
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
"total_assets": len(investments),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
llm_response = """Looking at the test failure and comparing the original and modified code, I can see several issues in the modified implementation:
|
||||||
|
|
||||||
|
1. **Hardcoded volatility**: The volatility is hardcoded to `2` instead of being calculated
|
||||||
|
2. **Incorrect weight validation**: Missing tolerance check for weight sum validation
|
||||||
|
3. **Wrong volatility calculation**: Using standard deviation of returns instead of weighted portfolio volatility
|
||||||
|
4. **Wrong weighted return calculation**: Using geometric mean instead of arithmetic weighted return
|
||||||
|
5. **Wrong asset selection criteria**: Using risk-adjusted return instead of simple return
|
||||||
|
|
||||||
|
Let me fix these issues:
|
||||||
|
|
||||||
|
<replace_in_file>
|
||||||
|
<path>demo.py</path>
|
||||||
|
<diff>
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
total_weight = sum(w for _, w, _ in investments)
|
||||||
|
if total_weight != 1.0: # Should use tolerance check
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
=======
|
||||||
|
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
|
|
||||||
|
<replace_in_file>
|
||||||
|
<path>demo.py</path>
|
||||||
|
<diff>
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
weighted_return = 1.0
|
||||||
|
for _, weight, ret in investments:
|
||||||
|
weighted_return *= (1 + ret) ** weight
|
||||||
|
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||||
|
=======
|
||||||
|
# Calculate weighted return
|
||||||
|
weighted_return = sum(weight * ret for _, weight, ret in investments)
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
|
|
||||||
|
<replace_in_file>
|
||||||
|
<path>demo.py</path>
|
||||||
|
<diff>
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
returns = [r for _, _, r in investments]
|
||||||
|
mean_return = sum(returns) / len(returns)
|
||||||
|
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||||
|
=======
|
||||||
|
# Calculate portfolio volatility (simplified)
|
||||||
|
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
|
|
||||||
|
<replace_in_file>
|
||||||
|
<path>demo.py</path>
|
||||||
|
<diff>
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
def risk_adjusted_return(return_val, weight):
|
||||||
|
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||||
|
|
||||||
|
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
=======
|
||||||
|
# Calculate Sharpe ratio
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
# Find best and worst performing assets
|
||||||
|
best_asset = max(investments, key=lambda x: x[2])
|
||||||
|
worst_asset = min(investments, key=lambda x: x[2])
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
|
|
||||||
|
<replace_in_file>
|
||||||
|
<path>demo.py</path>
|
||||||
|
<diff>
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
return {
|
||||||
|
"weighted_return": round(weighted_return, 6),
|
||||||
|
"volatility": 2,
|
||||||
|
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||||
|
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
"total_assets": len(investments),
|
||||||
|
}
|
||||||
|
=======
|
||||||
|
return {
|
||||||
|
'weighted_return': round(weighted_return, 6),
|
||||||
|
'volatility': round(volatility, 6),
|
||||||
|
'sharpe_ratio': round(sharpe_ratio, 6),
|
||||||
|
'best_performing': (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
'total_assets': len(investments)
|
||||||
|
}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
</diff>
|
||||||
|
</replace_in_file>
|
||||||
|
|
||||||
|
The key changes I made:
|
||||||
|
|
||||||
|
1. **Fixed weight validation**: Restored the tolerance-based check using `abs(sum(...) - 1.0) > 1e-10` instead of exact equality
|
||||||
|
2. **Fixed weighted return calculation**: Changed from geometric mean back to simple weighted arithmetic mean
|
||||||
|
3. **Fixed volatility calculation**: Restored the original simplified portfolio volatility formula using weighted returns
|
||||||
|
4. **Fixed volatility return value**: Removed the hardcoded `2` and properly calculated and rounded the volatility
|
||||||
|
5. **Fixed asset selection**: Restored the original logic to find best/worst assets based on simple returns rather than risk-adjusted returns
|
||||||
|
6. **Fixed dictionary formatting**: Changed from double quotes to single quotes to match original formatting
|
||||||
|
|
||||||
|
These changes align the modified code with the original implementation's behavior, ensuring that the test for zero volatility passes (when cash investment has 0% return, the volatility should indeed be 0.0)."""
|
||||||
|
ctx = CodeRepairContext(CodeRepairContextData(original_code, optimized_code, ""), "" , "")
|
||||||
|
diff_patches = ctx.extract_diff_patches_from_llm_res(llm_response)
|
||||||
|
refined_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
|
||||||
|
|
||||||
|
print(refined_optimization)
|
||||||
|
assert ctx.is_valid(refined_optimization)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
```python:demo.py
|
||||||
|
import math
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
def calculate_portfolio_metrics(
|
||||||
|
investments: List[Tuple[str, float, float]],
|
||||||
|
risk_free_rate: float = 0.02
|
||||||
|
) -> dict:
|
||||||
|
if not investments:
|
||||||
|
raise ValueError("Investments list cannot be empty")
|
||||||
|
|
||||||
|
total_weight = sum(w for _, w, _ in investments)
|
||||||
|
if total_weight != 1.0: # Should use tolerance check
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
|
||||||
|
weighted_return = 1.0
|
||||||
|
for _, weight, ret in investments:
|
||||||
|
weighted_return *= (1 + ret) ** weight
|
||||||
|
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||||
|
|
||||||
|
returns = [r for _, _, r in investments]
|
||||||
|
mean_return = sum(returns) / len(returns)
|
||||||
|
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||||
|
|
||||||
|
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
def risk_adjusted_return(return_val, weight):
|
||||||
|
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||||
|
|
||||||
|
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"weighted_return": round(weighted_return, 6),
|
||||||
|
"volatility": 2,
|
||||||
|
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||||
|
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
"total_assets": len(investments),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_apply():
|
||||||
|
patch = """<<<<<<< SEARCH
|
||||||
|
total_weight = sum(w for _, w, _ in investments)
|
||||||
|
if total_weight != 1.0: # Should use tolerance check
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
=======
|
||||||
|
if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
weighted_return = 1.0
|
||||||
|
for _, weight, ret in investments:
|
||||||
|
weighted_return *= (1 + ret) ** weight
|
||||||
|
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||||
|
=======
|
||||||
|
# Calculate weighted return
|
||||||
|
weighted_return = sum(weight * ret for _, weight, ret in investments)
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
returns = [r for _, _, r in investments]
|
||||||
|
mean_return = sum(returns) / len(returns)
|
||||||
|
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||||
|
=======
|
||||||
|
# Calculate portfolio volatility (simplified)
|
||||||
|
volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
def risk_adjusted_return(return_val, weight):
|
||||||
|
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||||
|
|
||||||
|
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
=======
|
||||||
|
# Calculate Sharpe ratio
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
# Find best and worst performing assets
|
||||||
|
best_asset = max(investments, key=lambda x: x[2])
|
||||||
|
worst_asset = min(investments, key=lambda x: x[2])
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
|
||||||
|
<<<<<<< SEARCH
|
||||||
|
return {
|
||||||
|
"weighted_return": round(weighted_return, 6),
|
||||||
|
"volatility": 2,
|
||||||
|
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||||
|
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
"total_assets": len(investments),
|
||||||
|
}
|
||||||
|
=======
|
||||||
|
return {
|
||||||
|
'weighted_return': round(weighted_return, 6),
|
||||||
|
'volatility': round(volatility, 6),
|
||||||
|
'sharpe_ratio': round(sharpe_ratio, 6),
|
||||||
|
'best_performing': (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
'total_assets': len(investments)
|
||||||
|
}
|
||||||
|
>>>>>>> REPLACE
|
||||||
|
"""
|
||||||
|
code = """import math
|
||||||
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
|
def calculate_portfolio_metrics(
|
||||||
|
investments: List[Tuple[str, float, float]],
|
||||||
|
risk_free_rate: float = 0.02
|
||||||
|
) -> dict:
|
||||||
|
if not investments:
|
||||||
|
raise ValueError("Investments list cannot be empty")
|
||||||
|
|
||||||
|
total_weight = sum(w for _, w, _ in investments)
|
||||||
|
if total_weight != 1.0: # Should use tolerance check
|
||||||
|
raise ValueError("Portfolio weights must sum to 1.0")
|
||||||
|
|
||||||
|
weighted_return = 1.0
|
||||||
|
for _, weight, ret in investments:
|
||||||
|
weighted_return *= (1 + ret) ** weight
|
||||||
|
weighted_return = weighted_return - 1.0 # Convert back from geometric
|
||||||
|
|
||||||
|
returns = [r for _, _, r in investments]
|
||||||
|
mean_return = sum(returns) / len(returns)
|
||||||
|
volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
|
||||||
|
|
||||||
|
# BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
|
||||||
|
if volatility == 0:
|
||||||
|
sharpe_ratio = 0.0
|
||||||
|
else:
|
||||||
|
sharpe_ratio = (weighted_return - risk_free_rate) / volatility
|
||||||
|
|
||||||
|
def risk_adjusted_return(return_val, weight):
|
||||||
|
return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
|
||||||
|
|
||||||
|
best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"weighted_return": round(weighted_return, 6),
|
||||||
|
"volatility": 2,
|
||||||
|
"sharpe_ratio": round(sharpe_ratio, 6),
|
||||||
|
"best_performing": (best_asset[0], round(best_asset[2], 6)),
|
||||||
|
"worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
|
||||||
|
"total_assets": len(investments),
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
new_code = apply_patches(patch, code)
|
||||||
|
print(new_code)
|
||||||
475
experiments/code_repair_dashboard.html
Normal file
475
experiments/code_repair_dashboard.html
Normal file
|
|
@ -0,0 +1,475 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Code Repair Logs Dashboard</title>
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||||
|
background: #1a1a2e;
|
||||||
|
color: #eee;
|
||||||
|
min-height: 100vh;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
color: #00d9ff;
|
||||||
|
}
|
||||||
|
.stats {
|
||||||
|
display: flex;
|
||||||
|
gap: 20px;
|
||||||
|
justify-content: center;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
.stat-card {
|
||||||
|
background: #16213e;
|
||||||
|
padding: 20px 30px;
|
||||||
|
border-radius: 10px;
|
||||||
|
text-align: center;
|
||||||
|
min-width: 150px;
|
||||||
|
}
|
||||||
|
.stat-card .number {
|
||||||
|
font-size: 2em;
|
||||||
|
font-weight: bold;
|
||||||
|
color: #00d9ff;
|
||||||
|
}
|
||||||
|
.stat-card .label {
|
||||||
|
color: #888;
|
||||||
|
font-size: 0.9em;
|
||||||
|
}
|
||||||
|
.search-bar {
|
||||||
|
display: flex;
|
||||||
|
gap: 10px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
justify-content: center;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
}
|
||||||
|
.search-bar input {
|
||||||
|
padding: 10px 15px;
|
||||||
|
border: none;
|
||||||
|
border-radius: 5px;
|
||||||
|
background: #16213e;
|
||||||
|
color: #eee;
|
||||||
|
width: 300px;
|
||||||
|
font-size: 1em;
|
||||||
|
}
|
||||||
|
.search-bar input:focus {
|
||||||
|
outline: 2px solid #00d9ff;
|
||||||
|
}
|
||||||
|
.trace-group {
|
||||||
|
background: #16213e;
|
||||||
|
border-radius: 10px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.trace-header {
|
||||||
|
background: #0f3460;
|
||||||
|
padding: 15px 20px;
|
||||||
|
cursor: pointer;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
.trace-header:hover {
|
||||||
|
background: #1a4a80;
|
||||||
|
}
|
||||||
|
.trace-header .trace-id {
|
||||||
|
font-family: monospace;
|
||||||
|
color: #00d9ff;
|
||||||
|
font-size: 0.95em;
|
||||||
|
}
|
||||||
|
.trace-header .count-badge {
|
||||||
|
background: #e94560;
|
||||||
|
color: white;
|
||||||
|
padding: 3px 10px;
|
||||||
|
border-radius: 15px;
|
||||||
|
font-size: 0.85em;
|
||||||
|
}
|
||||||
|
.trace-header .arrow {
|
||||||
|
transition: transform 0.2s;
|
||||||
|
}
|
||||||
|
.trace-header.expanded .arrow {
|
||||||
|
transform: rotate(180deg);
|
||||||
|
}
|
||||||
|
.trace-content {
|
||||||
|
display: none;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.trace-content.show {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
.log-entry {
|
||||||
|
border-top: 1px solid #0f3460;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
.log-entry:first-child {
|
||||||
|
border-top: none;
|
||||||
|
}
|
||||||
|
.log-entry-header {
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
align-items: center;
|
||||||
|
margin-bottom: 15px;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 10px;
|
||||||
|
}
|
||||||
|
.optimization-id {
|
||||||
|
font-family: monospace;
|
||||||
|
font-size: 0.8em;
|
||||||
|
color: #888;
|
||||||
|
}
|
||||||
|
.timestamp {
|
||||||
|
font-size: 0.8em;
|
||||||
|
color: #888;
|
||||||
|
}
|
||||||
|
.section {
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
.section-title {
|
||||||
|
font-weight: bold;
|
||||||
|
color: #00d9ff;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
font-size: 0.9em;
|
||||||
|
text-transform: uppercase;
|
||||||
|
}
|
||||||
|
.code-block {
|
||||||
|
background: #0d1117;
|
||||||
|
border-radius: 5px;
|
||||||
|
padding: 15px;
|
||||||
|
overflow-x: auto;
|
||||||
|
font-family: 'Fira Code', 'Consolas', monospace;
|
||||||
|
font-size: 0.85em;
|
||||||
|
line-height: 1.5;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
word-wrap: break-word;
|
||||||
|
max-height: 400px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
.explanation-block {
|
||||||
|
background: #1e3a5f;
|
||||||
|
border-radius: 5px;
|
||||||
|
padding: 15px;
|
||||||
|
line-height: 1.6;
|
||||||
|
font-size: 0.9em;
|
||||||
|
max-height: 300px;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
.refined-code {
|
||||||
|
background: #0d2818;
|
||||||
|
border: 1px solid #2ea043;
|
||||||
|
}
|
||||||
|
.expand-all-btn {
|
||||||
|
background: #0f3460;
|
||||||
|
color: #eee;
|
||||||
|
border: none;
|
||||||
|
padding: 10px 20px;
|
||||||
|
border-radius: 5px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.9em;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
.expand-all-btn:hover {
|
||||||
|
background: #1a4a80;
|
||||||
|
}
|
||||||
|
.no-results {
|
||||||
|
text-align: center;
|
||||||
|
padding: 40px;
|
||||||
|
color: #888;
|
||||||
|
}
|
||||||
|
.copy-btn {
|
||||||
|
background: #333;
|
||||||
|
color: #888;
|
||||||
|
border: none;
|
||||||
|
padding: 5px 10px;
|
||||||
|
border-radius: 3px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.75em;
|
||||||
|
float: right;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
}
|
||||||
|
.copy-btn:hover {
|
||||||
|
background: #444;
|
||||||
|
color: #eee;
|
||||||
|
}
|
||||||
|
.status-badges {
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
.status-badge {
|
||||||
|
padding: 4px 12px;
|
||||||
|
border-radius: 4px;
|
||||||
|
font-size: 0.8em;
|
||||||
|
font-weight: 600;
|
||||||
|
text-transform: uppercase;
|
||||||
|
}
|
||||||
|
.status-badge.passed-true {
|
||||||
|
background: #238636;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
.status-badge.passed-false {
|
||||||
|
background: #da3633;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
.status-badge.faster-true {
|
||||||
|
background: #1f6feb;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
.status-badge.faster-false {
|
||||||
|
background: #6e7681;
|
||||||
|
color: #fff;
|
||||||
|
}
|
||||||
|
.status-badge.pending {
|
||||||
|
background: #484f58;
|
||||||
|
color: #8b949e;
|
||||||
|
}
|
||||||
|
.trace-status-summary {
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
.trace-status-icon {
|
||||||
|
font-size: 1.1em;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>Code Repair Logs Dashboard</h1>
|
||||||
|
|
||||||
|
<div class="stats">
|
||||||
|
<div class="stat-card">
|
||||||
|
<div class="number" id="total-traces">0</div>
|
||||||
|
<div class="label">Trace Groups</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-card">
|
||||||
|
<div class="number" id="total-logs">0</div>
|
||||||
|
<div class="label">Total Logs</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-card">
|
||||||
|
<div class="number" id="avg-per-trace">0</div>
|
||||||
|
<div class="label">Avg per Trace</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-card">
|
||||||
|
<div class="number" id="passed-count" style="color: #238636;">0</div>
|
||||||
|
<div class="label">Passed</div>
|
||||||
|
</div>
|
||||||
|
<div class="stat-card">
|
||||||
|
<div class="number" id="faster-count" style="color: #1f6feb;">0</div>
|
||||||
|
<div class="label">Faster</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="search-bar">
|
||||||
|
<input type="text" id="search" placeholder="Search by trace ID, optimization ID, or content...">
|
||||||
|
<button class="expand-all-btn" onclick="toggleAll()">Expand All</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div id="dashboard"></div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
// Data will be injected here
|
||||||
|
const data = DATA_PLACEHOLDER;
|
||||||
|
|
||||||
|
let allExpanded = false;
|
||||||
|
|
||||||
|
function escapeHtml(text) {
|
||||||
|
if (!text) return '';
|
||||||
|
const div = document.createElement('div');
|
||||||
|
div.textContent = text;
|
||||||
|
return div.innerHTML;
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDate(dateStr) {
|
||||||
|
if (!dateStr) return '';
|
||||||
|
const date = new Date(dateStr);
|
||||||
|
return date.toLocaleString();
|
||||||
|
}
|
||||||
|
|
||||||
|
function copyToClipboard(text, btn) {
|
||||||
|
navigator.clipboard.writeText(text).then(() => {
|
||||||
|
const originalText = btn.textContent;
|
||||||
|
btn.textContent = 'Copied!';
|
||||||
|
setTimeout(() => btn.textContent = originalText, 1500);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function getStatusBadge(value, label, type) {
|
||||||
|
if (value === null || value === undefined) {
|
||||||
|
return `<span class="status-badge pending">${label}: Pending</span>`;
|
||||||
|
}
|
||||||
|
const isTrue = value === 'True' || value === 'true' || value === true || value === 'yes';
|
||||||
|
const className = `${type}-${isTrue}`;
|
||||||
|
const displayValue = isTrue ? 'Yes' : 'No';
|
||||||
|
return `<span class="status-badge ${className}">${label}: ${displayValue}</span>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTraceStatusSummary(logs) {
|
||||||
|
const withStatus = logs.filter(l => l.passed !== null && l.passed !== undefined);
|
||||||
|
if (withStatus.length === 0) return '';
|
||||||
|
|
||||||
|
const passedCount = withStatus.filter(l => l.passed === 'True' || l.passed === 'true' || l.passed === true).length;
|
||||||
|
const fasterCount = withStatus.filter(l => l.faster === 'True' || l.faster === 'true' || l.faster === true).length;
|
||||||
|
|
||||||
|
const passedIcon = passedCount === withStatus.length ? '✓' : passedCount > 0 ? '◐' : '✗';
|
||||||
|
const fasterIcon = fasterCount === withStatus.length ? '⚡' : fasterCount > 0 ? '◐' : '−';
|
||||||
|
|
||||||
|
return `<span class="trace-status-summary">
|
||||||
|
<span title="Passed: ${passedCount}/${withStatus.length}">${passedIcon}</span>
|
||||||
|
<span title="Faster: ${fasterCount}/${withStatus.length}">${fasterIcon}</span>
|
||||||
|
</span>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function renderDashboard(filteredData = data) {
|
||||||
|
const dashboard = document.getElementById('dashboard');
|
||||||
|
|
||||||
|
// Group by trace_id
|
||||||
|
const grouped = {};
|
||||||
|
filteredData.forEach(log => {
|
||||||
|
if (!grouped[log.trace_id]) {
|
||||||
|
grouped[log.trace_id] = [];
|
||||||
|
}
|
||||||
|
grouped[log.trace_id].push(log);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Sort groups by most recent
|
||||||
|
const sortedTraces = Object.entries(grouped).sort((a, b) => {
|
||||||
|
const aDate = new Date(a[1][0].created_at);
|
||||||
|
const bDate = new Date(b[1][0].created_at);
|
||||||
|
return bDate - aDate;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Update stats
|
||||||
|
document.getElementById('total-traces').textContent = sortedTraces.length;
|
||||||
|
document.getElementById('total-logs').textContent = filteredData.length;
|
||||||
|
document.getElementById('avg-per-trace').textContent = sortedTraces.length > 0
|
||||||
|
? (filteredData.length / sortedTraces.length).toFixed(1)
|
||||||
|
: '0';
|
||||||
|
|
||||||
|
// Calculate passed/faster stats
|
||||||
|
const withStatus = filteredData.filter(l => l.passed !== null && l.passed !== undefined);
|
||||||
|
const passedCount = withStatus.filter(l => l.passed === 'True' || l.passed === 'true' || l.passed === true).length;
|
||||||
|
const fasterCount = withStatus.filter(l => l.faster === 'True' || l.faster === 'true' || l.faster === true).length;
|
||||||
|
document.getElementById('passed-count').textContent = withStatus.length > 0 ? `${passedCount}/${withStatus.length}` : '−';
|
||||||
|
document.getElementById('faster-count').textContent = withStatus.length > 0 ? `${fasterCount}/${withStatus.length}` : '−';
|
||||||
|
|
||||||
|
if (sortedTraces.length === 0) {
|
||||||
|
dashboard.innerHTML = '<div class="no-results">No results found</div>';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let html = '';
|
||||||
|
sortedTraces.forEach(([traceId, logs], idx) => {
|
||||||
|
// Sort logs within group by created_at
|
||||||
|
logs.sort((a, b) => new Date(a.created_at) - new Date(b.created_at));
|
||||||
|
|
||||||
|
html += `
|
||||||
|
<div class="trace-group">
|
||||||
|
<div class="trace-header" onclick="toggleTrace(${idx})">
|
||||||
|
<span class="trace-id">Trace: ${escapeHtml(traceId)}</span>
|
||||||
|
<div style="display: flex; align-items: center; gap: 10px;">
|
||||||
|
${getTraceStatusSummary(logs)}
|
||||||
|
<span class="count-badge">${logs.length} entries</span>
|
||||||
|
<span class="arrow">▼</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="trace-content" id="trace-${idx}">
|
||||||
|
${logs.map((log, logIdx) => `
|
||||||
|
<div class="log-entry">
|
||||||
|
<div class="log-entry-header">
|
||||||
|
<span class="optimization-id">Optimization: ${escapeHtml(log.optimization_id)}</span>
|
||||||
|
<span class="timestamp">${formatDate(log.created_at)}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="status-badges">
|
||||||
|
${getStatusBadge(log.passed, 'Passed', 'passed')}
|
||||||
|
${getStatusBadge(log.faster, 'Faster', 'faster')}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section" style="margin-top: 15px;">
|
||||||
|
<div class="section-title">User Prompt</div>
|
||||||
|
<button class="copy-btn" onclick="copyToClipboard(data.find(d => d.optimization_id === '${log.optimization_id}').user_prompt, this)">Copy</button>
|
||||||
|
<div class="code-block">${escapeHtml(log.user_prompt)}</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section">
|
||||||
|
<div class="section-title">Explanation</div>
|
||||||
|
<div class="explanation-block">${escapeHtml(log.explanation)}</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="section">
|
||||||
|
<div class="section-title">Refined Optimization</div>
|
||||||
|
<button class="copy-btn" onclick="copyToClipboard(data.find(d => d.optimization_id === '${log.optimization_id}').refined_optimization, this)">Copy</button>
|
||||||
|
<div class="code-block refined-code">${escapeHtml(log.refined_optimization)}</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
`).join('')}
|
||||||
|
</div>
|
||||||
|
</div>`;
|
||||||
|
});
|
||||||
|
|
||||||
|
dashboard.innerHTML = html;
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleTrace(idx) {
|
||||||
|
const content = document.getElementById(`trace-${idx}`);
|
||||||
|
const header = content.previousElementSibling;
|
||||||
|
content.classList.toggle('show');
|
||||||
|
header.classList.toggle('expanded');
|
||||||
|
}
|
||||||
|
|
||||||
|
function toggleAll() {
|
||||||
|
const contents = document.querySelectorAll('.trace-content');
|
||||||
|
const headers = document.querySelectorAll('.trace-header');
|
||||||
|
allExpanded = !allExpanded;
|
||||||
|
|
||||||
|
contents.forEach(c => {
|
||||||
|
if (allExpanded) {
|
||||||
|
c.classList.add('show');
|
||||||
|
} else {
|
||||||
|
c.classList.remove('show');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
headers.forEach(h => {
|
||||||
|
if (allExpanded) {
|
||||||
|
h.classList.add('expanded');
|
||||||
|
} else {
|
||||||
|
h.classList.remove('expanded');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
document.querySelector('.expand-all-btn').textContent = allExpanded ? 'Collapse All' : 'Expand All';
|
||||||
|
}
|
||||||
|
|
||||||
|
document.getElementById('search').addEventListener('input', function(e) {
|
||||||
|
const query = e.target.value.toLowerCase();
|
||||||
|
if (!query) {
|
||||||
|
renderDashboard(data);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const filtered = data.filter(log =>
|
||||||
|
(log.trace_id && log.trace_id.toLowerCase().includes(query)) ||
|
||||||
|
(log.optimization_id && log.optimization_id.toLowerCase().includes(query)) ||
|
||||||
|
(log.user_prompt && log.user_prompt.toLowerCase().includes(query)) ||
|
||||||
|
(log.explanation && log.explanation.toLowerCase().includes(query)) ||
|
||||||
|
(log.refined_optimization && log.refined_optimization.toLowerCase().includes(query))
|
||||||
|
);
|
||||||
|
|
||||||
|
renderDashboard(filtered);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Initial render
|
||||||
|
renderDashboard();
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
79
experiments/generate_dashboard.py
Normal file
79
experiments/generate_dashboard.py
Normal file
|
|
@ -0,0 +1,79 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Generate an HTML dashboard from code_repair_logs SQLite database."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
import webbrowser
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
db_path = Path(__file__).parent / "code_repair_log.db"
|
||||||
|
cf_db_path = Path(__file__).parent / "code_repair_logs_cf.db"
|
||||||
|
template_path = Path(__file__).parent / "code_repair_dashboard.html"
|
||||||
|
output_path = Path(__file__).parent / "code_repair_dashboard_live.html"
|
||||||
|
|
||||||
|
# Connect to main database and fetch all logs
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute("""
|
||||||
|
SELECT optimization_id, trace_id, user_prompt, explanation,
|
||||||
|
refined_optimization, created_at, updated_at
|
||||||
|
FROM code_repair_logs
|
||||||
|
ORDER BY created_at DESC
|
||||||
|
""")
|
||||||
|
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Connect to cf database and fetch passed/faster status
|
||||||
|
cf_conn = sqlite3.connect(cf_db_path)
|
||||||
|
cf_conn.row_factory = sqlite3.Row
|
||||||
|
cf_cursor = cf_conn.cursor()
|
||||||
|
|
||||||
|
cf_cursor.execute("""
|
||||||
|
SELECT optimization_id, passed, faster
|
||||||
|
FROM code_repair_logs_cf
|
||||||
|
""")
|
||||||
|
|
||||||
|
cf_rows = cf_cursor.fetchall()
|
||||||
|
cf_conn.close()
|
||||||
|
|
||||||
|
# Create lookup dict for cf data
|
||||||
|
cf_data = {row["optimization_id"]: {"passed": row["passed"], "faster": row["faster"]} for row in cf_rows}
|
||||||
|
|
||||||
|
# Convert to list of dicts and merge cf data
|
||||||
|
data = []
|
||||||
|
for row in rows:
|
||||||
|
d = dict(row)
|
||||||
|
opt_id = d["optimization_id"][:-4] + "cdrp"
|
||||||
|
if opt_id in cf_data:
|
||||||
|
d["passed"] = cf_data[opt_id]["passed"]
|
||||||
|
d["faster"] = cf_data[opt_id]["faster"]
|
||||||
|
else:
|
||||||
|
d["passed"] = None
|
||||||
|
d["faster"] = None
|
||||||
|
data.append(d)
|
||||||
|
|
||||||
|
# Read template
|
||||||
|
template = template_path.read_text()
|
||||||
|
|
||||||
|
# Replace placeholder with actual data
|
||||||
|
json_data = json.dumps(data, default=str, indent=2)
|
||||||
|
html_content = template.replace("DATA_PLACEHOLDER", json_data)
|
||||||
|
|
||||||
|
# Write output
|
||||||
|
output_path.write_text(html_content)
|
||||||
|
|
||||||
|
print(f"Dashboard generated: {output_path}")
|
||||||
|
print(f"Total logs: {len(data)}")
|
||||||
|
print(f"Unique traces: {len(set(d['trace_id'] for d in data))}")
|
||||||
|
|
||||||
|
# Open in browser
|
||||||
|
webbrowser.open(f"file://{output_path.absolute()}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -39,7 +39,7 @@ const CollapsibleCodeBlock = memo(
|
||||||
setIsExpanded(false);
|
setIsExpanded(false);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[isExpanded]
|
[isExpanded],
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!hasMoreLines) {
|
if (!hasMoreLines) {
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,9 @@ const Tabs = () => {
|
||||||
>
|
>
|
||||||
<span className={styles.tabText}>
|
<span className={styles.tabText}>
|
||||||
Tasks
|
Tasks
|
||||||
{tasksCount > 0 && <span className={styles.badge}>{tasksCount}</span>}
|
{tasksCount > 0 && (
|
||||||
|
<span className={styles.badge}>{tasksCount}</span>
|
||||||
|
)}
|
||||||
</span>
|
</span>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
||||||
|
|
@ -180,7 +180,12 @@ function renderStepsUI(steps) {
|
||||||
|
|
||||||
// Show "select different interpreter" option after install button
|
// Show "select different interpreter" option after install button
|
||||||
if (vscodeActionCommand) {
|
if (vscodeActionCommand) {
|
||||||
const { title, btnText: vscodeBtnText, command: vscodeCmd, args = [] } = vscodeActionCommand;
|
const {
|
||||||
|
title,
|
||||||
|
btnText: vscodeBtnText,
|
||||||
|
command: vscodeCmd,
|
||||||
|
args = [],
|
||||||
|
} = vscodeActionCommand;
|
||||||
if (title) {
|
if (title) {
|
||||||
const detailsElem = document.createElement("p");
|
const detailsElem = document.createElement("p");
|
||||||
detailsElem.className = "step-action";
|
detailsElem.className = "step-action";
|
||||||
|
|
@ -192,7 +197,11 @@ function renderStepsUI(steps) {
|
||||||
actionBtn.textContent = vscodeBtnText;
|
actionBtn.textContent = vscodeBtnText;
|
||||||
actionBtn.className = "step-action-btn secondary-btn";
|
actionBtn.className = "step-action-btn secondary-btn";
|
||||||
actionBtn.addEventListener("click", () => {
|
actionBtn.addEventListener("click", () => {
|
||||||
vscode.postMessage({ command: "vscodeCommand", cmd: vscodeCmd, args });
|
vscode.postMessage({
|
||||||
|
command: "vscodeCommand",
|
||||||
|
cmd: vscodeCmd,
|
||||||
|
args,
|
||||||
|
});
|
||||||
});
|
});
|
||||||
actionsContainer.appendChild(actionBtn);
|
actionsContainer.appendChild(actionBtn);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -56,9 +56,7 @@ export async function getRepositoryPublicKey(
|
||||||
repo: string,
|
repo: string,
|
||||||
): Promise<{ public_key: string; key_id: string }> {
|
): Promise<{ public_key: string; key_id: string }> {
|
||||||
try {
|
try {
|
||||||
console.log(
|
console.log(`[secret-utils.ts:getRepositoryPublicKey] Getting public key for ${owner}/${repo}`)
|
||||||
`[secret-utils.ts:getRepositoryPublicKey] Getting public key for ${owner}/${repo}`,
|
|
||||||
)
|
|
||||||
|
|
||||||
const response = await octokit.rest.actions.getRepoPublicKey({
|
const response = await octokit.rest.actions.getRepoPublicKey({
|
||||||
owner,
|
owner,
|
||||||
|
|
@ -166,4 +164,3 @@ export async function encryptAndStoreSecret(
|
||||||
`[secret-utils.ts:encryptAndStoreSecret] Successfully encrypted and stored secret ${secretName} for ${owner}/${repo}`,
|
`[secret-utils.ts:encryptAndStoreSecret] Successfully encrypted and stored secret ${secretName} for ${owner}/${repo}`,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,2 @@
|
||||||
|
-- AlterTable
|
||||||
|
ALTER TABLE "public"."optimization_features" ADD COLUMN "optimizations_origin" JSONB;
|
||||||
|
|
@ -73,6 +73,7 @@ model optimization_features {
|
||||||
organization String?
|
organization String?
|
||||||
repository String?
|
repository String?
|
||||||
ranking Json?
|
ranking Json?
|
||||||
|
optimizations_origin Json?
|
||||||
review_quality String? // Hight, Med, low
|
review_quality String? // Hight, Med, low
|
||||||
review_explanation String?
|
review_explanation String?
|
||||||
calling_fn_details String?
|
calling_fn_details String?
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue