From 3bd02550400baa2d99238a0b2e342bf2d27d3d75 Mon Sep 17 00:00:00 2001 From: Mohamed Ashraf Date: Wed, 1 Apr 2026 16:37:23 +0000 Subject: [PATCH 1/3] fix: scope field extraction to target class to prevent cross-class injection find_fields() was called without a class_name filter, causing fields from inner/anonymous classes to be injected into the outer target class. Now scoped to target_method.class_name using the existing filter parameter. Co-Authored-By: Claude Opus 4.6 --- codeflash/languages/java/replacement.py | 7 +- .../test_java/test_replacement.py | 74 +++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/codeflash/languages/java/replacement.py b/codeflash/languages/java/replacement.py index 5ed9bf8f1..765b6837d 100644 --- a/codeflash/languages/java/replacement.py +++ b/codeflash/languages/java/replacement.py @@ -69,7 +69,6 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze if classes: # It's a class - extract components methods = analyzer.find_methods(new_source) - fields = analyzer.find_fields(new_source) # Find the target method and its index among all methods target_method = None @@ -128,7 +127,11 @@ def _parse_optimization_source(new_source: str, target_method_name: str, analyze ctor_end = c.end_line modified_constructors.append("".join(ctor_lines[ctor_start:ctor_end])) - # Extract fields + # Extract fields scoped to the target method's class only. + # Without class filtering, fields from inner/anonymous classes would be + # incorrectly injected into the outer target class. 
+ target_class_name = target_method.class_name if target_method else None + fields = analyzer.find_fields(new_source, class_name=target_class_name) for f in fields: if f.source_text: new_fields.append(f.source_text) diff --git a/tests/test_languages/test_java/test_replacement.py b/tests/test_languages/test_java/test_replacement.py index 1bd4f7abb..02776fc1d 100644 --- a/tests/test_languages/test_java/test_replacement.py +++ b/tests/test_languages/test_java/test_replacement.py @@ -1929,3 +1929,77 @@ public final class LuaMap { } """ assert new_code == expected_code + + +class TestFieldInjectionClassFiltering: + """Tests that fields from inner/anonymous classes are not injected into the target class.""" + + def test_inner_class_fields_not_injected_into_outer(self, tmp_path, java_support): + """Reproduces the Guava/Iterables.mergeSorted bug. + + When the LLM generates an optimization that includes an inner class with + fields (e.g., generic type parameters), those fields must NOT be injected + into the outer class where the target method lives. + """ + from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize + + java_file = tmp_path / "Outer.java" + original_code = """\ +public class Outer { + private int count; + + public int process(int x) { + return x + count; + } +} +""" + java_file.write_text(original_code, encoding="utf-8") + + # LLM generates optimization with an inner class that has its own field. + # The inner class's field should NOT be injected into Outer. 
+ optimized_markdown = f"""```java:{java_file.relative_to(tmp_path)} +public class Outer {{ + private int count; + private static final int OFFSET = 10; + + public int process(int x) {{ + return x + count + OFFSET; + }} + + private static class Inner {{ + private final String badField; + + Inner(String s) {{ + this.badField = s; + }} + }} +}} +```""" + + optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java") + + function_to_optimize = FunctionToOptimize( + function_name="process", + file_path=java_file, + starting_line=4, + ending_line=6, + parents=[FunctionParent(name="Outer", type="ClassDef")], + qualified_name="Outer.process", + is_method=True, + ) + + result = java_support.replace_function_definitions( + function_names=["process"], + optimized_code=optimized_code, + module_abspath=java_file, + project_root_path=tmp_path, + function_to_optimize=function_to_optimize, + ) + + assert result is True + new_code = java_file.read_text(encoding="utf-8") + + # Only OFFSET field should be added (belongs to Outer). + # badField belongs to Inner and should NOT appear. + assert "OFFSET" in new_code + assert "badField" not in new_code From efbd34159c2245a9bd1b8ccb679fcb229420e1b9 Mon Sep 17 00:00:00 2001 From: Mohamed Ashraf Date: Tue, 28 Apr 2026 15:22:42 +0000 Subject: [PATCH 2/3] test: annotate test_replacement.py for mypy prek hook Add -> None return annotations and Path / JavaSupport parameter annotations to every test method + fixture so the prek mypy hook passes when the file is in the CI diff. Also drop the explicit qualified_name= keyword argument from the five FunctionToOptimize(...) constructions in these tests. 
--- .../test_java/test_replacement.py | 73 +++++++++---------- 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/tests/test_languages/test_java/test_replacement.py b/tests/test_languages/test_java/test_replacement.py index 02776fc1d..35067130b 100644 --- a/tests/test_languages/test_java/test_replacement.py +++ b/tests/test_languages/test_java/test_replacement.py @@ -15,14 +15,14 @@ from codeflash.models.models import CodeStringsMarkdown @pytest.fixture -def java_support(): +def java_support() -> JavaSupport: return JavaSupport() class TestReplaceFunctionDefinitionsInModule: """Tests for replace_function_definitions_for_language with Java (basic cases).""" - def test_replace_simple_method(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_simple_method(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a simple method in a Java class.""" java_file = tmp_path / "Calculator.java" original_code = """public class Calculator { @@ -61,7 +61,7 @@ public class Calculator {{ """ assert new_code == expected - def test_replace_method_preserves_other_methods(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_preserves_other_methods(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test that replacing one method preserves other methods.""" java_file = tmp_path / "Calculator.java" original_code = """public class Calculator { @@ -124,7 +124,7 @@ public class Calculator {{ """ assert new_code == expected - def test_replace_method_with_javadoc(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_with_javadoc(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a method that has Javadoc comments.""" java_file = tmp_path / "MathUtils.java" original_code = """public class MathUtils { @@ -193,7 +193,7 @@ public class MathUtils {{ """ assert new_code == expected - def test_no_change_when_code_identical(self, tmp_path: Path, java_support: JavaSupport): 
+ def test_no_change_when_code_identical(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test that no change is made when optimized code is identical.""" java_file = tmp_path / "Identity.java" original_code = """public class Identity { @@ -230,7 +230,7 @@ public class Identity {{ class TestReplaceFunctionDefinitionsForLanguage: """Tests for replace_function_definitions_for_language with Java.""" - def test_replace_static_method(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_static_method(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a static method.""" java_file = tmp_path / "Utils.java" original_code = """public class Utils { @@ -269,7 +269,7 @@ public class Utils {{ """ assert new_code == expected - def test_replace_method_with_annotations(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_with_annotations(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a method with annotations.""" java_file = tmp_path / "Service.java" original_code = """public class Service { @@ -311,7 +311,7 @@ public class Service {{ """ assert new_code == expected - def test_replace_method_in_interface(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_in_interface(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a default method in an interface.""" java_file = tmp_path / "Processor.java" original_code = """public interface Processor { @@ -350,7 +350,7 @@ public interface Processor {{ """ assert new_code == expected - def test_replace_method_in_enum(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_in_enum(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a method in an enum.""" java_file = tmp_path / "Color.java" original_code = """public enum Color { @@ -395,7 +395,7 @@ public enum Color {{ """ assert new_code == expected - def test_replace_generic_method(self, tmp_path: 
Path, java_support: JavaSupport): + def test_replace_generic_method(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a method with generics.""" java_file = tmp_path / "Container.java" original_code = """import java.util.List; @@ -453,7 +453,7 @@ public class Container { """ assert new_code == expected - def test_replace_method_with_throws(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_with_throws(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a method with throws clause.""" java_file = tmp_path / "FileReader.java" original_code = """import java.io.IOException; @@ -508,7 +508,7 @@ public class FileReader { class TestRealWorldOptimizationScenarios: """Real-world optimization scenarios with complete valid Java code.""" - def test_optimize_string_concatenation(self, tmp_path: Path, java_support: JavaSupport): + def test_optimize_string_concatenation(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimizing string concatenation to StringBuilder.""" java_file = tmp_path / "StringJoiner.java" original_code = """public class StringJoiner { @@ -559,7 +559,7 @@ public class StringJoiner {{ """ assert new_code == expected - def test_optimize_list_iteration(self, tmp_path: Path, java_support: JavaSupport): + def test_optimize_list_iteration(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimizing list iteration with streams.""" java_file = tmp_path / "ListProcessor.java" original_code = """import java.util.List; @@ -608,7 +608,7 @@ public class ListProcessor { """ assert new_code == expected - def test_optimize_null_checks(self, tmp_path: Path, java_support: JavaSupport): + def test_optimize_null_checks(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimizing null checks with Objects utility.""" java_file = tmp_path / "NullChecker.java" original_code = """public class NullChecker { @@ -655,7 +655,7 @@ public class NullChecker {{ """ 
assert new_code == expected - def test_optimize_collection_creation(self, tmp_path: Path, java_support: JavaSupport): + def test_optimize_collection_creation(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimizing collection creation with factory methods.""" java_file = tmp_path / "CollectionFactory.java" original_code = """import java.util.ArrayList; @@ -711,7 +711,7 @@ public class CollectionFactory { class TestMultipleClassesAndMethods: """Tests for files with multiple classes or multiple methods being optimized.""" - def test_replace_method_in_first_class(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_in_first_class(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a method in the first class when multiple classes exist.""" java_file = tmp_path / "MultiClass.java" original_code = """public class Calculator { @@ -768,7 +768,7 @@ class Helper { """ assert new_code == expected - def test_replace_multiple_methods(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_multiple_methods(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing multiple methods in the same class.""" java_file = tmp_path / "MathOps.java" original_code = """public class MathOps { @@ -835,7 +835,7 @@ public class MathOps {{ class TestNestedClasses: """Tests for nested class scenarios.""" - def test_replace_method_in_nested_class(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_method_in_nested_class(self, tmp_path: Path, java_support: JavaSupport) -> None: """Nested class methods are skipped by discovery (PR #1726), so replacement returns False.""" java_file = tmp_path / "Outer.java" original_code = """public class Outer { @@ -882,7 +882,7 @@ public class Outer {{ class TestPreservesStructure: """Tests that verify code structure is preserved during replacement.""" - def test_preserves_fields_and_constructors(self, tmp_path: Path, java_support: JavaSupport): + def 
test_preserves_fields_and_constructors(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test that fields and constructors are preserved.""" java_file = tmp_path / "Counter.java" original_code = """public class Counter { @@ -952,7 +952,7 @@ public class Counter {{ class TestEdgeCases: """Edge cases and error handling tests.""" - def test_empty_optimized_code_returns_false(self, tmp_path: Path, java_support: JavaSupport): + def test_empty_optimized_code_returns_false(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test that empty optimized code returns False.""" java_file = tmp_path / "Empty.java" original_code = """public class Empty { @@ -980,7 +980,7 @@ class TestEdgeCases: new_code = java_file.read_text(encoding="utf-8") assert new_code == original_code - def test_function_not_found_returns_false(self, tmp_path: Path, java_support: JavaSupport): + def test_function_not_found_returns_false(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test that function not found returns False.""" java_file = tmp_path / "NotFound.java" original_code = """public class NotFound { @@ -1011,7 +1011,7 @@ public class NotFound {{ assert result is False - def test_unicode_in_code(self, tmp_path: Path, java_support: JavaSupport): + def test_unicode_in_code(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test handling of unicode characters in code.""" java_file = tmp_path / "Unicode.java" original_code = """public class Unicode { @@ -1054,7 +1054,7 @@ public class Unicode {{ class TestOptimizationWithStaticFields: """Tests for optimizations that add new static fields to the class.""" - def test_add_static_lookup_table(self, tmp_path: Path, java_support: JavaSupport): + def test_add_static_lookup_table(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimization that adds a static lookup table.""" java_file = tmp_path / "Buffer.java" original_code = """public class Buffer { @@ -1114,7 +1114,7 @@ public class Buffer {{ 
""" assert new_code == expected - def test_add_precomputed_array(self, tmp_path: Path, java_support: JavaSupport): + def test_add_precomputed_array(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimization that adds a precomputed static array.""" java_file = tmp_path / "Encoder.java" original_code = """public class Encoder { @@ -1174,7 +1174,7 @@ public class Encoder {{ """ assert new_code == expected - def test_preserve_existing_fields(self, tmp_path: Path, java_support: JavaSupport): + def test_preserve_existing_fields(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test that existing fields are preserved when adding new ones.""" java_file = tmp_path / "Calculator.java" original_code = """public class Calculator { @@ -1260,7 +1260,7 @@ public class Calculator {{ class TestOptimizationWithHelperMethods: """Tests for optimizations that add new helper methods.""" - def test_add_private_helper_method(self, tmp_path: Path, java_support: JavaSupport): + def test_add_private_helper_method(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimization that adds a private helper method.""" java_file = tmp_path / "StringUtils.java" original_code = """public class StringUtils { @@ -1330,7 +1330,7 @@ public class StringUtils {{ """ assert new_code == expected - def test_add_multiple_helpers(self, tmp_path: Path, java_support: JavaSupport): + def test_add_multiple_helpers(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test optimization that adds multiple helper methods.""" java_file = tmp_path / "MathUtils.java" original_code = """public class MathUtils { @@ -1395,7 +1395,7 @@ public class MathUtils {{ class TestOptimizationWithFieldsAndHelpers: """Tests for optimizations that add both static fields and helper methods.""" - def test_add_field_and_helper_together(self, tmp_path: Path, java_support: JavaSupport): + def test_add_field_and_helper_together(self, tmp_path: Path, java_support: JavaSupport) -> None: 
"""Test optimization that adds both a static field and helper method.""" java_file = tmp_path / "Fibonacci.java" original_code = """public class Fibonacci { @@ -1464,7 +1464,7 @@ public class Fibonacci {{ """ assert new_code == expected - def test_real_world_bytes_to_hex_optimization(self, tmp_path: Path, java_support: JavaSupport): + def test_real_world_bytes_to_hex_optimization(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test the actual bytesToHexString optimization pattern from aerospike.""" java_file = tmp_path / "Buffer.java" original_code = """package com.example; @@ -1561,7 +1561,7 @@ public final class Buffer { class TestOverloadedMethods: """Tests for handling overloaded methods (same name, different signatures).""" - def test_replace_specific_overload_by_line_number(self, tmp_path: Path, java_support: JavaSupport): + def test_replace_specific_overload_by_line_number(self, tmp_path: Path, java_support: JavaSupport) -> None: """Test replacing a specific overload when multiple exist.""" java_file = tmp_path / "Buffer.java" original_code = """public final class Buffer { @@ -1615,7 +1615,6 @@ public final class Buffer {{ starting_line=13, # Line where 3-arg version starts (1-indexed) ending_line=18, parents=[FunctionParent(name="Buffer", type="ClassDef")], - qualified_name="Buffer.bytesToHexString", is_method=True, ) @@ -1670,7 +1669,7 @@ class TestWrongMethodNameGeneration: source file unchanged. """ - def test_standalone_wrong_method_name_leaves_source_unchanged(self, tmp_path, java_support): + def test_standalone_wrong_method_name_leaves_source_unchanged(self, tmp_path: Path, java_support: JavaSupport) -> None: """Standalone generated method with wrong name must not replace the target. 
Reproduces the Unpacker.unpackObjectMap bug: the LLM was asked to optimise @@ -1710,7 +1709,6 @@ public final Object unpackMap() {{ starting_line=2, ending_line=4, parents=[FunctionParent(name="Unpacker", type="ClassDef")], - qualified_name="Unpacker.unpackObjectMap", is_method=True, ) @@ -1726,7 +1724,7 @@ public final Object unpackMap() {{ assert result is False assert java_file.read_text(encoding="utf-8") == original_code - def test_class_wrapper_with_wrong_target_method_leaves_source_unchanged(self, tmp_path, java_support): + def test_class_wrapper_with_wrong_target_method_leaves_source_unchanged(self, tmp_path: Path, java_support: JavaSupport) -> None: """Class-wrapped generated code missing the target method must not modify source. Reproduces the Command.estimateKeySize bug: the LLM generated a class that @@ -1767,7 +1765,6 @@ public class Command {{ starting_line=2, ending_line=4, parents=[FunctionParent(name="Command", type="ClassDef")], - qualified_name="Command.estimateKeySize", is_method=True, ) @@ -1795,7 +1792,7 @@ class TestAnonymousInnerClassMethods: enclosing method scope. """ - def test_anonymous_iterator_methods_not_hoisted_to_class(self, tmp_path, java_support): + def test_anonymous_iterator_methods_not_hoisted_to_class(self, tmp_path: Path, java_support: JavaSupport) -> None: """Reproduces the LuaMap.keySetIterator bug. 
The LLM optimised ``keySetIterator`` by returning an anonymous @@ -1876,7 +1873,6 @@ public final class LuaMap {{ starting_line=11, ending_line=13, parents=[FunctionParent(name="LuaMap", type="ClassDef")], - qualified_name="LuaMap.keySetIterator", is_method=True, ) @@ -1934,7 +1930,7 @@ public final class LuaMap { class TestFieldInjectionClassFiltering: """Tests that fields from inner/anonymous classes are not injected into the target class.""" - def test_inner_class_fields_not_injected_into_outer(self, tmp_path, java_support): + def test_inner_class_fields_not_injected_into_outer(self, tmp_path: Path, java_support: JavaSupport) -> None: """Reproduces the Guava/Iterables.mergeSorted bug. When the LLM generates an optimization that includes an inner class with @@ -1984,7 +1980,6 @@ public class Outer {{ starting_line=4, ending_line=6, parents=[FunctionParent(name="Outer", type="ClassDef")], - qualified_name="Outer.process", is_method=True, ) From f02b99f8fb76916d5d2db5d61aefac3f26884ed3 Mon Sep 17 00:00:00 2001 From: Mohamed Ashraf Date: Tue, 28 Apr 2026 16:39:00 +0000 Subject: [PATCH 3/3] fix: decode help-banner test subprocess output as UTF-8 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rich renders the banner panel with box-drawing characters (╭, ╮, │, etc.) that cp1252 cannot decode. On Windows, subprocess.run(..., text=True) uses cp1252 by default, so decoding the child stdout raises UnicodeDecodeError and subprocess sets result.stdout to None — breaking the assertion with a misleading "argument of type 'NoneType' is not iterable". Pass encoding="utf-8" explicitly so the test passes on every platform. 
--- tests/test_help_banner.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_help_banner.py b/tests/test_help_banner.py index c5d801b23..27d749790 100644 --- a/tests/test_help_banner.py +++ b/tests/test_help_banner.py @@ -4,7 +4,10 @@ import sys def test_help_displays_logo() -> None: result = subprocess.run( - [sys.executable, "-c", "from codeflash.main import main; main()", "--help"], capture_output=True, text=True + [sys.executable, "-c", "from codeflash.main import main; main()", "--help"], + capture_output=True, + text=True, + encoding="utf-8", ) assert result.returncode == 0 assert "codeflash.ai" in result.stdout @@ -12,7 +15,10 @@ def test_help_displays_logo() -> None: def test_help_short_flag_displays_logo() -> None: result = subprocess.run( - [sys.executable, "-c", "from codeflash.main import main; main()", "-h"], capture_output=True, text=True + [sys.executable, "-c", "from codeflash.main import main; main()", "-h"], + capture_output=True, + text=True, + encoding="utf-8", ) assert result.returncode == 0 assert "codeflash.ai" in result.stdout