mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: early dedup of optimization candidates before benchmark loop
Dedup candidates in CandidateProcessor when each batch arrives (initial, line profiler, repair, refinement, adaptive) instead of only catching duplicates one-by-one during the benchmark loop. Changes: - Add dedup_candidates() to CandidateProcessor with persistent seen_normalized set that tracks queued candidates across batches - Simplify handle_duplicate_candidate/register_new_candidate to accept original_flat_code string instead of full CodeOptimizationContext - 14 unit tests covering all dedup paths
This commit is contained in:
parent
59031a145e
commit
c72945b045
3 changed files with 350 additions and 10 deletions
|
|
@@ -257,6 +257,9 @@ class CandidateProcessor:
|
|||
future_all_refinements: list[concurrent.futures.Future],
|
||||
future_all_code_repair: list[concurrent.futures.Future],
|
||||
future_adaptive_optimizations: list[concurrent.futures.Future],
|
||||
normalize_fn: Callable[[str], str],
|
||||
normalized_original: str,
|
||||
original_flat_code: str,
|
||||
) -> None:
|
||||
self.candidate_queue = queue.Queue()
|
||||
self.forest = CandidateForest()
|
||||
|
|
@@ -264,12 +267,17 @@ class CandidateProcessor:
|
|||
self.refinement_done = False
|
||||
self.eval_ctx = eval_ctx
|
||||
self.effort = effort
|
||||
self.candidate_len = len(initial_candidates)
|
||||
self.refinement_calls_count = 0
|
||||
self.original_markdown_code = original_markdown_code
|
||||
self.normalize_fn = normalize_fn
|
||||
self.normalized_original = normalized_original
|
||||
self.original_flat_code = original_flat_code
|
||||
self.seen_normalized: set[str] = set()
|
||||
|
||||
# Initialize queue with initial candidates
|
||||
for candidate in initial_candidates:
|
||||
# Dedup initial candidates before queuing
|
||||
deduped = self.dedup_candidates(initial_candidates)
|
||||
self.candidate_len = len(deduped)
|
||||
for candidate in deduped:
|
||||
self.forest.add(candidate)
|
||||
self.candidate_queue.put(candidate)
|
||||
|
||||
|
|
@@ -278,6 +286,53 @@ class CandidateProcessor:
|
|||
self.future_all_code_repair = future_all_code_repair
|
||||
self.future_adaptive_optimizations = future_adaptive_optimizations
|
||||
|
||||
def dedup_candidates(self, candidates: list[OptimizedCandidate]) -> list[OptimizedCandidate]:
|
||||
"""Remove duplicates from a batch of candidates before queuing.
|
||||
|
||||
Filters out candidates that are:
|
||||
- Identical to the original code
|
||||
- Duplicates of previously-registered candidates in eval_ctx.ast_code_to_id
|
||||
(called when the queue is empty, after all prior candidates have been
|
||||
both registered and benchmarked)
|
||||
- Duplicates of candidates already queued from prior batches
|
||||
(tracked in self.seen_normalized which persists across calls)
|
||||
- Intra-batch duplicates
|
||||
"""
|
||||
unique: list[OptimizedCandidate] = []
|
||||
removed_original = 0
|
||||
removed_cross_batch = 0
|
||||
removed_duplicate = 0
|
||||
|
||||
for candidate in candidates:
|
||||
normalized = self.normalize_fn(candidate.source_code.flat.strip())
|
||||
|
||||
if normalized == self.normalized_original:
|
||||
removed_original += 1
|
||||
continue
|
||||
|
||||
if normalized in self.eval_ctx.ast_code_to_id:
|
||||
self.eval_ctx.handle_duplicate_candidate(candidate, normalized, self.original_flat_code)
|
||||
removed_cross_batch += 1
|
||||
continue
|
||||
|
||||
if normalized in self.seen_normalized:
|
||||
removed_duplicate += 1
|
||||
continue
|
||||
|
||||
self.seen_normalized.add(normalized)
|
||||
unique.append(candidate)
|
||||
|
||||
total_removed = removed_original + removed_cross_batch + removed_duplicate
|
||||
if total_removed > 0:
|
||||
logger.info(
|
||||
f"Early dedup removed {total_removed} candidate(s) "
|
||||
f"({removed_original} identical to original, "
|
||||
f"{removed_cross_batch} already-benchmarked duplicates, "
|
||||
f"{removed_duplicate} duplicates)"
|
||||
)
|
||||
|
||||
return unique
|
||||
|
||||
def get_total_llm_calls(self) -> int:
|
||||
return self.refinement_calls_count
|
||||
|
||||
|
|
@@ -347,6 +402,7 @@ class CandidateProcessor:
|
|||
candidates.append(candidate_result)
|
||||
|
||||
candidates = filter_candidates_func(candidates) if filter_candidates_func else candidates
|
||||
candidates = self.dedup_candidates(candidates)
|
||||
for candidate in candidates:
|
||||
self.forest.add(candidate)
|
||||
self.candidate_queue.put(candidate)
|
||||
|
|
@@ -1107,7 +1163,7 @@ class FunctionOptimizer:
|
|||
logger.info(
|
||||
f"h3|Candidate {candidate_index}/{total_candidates}: Duplicate of a previous candidate, skipping."
|
||||
)
|
||||
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context)
|
||||
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context.read_writable_code.flat)
|
||||
console.rule()
|
||||
return None
|
||||
|
||||
|
|
@@ -1139,7 +1195,7 @@ class FunctionOptimizer:
|
|||
)
|
||||
return None
|
||||
|
||||
eval_ctx.register_new_candidate(normalized_code, candidate, code_context)
|
||||
eval_ctx.register_new_candidate(normalized_code, candidate, code_context.read_writable_code.flat)
|
||||
|
||||
# Run the optimized candidate
|
||||
run_results = self.run_optimized_candidate(
|
||||
|
|
@@ -1299,6 +1355,7 @@ class FunctionOptimizer:
|
|||
language_version=self.language_support.language_version,
|
||||
)
|
||||
|
||||
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
|
||||
processor = CandidateProcessor(
|
||||
candidates,
|
||||
future_line_profile_results,
|
||||
|
|
@@ -1308,9 +1365,11 @@ class FunctionOptimizer:
|
|||
self.future_all_refinements,
|
||||
self.future_all_code_repair,
|
||||
self.future_adaptive_optimizations,
|
||||
normalize_fn=self.language_support.normalize_code,
|
||||
normalized_original=normalized_original,
|
||||
original_flat_code=code_context.read_writable_code.flat,
|
||||
)
|
||||
candidate_index = 0
|
||||
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
|
||||
|
||||
# Process candidates using queue-based approach
|
||||
while not processor.is_done():
|
||||
|
|
|
|||
|
|
@@ -542,7 +542,7 @@ class CandidateEvaluationContext:
|
|||
self.optimized_line_profiler_results[optimization_id] = result
|
||||
|
||||
def handle_duplicate_candidate(
|
||||
self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
|
||||
self, candidate: OptimizedCandidate, normalized_code: str, original_flat_code: str
|
||||
) -> None:
|
||||
"""Handle a candidate that has been seen before."""
|
||||
past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]
|
||||
|
|
@@ -564,19 +564,19 @@ class CandidateEvaluationContext:
|
|||
self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
|
||||
|
||||
# Update to shorter code if this candidate has a shorter diff
|
||||
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
|
||||
new_diff_len = diff_length(candidate.source_code.flat, original_flat_code)
|
||||
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
|
||||
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
|
||||
self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
|
||||
|
||||
def register_new_candidate(
|
||||
self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
|
||||
self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str
|
||||
) -> None:
|
||||
"""Register a new candidate that hasn't been seen before."""
|
||||
self.ast_code_to_id[normalized_code] = {
|
||||
"optimization_id": candidate.optimization_id,
|
||||
"shorter_source_code": candidate.source_code,
|
||||
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
|
||||
"diff_len": diff_length(candidate.source_code.flat, original_flat_code),
|
||||
}
|
||||
|
||||
def get_speedup_ratio(self, optimization_id: str) -> float | None:
|
||||
|
|
|
|||
281
tests/test_early_dedup.py
Normal file
281
tests/test_early_dedup.py
Normal file
|
|
@@ -0,0 +1,281 @@
|
|||
"""Tests for early candidate deduplication in CandidateProcessor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.function_optimizer import CandidateProcessor
|
||||
from codeflash.languages.python.normalizer import normalize_python_code
|
||||
from codeflash.models.models import (
|
||||
CandidateEvaluationContext,
|
||||
CodeString,
|
||||
CodeStringsMarkdown,
|
||||
OptimizedCandidate,
|
||||
OptimizedCandidateSource,
|
||||
)
|
||||
|
||||
|
||||
def make_source_code(code: str) -> CodeStringsMarkdown:
|
||||
return CodeStringsMarkdown(
|
||||
code_strings=[CodeString(code=code, file_path=Path("test.py"))],
|
||||
)
|
||||
|
||||
|
||||
def make_candidate(code: str, opt_id: str | None = None, source: OptimizedCandidateSource = OptimizedCandidateSource.OPTIMIZE) -> OptimizedCandidate:
|
||||
return OptimizedCandidate(
|
||||
source_code=make_source_code(code),
|
||||
explanation="test",
|
||||
optimization_id=opt_id or f"opt-{id(code)}",
|
||||
source=source,
|
||||
)
|
||||
|
||||
|
||||
def normalize_fn(source: str) -> str:
|
||||
try:
|
||||
return normalize_python_code(source, remove_docstrings=True)
|
||||
except Exception:
|
||||
return source
|
||||
|
||||
|
||||
ORIGINAL_CODE = "def foo(x):\n return x + 1\n"
|
||||
ORIGINAL_FLAT = f"# file: test.py\n{ORIGINAL_CODE}"
|
||||
|
||||
# Normalizes identically to ORIGINAL_CODE (docstring and comment stripped)
|
||||
IDENTICAL_TO_ORIGINAL = 'def foo(x):\n """Docstring."""\n # comment\n return x + 1\n'
|
||||
|
||||
# Different from original
|
||||
CANDIDATE_A = "def foo(x):\n return x + 2\n"
|
||||
CANDIDATE_B = "def foo(x):\n return x * 2\n"
|
||||
CANDIDATE_C = "def foo(x):\n return x << 1\n"
|
||||
|
||||
# Normalizes identically to CANDIDATE_A (added comment stripped by normalizer)
|
||||
CANDIDATE_A_DUP = "def foo(x):\n # optimized\n return x + 2\n"
|
||||
|
||||
|
||||
def make_done_future(value=None):
|
||||
f = concurrent.futures.Future()
|
||||
f.set_result(value)
|
||||
return f
|
||||
|
||||
|
||||
def make_processor(initial_candidates, eval_ctx=None):
|
||||
if eval_ctx is None:
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
return CandidateProcessor(
|
||||
initial_candidates=initial_candidates,
|
||||
future_line_profile_results=make_done_future(None),
|
||||
eval_ctx=eval_ctx,
|
||||
effort="default",
|
||||
original_markdown_code=f"```python\n{ORIGINAL_CODE}```",
|
||||
future_all_refinements=[],
|
||||
future_all_code_repair=[],
|
||||
future_adaptive_optimizations=[],
|
||||
normalize_fn=normalize_fn,
|
||||
normalized_original=normalize_fn(ORIGINAL_CODE.strip()),
|
||||
original_flat_code=ORIGINAL_FLAT,
|
||||
)
|
||||
|
||||
|
||||
class TestDedup:
|
||||
def test_unique_candidates_pass_through(self):
|
||||
candidates = [
|
||||
make_candidate(CANDIDATE_A, "opt-a"),
|
||||
make_candidate(CANDIDATE_B, "opt-b"),
|
||||
make_candidate(CANDIDATE_C, "opt-c"),
|
||||
]
|
||||
proc = make_processor(candidates)
|
||||
assert proc.candidate_len == 3
|
||||
|
||||
def test_identical_to_original_removed(self):
|
||||
candidates = [
|
||||
make_candidate(IDENTICAL_TO_ORIGINAL, "opt-dup-orig"),
|
||||
make_candidate(CANDIDATE_A, "opt-a"),
|
||||
]
|
||||
proc = make_processor(candidates)
|
||||
assert proc.candidate_len == 1
|
||||
|
||||
def test_intra_batch_duplicates_removed(self):
|
||||
candidates = [
|
||||
make_candidate(CANDIDATE_A, "opt-a1"),
|
||||
make_candidate(CANDIDATE_A_DUP, "opt-a2"),
|
||||
make_candidate(CANDIDATE_B, "opt-b"),
|
||||
]
|
||||
proc = make_processor(candidates)
|
||||
assert proc.candidate_len == 2
|
||||
|
||||
def test_cross_batch_duplicates_copy_results(self):
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
# Simulate a previously-benchmarked candidate
|
||||
prev_candidate = make_candidate(CANDIDATE_A, "opt-prev")
|
||||
eval_ctx.register_new_candidate(
|
||||
normalize_fn(CANDIDATE_A.strip()),
|
||||
prev_candidate,
|
||||
ORIGINAL_FLAT,
|
||||
)
|
||||
eval_ctx.record_successful_candidate("opt-prev", runtime=1000.0, speedup=2.0)
|
||||
|
||||
# New batch has a duplicate of the already-benchmarked candidate
|
||||
new_candidates = [
|
||||
make_candidate(CANDIDATE_A_DUP, "opt-new-dup"),
|
||||
make_candidate(CANDIDATE_B, "opt-b"),
|
||||
]
|
||||
proc = make_processor(new_candidates, eval_ctx=eval_ctx)
|
||||
# Only CANDIDATE_B should be queued (A_DUP is a cross-batch dup)
|
||||
assert proc.candidate_len == 1
|
||||
# Results should be copied to the duplicate
|
||||
assert eval_ctx.speedup_ratios["opt-new-dup"] == 2.0
|
||||
assert eval_ctx.optimized_runtimes["opt-new-dup"] == 1000.0
|
||||
assert eval_ctx.is_correct["opt-new-dup"] is True
|
||||
|
||||
def test_empty_list(self):
|
||||
proc = make_processor([])
|
||||
assert proc.candidate_len == 0
|
||||
|
||||
def test_all_duplicates_of_original(self):
|
||||
candidates = [
|
||||
make_candidate(IDENTICAL_TO_ORIGINAL, "opt-1"),
|
||||
make_candidate(ORIGINAL_CODE, "opt-2"),
|
||||
]
|
||||
proc = make_processor(candidates)
|
||||
assert proc.candidate_len == 0
|
||||
|
||||
def test_mixed_removal_types(self):
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
prev = make_candidate(CANDIDATE_C, "opt-prev-c")
|
||||
eval_ctx.register_new_candidate(normalize_fn(CANDIDATE_C.strip()), prev, ORIGINAL_FLAT)
|
||||
eval_ctx.record_successful_candidate("opt-prev-c", runtime=500.0, speedup=3.0)
|
||||
|
||||
candidates = [
|
||||
make_candidate(IDENTICAL_TO_ORIGINAL, "opt-orig"), # identical to original
|
||||
make_candidate(CANDIDATE_A, "opt-a1"), # unique
|
||||
make_candidate(CANDIDATE_A_DUP, "opt-a2"), # intra-batch dup of opt-a1
|
||||
make_candidate(CANDIDATE_C, "opt-c-dup"), # cross-batch dup
|
||||
make_candidate(CANDIDATE_B, "opt-b"), # unique
|
||||
]
|
||||
proc = make_processor(candidates, eval_ctx=eval_ctx)
|
||||
# Only CANDIDATE_A and CANDIDATE_B should survive
|
||||
assert proc.candidate_len == 2
|
||||
# Cross-batch dup should have results copied
|
||||
assert eval_ctx.speedup_ratios["opt-c-dup"] == 3.0
|
||||
|
||||
def test_dedup_in_async_batch(self):
|
||||
"""Candidates arriving from line profiler are deduped against prior batches via seen_normalized."""
|
||||
candidates_initial = [make_candidate(CANDIDATE_A, "opt-a")]
|
||||
proc = make_processor(candidates_initial)
|
||||
assert proc.candidate_len == 1
|
||||
|
||||
# Simulate what _process_candidates does: dedup a new batch
|
||||
async_batch = [
|
||||
make_candidate(CANDIDATE_B, "opt-b"),
|
||||
make_candidate(CANDIDATE_A_DUP, "opt-a-lp"), # dup of initial, caught by seen_normalized
|
||||
]
|
||||
deduped = proc.dedup_candidates(async_batch)
|
||||
assert len(deduped) == 1
|
||||
assert deduped[0].optimization_id == "opt-b"
|
||||
|
||||
def test_dedup_in_async_batch_after_benchmark(self):
|
||||
"""After initial candidates are benchmarked, async batch dedup catches cross-batch dups."""
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
# Simulate initial candidate already benchmarked
|
||||
prev = make_candidate(CANDIDATE_A, "opt-a")
|
||||
eval_ctx.register_new_candidate(normalize_fn(CANDIDATE_A.strip()), prev, ORIGINAL_FLAT)
|
||||
eval_ctx.record_successful_candidate("opt-a", runtime=2000.0, speedup=1.5)
|
||||
|
||||
proc = make_processor([], eval_ctx=eval_ctx)
|
||||
|
||||
async_batch = [
|
||||
make_candidate(CANDIDATE_A_DUP, "opt-a-lp"),
|
||||
make_candidate(CANDIDATE_B, "opt-b"),
|
||||
]
|
||||
deduped = proc.dedup_candidates(async_batch)
|
||||
assert len(deduped) == 1
|
||||
assert deduped[0].optimization_id == "opt-b"
|
||||
assert eval_ctx.speedup_ratios["opt-a-lp"] == 1.5
|
||||
|
||||
|
||||
class TestCandidateEvaluationContext:
|
||||
"""Direct tests for register_new_candidate and handle_duplicate_candidate with original_flat_code param."""
|
||||
|
||||
def test_register_new_candidate_stores_diff_len(self):
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
candidate = make_candidate(CANDIDATE_A, "opt-a")
|
||||
normalized = normalize_fn(CANDIDATE_A.strip())
|
||||
|
||||
eval_ctx.register_new_candidate(normalized, candidate, ORIGINAL_FLAT)
|
||||
|
||||
entry = eval_ctx.ast_code_to_id[normalized]
|
||||
assert entry["optimization_id"] == "opt-a"
|
||||
assert entry["shorter_source_code"] is candidate.source_code
|
||||
assert isinstance(entry["diff_len"], int)
|
||||
assert entry["diff_len"] > 0
|
||||
|
||||
def test_handle_duplicate_copies_all_results(self):
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
first = make_candidate(CANDIDATE_A, "opt-first")
|
||||
normalized = normalize_fn(CANDIDATE_A.strip())
|
||||
|
||||
eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT)
|
||||
eval_ctx.record_successful_candidate("opt-first", runtime=1234.0, speedup=2.5)
|
||||
eval_ctx.record_line_profiler_result("opt-first", "line profiler output")
|
||||
|
||||
dup = make_candidate(CANDIDATE_A_DUP, "opt-dup")
|
||||
eval_ctx.handle_duplicate_candidate(dup, normalized, ORIGINAL_FLAT)
|
||||
|
||||
assert eval_ctx.speedup_ratios["opt-dup"] == 2.5
|
||||
assert eval_ctx.optimized_runtimes["opt-dup"] == 1234.0
|
||||
assert eval_ctx.is_correct["opt-dup"] is True
|
||||
assert eval_ctx.optimized_line_profiler_results["opt-dup"] == "line profiler output"
|
||||
|
||||
def test_handle_duplicate_copies_failed_results(self):
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
first = make_candidate(CANDIDATE_A, "opt-first")
|
||||
normalized = normalize_fn(CANDIDATE_A.strip())
|
||||
|
||||
eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT)
|
||||
eval_ctx.record_failed_candidate("opt-first")
|
||||
|
||||
dup = make_candidate(CANDIDATE_A_DUP, "opt-dup")
|
||||
eval_ctx.handle_duplicate_candidate(dup, normalized, ORIGINAL_FLAT)
|
||||
|
||||
assert eval_ctx.speedup_ratios["opt-dup"] is None
|
||||
assert eval_ctx.optimized_runtimes["opt-dup"] is None
|
||||
assert eval_ctx.is_correct["opt-dup"] is False
|
||||
|
||||
def test_handle_duplicate_tracks_shorter_source(self):
|
||||
"""When a duplicate has a shorter diff, it replaces the stored shorter_source_code."""
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
# Register a candidate with longer code
|
||||
longer_code = "def foo(x):\n # this comment makes it longer\n # and this one too\n return x + 2\n"
|
||||
first = make_candidate(longer_code, "opt-long")
|
||||
normalized = normalize_fn(longer_code.strip())
|
||||
|
||||
eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT)
|
||||
eval_ctx.record_successful_candidate("opt-long", runtime=500.0, speedup=3.0)
|
||||
original_diff_len = eval_ctx.ast_code_to_id[normalized]["diff_len"]
|
||||
|
||||
# Duplicate with shorter code (same normalized form)
|
||||
shorter = make_candidate(CANDIDATE_A, "opt-short")
|
||||
eval_ctx.handle_duplicate_candidate(shorter, normalized, ORIGINAL_FLAT)
|
||||
new_diff_len = eval_ctx.ast_code_to_id[normalized]["diff_len"]
|
||||
|
||||
# Shorter code should have replaced the longer one
|
||||
assert new_diff_len <= original_diff_len
|
||||
assert eval_ctx.ast_code_to_id[normalized]["shorter_source_code"] is shorter.source_code
|
||||
|
||||
def test_handle_duplicate_keeps_shorter_when_new_is_longer(self):
|
||||
"""When a duplicate has a longer diff, the original shorter_source_code is kept."""
|
||||
eval_ctx = CandidateEvaluationContext()
|
||||
first = make_candidate(CANDIDATE_A, "opt-short")
|
||||
normalized = normalize_fn(CANDIDATE_A.strip())
|
||||
|
||||
eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT)
|
||||
eval_ctx.record_successful_candidate("opt-short", runtime=500.0, speedup=3.0)
|
||||
|
||||
longer_code = "def foo(x):\n # this comment makes it longer\n # and this one too\n return x + 2\n"
|
||||
dup = make_candidate(longer_code, "opt-long")
|
||||
eval_ctx.handle_duplicate_candidate(dup, normalized, ORIGINAL_FLAT)
|
||||
|
||||
assert eval_ctx.ast_code_to_id[normalized]["shorter_source_code"] is first.source_code
|
||||
Loading…
Reference in a new issue