mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
Wire baseline runtime, test examples, and LP diversity into optimizer
Accept baseline_runtime_ns, loop_count, line_profiler_results, and test_input_examples on the optimize endpoint. Pass runtime context and test examples to the user prompt so the LLM can generate better-informed candidates. Alternate line profiler data across parallel calls for diversity (odd calls get LP, even calls don't).
This commit is contained in:
parent
9b3cd48048
commit
ccfe0998e7
6 changed files with 97 additions and 4 deletions
|
|
@@ -18,6 +18,8 @@ MARKDOWN_CONTEXT_PROMPT = (
|
|||
DEPS_CONTEXT_PROMPT = (parent_dir / "dependency_context_prompt.md").read_text()
|
||||
INIT_OPTIMIZATION_PROMPT = (parent_dir / "init_optimization_prompt.md").read_text()
|
||||
LINE_PROF_CONTEXT_PROMPT = (parent_dir / "lineprof_context_prompt.md").read_text()
|
||||
RUNTIME_CONTEXT_PROMPT = (parent_dir / "runtime_context_prompt.md").read_text()
|
||||
TEST_EXAMPLES_PROMPT = (parent_dir / "test_examples_prompt.md").read_text()
|
||||
|
||||
FULL_CODE_PROMPT_INSTRUCTIONS = """- Always provide the FULL, updated content of the artifact. This means:
|
||||
- Include ALL updated code, even if some parts are unchanged
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from core.languages.python.optimizer.context_utils.constants import (
|
|||
INIT_OPTIMIZATION_PROMPT,
|
||||
LINE_PROF_CONTEXT_PROMPT,
|
||||
MARKDOWN_CONTEXT_PROMPT,
|
||||
RUNTIME_CONTEXT_PROMPT,
|
||||
TEST_EXAMPLES_PROMPT,
|
||||
)
|
||||
from core.languages.python.optimizer.context_utils.context_helpers import is_multi_context
|
||||
from core.languages.python.optimizer.diff_patches_utils.diff import (
|
||||
|
|
@ -35,6 +37,17 @@ if TYPE_CHECKING:
|
|||
from core.languages.python.optimizer.diff_patches_utils.diff import Diff
|
||||
|
||||
|
||||
def _humanize_ns(ns: int) -> str:
|
||||
"""Convert nanoseconds to a human-readable string."""
|
||||
if ns < 1_000:
|
||||
return f"{ns}ns"
|
||||
if ns < 1_000_000:
|
||||
return f"{ns / 1_000:.1f}us"
|
||||
if ns < 1_000_000_000:
|
||||
return f"{ns / 1_000_000:.1f}ms"
|
||||
return f"{ns / 1_000_000_000:.2f}s"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CodeStrAndExplanation:
|
||||
code: str
|
||||
|
|
@ -74,7 +87,14 @@ class BaseOptimizerContext:
|
|||
def get_system_prompt(self, python_version_str: str) -> str: # noqa: ARG002
|
||||
return self.base_system_prompt
|
||||
|
||||
def get_user_prompt(self, dependency_code: str, line_profiler_results: str | None) -> str: # noqa: ARG002
|
||||
def get_user_prompt( # noqa: ARG002
|
||||
self,
|
||||
dependency_code: str,
|
||||
line_profiler_results: str | None,
|
||||
baseline_runtime_ns: int | None = None,
|
||||
loop_count: int | None = None,
|
||||
test_input_examples: str | None = None,
|
||||
) -> str:
|
||||
return self.base_user_prompt
|
||||
|
||||
def extract_code_and_explanation_from_llm_res(self, content: str) -> CodeStrAndExplanation:
|
||||
|
|
@ -116,12 +136,31 @@ class SingleOptimizerContext(BaseOptimizerContext):
|
|||
+ EXPLANATION_THEN_CODE
|
||||
)
|
||||
|
||||
def get_user_prompt(
    self,
    dependency_code: str,
    line_profiler_results: str | None,
    baseline_runtime_ns: int | None = None,
    loop_count: int | None = None,
    test_input_examples: str | None = None,
) -> str:
    """Assemble the user prompt for single-file optimization.

    Segments, in order: optional dependency context, the base prompt with the
    markdown-wrapped source (plus __init__ guidance when the code defines one),
    optional baseline-runtime context, optional line-profiler context, and
    optional test-input examples.
    """
    segments: list[str] = []
    if dependency_code:
        segments.append(DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code))
    segments.append(
        self.base_user_prompt.format(
            source_code=wrap_code_in_markdown(self.source_code),
            init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if find_init(ast.parse(self.source_code)) else "",
        )
    )
    # Runtime context needs both the measured time and the loop count.
    if baseline_runtime_ns is not None and loop_count is not None:
        segments.append(
            RUNTIME_CONTEXT_PROMPT.format(
                baseline_runtime_ns=baseline_runtime_ns,
                baseline_runtime_human=_humanize_ns(baseline_runtime_ns),
                loop_count=loop_count,
            )
        )
    if line_profiler_results:
        segments.append(LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results))
    if test_input_examples:
        segments.append(TEST_EXAMPLES_PROMPT.format(test_input_examples=test_input_examples))
    return "".join(segments)
|
||||
|
||||
def extract_code_and_explanation_from_llm_res(self, content: str) -> CodeStrAndExplanation:
|
||||
|
|
@ -215,12 +254,31 @@ class MultiOptimizerContext(BaseOptimizerContext):
|
|||
+ code_format_instructions
|
||||
)
|
||||
|
||||
def get_user_prompt(
    self,
    dependency_code: str,
    line_profiler_results: str | None,
    baseline_runtime_ns: int | None = None,
    loop_count: int | None = None,
    test_input_examples: str | None = None,
) -> str:
    """Assemble the user prompt for multi-file optimization.

    Segments, in order: optional dependency context, the base prompt with the
    combined source (plus __init__ guidance when any of the original files
    defines an __init__), optional baseline-runtime context, optional
    line-profiler context, and optional test-input examples.
    """
    # __init__ guidance applies if ANY constituent file defines one.
    uses_init = any(find_init(ast.parse(code)) for code in self.original_file_to_code.values())
    segments: list[str] = []
    if dependency_code:
        segments.append(DEPS_CONTEXT_PROMPT.format(dependency_code=dependency_code))
    segments.append(
        self.base_user_prompt.format(
            source_code=self.source_code,
            init_optimization_prompt=INIT_OPTIMIZATION_PROMPT if uses_init else "",
        )
    )
    # Runtime context needs both the measured time and the loop count.
    if baseline_runtime_ns is not None and loop_count is not None:
        segments.append(
            RUNTIME_CONTEXT_PROMPT.format(
                baseline_runtime_ns=baseline_runtime_ns,
                baseline_runtime_human=_humanize_ns(baseline_runtime_ns),
                loop_count=loop_count,
            )
        )
    if line_profiler_results:
        segments.append(LINE_PROF_CONTEXT_PROMPT.format(line_profiler_results=line_profiler_results))
    if test_input_examples:
        segments.append(TEST_EXAMPLES_PROMPT.format(test_input_examples=test_input_examples))
    return "".join(segments)
|
||||
|
||||
def extract_code_and_explanation_from_llm_res(self, content: str) -> CodeStrAndExplanation:
|
||||
|
|
|
|||
|
|
@ -53,6 +53,10 @@ async def generate_optimization_candidate(
|
|||
optimize_model: LLM = OPTIMIZE_MODEL,
|
||||
python_version: tuple[int, int, int] = (3, 12, 9),
|
||||
call_sequence: int | None = None,
|
||||
baseline_runtime_ns: int | None = None,
|
||||
loop_count: int | None = None,
|
||||
line_profiler_results: str | None = None,
|
||||
test_input_examples: str | None = None,
|
||||
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
|
||||
"""Optimize the given python code for performance using LLMs."""
|
||||
logging.info("/optimize: Optimizing python code.")
|
||||
|
|
@ -64,7 +68,13 @@ async def generate_optimization_candidate(
|
|||
python_version_str = ".".join(str(x) for x in python_version)
|
||||
|
||||
system_prompt = ctx.get_system_prompt(python_version_str)
|
||||
user_prompt = ctx.get_user_prompt(dependency_code or "", None)
|
||||
user_prompt = ctx.get_user_prompt(
|
||||
dependency_code or "",
|
||||
line_profiler_results,
|
||||
baseline_runtime_ns=baseline_runtime_ns,
|
||||
loop_count=loop_count,
|
||||
test_input_examples=test_input_examples,
|
||||
)
|
||||
|
||||
obs_context: dict[str, Any] | None = {"call_sequence": call_sequence} if call_sequence is not None else None
|
||||
|
||||
|
|
@ -123,6 +133,10 @@ async def optimize_python_code(
|
|||
dependency_code: str | None = None,
|
||||
python_version: tuple[int, int, int] = (3, 12, 9),
|
||||
n_candidates: int = 0,
|
||||
baseline_runtime_ns: int | None = None,
|
||||
loop_count: int | None = None,
|
||||
line_profiler_results: str | None = None,
|
||||
test_input_examples: str | None = None,
|
||||
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict[str, str]], dict[str, str]]:
|
||||
"""Run parallel optimizations with multiple models based on the distribution config.
|
||||
|
||||
|
|
@ -151,6 +165,8 @@ async def optimize_python_code(
|
|||
task_ctx = BaseOptimizerContext.get_dynamic_context(
|
||||
ctx.base_system_prompt, ctx.base_user_prompt, original_source_code
|
||||
)
|
||||
# Diversity: odd-numbered calls include line profiler, even-numbered calls don't
|
||||
lp_for_this_call = line_profiler_results if call_sequence % 2 == 1 else None
|
||||
task = tg.create_task(
|
||||
generate_optimization_candidate(
|
||||
user_id=user_id,
|
||||
|
|
@ -160,6 +176,10 @@ async def optimize_python_code(
|
|||
optimize_model=model,
|
||||
python_version=python_version,
|
||||
call_sequence=call_sequence,
|
||||
baseline_runtime_ns=baseline_runtime_ns,
|
||||
loop_count=loop_count,
|
||||
line_profiler_results=lp_for_this_call,
|
||||
test_input_examples=test_input_examples,
|
||||
)
|
||||
)
|
||||
tasks.append((task, task_ctx))
|
||||
|
|
@ -269,6 +289,10 @@ async def optimize_python(
|
|||
dependency_code=data.dependency_code,
|
||||
python_version=python_version,
|
||||
n_candidates=data.n_candidates,
|
||||
baseline_runtime_ns=data.baseline_runtime_ns,
|
||||
loop_count=data.loop_count,
|
||||
line_profiler_results=data.line_profiler_results,
|
||||
test_input_examples=data.test_input_examples,
|
||||
)
|
||||
)
|
||||
user_task = None
|
||||
|
|
|
|||
|
|
@@ -0,0 +1,2 @@
|
|||
|
||||
The current measured runtime of this function is {baseline_runtime_ns} nanoseconds ({baseline_runtime_human}) over {loop_count} benchmark loops. Focus your optimization on changes that will meaningfully reduce this runtime.
|
||||
|
|
@@ -0,0 +1,3 @@
|
|||
|
||||
Here are example test inputs and expected outputs for the function you are optimizing. Your optimization MUST preserve the same behavior for these inputs.
|
||||
{test_input_examples}
|
||||
|
|
@ -36,6 +36,10 @@ class OptimizeSchema(Schema):
|
|||
n_candidates: int = 5 # default value for backward compatibility
|
||||
is_numerical_code: bool | None = None
|
||||
rerun_trace_id: str | None = None
|
||||
baseline_runtime_ns: int | None = None
|
||||
loop_count: int | None = None
|
||||
line_profiler_results: str | None = None
|
||||
test_input_examples: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def resolve_python_version(self) -> Self:
|
||||
|
|
|
|||
Loading…
Reference in a new issue