feat: early dedup of optimization candidates before benchmark loop

Dedup candidates in CandidateProcessor when each batch arrives (initial,
line profiler, repair, refinement, adaptive) instead of only catching
duplicates one-by-one during the benchmark loop.

Changes:
- Add dedup_candidates() to CandidateProcessor with persistent
  seen_normalized set that tracks queued candidates across batches
- Simplify handle_duplicate_candidate/register_new_candidate to accept
  original_flat_code string instead of full CodeOptimizationContext
- Add 14 unit tests covering all dedup paths
This commit is contained in:
Kevin Turcios 2026-03-22 05:19:15 -05:00
parent 59031a145e
commit c72945b045
3 changed files with 350 additions and 10 deletions

View file

@ -257,6 +257,9 @@ class CandidateProcessor:
future_all_refinements: list[concurrent.futures.Future],
future_all_code_repair: list[concurrent.futures.Future],
future_adaptive_optimizations: list[concurrent.futures.Future],
normalize_fn: Callable[[str], str],
normalized_original: str,
original_flat_code: str,
) -> None:
self.candidate_queue = queue.Queue()
self.forest = CandidateForest()
@ -264,12 +267,17 @@ class CandidateProcessor:
self.refinement_done = False
self.eval_ctx = eval_ctx
self.effort = effort
self.candidate_len = len(initial_candidates)
self.refinement_calls_count = 0
self.original_markdown_code = original_markdown_code
self.normalize_fn = normalize_fn
self.normalized_original = normalized_original
self.original_flat_code = original_flat_code
self.seen_normalized: set[str] = set()
# Initialize queue with initial candidates
for candidate in initial_candidates:
# Dedup initial candidates before queuing
deduped = self.dedup_candidates(initial_candidates)
self.candidate_len = len(deduped)
for candidate in deduped:
self.forest.add(candidate)
self.candidate_queue.put(candidate)
@ -278,6 +286,53 @@ class CandidateProcessor:
self.future_all_code_repair = future_all_code_repair
self.future_adaptive_optimizations = future_adaptive_optimizations
def dedup_candidates(self, candidates: list[OptimizedCandidate]) -> list[OptimizedCandidate]:
    """Filter a batch of candidates down to those worth queuing.

    Each candidate's flat source is stripped and passed through
    ``self.normalize_fn``; the normalized text decides its fate, checked in
    this order:

    1. Equal to ``self.normalized_original``: dropped.
    2. Already a key of ``self.eval_ctx.ast_code_to_id``: handed to
       ``eval_ctx.handle_duplicate_candidate`` and dropped.
    3. Already in ``self.seen_normalized`` (kept by an earlier call, or
       earlier in this same batch): dropped.
    4. Otherwise: recorded in ``self.seen_normalized`` and kept.

    Args:
        candidates: Batch of candidates to filter, in arrival order.

    Returns:
        The kept candidates, in their original relative order.
    """
    kept: list[OptimizedCandidate] = []
    same_as_original = 0
    already_registered = 0
    already_queued = 0

    for cand in candidates:
        key = self.normalize_fn(cand.source_code.flat.strip())
        if key == self.normalized_original:
            same_as_original += 1
        elif key in self.eval_ctx.ast_code_to_id:
            # Let the evaluation context account for this duplicate of a registered candidate.
            self.eval_ctx.handle_duplicate_candidate(cand, key, self.original_flat_code)
            already_registered += 1
        elif key in self.seen_normalized:
            already_queued += 1
        else:
            self.seen_normalized.add(key)
            kept.append(cand)

    dropped = same_as_original + already_registered + already_queued
    if dropped > 0:
        logger.info(
            f"Early dedup removed {dropped} candidate(s) "
            f"({same_as_original} identical to original, "
            f"{already_registered} already-benchmarked duplicates, "
            f"{already_queued} duplicates)"
        )
    return kept
def get_total_llm_calls(self) -> int:
    """Return the current value of ``self.refinement_calls_count``."""
    total_calls = self.refinement_calls_count
    return total_calls
@ -347,6 +402,7 @@ class CandidateProcessor:
candidates.append(candidate_result)
candidates = filter_candidates_func(candidates) if filter_candidates_func else candidates
candidates = self.dedup_candidates(candidates)
for candidate in candidates:
self.forest.add(candidate)
self.candidate_queue.put(candidate)
@ -1107,7 +1163,7 @@ class FunctionOptimizer:
logger.info(
f"h3|Candidate {candidate_index}/{total_candidates}: Duplicate of a previous candidate, skipping."
)
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context)
eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context.read_writable_code.flat)
console.rule()
return None
@ -1139,7 +1195,7 @@ class FunctionOptimizer:
)
return None
eval_ctx.register_new_candidate(normalized_code, candidate, code_context)
eval_ctx.register_new_candidate(normalized_code, candidate, code_context.read_writable_code.flat)
# Run the optimized candidate
run_results = self.run_optimized_candidate(
@ -1299,6 +1355,7 @@ class FunctionOptimizer:
language_version=self.language_support.language_version,
)
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
processor = CandidateProcessor(
candidates,
future_line_profile_results,
@ -1308,9 +1365,11 @@ class FunctionOptimizer:
self.future_all_refinements,
self.future_all_code_repair,
self.future_adaptive_optimizations,
normalize_fn=self.language_support.normalize_code,
normalized_original=normalized_original,
original_flat_code=code_context.read_writable_code.flat,
)
candidate_index = 0
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
# Process candidates using queue-based approach
while not processor.is_done():

View file

@ -542,7 +542,7 @@ class CandidateEvaluationContext:
self.optimized_line_profiler_results[optimization_id] = result
def handle_duplicate_candidate(
self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext
self, candidate: OptimizedCandidate, normalized_code: str, original_flat_code: str
) -> None:
"""Handle a candidate that has been seen before."""
past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"]
@ -564,19 +564,19 @@ class CandidateEvaluationContext:
self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown
# Update to shorter code if this candidate has a shorter diff
new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat)
new_diff_len = diff_length(candidate.source_code.flat, original_flat_code)
if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]:
self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code
self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len
def register_new_candidate(
self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext
self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str
) -> None:
"""Register a new candidate that hasn't been seen before."""
self.ast_code_to_id[normalized_code] = {
"optimization_id": candidate.optimization_id,
"shorter_source_code": candidate.source_code,
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
"diff_len": diff_length(candidate.source_code.flat, original_flat_code),
}
def get_speedup_ratio(self, optimization_id: str) -> float | None:

281
tests/test_early_dedup.py Normal file
View file

@ -0,0 +1,281 @@
"""Tests for early candidate deduplication in CandidateProcessor."""
from __future__ import annotations
import concurrent.futures
from pathlib import Path
import pytest
from codeflash.languages.function_optimizer import CandidateProcessor
from codeflash.languages.python.normalizer import normalize_python_code
from codeflash.models.models import (
CandidateEvaluationContext,
CodeString,
CodeStringsMarkdown,
OptimizedCandidate,
OptimizedCandidateSource,
)
def make_source_code(code: str) -> CodeStringsMarkdown:
    """Wrap *code* in a CodeStringsMarkdown holding a single CodeString for ``test.py``."""
    only_file = CodeString(code=code, file_path=Path("test.py"))
    return CodeStringsMarkdown(code_strings=[only_file])
def make_candidate(
    code: str,
    opt_id: str | None = None,
    source: OptimizedCandidateSource = OptimizedCandidateSource.OPTIMIZE,
) -> OptimizedCandidate:
    """Build an OptimizedCandidate whose source code is *code*.

    Args:
        code: Python source used for the candidate's single code string.
        opt_id: Optimization id to assign. When omitted (or empty), a fresh
            ``opt-<uuid4 hex>`` id is generated.
        source: Value for the candidate's ``source`` field.

    Returns:
        A candidate with explanation ``"test"`` and the requested id/source.
    """
    # Function-scope import keeps this helper independent of the module's import block.
    import uuid

    # A default derived from id(code) is shared by every candidate built from
    # the same str object (e.g. a module-level constant), so two distinct
    # candidates would silently get the same optimization_id.
    return OptimizedCandidate(
        source_code=make_source_code(code),
        explanation="test",
        optimization_id=opt_id or f"opt-{uuid.uuid4().hex}",
        source=source,
    )
def normalize_fn(source: str) -> str:
    """Normalize *source* with docstrings removed, falling back to the raw text.

    If ``normalize_python_code`` raises any ``Exception``, *source* is returned
    unchanged so callers always get a string back.
    """
    try:
        normalized = normalize_python_code(source, remove_docstrings=True)
    except Exception:
        # Best-effort: input that cannot be normalized is compared as-is.
        return source
    return normalized
# Source of the function being "optimized" in these tests.
ORIGINAL_CODE = "def foo(x):\n return x + 1\n"
# ORIGINAL_CODE prefixed with a "# file: test.py" header line; this is the value
# the tests pass wherever an original_flat_code argument is required.
ORIGINAL_FLAT = f"# file: test.py\n{ORIGINAL_CODE}"
# Normalizes identically to ORIGINAL_CODE (docstring and comment stripped),
# so the tests expect it to be dropped as "identical to original".
IDENTICAL_TO_ORIGINAL = 'def foo(x):\n """Docstring."""\n # comment\n return x + 1\n'
# Three candidates that differ from the original and from each other.
CANDIDATE_A = "def foo(x):\n return x + 2\n"
CANDIDATE_B = "def foo(x):\n return x * 2\n"
CANDIDATE_C = "def foo(x):\n return x << 1\n"
# Normalizes identically to CANDIDATE_A (added comment stripped by normalizer);
# used as the duplicate in intra-batch and cross-batch tests.
CANDIDATE_A_DUP = "def foo(x):\n # optimized\n return x + 2\n"
def make_done_future(value=None):
    """Return a ``concurrent.futures.Future`` that is already resolved with *value*."""
    resolved = concurrent.futures.Future()
    resolved.set_result(value)
    return resolved
def make_processor(initial_candidates, eval_ctx=None):
    """Construct a CandidateProcessor wired up for the dedup tests.

    Args:
        initial_candidates: Candidates handed to the processor's constructor.
        eval_ctx: Evaluation context to share with the processor; a fresh
            CandidateEvaluationContext is created when ``None``.

    Returns:
        A CandidateProcessor with an already-resolved line-profiler future, no
        refinement/repair/adaptive futures, and ORIGINAL_CODE as the original.
    """
    ctx = CandidateEvaluationContext() if eval_ctx is None else eval_ctx
    processor_kwargs = {
        "initial_candidates": initial_candidates,
        "future_line_profile_results": make_done_future(None),
        "eval_ctx": ctx,
        "effort": "default",
        "original_markdown_code": f"```python\n{ORIGINAL_CODE}```",
        "future_all_refinements": [],
        "future_all_code_repair": [],
        "future_adaptive_optimizations": [],
        "normalize_fn": normalize_fn,
        "normalized_original": normalize_fn(ORIGINAL_CODE.strip()),
        "original_flat_code": ORIGINAL_FLAT,
    }
    return CandidateProcessor(**processor_kwargs)
class TestDedup:
    """Early-dedup behaviour of CandidateProcessor, at construction time and via dedup_candidates()."""

    def test_unique_candidates_pass_through(self):
        """Three mutually distinct candidates are all counted."""
        processor = make_processor(
            [
                make_candidate(CANDIDATE_A, "opt-a"),
                make_candidate(CANDIDATE_B, "opt-b"),
                make_candidate(CANDIDATE_C, "opt-c"),
            ]
        )
        assert processor.candidate_len == 3

    def test_identical_to_original_removed(self):
        """A candidate that normalizes to the original code is not counted."""
        batch = [
            make_candidate(IDENTICAL_TO_ORIGINAL, "opt-dup-orig"),
            make_candidate(CANDIDATE_A, "opt-a"),
        ]
        processor = make_processor(batch)
        assert processor.candidate_len == 1

    def test_intra_batch_duplicates_removed(self):
        """Two candidates with the same normalized form in one batch count once."""
        batch = [
            make_candidate(CANDIDATE_A, "opt-a1"),
            make_candidate(CANDIDATE_A_DUP, "opt-a2"),
            make_candidate(CANDIDATE_B, "opt-b"),
        ]
        processor = make_processor(batch)
        assert processor.candidate_len == 2

    def test_cross_batch_duplicates_copy_results(self):
        """A duplicate of a registered, successful candidate is dropped and inherits its results."""
        ctx = CandidateEvaluationContext()

        # Pretend CANDIDATE_A was registered and benchmarked earlier.
        earlier = make_candidate(CANDIDATE_A, "opt-prev")
        ctx.register_new_candidate(
            normalize_fn(CANDIDATE_A.strip()),
            earlier,
            ORIGINAL_FLAT,
        )
        ctx.record_successful_candidate("opt-prev", runtime=1000.0, speedup=2.0)

        # The incoming batch repeats CANDIDATE_A (with an extra comment) and adds CANDIDATE_B.
        incoming = [
            make_candidate(CANDIDATE_A_DUP, "opt-new-dup"),
            make_candidate(CANDIDATE_B, "opt-b"),
        ]
        processor = make_processor(incoming, eval_ctx=ctx)

        # Only CANDIDATE_B is left to queue.
        assert processor.candidate_len == 1
        # The duplicate's id now carries the earlier candidate's results.
        assert ctx.speedup_ratios["opt-new-dup"] == 2.0
        assert ctx.optimized_runtimes["opt-new-dup"] == 1000.0
        assert ctx.is_correct["opt-new-dup"] is True

    def test_empty_list(self):
        """An empty initial batch yields a zero candidate count."""
        processor = make_processor([])
        assert processor.candidate_len == 0

    def test_all_duplicates_of_original(self):
        """A batch made only of original-equivalent candidates leaves nothing."""
        batch = [
            make_candidate(IDENTICAL_TO_ORIGINAL, "opt-1"),
            make_candidate(ORIGINAL_CODE, "opt-2"),
        ]
        processor = make_processor(batch)
        assert processor.candidate_len == 0

    def test_mixed_removal_types(self):
        """All three removal reasons in one batch: original, intra-batch dup, cross-batch dup."""
        ctx = CandidateEvaluationContext()
        earlier = make_candidate(CANDIDATE_C, "opt-prev-c")
        ctx.register_new_candidate(normalize_fn(CANDIDATE_C.strip()), earlier, ORIGINAL_FLAT)
        ctx.record_successful_candidate("opt-prev-c", runtime=500.0, speedup=3.0)

        batch = [
            make_candidate(IDENTICAL_TO_ORIGINAL, "opt-orig"),  # same as the original
            make_candidate(CANDIDATE_A, "opt-a1"),  # kept
            make_candidate(CANDIDATE_A_DUP, "opt-a2"),  # repeats opt-a1 within the batch
            make_candidate(CANDIDATE_C, "opt-c-dup"),  # repeats the registered opt-prev-c
            make_candidate(CANDIDATE_B, "opt-b"),  # kept
        ]
        processor = make_processor(batch, eval_ctx=ctx)

        # CANDIDATE_A and CANDIDATE_B are the only survivors.
        assert processor.candidate_len == 2
        # The cross-batch duplicate inherits the registered candidate's speedup.
        assert ctx.speedup_ratios["opt-c-dup"] == 3.0

    def test_dedup_in_async_batch(self):
        """A later batch is deduped against candidates kept by the constructor via seen_normalized."""
        processor = make_processor([make_candidate(CANDIDATE_A, "opt-a")])
        assert processor.candidate_len == 1

        # Mirror what happens when a later batch arrives: run it through dedup_candidates.
        later_batch = [
            make_candidate(CANDIDATE_B, "opt-b"),
            make_candidate(CANDIDATE_A_DUP, "opt-a-lp"),  # repeats the initial candidate
        ]
        survivors = processor.dedup_candidates(later_batch)
        assert len(survivors) == 1
        assert survivors[0].optimization_id == "opt-b"

    def test_dedup_in_async_batch_after_benchmark(self):
        """A later batch is deduped against a candidate already registered in the evaluation context."""
        ctx = CandidateEvaluationContext()

        # CANDIDATE_A is registered with a successful result before the processor exists.
        earlier = make_candidate(CANDIDATE_A, "opt-a")
        ctx.register_new_candidate(normalize_fn(CANDIDATE_A.strip()), earlier, ORIGINAL_FLAT)
        ctx.record_successful_candidate("opt-a", runtime=2000.0, speedup=1.5)

        processor = make_processor([], eval_ctx=ctx)
        later_batch = [
            make_candidate(CANDIDATE_A_DUP, "opt-a-lp"),
            make_candidate(CANDIDATE_B, "opt-b"),
        ]
        survivors = processor.dedup_candidates(later_batch)
        assert len(survivors) == 1
        assert survivors[0].optimization_id == "opt-b"
        assert ctx.speedup_ratios["opt-a-lp"] == 1.5
class TestCandidateEvaluationContext:
    """Exercise register_new_candidate and handle_duplicate_candidate called with an original_flat_code string."""

    def test_register_new_candidate_stores_diff_len(self):
        """Registering stores the id, the source object itself, and a positive integer diff length."""
        ctx = CandidateEvaluationContext()
        cand = make_candidate(CANDIDATE_A, "opt-a")
        key = normalize_fn(CANDIDATE_A.strip())

        ctx.register_new_candidate(key, cand, ORIGINAL_FLAT)

        stored = ctx.ast_code_to_id[key]
        assert stored["optimization_id"] == "opt-a"
        assert stored["shorter_source_code"] is cand.source_code
        assert isinstance(stored["diff_len"], int)
        assert stored["diff_len"] > 0

    def test_handle_duplicate_copies_all_results(self):
        """A duplicate of a successful candidate receives its speedup, runtime, correctness and profile."""
        ctx = CandidateEvaluationContext()
        key = normalize_fn(CANDIDATE_A.strip())
        original_hit = make_candidate(CANDIDATE_A, "opt-first")
        ctx.register_new_candidate(key, original_hit, ORIGINAL_FLAT)
        ctx.record_successful_candidate("opt-first", runtime=1234.0, speedup=2.5)
        ctx.record_line_profiler_result("opt-first", "line profiler output")

        repeat = make_candidate(CANDIDATE_A_DUP, "opt-dup")
        ctx.handle_duplicate_candidate(repeat, key, ORIGINAL_FLAT)

        assert ctx.speedup_ratios["opt-dup"] == 2.5
        assert ctx.optimized_runtimes["opt-dup"] == 1234.0
        assert ctx.is_correct["opt-dup"] is True
        assert ctx.optimized_line_profiler_results["opt-dup"] == "line profiler output"

    def test_handle_duplicate_copies_failed_results(self):
        """A duplicate of a failed candidate is recorded as failed too."""
        ctx = CandidateEvaluationContext()
        key = normalize_fn(CANDIDATE_A.strip())
        original_hit = make_candidate(CANDIDATE_A, "opt-first")
        ctx.register_new_candidate(key, original_hit, ORIGINAL_FLAT)
        ctx.record_failed_candidate("opt-first")

        repeat = make_candidate(CANDIDATE_A_DUP, "opt-dup")
        ctx.handle_duplicate_candidate(repeat, key, ORIGINAL_FLAT)

        assert ctx.speedup_ratios["opt-dup"] is None
        assert ctx.optimized_runtimes["opt-dup"] is None
        assert ctx.is_correct["opt-dup"] is False

    def test_handle_duplicate_tracks_shorter_source(self):
        """A duplicate with a shorter diff takes over as the stored shorter_source_code."""
        ctx = CandidateEvaluationContext()

        # First registration uses the wordier variant of the code.
        verbose_code = "def foo(x):\n # this comment makes it longer\n # and this one too\n return x + 2\n"
        verbose = make_candidate(verbose_code, "opt-long")
        key = normalize_fn(verbose_code.strip())
        ctx.register_new_candidate(key, verbose, ORIGINAL_FLAT)
        ctx.record_successful_candidate("opt-long", runtime=500.0, speedup=3.0)
        diff_len_before = ctx.ast_code_to_id[key]["diff_len"]

        # The duplicate shares the normalized form but has less text.
        compact = make_candidate(CANDIDATE_A, "opt-short")
        ctx.handle_duplicate_candidate(compact, key, ORIGINAL_FLAT)
        diff_len_after = ctx.ast_code_to_id[key]["diff_len"]

        # The compact variant must now be the stored one.
        assert diff_len_after <= diff_len_before
        assert ctx.ast_code_to_id[key]["shorter_source_code"] is compact.source_code

    def test_handle_duplicate_keeps_shorter_when_new_is_longer(self):
        """A duplicate with a longer diff leaves the stored shorter_source_code untouched."""
        ctx = CandidateEvaluationContext()
        compact = make_candidate(CANDIDATE_A, "opt-short")
        key = normalize_fn(CANDIDATE_A.strip())
        ctx.register_new_candidate(key, compact, ORIGINAL_FLAT)
        ctx.record_successful_candidate("opt-short", runtime=500.0, speedup=3.0)

        verbose_code = "def foo(x):\n # this comment makes it longer\n # and this one too\n return x + 2\n"
        verbose = make_candidate(verbose_code, "opt-long")
        ctx.handle_duplicate_candidate(verbose, key, ORIGINAL_FLAT)

        assert ctx.ast_code_to_id[key]["shorter_source_code"] is compact.source_code