Merge pull request #1919 from codeflash-ai/cf-feat-rerun-trace

feat: add --rerun flag to rerun optimization from stored trace
This commit is contained in:
Sarthak Agarwal 2026-03-29 18:09:33 +05:30 committed by GitHub
commit 14dafe08fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 31 additions and 3 deletions

View file

@ -154,6 +154,7 @@ class AiServiceClient:
is_async: bool = False,
n_candidates: int = 5,
is_numerical_code: bool | None = None,
rerun_trace_id: str | None = None,
) -> list[OptimizedCandidate]:
"""Optimize the given code for performance by making a request to the Django endpoint.
@ -194,6 +195,7 @@ class AiServiceClient:
"call_sequence": self.get_next_sequence(),
"n_candidates": n_candidates,
"is_numerical_code": is_numerical_code,
"rerun_trace_id": rerun_trace_id,
}
self.add_language_metadata(payload, language_version, module_system)
@ -234,6 +236,7 @@ class AiServiceClient:
is_numerical_code: bool | None = None,
language: str = "python",
language_version: str | None = None,
rerun_trace_id: str | None = None,
) -> list[OptimizedCandidate]:
"""Optimize code for performance using line profiler results.
@ -272,6 +275,7 @@ class AiServiceClient:
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
"is_numerical_code": is_numerical_code,
"rerun_trace_id": rerun_trace_id,
}
try:
@ -318,7 +322,9 @@ class AiServiceClient:
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
return None
def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
def optimize_code_refinement(
self, request: list[AIServiceRefinerRequest], rerun_trace_id: str | None = None
) -> list[OptimizedCandidate]:
"""Refine optimization candidates for improved performance.
Supports Python, JavaScript, and TypeScript code refinement with optional
@ -349,6 +355,7 @@ class AiServiceClient:
"call_sequence": self.get_next_sequence(),
# Multi-language support
"language": opt.language,
"rerun_trace_id": rerun_trace_id,
}
self.add_language_metadata(item, opt.language_version)
@ -375,7 +382,9 @@ class AiServiceClient:
console.rule()
return []
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
def code_repair(
self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None
) -> OptimizedCandidate | None:
console.rule()
try:
payload = {
@ -385,6 +394,7 @@ class AiServiceClient:
"trace_id": request.trace_id,
"test_diffs": request.test_diffs,
"language": request.language,
"rerun_trace_id": rerun_trace_id,
}
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=self.timeout)
except (requests.exceptions.RequestException, TypeError) as e:
@ -607,6 +617,7 @@ class AiServiceClient:
language_version: str | None = None,
module_system: str | None = None,
is_numerical_code: bool | None = None,
rerun_trace_id: str | None = None,
) -> tuple[str, str, str, str | None] | None:
"""Generate regression tests for the given function by making a request to the Django endpoint.
@ -655,6 +666,7 @@ class AiServiceClient:
"is_numerical_code": is_numerical_code,
"class_name": function_to_optimize.class_name,
"qualified_name": function_to_optimize.qualified_name,
"rerun_trace_id": rerun_trace_id,
}
self.add_language_metadata(payload, language_version, module_system)

View file

@ -437,6 +437,12 @@ def _build_parser() -> ArgumentParser:
)
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
parser.add_argument(
"--rerun",
type=str,
help="Rerun a previous optimization by trace ID, using stored LLM results",
metavar="TRACE_ID",
)
parser.add_argument(
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
)

View file

@ -877,6 +877,10 @@ class FunctionOptimizer:
return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
return Success(best_optimization)
@property
def rerun_trace_id(self) -> str | None:
return getattr(self.args, "rerun", None) if self.args else None
def get_trace_id(self, exp_type: str) -> str:
"""Get the trace ID for the current experiment type."""
if self.experiment_id:
@ -1291,6 +1295,7 @@ class FunctionOptimizer:
language_version=self.language_support.language_version,
)
],
rerun_trace_id=self.rerun_trace_id,
)
self.future_all_refinements.append(future_refinement)
@ -1353,6 +1358,7 @@ class FunctionOptimizer:
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
language=self.function_to_optimize.language,
language_version=self.language_support.language_version,
rerun_trace_id=self.rerun_trace_id,
)
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
@ -1485,7 +1491,7 @@ class FunctionOptimizer:
trace_id=trace_id,
language=language,
)
return executor.submit(ai_service_client.code_repair, request=request)
return executor.submit(ai_service_client.code_repair, request=request, rerun_trace_id=self.rerun_trace_id)
def log_successful_optimization(
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
@ -1861,6 +1867,7 @@ class FunctionOptimizer:
is_async=self.function_to_optimize.is_async,
n_candidates=n_candidates,
is_numerical_code=is_numerical_code,
rerun_trace_id=self.rerun_trace_id,
)
future_references = self.executor.submit(
@ -3181,6 +3188,7 @@ class FunctionOptimizer:
test_path,
test_perf_path,
self.is_numerical_code,
self.rerun_trace_id,
)
for test_index, (test_path, test_perf_path) in enumerate(
zip(generated_test_paths, generated_perf_test_paths)

View file

@ -29,6 +29,7 @@ def generate_tests(
test_path: Path,
test_perf_path: Path,
is_numerical_code: bool | None = None,
rerun_trace_id: str | None = None,
) -> tuple[str, str, str, str | None, Path, Path] | None:
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
# class import. Remove the recreation of the class definition
@ -73,6 +74,7 @@ def generate_tests(
language_version=current_language_support().language_version,
module_system=project_module_system,
is_numerical_code=is_numerical_code,
rerun_trace_id=rerun_trace_id,
)
if response and isinstance(response, tuple) and len(response) == 4:
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, raw_generated_tests = (