mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge pull request #1919 from codeflash-ai/cf-feat-rerun-trace
feat: add --rerun flag to rerun optimization from stored trace
This commit is contained in:
commit
14dafe08fd
4 changed files with 31 additions and 3 deletions
|
|
@ -154,6 +154,7 @@ class AiServiceClient:
|
|||
is_async: bool = False,
|
||||
n_candidates: int = 5,
|
||||
is_numerical_code: bool | None = None,
|
||||
rerun_trace_id: str | None = None,
|
||||
) -> list[OptimizedCandidate]:
|
||||
"""Optimize the given code for performance by making a request to the Django endpoint.
|
||||
|
||||
|
|
@ -194,6 +195,7 @@ class AiServiceClient:
|
|||
"call_sequence": self.get_next_sequence(),
|
||||
"n_candidates": n_candidates,
|
||||
"is_numerical_code": is_numerical_code,
|
||||
"rerun_trace_id": rerun_trace_id,
|
||||
}
|
||||
|
||||
self.add_language_metadata(payload, language_version, module_system)
|
||||
|
|
@ -234,6 +236,7 @@ class AiServiceClient:
|
|||
is_numerical_code: bool | None = None,
|
||||
language: str = "python",
|
||||
language_version: str | None = None,
|
||||
rerun_trace_id: str | None = None,
|
||||
) -> list[OptimizedCandidate]:
|
||||
"""Optimize code for performance using line profiler results.
|
||||
|
||||
|
|
@ -272,6 +275,7 @@ class AiServiceClient:
|
|||
"codeflash_version": codeflash_version,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
"is_numerical_code": is_numerical_code,
|
||||
"rerun_trace_id": rerun_trace_id,
|
||||
}
|
||||
|
||||
try:
|
||||
|
|
@ -318,7 +322,9 @@ class AiServiceClient:
|
|||
self.log_error_response(response, "generating optimized candidates", "cli-optimize-error-response")
|
||||
return None
|
||||
|
||||
def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
|
||||
def optimize_code_refinement(
|
||||
self, request: list[AIServiceRefinerRequest], rerun_trace_id: str | None = None
|
||||
) -> list[OptimizedCandidate]:
|
||||
"""Refine optimization candidates for improved performance.
|
||||
|
||||
Supports Python, JavaScript, and TypeScript code refinement with optional
|
||||
|
|
@ -349,6 +355,7 @@ class AiServiceClient:
|
|||
"call_sequence": self.get_next_sequence(),
|
||||
# Multi-language support
|
||||
"language": opt.language,
|
||||
"rerun_trace_id": rerun_trace_id,
|
||||
}
|
||||
|
||||
self.add_language_metadata(item, opt.language_version)
|
||||
|
|
@ -375,7 +382,9 @@ class AiServiceClient:
|
|||
console.rule()
|
||||
return []
|
||||
|
||||
def code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None:
|
||||
def code_repair(
|
||||
self, request: AIServiceCodeRepairRequest, rerun_trace_id: str | None = None
|
||||
) -> OptimizedCandidate | None:
|
||||
console.rule()
|
||||
try:
|
||||
payload = {
|
||||
|
|
@ -385,6 +394,7 @@ class AiServiceClient:
|
|||
"trace_id": request.trace_id,
|
||||
"test_diffs": request.test_diffs,
|
||||
"language": request.language,
|
||||
"rerun_trace_id": rerun_trace_id,
|
||||
}
|
||||
response = self.make_ai_service_request("/code_repair", payload=payload, timeout=self.timeout)
|
||||
except (requests.exceptions.RequestException, TypeError) as e:
|
||||
|
|
@ -607,6 +617,7 @@ class AiServiceClient:
|
|||
language_version: str | None = None,
|
||||
module_system: str | None = None,
|
||||
is_numerical_code: bool | None = None,
|
||||
rerun_trace_id: str | None = None,
|
||||
) -> tuple[str, str, str, str | None] | None:
|
||||
"""Generate regression tests for the given function by making a request to the Django endpoint.
|
||||
|
||||
|
|
@ -655,6 +666,7 @@ class AiServiceClient:
|
|||
"is_numerical_code": is_numerical_code,
|
||||
"class_name": function_to_optimize.class_name,
|
||||
"qualified_name": function_to_optimize.qualified_name,
|
||||
"rerun_trace_id": rerun_trace_id,
|
||||
}
|
||||
|
||||
self.add_language_metadata(payload, language_version, module_system)
|
||||
|
|
|
|||
|
|
@ -437,6 +437,12 @@ def _build_parser() -> ArgumentParser:
|
|||
)
|
||||
parser.add_argument("--config-file", type=str, help="Path to the pyproject.toml with codeflash configs.")
|
||||
parser.add_argument("--replay-test", type=str, nargs="+", help="Paths to replay test to optimize functions from")
|
||||
parser.add_argument(
|
||||
"--rerun",
|
||||
type=str,
|
||||
help="Rerun a previous optimization by trace ID, using stored LLM results",
|
||||
metavar="TRACE_ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-pr", action="store_true", help="Do not create a PR for the optimization, only update the code locally."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -877,6 +877,10 @@ class FunctionOptimizer:
|
|||
return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
|
||||
return Success(best_optimization)
|
||||
|
||||
@property
|
||||
def rerun_trace_id(self) -> str | None:
|
||||
return getattr(self.args, "rerun", None) if self.args else None
|
||||
|
||||
def get_trace_id(self, exp_type: str) -> str:
|
||||
"""Get the trace ID for the current experiment type."""
|
||||
if self.experiment_id:
|
||||
|
|
@ -1291,6 +1295,7 @@ class FunctionOptimizer:
|
|||
language_version=self.language_support.language_version,
|
||||
)
|
||||
],
|
||||
rerun_trace_id=self.rerun_trace_id,
|
||||
)
|
||||
self.future_all_refinements.append(future_refinement)
|
||||
|
||||
|
|
@ -1353,6 +1358,7 @@ class FunctionOptimizer:
|
|||
is_numerical_code=self.is_numerical_code and not self.args.no_jit_opts,
|
||||
language=self.function_to_optimize.language,
|
||||
language_version=self.language_support.language_version,
|
||||
rerun_trace_id=self.rerun_trace_id,
|
||||
)
|
||||
|
||||
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
|
||||
|
|
@ -1485,7 +1491,7 @@ class FunctionOptimizer:
|
|||
trace_id=trace_id,
|
||||
language=language,
|
||||
)
|
||||
return executor.submit(ai_service_client.code_repair, request=request)
|
||||
return executor.submit(ai_service_client.code_repair, request=request, rerun_trace_id=self.rerun_trace_id)
|
||||
|
||||
def log_successful_optimization(
|
||||
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
|
||||
|
|
@ -1861,6 +1867,7 @@ class FunctionOptimizer:
|
|||
is_async=self.function_to_optimize.is_async,
|
||||
n_candidates=n_candidates,
|
||||
is_numerical_code=is_numerical_code,
|
||||
rerun_trace_id=self.rerun_trace_id,
|
||||
)
|
||||
|
||||
future_references = self.executor.submit(
|
||||
|
|
@ -3181,6 +3188,7 @@ class FunctionOptimizer:
|
|||
test_path,
|
||||
test_perf_path,
|
||||
self.is_numerical_code,
|
||||
self.rerun_trace_id,
|
||||
)
|
||||
for test_index, (test_path, test_perf_path) in enumerate(
|
||||
zip(generated_test_paths, generated_perf_test_paths)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ def generate_tests(
|
|||
test_path: Path,
|
||||
test_perf_path: Path,
|
||||
is_numerical_code: bool | None = None,
|
||||
rerun_trace_id: str | None = None,
|
||||
) -> tuple[str, str, str, str | None, Path, Path] | None:
|
||||
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
|
||||
# class import. Remove the recreation of the class definition
|
||||
|
|
@ -73,6 +74,7 @@ def generate_tests(
|
|||
language_version=current_language_support().language_version,
|
||||
module_system=project_module_system,
|
||||
is_numerical_code=is_numerical_code,
|
||||
rerun_trace_id=rerun_trace_id,
|
||||
)
|
||||
if response and isinstance(response, tuple) and len(response) == 4:
|
||||
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source, raw_generated_tests = (
|
||||
|
|
|
|||
Loading…
Reference in a new issue