mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Wire baseline runtime, test examples, and LP diversity to AI service
Send baseline_runtime_ns, loop_count, test_input_examples, and line_profiler_results from the client to the optimize endpoint so the AI service can generate better-informed candidates. Restructure the per-function optimizer to establish baseline before candidate generation, and alternate line profiler data across calls for diversity.
This commit is contained in:
parent
f23822b919
commit
92e7a23722
4 changed files with 90 additions and 34 deletions
|
|
@ -115,7 +115,7 @@ class AIClient:
|
|||
if not trace_id:
|
||||
trace_id = str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
payload: dict[str, Any] = {
|
||||
"source_code": request.source_code,
|
||||
"dependency_code": request.context_code,
|
||||
"trace_id": trace_id,
|
||||
|
|
@ -127,6 +127,14 @@ class AIClient:
|
|||
"is_numerical_code": request.is_numerical_code,
|
||||
"codeflash_version": request.codeflash_version,
|
||||
}
|
||||
if request.baseline_runtime_ns is not None:
|
||||
payload["baseline_runtime_ns"] = request.baseline_runtime_ns
|
||||
if request.loop_count is not None:
|
||||
payload["loop_count"] = request.loop_count
|
||||
if request.line_profiler_results is not None:
|
||||
payload["line_profiler_results"] = request.line_profiler_results
|
||||
if request.test_input_examples is not None:
|
||||
payload["test_input_examples"] = request.test_input_examples
|
||||
data = self.post("/optimize", payload)
|
||||
return [
|
||||
Candidate(
|
||||
|
|
|
|||
|
|
@ -18,6 +18,10 @@ class OptimizationRequest:
|
|||
is_async: bool = False
|
||||
is_numerical_code: bool | None = None
|
||||
codeflash_version: str = ""
|
||||
baseline_runtime_ns: int | None = None
|
||||
loop_count: int | None = None
|
||||
line_profiler_results: str | None = None
|
||||
test_input_examples: str | None = None
|
||||
|
||||
|
||||
@attrs.frozen
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ def generate_candidates(
|
|||
function_trace_id: str,
|
||||
fn_input: FunctionInput,
|
||||
code_context: CodeOptimizationContext,
|
||||
baseline: OriginalCodeBaseline | None = None,
|
||||
test_input_examples: str | None = None,
|
||||
*,
|
||||
is_numerical: bool = False,
|
||||
) -> list[Candidate]:
|
||||
|
|
@ -45,6 +47,12 @@ def generate_candidates(
|
|||
CodeStringsMarkdown,
|
||||
)
|
||||
|
||||
baseline_runtime_ns: int | None = None
|
||||
loop_count: int | None = None
|
||||
if baseline is not None:
|
||||
baseline_runtime_ns = int(baseline.runtime)
|
||||
loop_count = baseline.benchmarking_test_results.number_of_loops()
|
||||
|
||||
request = OptimizationRequest(
|
||||
source_code=code_context.read_writable_code.markdown,
|
||||
language=ctx.plugin.language_id,
|
||||
|
|
@ -53,6 +61,9 @@ def generate_candidates(
|
|||
is_async=fn_input.function.is_async,
|
||||
is_numerical_code=is_numerical,
|
||||
codeflash_version=_core_version,
|
||||
baseline_runtime_ns=baseline_runtime_ns,
|
||||
loop_count=loop_count,
|
||||
test_input_examples=test_input_examples,
|
||||
)
|
||||
try:
|
||||
raw = ctx.ai_client.get_candidates(
|
||||
|
|
|
|||
|
|
@ -58,6 +58,43 @@ if TYPE_CHECKING:
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_MAX_TEST_EXAMPLES_CHARS = 4000
|
||||
|
||||
|
||||
def _extract_test_input_examples(
|
||||
test_files: TestFiles | None,
|
||||
) -> str | None:
|
||||
"""Extract generated test source code as test input examples.
|
||||
|
||||
Reads the raw source of GENERATED_REGRESSION test files and
|
||||
concatenates them, truncated to *_MAX_TEST_EXAMPLES_CHARS*.
|
||||
Returns *None* when no generated tests are available.
|
||||
"""
|
||||
if test_files is None:
|
||||
return None
|
||||
|
||||
from ..test_discovery.models import TestType # noqa: PLC0415
|
||||
|
||||
parts: list[str] = []
|
||||
total = 0
|
||||
for tf in test_files.test_files:
|
||||
if tf.test_type != TestType.GENERATED_REGRESSION:
|
||||
continue
|
||||
try:
|
||||
source = tf.original_file_path.read_text(encoding="utf-8")
|
||||
except Exception: # noqa: BLE001
|
||||
continue
|
||||
if not source.strip():
|
||||
continue
|
||||
remaining = _MAX_TEST_EXAMPLES_CHARS - total
|
||||
if remaining <= 0:
|
||||
break
|
||||
chunk = source[:remaining]
|
||||
parts.append(chunk)
|
||||
total += len(chunk)
|
||||
return "\n\n".join(parts) if parts else None
|
||||
|
||||
|
||||
_HAS_NUMBA: bool = importlib.util.find_spec("numba") is not None
|
||||
|
||||
NUMERICAL_MODULES: frozenset[str] = frozenset(
|
||||
|
|
@ -546,44 +583,40 @@ class PythonFunctionOptimizer:
|
|||
fn_input, self.ctx.project_root, self.ctx.test_cfg
|
||||
)
|
||||
|
||||
# 4. Fire off AI candidate generation concurrently
|
||||
# with baseline — the HTTP call doesn't need
|
||||
# baseline results.
|
||||
from concurrent.futures import ( # noqa: PLC0415
|
||||
ThreadPoolExecutor,
|
||||
# 4. Establish baseline first so we can send runtime
|
||||
# data to the AI service for better-informed candidates.
|
||||
baseline = establish_original_code_baseline(
|
||||
test_files=self.test_files,
|
||||
test_config=self.ctx.test_cfg,
|
||||
test_env=test_env,
|
||||
cwd=self.ctx.project_root,
|
||||
is_async=func.is_async,
|
||||
async_function=(
|
||||
func if func.is_async else None
|
||||
),
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as pool:
|
||||
candidates_future = pool.submit(
|
||||
generate_candidates,
|
||||
ctx=self.ctx,
|
||||
function_trace_id=self.function_trace_id,
|
||||
fn_input=fn_input,
|
||||
code_context=code_context,
|
||||
is_numerical=numerical,
|
||||
if baseline is None:
|
||||
return FunctionResult(
|
||||
function=func,
|
||||
module_path=fn_input.module_path,
|
||||
success=False,
|
||||
message="Baseline establishment failed",
|
||||
)
|
||||
|
||||
baseline = establish_original_code_baseline(
|
||||
test_files=self.test_files,
|
||||
test_config=self.ctx.test_cfg,
|
||||
test_env=test_env,
|
||||
cwd=self.ctx.project_root,
|
||||
is_async=func.is_async,
|
||||
async_function=(
|
||||
func if func.is_async else None
|
||||
),
|
||||
)
|
||||
test_input_examples = _extract_test_input_examples(
|
||||
self.test_files,
|
||||
)
|
||||
|
||||
if baseline is None:
|
||||
candidates_future.cancel()
|
||||
return FunctionResult(
|
||||
function=func,
|
||||
module_path=fn_input.module_path,
|
||||
success=False,
|
||||
message="Baseline establishment failed",
|
||||
)
|
||||
|
||||
candidates = candidates_future.result()
|
||||
candidates = generate_candidates(
|
||||
ctx=self.ctx,
|
||||
function_trace_id=self.function_trace_id,
|
||||
fn_input=fn_input,
|
||||
code_context=code_context,
|
||||
baseline=baseline,
|
||||
test_input_examples=test_input_examples,
|
||||
is_numerical=numerical,
|
||||
)
|
||||
|
||||
# 3a. Collect async metrics if function is async.
|
||||
if func.is_async:
|
||||
|
|
|
|||
Loading…
Reference in a new issue