Merge pull request #661 from codeflash-ai/rewrite-candidate-loop-clean
Abstracted Clean candidate loop using queue.Queue()
This commit is contained in:
commit
02b4d6557a
1 changed files with 96 additions and 45 deletions
|
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
|||
import ast
|
||||
import concurrent.futures
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict, deque
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
|
@ -104,6 +105,83 @@ if TYPE_CHECKING:
|
|||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
class CandidateProcessor:
|
||||
"""Handles candidate processing using a queue-based approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
initial_candidates: list,
|
||||
future_line_profile_results: concurrent.futures.Future,
|
||||
future_all_refinements: list,
|
||||
) -> None:
|
||||
self.candidate_queue = queue.Queue()
|
||||
self.line_profiler_done = False
|
||||
self.refinement_done = False
|
||||
self.candidate_len = len(initial_candidates)
|
||||
|
||||
# Initialize queue with initial candidates
|
||||
for candidate in initial_candidates:
|
||||
self.candidate_queue.put(candidate)
|
||||
|
||||
self.future_line_profile_results = future_line_profile_results
|
||||
self.future_all_refinements = future_all_refinements
|
||||
|
||||
def get_next_candidate(self) -> OptimizedCandidate | None:
|
||||
"""Get the next candidate from the queue, handling async results as needed."""
|
||||
try:
|
||||
return self.candidate_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
return self._handle_empty_queue()
|
||||
|
||||
def _handle_empty_queue(self) -> OptimizedCandidate | None:
|
||||
"""Handle empty queue by checking for pending async results."""
|
||||
if not self.line_profiler_done:
|
||||
return self._process_line_profiler_results()
|
||||
if self.line_profiler_done and not self.refinement_done:
|
||||
return self._process_refinement_results()
|
||||
return None # All done
|
||||
|
||||
def _process_line_profiler_results(self) -> OptimizedCandidate | None:
|
||||
"""Process line profiler results and add to queue."""
|
||||
logger.debug("all candidates processed, await candidates from line profiler")
|
||||
concurrent.futures.wait([self.future_line_profile_results])
|
||||
line_profile_results = self.future_line_profile_results.result()
|
||||
|
||||
for candidate in line_profile_results:
|
||||
self.candidate_queue.put(candidate)
|
||||
|
||||
self.candidate_len += len(line_profile_results)
|
||||
logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
|
||||
self.line_profiler_done = True
|
||||
|
||||
return self.get_next_candidate()
|
||||
|
||||
def _process_refinement_results(self) -> OptimizedCandidate | None:
|
||||
"""Process refinement results and add to queue."""
|
||||
concurrent.futures.wait(self.future_all_refinements)
|
||||
refinement_response = []
|
||||
|
||||
for future_refinement in self.future_all_refinements:
|
||||
possible_refinement = future_refinement.result()
|
||||
if len(possible_refinement) > 0:
|
||||
refinement_response.append(possible_refinement[0])
|
||||
|
||||
for candidate in refinement_response:
|
||||
self.candidate_queue.put(candidate)
|
||||
|
||||
self.candidate_len += len(refinement_response)
|
||||
logger.info(
|
||||
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
|
||||
)
|
||||
self.refinement_done = True
|
||||
|
||||
return self.get_next_candidate()
|
||||
|
||||
def is_done(self) -> bool:
|
||||
"""Check if processing is complete."""
|
||||
return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
|
||||
|
||||
|
||||
class FunctionOptimizer:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -378,15 +456,13 @@ class FunctionOptimizer:
|
|||
f"{self.function_to_optimize.qualified_name}…"
|
||||
)
|
||||
console.rule()
|
||||
candidates = deque(candidates)
|
||||
refinement_done = False
|
||||
line_profiler_done = False
|
||||
|
||||
future_all_refinements: list[concurrent.futures.Future] = []
|
||||
ast_code_to_id = {}
|
||||
valid_optimizations = []
|
||||
optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated
|
||||
# Start a new thread for AI service request, start loop in main thread
|
||||
# check if aiservice request is complete, when it is complete, append result to the candidates list
|
||||
|
||||
# Start a new thread for AI service request
|
||||
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
|
||||
future_line_profile_results = self.executor.submit(
|
||||
ai_service_client.optimize_python_code_line_profiler,
|
||||
|
|
@ -401,48 +477,23 @@ class FunctionOptimizer:
|
|||
if self.experiment_id
|
||||
else None,
|
||||
)
|
||||
candidate_index = 0
|
||||
original_len = len(candidates)
|
||||
# TODO : We need to rewrite this candidate loop as a class, the container which has candidates receives new candidates at unknown times due to the async nature of lp and refinement calls,
|
||||
# TODO : in addition, the refinement calls depend on line profiler calls being complete so we need to check that reliably
|
||||
while True:
|
||||
try:
|
||||
if len(candidates) > 0:
|
||||
candidate = candidates.popleft()
|
||||
else:
|
||||
if not line_profiler_done:
|
||||
logger.debug("all candidates processed, await candidates from line profiler")
|
||||
concurrent.futures.wait([future_line_profile_results])
|
||||
line_profile_results = future_line_profile_results.result()
|
||||
candidates.extend(line_profile_results)
|
||||
original_len += len(line_profile_results)
|
||||
logger.info(
|
||||
f"Added results from line profiler to candidates, total candidates now: {original_len}"
|
||||
)
|
||||
line_profiler_done = True
|
||||
continue
|
||||
if line_profiler_done and not refinement_done:
|
||||
concurrent.futures.wait(future_all_refinements)
|
||||
refinement_response = []
|
||||
for future_refinement in future_all_refinements:
|
||||
possible_refinement = future_refinement.result()
|
||||
if len(possible_refinement) > 0: # if the api returns a valid response
|
||||
refinement_response.append(possible_refinement[0])
|
||||
candidates.extend(refinement_response)
|
||||
original_len += len(refinement_response)
|
||||
logger.info(
|
||||
f"Added {len(refinement_response)} candidates from refinement, total candidates now: {original_len}"
|
||||
)
|
||||
refinement_done = True
|
||||
continue
|
||||
if line_profiler_done and refinement_done:
|
||||
logger.debug("everything done, exiting")
|
||||
break
|
||||
|
||||
# Initialize candidate processor
|
||||
processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements)
|
||||
candidate_index = 0
|
||||
|
||||
# Process candidates using queue-based approach
|
||||
while not processor.is_done():
|
||||
candidate = processor.get_next_candidate()
|
||||
if candidate is None:
|
||||
logger.debug("everything done, exiting")
|
||||
break
|
||||
|
||||
try:
|
||||
candidate_index += 1
|
||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
|
||||
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
|
||||
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
|
||||
logger.info(f"Optimization candidate {candidate_index}/{processor.candidate_len}:")
|
||||
code_print(candidate.source_code.flat)
|
||||
# map ast normalized code to diff len, unnormalized code
|
||||
# map opt id to the shortest unnormalized code
|
||||
|
|
@ -467,7 +518,7 @@ class FunctionOptimizer:
|
|||
# check if this code has been evaluated before by checking the ast normalized code string
|
||||
normalized_code = ast.unparse(ast.parse(candidate.source_code.flat.strip()))
|
||||
if normalized_code in ast_code_to_id:
|
||||
logger.warning(
|
||||
logger.info(
|
||||
"Current candidate has been encountered before in testing, Skipping optimization candidate."
|
||||
)
|
||||
past_opt_id = ast_code_to_id[normalized_code]["optimization_id"]
|
||||
|
|
|
|||
Loading…
Reference in a new issue