mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Parallelize candidate evaluation across all optimization passes
Evaluates candidates concurrently using ThreadPoolExecutor with project overlays for isolation. Each candidate gets its own symlinked copy of the project so test subprocesses don't interfere with each other or the original source. Shared result lists protected with threading.Lock.
This commit is contained in:
parent
b455c1e69f
commit
66461ad4e7
1 changed files with 58 additions and 14 deletions
|
|
@ -69,6 +69,41 @@ NUMBA_REQUIRED_MODULES: frozenset[str] = frozenset(
|
|||
)
|
||||
|
||||
|
||||
def _evaluate_batch_parallel(
|
||||
candidates: list[Candidate],
|
||||
try_fn: _TryCandidateFn,
|
||||
max_workers: int = 4,
|
||||
) -> None:
|
||||
"""Evaluate *candidates* using *try_fn*, in parallel when possible.
|
||||
|
||||
Falls back to sequential evaluation when only one candidate is
|
||||
present or when the thread pool raises an unexpected error.
|
||||
"""
|
||||
if len(candidates) <= 1:
|
||||
for c in candidates:
|
||||
try_fn(c)
|
||||
return
|
||||
|
||||
from concurrent.futures import ( # noqa: PLC0415
|
||||
ThreadPoolExecutor,
|
||||
as_completed,
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
futures = {
|
||||
pool.submit(try_fn, c): c for c in candidates
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
exc = future.exception()
|
||||
if exc is not None:
|
||||
cid = futures[future].candidate_id
|
||||
log.warning(
|
||||
"Candidate %s raised during evaluation: %s",
|
||||
cid,
|
||||
exc,
|
||||
)
|
||||
|
||||
|
||||
def _uses_numerical_names(node: ast.AST, numerical_names: set[str]) -> bool:
|
||||
"""Return *True* if *node* references any of *numerical_names*."""
|
||||
return any(
|
||||
|
|
@ -602,7 +637,7 @@ class PythonFunctionOptimizer:
|
|||
|
||||
# -- Evaluation and selection ------------------------------------
|
||||
|
||||
def _evaluate_and_select( # noqa: C901, PLR0912, PLR0915
|
||||
def _evaluate_and_select( # noqa: C901, PLR0915
|
||||
self,
|
||||
candidates: list[Candidate],
|
||||
fn_input: FunctionInput,
|
||||
|
|
@ -614,7 +649,7 @@ class PythonFunctionOptimizer:
|
|||
normalize_python_code,
|
||||
)
|
||||
from ._candidate_eval import ( # noqa: PLC0415
|
||||
evaluate_candidate,
|
||||
evaluate_candidate_isolated,
|
||||
log_evaluation_results,
|
||||
rank_candidates,
|
||||
)
|
||||
|
|
@ -643,9 +678,12 @@ class PythonFunctionOptimizer:
|
|||
message=("All candidates duplicated the original"),
|
||||
)
|
||||
|
||||
import threading # noqa: PLC0415
|
||||
|
||||
eval_ctx = EvaluationContext()
|
||||
valid: list[Candidate] = []
|
||||
diff_lengths: list[int] = []
|
||||
_lock = threading.Lock()
|
||||
|
||||
async_eval = self._make_async_evaluator()
|
||||
test_env = build_test_env(
|
||||
|
|
@ -654,7 +692,7 @@ class PythonFunctionOptimizer:
|
|||
|
||||
def _try_candidate(c: Candidate) -> None:
|
||||
"""Evaluate *c* and append to *valid* if it improves."""
|
||||
sp = evaluate_candidate(
|
||||
sp = evaluate_candidate_isolated(
|
||||
candidate=c,
|
||||
fn_input=fn_input,
|
||||
baseline=baseline,
|
||||
|
|
@ -668,14 +706,14 @@ class PythonFunctionOptimizer:
|
|||
evaluate_async_fn=async_eval,
|
||||
)
|
||||
if sp is not None and sp > 0:
|
||||
valid.append(c)
|
||||
diff_lengths.append(
|
||||
diff_length(c.code, fn_input.source_code),
|
||||
)
|
||||
with _lock:
|
||||
valid.append(c)
|
||||
diff_lengths.append(
|
||||
diff_length(c.code, fn_input.source_code),
|
||||
)
|
||||
|
||||
# Pass 1: evaluate initial candidates.
|
||||
for candidate in unique:
|
||||
_try_candidate(candidate)
|
||||
# Pass 1: evaluate initial candidates in parallel.
|
||||
_evaluate_batch_parallel(unique, _try_candidate)
|
||||
|
||||
# Pass 2: refinement + repair.
|
||||
pass2 = generate_refinement_candidates(
|
||||
|
|
@ -703,8 +741,9 @@ class PythonFunctionOptimizer:
|
|||
normalize_fn=normalize_python_code,
|
||||
original_normalized=normalized_original,
|
||||
)
|
||||
for candidate in pass2_unique:
|
||||
_try_candidate(candidate)
|
||||
_evaluate_batch_parallel(
|
||||
pass2_unique, _try_candidate,
|
||||
)
|
||||
|
||||
# Pass 3: adaptive optimization (needs >=2 valid).
|
||||
if len(valid) >= 2: # noqa: PLR2004
|
||||
|
|
@ -715,8 +754,9 @@ class PythonFunctionOptimizer:
|
|||
eval_ctx=eval_ctx,
|
||||
fn_input=fn_input,
|
||||
)
|
||||
for candidate in adaptive:
|
||||
_try_candidate(candidate)
|
||||
_evaluate_batch_parallel(
|
||||
adaptive, _try_candidate,
|
||||
)
|
||||
|
||||
if not valid:
|
||||
return FunctionResult(
|
||||
|
|
@ -1218,4 +1258,8 @@ def write_code_and_helpers(
|
|||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
_TryCandidateFn = Callable[[Candidate], None]
|
||||
|
||||
from ._candidate_eval import _EvalAsyncFn
|
||||
|
|
|
|||
Loading…
Reference in a new issue