From c72945b045263ec141478a6bc6013add154c1ccb Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Sun, 22 Mar 2026 05:19:15 -0500 Subject: [PATCH] feat: early dedup of optimization candidates before benchmark loop Dedup candidates in CandidateProcessor when each batch arrives (initial, line profiler, repair, refinement, adaptive) instead of only catching duplicates one-by-one during the benchmark loop. Changes: - Add dedup_candidates() to CandidateProcessor with persistent seen_normalized set that tracks queued candidates across batches - Simplify handle_duplicate_candidate/register_new_candidate to accept original_flat_code string instead of full CodeOptimizationContext - 14 unit tests covering all dedup paths --- codeflash/languages/function_optimizer.py | 71 +++++- codeflash/models/models.py | 8 +- tests/test_early_dedup.py | 281 ++++++++++++++++++++++ 3 files changed, 350 insertions(+), 10 deletions(-) create mode 100644 tests/test_early_dedup.py diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 72b02b486..59fea9b6d 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -257,6 +257,9 @@ class CandidateProcessor: future_all_refinements: list[concurrent.futures.Future], future_all_code_repair: list[concurrent.futures.Future], future_adaptive_optimizations: list[concurrent.futures.Future], + normalize_fn: Callable[[str], str], + normalized_original: str, + original_flat_code: str, ) -> None: self.candidate_queue = queue.Queue() self.forest = CandidateForest() @@ -264,12 +267,17 @@ class CandidateProcessor: self.refinement_done = False self.eval_ctx = eval_ctx self.effort = effort - self.candidate_len = len(initial_candidates) self.refinement_calls_count = 0 self.original_markdown_code = original_markdown_code + self.normalize_fn = normalize_fn + self.normalized_original = normalized_original + self.original_flat_code = original_flat_code + self.seen_normalized: set[str] = set() - # Initialize queue with initial candidates - for candidate in initial_candidates: + # Dedup initial candidates before queuing + deduped = self.dedup_candidates(initial_candidates) + self.candidate_len = len(deduped) + for candidate in deduped: self.forest.add(candidate) self.candidate_queue.put(candidate) @@ -278,6 +286,53 @@ class CandidateProcessor: self.future_all_code_repair = future_all_code_repair self.future_adaptive_optimizations = future_adaptive_optimizations + def dedup_candidates(self, candidates: list[OptimizedCandidate]) -> list[OptimizedCandidate]: + """Remove duplicates from a batch of candidates before queuing. + + Filters out candidates that are: + - Identical to the original code + - Duplicates of previously-registered candidates in eval_ctx.ast_code_to_id + (called when the queue is empty, after all prior candidates have been + both registered and benchmarked) + - Duplicates of candidates already queued from prior batches + (tracked in self.seen_normalized which persists across calls) + - Intra-batch duplicates + """ + unique: list[OptimizedCandidate] = [] + removed_original = 0 + removed_cross_batch = 0 + removed_duplicate = 0 + + for candidate in candidates: + normalized = self.normalize_fn(candidate.source_code.flat.strip()) + + if normalized == self.normalized_original: + removed_original += 1 + continue + + if normalized in self.eval_ctx.ast_code_to_id: + self.eval_ctx.handle_duplicate_candidate(candidate, normalized, self.original_flat_code) + removed_cross_batch += 1 + continue + + if normalized in self.seen_normalized: + removed_duplicate += 1 + continue + + self.seen_normalized.add(normalized) + unique.append(candidate) + + total_removed = removed_original + removed_cross_batch + removed_duplicate + if total_removed > 0: + logger.info( + f"Early dedup removed {total_removed} candidate(s) " + f"({removed_original} identical to original, " + f"{removed_cross_batch} already-benchmarked duplicates, " + f"{removed_duplicate} duplicates)" + ) + + return unique + def get_total_llm_calls(self) -> int: return self.refinement_calls_count @@ -347,6 +402,7 @@ class CandidateProcessor: candidates.append(candidate_result) candidates = filter_candidates_func(candidates) if filter_candidates_func else candidates + candidates = self.dedup_candidates(candidates) for candidate in candidates: self.forest.add(candidate) self.candidate_queue.put(candidate) @@ -1107,7 +1163,7 @@ class FunctionOptimizer: logger.info( f"h3|Candidate {candidate_index}/{total_candidates}: Duplicate of a previous candidate, skipping." ) - eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context) + eval_ctx.handle_duplicate_candidate(candidate, normalized_code, code_context.read_writable_code.flat) console.rule() return None @@ -1139,7 +1195,7 @@ class FunctionOptimizer: ) return None - eval_ctx.register_new_candidate(normalized_code, candidate, code_context) + eval_ctx.register_new_candidate(normalized_code, candidate, code_context.read_writable_code.flat) # Run the optimized candidate run_results = self.run_optimized_candidate( @@ -1299,6 +1355,7 @@ class FunctionOptimizer: language_version=self.language_support.language_version, ) + normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip()) processor = CandidateProcessor( candidates, future_line_profile_results, @@ -1308,9 +1365,11 @@ class FunctionOptimizer: self.future_all_refinements, self.future_all_code_repair, self.future_adaptive_optimizations, + normalize_fn=self.language_support.normalize_code, + normalized_original=normalized_original, + original_flat_code=code_context.read_writable_code.flat, ) candidate_index = 0 - normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip()) # Process candidates using queue-based approach while not processor.is_done(): diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d1cecb554..2e0900640 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -542,7 +542,7 @@ class CandidateEvaluationContext: self.optimized_line_profiler_results[optimization_id] = result def handle_duplicate_candidate( - self, candidate: OptimizedCandidate, normalized_code: str, code_context: CodeOptimizationContext + self, candidate: OptimizedCandidate, normalized_code: str, original_flat_code: str ) -> None: """Handle a candidate that has been seen before.""" past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] @@ -564,19 +564,19 @@ class CandidateEvaluationContext: self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown # Update to shorter code if this candidate has a shorter diff - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + new_diff_len = diff_length(candidate.source_code.flat, original_flat_code) if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]: self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len def register_new_candidate( - self, normalized_code: str, candidate: OptimizedCandidate, code_context: CodeOptimizationContext + self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str ) -> None: """Register a new candidate that hasn't been seen before.""" self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, - "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), + "diff_len": diff_length(candidate.source_code.flat, original_flat_code), } def get_speedup_ratio(self, optimization_id: str) -> float | None: diff --git a/tests/test_early_dedup.py b/tests/test_early_dedup.py new file mode 100644 index 000000000..71649a9b4 --- /dev/null +++ b/tests/test_early_dedup.py @@ -0,0 +1,281 @@ +"""Tests for early candidate deduplication in CandidateProcessor.""" + +from __future__ import annotations + +import concurrent.futures +from pathlib import Path + +import pytest + +from codeflash.languages.function_optimizer import CandidateProcessor +from codeflash.languages.python.normalizer import normalize_python_code +from codeflash.models.models import ( + CandidateEvaluationContext, + CodeString, + CodeStringsMarkdown, + OptimizedCandidate, + OptimizedCandidateSource, +) + + +def make_source_code(code: str) -> CodeStringsMarkdown: + return CodeStringsMarkdown( + code_strings=[CodeString(code=code, file_path=Path("test.py"))], + ) + + +def make_candidate(code: str, opt_id: str | None = None, source: OptimizedCandidateSource = OptimizedCandidateSource.OPTIMIZE) -> OptimizedCandidate: + return OptimizedCandidate( + source_code=make_source_code(code), + explanation="test", + optimization_id=opt_id or f"opt-{id(code)}", + source=source, + ) + + +def normalize_fn(source: str) -> str: + try: + return normalize_python_code(source, remove_docstrings=True) + except Exception: + return source + + +ORIGINAL_CODE = "def foo(x):\n return x + 1\n" +ORIGINAL_FLAT = f"# file: test.py\n{ORIGINAL_CODE}" + +# Normalizes identically to ORIGINAL_CODE (docstring and comment stripped) +IDENTICAL_TO_ORIGINAL = 'def foo(x):\n """Docstring."""\n # comment\n return x + 1\n' + +# Different from original +CANDIDATE_A = "def foo(x):\n return x + 2\n" +CANDIDATE_B = "def foo(x):\n return x * 2\n" +CANDIDATE_C = "def foo(x):\n return x << 1\n" + +# Normalizes identically to CANDIDATE_A (added comment stripped by normalizer) +CANDIDATE_A_DUP = "def foo(x):\n # optimized\n return x + 2\n" + + +def make_done_future(value=None): + f = concurrent.futures.Future() + f.set_result(value) + return f + + +def make_processor(initial_candidates, eval_ctx=None): + if eval_ctx is None: + eval_ctx = CandidateEvaluationContext() + return CandidateProcessor( + initial_candidates=initial_candidates, + future_line_profile_results=make_done_future(None), + eval_ctx=eval_ctx, + effort="default", + original_markdown_code=f"```python\n{ORIGINAL_CODE}```", + future_all_refinements=[], + future_all_code_repair=[], + future_adaptive_optimizations=[], + normalize_fn=normalize_fn, + normalized_original=normalize_fn(ORIGINAL_CODE.strip()), + original_flat_code=ORIGINAL_FLAT, + ) + + +class TestDedup: + def test_unique_candidates_pass_through(self): + candidates = [ + make_candidate(CANDIDATE_A, "opt-a"), + make_candidate(CANDIDATE_B, "opt-b"), + make_candidate(CANDIDATE_C, "opt-c"), + ] + proc = make_processor(candidates) + assert proc.candidate_len == 3 + + def test_identical_to_original_removed(self): + candidates = [ + make_candidate(IDENTICAL_TO_ORIGINAL, "opt-dup-orig"), + make_candidate(CANDIDATE_A, "opt-a"), + ] + proc = make_processor(candidates) + assert proc.candidate_len == 1 + + def test_intra_batch_duplicates_removed(self): + candidates = [ + make_candidate(CANDIDATE_A, "opt-a1"), + make_candidate(CANDIDATE_A_DUP, "opt-a2"), + make_candidate(CANDIDATE_B, "opt-b"), + ] + proc = make_processor(candidates) + assert proc.candidate_len == 2 + + def test_cross_batch_duplicates_copy_results(self): + eval_ctx = CandidateEvaluationContext() + # Simulate a previously-benchmarked candidate + prev_candidate = make_candidate(CANDIDATE_A, "opt-prev") + eval_ctx.register_new_candidate( + normalize_fn(CANDIDATE_A.strip()), + prev_candidate, + ORIGINAL_FLAT, + ) + eval_ctx.record_successful_candidate("opt-prev", runtime=1000.0, speedup=2.0) + + # New batch has a duplicate of the already-benchmarked candidate + new_candidates = [ + make_candidate(CANDIDATE_A_DUP, "opt-new-dup"), + make_candidate(CANDIDATE_B, "opt-b"), + ] + proc = make_processor(new_candidates, eval_ctx=eval_ctx) + # Only CANDIDATE_B should be queued (A_DUP is a cross-batch dup) + assert proc.candidate_len == 1 + # Results should be copied to the duplicate + assert eval_ctx.speedup_ratios["opt-new-dup"] == 2.0 + assert eval_ctx.optimized_runtimes["opt-new-dup"] == 1000.0 + assert eval_ctx.is_correct["opt-new-dup"] is True + + def test_empty_list(self): + proc = make_processor([]) + assert proc.candidate_len == 0 + + def test_all_duplicates_of_original(self): + candidates = [ + make_candidate(IDENTICAL_TO_ORIGINAL, "opt-1"), + make_candidate(ORIGINAL_CODE, "opt-2"), + ] + proc = make_processor(candidates) + assert proc.candidate_len == 0 + + def test_mixed_removal_types(self): + eval_ctx = CandidateEvaluationContext() + prev = make_candidate(CANDIDATE_C, "opt-prev-c") + eval_ctx.register_new_candidate(normalize_fn(CANDIDATE_C.strip()), prev, ORIGINAL_FLAT) + eval_ctx.record_successful_candidate("opt-prev-c", runtime=500.0, speedup=3.0) + + candidates = [ + make_candidate(IDENTICAL_TO_ORIGINAL, "opt-orig"), # identical to original + make_candidate(CANDIDATE_A, "opt-a1"), # unique + make_candidate(CANDIDATE_A_DUP, "opt-a2"), # intra-batch dup of opt-a1 + make_candidate(CANDIDATE_C, "opt-c-dup"), # cross-batch dup + make_candidate(CANDIDATE_B, "opt-b"), # unique + ] + proc = make_processor(candidates, eval_ctx=eval_ctx) + # Only CANDIDATE_A and CANDIDATE_B should survive + assert proc.candidate_len == 2 + # Cross-batch dup should have results copied + assert eval_ctx.speedup_ratios["opt-c-dup"] == 3.0 + + def test_dedup_in_async_batch(self): + """Candidates arriving from line profiler are deduped against prior batches via seen_normalized.""" + candidates_initial = [make_candidate(CANDIDATE_A, "opt-a")] + proc = make_processor(candidates_initial) + assert proc.candidate_len == 1 + + # Simulate what _process_candidates does: dedup a new batch + async_batch = [ + make_candidate(CANDIDATE_B, "opt-b"), + make_candidate(CANDIDATE_A_DUP, "opt-a-lp"), # dup of initial, caught by seen_normalized + ] + deduped = proc.dedup_candidates(async_batch) + assert len(deduped) == 1 + assert deduped[0].optimization_id == "opt-b" + + def test_dedup_in_async_batch_after_benchmark(self): + """After initial candidates are benchmarked, async batch dedup catches cross-batch dups.""" + eval_ctx = CandidateEvaluationContext() + # Simulate initial candidate already benchmarked + prev = make_candidate(CANDIDATE_A, "opt-a") + eval_ctx.register_new_candidate(normalize_fn(CANDIDATE_A.strip()), prev, ORIGINAL_FLAT) + eval_ctx.record_successful_candidate("opt-a", runtime=2000.0, speedup=1.5) + + proc = make_processor([], eval_ctx=eval_ctx) + + async_batch = [ + make_candidate(CANDIDATE_A_DUP, "opt-a-lp"), + make_candidate(CANDIDATE_B, "opt-b"), + ] + deduped = proc.dedup_candidates(async_batch) + assert len(deduped) == 1 + assert deduped[0].optimization_id == "opt-b" + assert eval_ctx.speedup_ratios["opt-a-lp"] == 1.5 + + +class TestCandidateEvaluationContext: + """Direct tests for register_new_candidate and handle_duplicate_candidate with original_flat_code param.""" + + def test_register_new_candidate_stores_diff_len(self): + eval_ctx = CandidateEvaluationContext() + candidate = make_candidate(CANDIDATE_A, "opt-a") + normalized = normalize_fn(CANDIDATE_A.strip()) + + eval_ctx.register_new_candidate(normalized, candidate, ORIGINAL_FLAT) + + entry = eval_ctx.ast_code_to_id[normalized] + assert entry["optimization_id"] == "opt-a" + assert entry["shorter_source_code"] is candidate.source_code + assert isinstance(entry["diff_len"], int) + assert entry["diff_len"] > 0 + + def test_handle_duplicate_copies_all_results(self): + eval_ctx = CandidateEvaluationContext() + first = make_candidate(CANDIDATE_A, "opt-first") + normalized = normalize_fn(CANDIDATE_A.strip()) + + eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT) + eval_ctx.record_successful_candidate("opt-first", runtime=1234.0, speedup=2.5) + eval_ctx.record_line_profiler_result("opt-first", "line profiler output") + + dup = make_candidate(CANDIDATE_A_DUP, "opt-dup") + eval_ctx.handle_duplicate_candidate(dup, normalized, ORIGINAL_FLAT) + + assert eval_ctx.speedup_ratios["opt-dup"] == 2.5 + assert eval_ctx.optimized_runtimes["opt-dup"] == 1234.0 + assert eval_ctx.is_correct["opt-dup"] is True + assert eval_ctx.optimized_line_profiler_results["opt-dup"] == "line profiler output" + + def test_handle_duplicate_copies_failed_results(self): + eval_ctx = CandidateEvaluationContext() + first = make_candidate(CANDIDATE_A, "opt-first") + normalized = normalize_fn(CANDIDATE_A.strip()) + + eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT) + eval_ctx.record_failed_candidate("opt-first") + + dup = make_candidate(CANDIDATE_A_DUP, "opt-dup") + eval_ctx.handle_duplicate_candidate(dup, normalized, ORIGINAL_FLAT) + + assert eval_ctx.speedup_ratios["opt-dup"] is None + assert eval_ctx.optimized_runtimes["opt-dup"] is None + assert eval_ctx.is_correct["opt-dup"] is False + + def test_handle_duplicate_tracks_shorter_source(self): + """When a duplicate has a shorter diff, it replaces the stored shorter_source_code.""" + eval_ctx = CandidateEvaluationContext() + # Register a candidate with longer code + longer_code = "def foo(x):\n # this comment makes it longer\n # and this one too\n return x + 2\n" + first = make_candidate(longer_code, "opt-long") + normalized = normalize_fn(longer_code.strip()) + + eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT) + eval_ctx.record_successful_candidate("opt-long", runtime=500.0, speedup=3.0) + original_diff_len = eval_ctx.ast_code_to_id[normalized]["diff_len"] + + # Duplicate with shorter code (same normalized form) + shorter = make_candidate(CANDIDATE_A, "opt-short") + eval_ctx.handle_duplicate_candidate(shorter, normalized, ORIGINAL_FLAT) + new_diff_len = eval_ctx.ast_code_to_id[normalized]["diff_len"] + + # Shorter code should have replaced the longer one + assert new_diff_len <= original_diff_len + assert eval_ctx.ast_code_to_id[normalized]["shorter_source_code"] is shorter.source_code + + def test_handle_duplicate_keeps_shorter_when_new_is_longer(self): + """When a duplicate has a longer diff, the original shorter_source_code is kept.""" + eval_ctx = CandidateEvaluationContext() + first = make_candidate(CANDIDATE_A, "opt-short") + normalized = normalize_fn(CANDIDATE_A.strip()) + + eval_ctx.register_new_candidate(normalized, first, ORIGINAL_FLAT) + eval_ctx.record_successful_candidate("opt-short", runtime=500.0, speedup=3.0) + + longer_code = "def foo(x):\n # this comment makes it longer\n # and this one too\n return x + 2\n" + dup = make_candidate(longer_code, "opt-long") + eval_ctx.handle_duplicate_candidate(dup, normalized, ORIGINAL_FLAT) + + assert eval_ctx.ast_code_to_id[normalized]["shorter_source_code"] is first.source_code