Merge pull request #661 from codeflash-ai/rewrite-candidate-loop-clean

Abstracted and cleaned up the candidate loop using queue.Queue()
This commit is contained in:
Aseem Saxena 2025-08-22 18:09:26 -04:00 committed by GitHub
commit 02b4d6557a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3,11 +3,12 @@ from __future__ import annotations
import ast
import concurrent.futures
import os
import queue
import random
import subprocess
import time
import uuid
from collections import defaultdict, deque
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
@ -104,6 +105,83 @@ if TYPE_CHECKING:
from codeflash.verification.verification_utils import TestConfig
class CandidateProcessor:
    """Serve optimization candidates one at a time from a FIFO queue.

    The queue starts with the initial candidates. Whenever it runs dry, it is
    topped up from pending futures in a fixed order: first the line-profiler
    future, then the refinement futures. Once both sources have been drained
    and the queue is empty, there is nothing left to hand out.
    """

    def __init__(
        self,
        initial_candidates: list,
        future_line_profile_results: concurrent.futures.Future,
        future_all_refinements: list,
    ) -> None:
        """Seed the queue and remember the futures to consult later.

        Args:
            initial_candidates: Candidates enqueued immediately, in order.
            future_line_profile_results: Future whose result is an iterable,
                sized collection of further candidates.
            future_all_refinements: List of futures, each resolving to a
                sequence whose first element (if any) is a candidate. The
                list is stored by reference, not copied.
        """
        self.future_line_profile_results = future_line_profile_results
        self.future_all_refinements = future_all_refinements

        # Each flag flips to True once the matching source has been drained.
        self.line_profiler_done = False
        self.refinement_done = False

        # Running total of candidates ever enqueued (used for progress logs).
        self.candidate_len = len(initial_candidates)

        self.candidate_queue = queue.Queue()
        for seed in initial_candidates:
            self.candidate_queue.put(seed)

    def get_next_candidate(self) -> OptimizedCandidate | None:
        """Return the next queued candidate, refilling from futures if empty.

        Returns None only when the queue is empty and both the line-profiler
        and refinement sources have already been drained.
        """
        try:
            upcoming = self.candidate_queue.get_nowait()
        except queue.Empty:
            return self._handle_empty_queue()
        return upcoming

    def _handle_empty_queue(self) -> OptimizedCandidate | None:
        """Pick the next source to drain, or return None if none remain."""
        if not self.line_profiler_done:
            return self._process_line_profiler_results()
        if self.refinement_done:
            return None  # All done
        return self._process_refinement_results()

    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
        """Block on the line-profiler future and enqueue its candidates."""
        logger.debug("all candidates processed, await candidates from line profiler")
        pending = self.future_line_profile_results
        concurrent.futures.wait([pending])
        profiled = pending.result()
        for item in profiled:
            self.candidate_queue.put(item)
        self.candidate_len += len(profiled)
        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
        self.line_profiler_done = True
        # Re-enter the normal path so an empty result falls through to refinements.
        return self.get_next_candidate()

    def _process_refinement_results(self) -> OptimizedCandidate | None:
        """Block on every refinement future and enqueue each one's first result."""
        concurrent.futures.wait(self.future_all_refinements)
        # Resolve the futures in list order; only non-empty responses contribute,
        # and only their first element is used.
        responses = (pending.result() for pending in self.future_all_refinements)
        refinement_response = [response[0] for response in responses if len(response) > 0]
        for item in refinement_response:
            self.candidate_queue.put(item)
        self.candidate_len += len(refinement_response)
        logger.info(
            f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
        )
        self.refinement_done = True
        return self.get_next_candidate()

    def is_done(self) -> bool:
        """Return True once both sources are drained and the queue is empty."""
        if not self.line_profiler_done:
            return False
        if not self.refinement_done:
            return False
        return self.candidate_queue.empty()
class FunctionOptimizer:
def __init__(
self,
@ -378,15 +456,13 @@ class FunctionOptimizer:
f"{self.function_to_optimize.qualified_name}"
)
console.rule()
candidates = deque(candidates)
refinement_done = False
line_profiler_done = False
future_all_refinements: list[concurrent.futures.Future] = []
ast_code_to_id = {}
valid_optimizations = []
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list
# Start a new thread for AI service request
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = self.executor.submit(
ai_service_client.optimize_python_code_line_profiler,
@ -401,48 +477,23 @@ class FunctionOptimizer:
if self.experiment_id
else None,
)
candidate_index = 0
original_len = len(candidates)
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
while True:
try:
if len(candidates) > 0:
candidate = candidates.popleft()
else:
if not line_profiler_done:
logger.debug("all candidates processed, await candidates from line profiler")
concurrent.futures.wait([future_line_profile_results])
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len += len(line_profile_results)
logger.info(
f"Added results from line profiler to candidates, total candidates now: {original_len}"
)
line_profiler_done = True
continue
if line_profiler_done and not refinement_done:
concurrent.futures.wait(future_all_refinements)
refinement_response = []
for future_refinement in future_all_refinements:
possible_refinement = future_refinement.result()
if len(possible_refinement) > 0: # if the api returns a valid response
refinement_response.append(possible_refinement[0])
candidates.extend(refinement_response)
original_len += len(refinement_response)
logger.info(
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
)
refinement_done = True
continue
if line_profiler_done and refinement_done:
logger.debug("everything done, exiting")
break
# Initialize candidate processor
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
candidate_index = 0
# Process candidates using queue-based approach
while not processor.is_done():
candidate = processor.get_next_candidate()
if candidate is None:
logger.debug("everything done, exiting")
break
try:
candidate_index += 1
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
code_print(candidate.source_code.flat)
# map ast normalized code to diff len, unnormalized code
# map opt id to the shortest unnormalized code
@ -467,7 +518,7 @@ class FunctionOptimizer:
# check if this code has been evaluated before by checking the ast normalized code string
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
if normalized_code in ast_code_to_id:
logger.warning(
logger.info(
"Current candidate has been encountered before in testing, Skipping optimization candidate."
)
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]