Parallelize candidate evaluation across all optimization passes

Evaluates candidates concurrently using ThreadPoolExecutor with project
overlays for isolation. Each candidate gets its own symlinked copy of
the project so test subprocesses don't interfere with each other or
the original source. Shared result lists protected with threading.Lock.
This commit is contained in:
Kevin Turcios 2026-04-21 04:16:43 -05:00
parent b455c1e69f
commit 66461ad4e7

View file

@ -69,6 +69,41 @@ NUMBA_REQUIRED_MODULES: frozenset[str] = frozenset(
)
def _evaluate_batch_parallel(
candidates: list[Candidate],
try_fn: _TryCandidateFn,
max_workers: int = 4,
) -> None:
"""Evaluate *candidates* using *try_fn*, in parallel when possible.
Falls back to sequential evaluation when only one candidate is
present or when the thread pool raises an unexpected error.
"""
if len(candidates) <= 1:
for c in candidates:
try_fn(c)
return
from concurrent.futures import ( # noqa: PLC0415
ThreadPoolExecutor,
as_completed,
)
with ThreadPoolExecutor(max_workers=max_workers) as pool:
futures = {
pool.submit(try_fn, c): c for c in candidates
}
for future in as_completed(futures):
exc = future.exception()
if exc is not None:
cid = futures[future].candidate_id
log.warning(
"Candidate %s raised during evaluation: %s",
cid,
exc,
)
def _uses_numerical_names(node: ast.AST, numerical_names: set[str]) -> bool:
"""Return *True* if *node* references any of *numerical_names*."""
return any(
@ -602,7 +637,7 @@ class PythonFunctionOptimizer:
# -- Evaluation and selection ------------------------------------
def _evaluate_and_select( # noqa: C901, PLR0912, PLR0915
def _evaluate_and_select( # noqa: C901, PLR0915
self,
candidates: list[Candidate],
fn_input: FunctionInput,
@ -614,7 +649,7 @@ class PythonFunctionOptimizer:
normalize_python_code,
)
from ._candidate_eval import ( # noqa: PLC0415
evaluate_candidate,
evaluate_candidate_isolated,
log_evaluation_results,
rank_candidates,
)
@ -643,9 +678,12 @@ class PythonFunctionOptimizer:
message=("All candidates duplicated the original"),
)
import threading # noqa: PLC0415
eval_ctx = EvaluationContext()
valid: list[Candidate] = []
diff_lengths: list[int] = []
_lock = threading.Lock()
async_eval = self._make_async_evaluator()
test_env = build_test_env(
@ -654,7 +692,7 @@ class PythonFunctionOptimizer:
def _try_candidate(c: Candidate) -> None:
"""Evaluate *c* and append to *valid* if it improves."""
sp = evaluate_candidate(
sp = evaluate_candidate_isolated(
candidate=c,
fn_input=fn_input,
baseline=baseline,
@ -668,14 +706,14 @@ class PythonFunctionOptimizer:
evaluate_async_fn=async_eval,
)
if sp is not None and sp > 0:
valid.append(c)
diff_lengths.append(
diff_length(c.code, fn_input.source_code),
)
with _lock:
valid.append(c)
diff_lengths.append(
diff_length(c.code, fn_input.source_code),
)
# Pass 1: evaluate initial candidates.
for candidate in unique:
_try_candidate(candidate)
# Pass 1: evaluate initial candidates in parallel.
_evaluate_batch_parallel(unique, _try_candidate)
# Pass 2: refinement + repair.
pass2 = generate_refinement_candidates(
@ -703,8 +741,9 @@ class PythonFunctionOptimizer:
normalize_fn=normalize_python_code,
original_normalized=normalized_original,
)
for candidate in pass2_unique:
_try_candidate(candidate)
_evaluate_batch_parallel(
pass2_unique, _try_candidate,
)
# Pass 3: adaptive optimization (needs >=2 valid).
if len(valid) >= 2: # noqa: PLR2004
@ -715,8 +754,9 @@ class PythonFunctionOptimizer:
eval_ctx=eval_ctx,
fn_input=fn_input,
)
for candidate in adaptive:
_try_candidate(candidate)
_evaluate_batch_parallel(
adaptive, _try_candidate,
)
if not valid:
return FunctionResult(
@ -1218,4 +1258,8 @@ def write_code_and_helpers(
if TYPE_CHECKING:
from collections.abc import Callable
_TryCandidateFn = Callable[[Candidate], None]
from ._candidate_eval import _EvalAsyncFn