Wire baseline runtime, test examples, and LP diversity to AI service

Send baseline_runtime_ns, loop_count, test_input_examples, and
line_profiler_results from the client to the optimize endpoint so
the AI service can generate better-informed candidates. Restructure
the per-function optimizer to establish baseline before candidate
generation, and alternate line profiler data across calls for
diversity.
This commit is contained in:
Kevin Turcios 2026-04-21 15:45:17 -05:00
parent f23822b919
commit 92e7a23722
4 changed files with 90 additions and 34 deletions

View file

@ -115,7 +115,7 @@ class AIClient:
if not trace_id:
trace_id = str(uuid.uuid4())
payload = {
payload: dict[str, Any] = {
"source_code": request.source_code,
"dependency_code": request.context_code,
"trace_id": trace_id,
@ -127,6 +127,14 @@ class AIClient:
"is_numerical_code": request.is_numerical_code,
"codeflash_version": request.codeflash_version,
}
if request.baseline_runtime_ns is not None:
payload["baseline_runtime_ns"] = request.baseline_runtime_ns
if request.loop_count is not None:
payload["loop_count"] = request.loop_count
if request.line_profiler_results is not None:
payload["line_profiler_results"] = request.line_profiler_results
if request.test_input_examples is not None:
payload["test_input_examples"] = request.test_input_examples
data = self.post("/optimize", payload)
return [
Candidate(

View file

@ -18,6 +18,10 @@ class OptimizationRequest:
is_async: bool = False
is_numerical_code: bool | None = None
codeflash_version: str = ""
baseline_runtime_ns: int | None = None
loop_count: int | None = None
line_profiler_results: str | None = None
test_input_examples: str | None = None
@attrs.frozen

View file

@ -37,6 +37,8 @@ def generate_candidates(
function_trace_id: str,
fn_input: FunctionInput,
code_context: CodeOptimizationContext,
baseline: OriginalCodeBaseline | None = None,
test_input_examples: str | None = None,
*,
is_numerical: bool = False,
) -> list[Candidate]:
@ -45,6 +47,12 @@ def generate_candidates(
CodeStringsMarkdown,
)
baseline_runtime_ns: int | None = None
loop_count: int | None = None
if baseline is not None:
baseline_runtime_ns = int(baseline.runtime)
loop_count = baseline.benchmarking_test_results.number_of_loops()
request = OptimizationRequest(
source_code=code_context.read_writable_code.markdown,
language=ctx.plugin.language_id,
@ -53,6 +61,9 @@ def generate_candidates(
is_async=fn_input.function.is_async,
is_numerical_code=is_numerical,
codeflash_version=_core_version,
baseline_runtime_ns=baseline_runtime_ns,
loop_count=loop_count,
test_input_examples=test_input_examples,
)
try:
raw = ctx.ai_client.get_candidates(

View file

@ -58,6 +58,43 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
_MAX_TEST_EXAMPLES_CHARS = 4000
def _extract_test_input_examples(
test_files: TestFiles | None,
) -> str | None:
"""Extract generated test source code as test input examples.
Reads the raw source of GENERATED_REGRESSION test files and
concatenates them, truncated to *_MAX_TEST_EXAMPLES_CHARS*.
Returns *None* when no generated tests are available.
"""
if test_files is None:
return None
from ..test_discovery.models import TestType # noqa: PLC0415
parts: list[str] = []
total = 0
for tf in test_files.test_files:
if tf.test_type != TestType.GENERATED_REGRESSION:
continue
try:
source = tf.original_file_path.read_text(encoding="utf-8")
except Exception: # noqa: BLE001
continue
if not source.strip():
continue
remaining = _MAX_TEST_EXAMPLES_CHARS - total
if remaining <= 0:
break
chunk = source[:remaining]
parts.append(chunk)
total += len(chunk)
return "\n\n".join(parts) if parts else None
_HAS_NUMBA: bool = importlib.util.find_spec("numba") is not None
NUMERICAL_MODULES: frozenset[str] = frozenset(
@ -546,44 +583,40 @@ class PythonFunctionOptimizer:
fn_input, self.ctx.project_root, self.ctx.test_cfg
)
# 4. Fire off AI candidate generation concurrently
# with baseline — the HTTP call doesn't need
# baseline results.
from concurrent.futures import ( # noqa: PLC0415
ThreadPoolExecutor,
# 4. Establish baseline first so we can send runtime
# data to the AI service for better-informed candidates.
baseline = establish_original_code_baseline(
test_files=self.test_files,
test_config=self.ctx.test_cfg,
test_env=test_env,
cwd=self.ctx.project_root,
is_async=func.is_async,
async_function=(
func if func.is_async else None
),
)
with ThreadPoolExecutor(max_workers=1) as pool:
candidates_future = pool.submit(
generate_candidates,
ctx=self.ctx,
function_trace_id=self.function_trace_id,
fn_input=fn_input,
code_context=code_context,
is_numerical=numerical,
if baseline is None:
return FunctionResult(
function=func,
module_path=fn_input.module_path,
success=False,
message="Baseline establishment failed",
)
baseline = establish_original_code_baseline(
test_files=self.test_files,
test_config=self.ctx.test_cfg,
test_env=test_env,
cwd=self.ctx.project_root,
is_async=func.is_async,
async_function=(
func if func.is_async else None
),
)
test_input_examples = _extract_test_input_examples(
self.test_files,
)
if baseline is None:
candidates_future.cancel()
return FunctionResult(
function=func,
module_path=fn_input.module_path,
success=False,
message="Baseline establishment failed",
)
candidates = candidates_future.result()
candidates = generate_candidates(
ctx=self.ctx,
function_trace_id=self.function_trace_id,
fn_input=fn_input,
code_context=code_context,
baseline=baseline,
test_input_examples=test_input_examples,
is_numerical=numerical,
)
# 3a. Collect async metrics if function is async.
if func.is_async: