codeflash/tests/test_languages/test_java/test_replacement.py
Mohamed Ashraf efbd34159c test: annotate test_replacement.py for mypy prek hook
Add -> None return annotations and Path / JavaSupport parameter annotations
to every test method + fixture so the prek mypy hook passes when the file
is in the CI diff.
2026-04-28 15:22:42 +00:00

2000 lines
61 KiB
Python

"""Tests for Java code replacement.
Tests the high-level replacement functions using complete valid Java source files.
All optimized code is syntactically valid Java that could compile.
All assertions use exact string equality for rigorous verification.
"""
from pathlib import Path
import pytest
from codeflash.languages.code_replacer import replace_function_definitions_for_language
from codeflash.languages.java.support import JavaSupport
from codeflash.models.models import CodeStringsMarkdown
@pytest.fixture
def java_support() -> JavaSupport:
    """Provide a newly constructed ``JavaSupport`` to each test that requests it."""
    support = JavaSupport()
    return support
class TestReplaceFunctionDefinitionsInModule:
    """Basic Java scenarios for ``replace_function_definitions_for_language``.

    Each test writes a Java file under ``tmp_path``, feeds an optimized
    markdown block to the replacer and compares the file on disk with an
    exact expected string.
    """

    def test_replace_simple_method(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """The only method of a class is swapped for its optimized version."""
        java_file = tmp_path / "Calculator.java"
        java_file.write_text(
            """public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Calculator {{
public int add(int a, int b) {{
return Math.addExact(a, b);
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["add"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Calculator {
public int add(int a, int b) {
return Math.addExact(a, b);
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_method_preserves_other_methods(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Only ``add`` is requested; ``subtract`` and ``multiply`` must come out unchanged."""
        java_file = tmp_path / "Calculator.java"
        java_file.write_text(
            """public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
public int multiply(int a, int b) {
return a * b;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Calculator {{
public int add(int a, int b) {{
return Integer.sum(a, b);
}}
public int subtract(int a, int b) {{
return a - b;
}}
public int multiply(int a, int b) {{
return a * b;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["add"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Calculator {
public int add(int a, int b) {
return Integer.sum(a, b);
}
public int subtract(int a, int b) {
return a - b;
}
public int multiply(int a, int b) {
return a * b;
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_method_with_javadoc(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A method carrying a Javadoc block is replaced together with its updated Javadoc."""
        java_file = tmp_path / "MathUtils.java"
        java_file.write_text(
            """public class MathUtils {
/**
* Calculates the factorial.
* @param n the number
* @return factorial of n
*/
public long factorial(int n) {
if (n <= 1) return 1;
long result = 1;
for (int i = 2; i <= n; i++) {
result *= i;
}
return result;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class MathUtils {{
/**
* Calculates the factorial (optimized).
* @param n the number
* @return factorial of n
*/
public long factorial(int n) {{
if (n <= 1) return 1;
long result = 1;
for (int i = 2; i <= n; i++) {{
result = Math.multiplyExact(result, i);
}}
return result;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["factorial"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class MathUtils {
/**
* Calculates the factorial (optimized).
* @param n the number
* @return factorial of n
*/
public long factorial(int n) {
if (n <= 1) return 1;
long result = 1;
for (int i = 2; i <= n; i++) {
result = Math.multiplyExact(result, i);
}
return result;
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_no_change_when_code_identical(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Identical optimized code yields ``False`` and leaves the file as written."""
        java_file = tmp_path / "Identity.java"
        original_source = """public class Identity {
public int getValue() {
return 42;
}
}
"""
        java_file.write_text(original_source, encoding="utf-8")

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Identity {{
public int getValue() {{
return 42;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["getValue"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        assert was_modified is False
        assert java_file.read_text(encoding="utf-8") == original_source
class TestReplaceFunctionDefinitionsForLanguage:
    """Replacement across different Java declaration forms.

    Covers static methods, annotated methods, interface default methods,
    enum methods, generic classes and methods with a ``throws`` clause.
    """

    def test_replace_static_method(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A ``public static`` method is replaced."""
        java_file = tmp_path / "Utils.java"
        java_file.write_text(
            """public class Utils {
public static int square(int n) {
return n * n;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Utils {{
public static int square(int n) {{
return Math.multiplyExact(n, n);
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["square"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Utils {
public static int square(int n) {
return Math.multiplyExact(n, n);
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_method_with_annotations(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A method preceded by ``@Override`` is replaced and keeps the annotation."""
        java_file = tmp_path / "Service.java"
        java_file.write_text(
            """public class Service {
@Override
public String process(String input) {
return input.trim();
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Service {{
@Override
public String process(String input) {{
return input == null ? "" : input.strip();
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["process"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Service {
@Override
public String process(String input) {
return input == null ? "" : input.strip();
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_method_in_interface(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A ``default`` method declared in an interface is replaced."""
        java_file = tmp_path / "Processor.java"
        java_file.write_text(
            """public interface Processor {
default String process(String input) {
return input.toUpperCase();
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public interface Processor {{
default String process(String input) {{
return input == null ? null : input.toUpperCase();
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["process"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public interface Processor {
default String process(String input) {
return input == null ? null : input.toUpperCase();
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_method_in_enum(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A method declared after the constants of an enum is replaced."""
        java_file = tmp_path / "Color.java"
        java_file.write_text(
            """public enum Color {
RED, GREEN, BLUE;
public String getCode() {
return name().substring(0, 1);
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public enum Color {{
RED, GREEN, BLUE;
public String getCode() {{
return String.valueOf(name().charAt(0));
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["getCode"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public enum Color {
RED, GREEN, BLUE;
public String getCode() {
return String.valueOf(name().charAt(0));
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_generic_method(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A method of a generic class, in a file that starts with imports, is replaced."""
        java_file = tmp_path / "Container.java"
        java_file.write_text(
            """import java.util.List;
import java.util.ArrayList;
public class Container<T> {
private List<T> items = new ArrayList<>();
public List<T> getItems() {
List<T> copy = new ArrayList<>();
for (T item : items) {
copy.add(item);
}
return copy;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
import java.util.List;
import java.util.ArrayList;
public class Container<T> {{
private List<T> items = new ArrayList<>();
public List<T> getItems() {{
return new ArrayList<>(items);
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["getItems"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """import java.util.List;
import java.util.ArrayList;
public class Container<T> {
private List<T> items = new ArrayList<>();
public List<T> getItems() {
return new ArrayList<>(items);
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_method_with_throws(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A method whose signature has a ``throws`` clause is replaced."""
        java_file = tmp_path / "FileReader.java"
        java_file.write_text(
            """import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
public class FileReader {
public String readFile(String path) throws IOException {
return new String(Files.readAllBytes(Path.of(path)));
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
public class FileReader {{
public String readFile(String path) throws IOException {{
return Files.readString(Path.of(path));
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["readFile"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
public class FileReader {
public String readFile(String path) throws IOException {
return Files.readString(Path.of(path));
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source
class TestRealWorldOptimizationScenarios:
    """Replacement scenarios modelled on typical Java optimizations."""

    def test_optimize_string_concatenation(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """String ``+`` concatenation in a loop is replaced by a ``StringBuilder`` version."""
        java_file = tmp_path / "StringJoiner.java"
        java_file.write_text(
            """public class StringJoiner {
public String buildString(String[] items) {
String result = "";
for (String item : items) {
result = result + item;
}
return result;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class StringJoiner {{
public String buildString(String[] items) {{
StringBuilder sb = new StringBuilder();
for (String item : items) {{
sb.append(item);
}}
return sb.toString();
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["buildString"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class StringJoiner {
public String buildString(String[] items) {
StringBuilder sb = new StringBuilder();
for (String item : items) {
sb.append(item);
}
return sb.toString();
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_optimize_list_iteration(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """An index-based summing loop is replaced by a stream pipeline."""
        java_file = tmp_path / "ListProcessor.java"
        java_file.write_text(
            """import java.util.List;
public class ListProcessor {
public int sumList(List<Integer> numbers) {
int sum = 0;
for (int i = 0; i < numbers.size(); i++) {
sum += numbers.get(i);
}
return sum;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
import java.util.List;
public class ListProcessor {{
public int sumList(List<Integer> numbers) {{
return numbers.stream().mapToInt(Integer::intValue).sum();
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["sumList"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """import java.util.List;
public class ListProcessor {
public int sumList(List<Integer> numbers) {
return numbers.stream().mapToInt(Integer::intValue).sum();
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_optimize_null_checks(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Manual null comparisons are replaced by ``Objects.equals``."""
        java_file = tmp_path / "NullChecker.java"
        java_file.write_text(
            """public class NullChecker {
public boolean isEqual(String s1, String s2) {
if (s1 == null && s2 == null) {
return true;
}
if (s1 == null || s2 == null) {
return false;
}
return s1.equals(s2);
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
import java.util.Objects;
public class NullChecker {{
public boolean isEqual(String s1, String s2) {{
return Objects.equals(s1, s2);
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["isEqual"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        # NOTE(review): the optimized snippet carries `import java.util.Objects;`
        # but the expected file below does not - confirm that not propagating
        # new imports is the intended behaviour being pinned here.
        expected_source = """public class NullChecker {
public boolean isEqual(String s1, String s2) {
return Objects.equals(s1, s2);
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_optimize_collection_creation(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Step-by-step list building is replaced by a ``List.of`` factory call."""
        java_file = tmp_path / "CollectionFactory.java"
        java_file.write_text(
            """import java.util.ArrayList;
import java.util.List;
public class CollectionFactory {
public List<String> createList() {
List<String> list = new ArrayList<>();
list.add("one");
list.add("two");
list.add("three");
return list;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
import java.util.ArrayList;
import java.util.List;
public class CollectionFactory {{
public List<String> createList() {{
return List.of("one", "two", "three");
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["createList"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """import java.util.ArrayList;
import java.util.List;
public class CollectionFactory {
public List<String> createList() {
return List.of("one", "two", "three");
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source
class TestMultipleClassesAndMethods:
    """Files holding several classes, and calls that target several methods at once."""

    def test_replace_method_in_first_class(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """With two top-level classes in the file, the first class's method is replaced."""
        java_file = tmp_path / "MultiClass.java"
        java_file.write_text(
            """public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
class Helper {
public int helper() {
return 0;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Calculator {{
public int add(int a, int b) {{
return Math.addExact(a, b);
}}
}}
class Helper {{
public int helper() {{
return 0;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["add"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Calculator {
public int add(int a, int b) {
return Math.addExact(a, b);
}
}
class Helper {
public int helper() {
return 0;
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_replace_multiple_methods(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """``add`` and ``subtract`` are both requested; ``multiply`` stays as it was."""
        java_file = tmp_path / "MathOps.java"
        java_file.write_text(
            """public class MathOps {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
public int multiply(int a, int b) {
return a * b;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class MathOps {{
public int add(int a, int b) {{
return Math.addExact(a, b);
}}
public int subtract(int a, int b) {{
return Math.subtractExact(a, b);
}}
public int multiply(int a, int b) {{
return a * b;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["add", "subtract"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class MathOps {
public int add(int a, int b) {
return Math.addExact(a, b);
}
public int subtract(int a, int b) {
return Math.subtractExact(a, b);
}
public int multiply(int a, int b) {
return a * b;
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source
class TestNestedClasses:
    """Tests for nested class scenarios."""

    def test_replace_method_in_nested_class(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Nested class methods are skipped by discovery (PR #1726), so replacement returns False.

        Besides the ``False`` return value, the file on disk must still hold the
        original source: a skipped replacement must not rewrite the module.
        """
        java_file = tmp_path / "Outer.java"
        original_code = """public class Outer {
public int outerMethod() {
return 1;
}
public static class Inner {
public int innerMethod() {
return 2;
}
}
}
"""
        java_file.write_text(original_code, encoding="utf-8")

        # Only the nested class's method differs from the original.
        optimized_markdown = f"""```java:{java_file.relative_to(tmp_path)}
public class Outer {{
public int outerMethod() {{
return 1;
}}
public static class Inner {{
public int innerMethod() {{
return 2 + 0;
}}
}}
}}
```"""
        optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java")
        result = replace_function_definitions_for_language(
            function_names=["innerMethod"],
            optimized_code=optimized_code,
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        assert result is False
        # Same check the other "returns False" tests in this module perform.
        assert java_file.read_text(encoding="utf-8") == original_code
class TestPreservesStructure:
    """Checks that members other than the target method are carried over unchanged."""

    def test_preserves_fields_and_constructors(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Replacing ``increment`` keeps both fields and the constructor as they were."""
        java_file = tmp_path / "Counter.java"
        java_file.write_text(
            """public class Counter {
private int count;
private final int max;
public Counter(int max) {
this.count = 0;
this.max = max;
}
public int increment() {
if (count < max) {
count++;
}
return count;
}
}
""",
            encoding="utf-8",
        )

        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Counter {{
private int count;
private final int max;
public Counter(int max) {{
this.count = 0;
this.max = max;
}}
public int increment() {{
return count < max ? ++count : count;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["increment"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Counter {
private int count;
private final int max;
public Counter(int max) {
this.count = 0;
this.max = max;
}
public int increment() {
return count < max ? ++count : count;
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source
class TestEdgeCases:
    """Edge cases and error handling tests."""

    def test_empty_optimized_code_returns_false(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Test that empty optimized code returns False and leaves the file untouched."""
        java_file = tmp_path / "Empty.java"
        original_code = """public class Empty {
public int getValue() {
return 42;
}
}
"""
        java_file.write_text(original_code, encoding="utf-8")

        # A fenced block for the right file that contains no code at all.
        optimized_markdown = """```java:Empty.java
```"""
        optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java")
        result = replace_function_definitions_for_language(
            function_names=["getValue"],
            optimized_code=optimized_code,
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        assert result is False
        assert java_file.read_text(encoding="utf-8") == original_code

    def test_function_not_found_returns_false(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Test that function not found returns False and leaves the file untouched."""
        java_file = tmp_path / "NotFound.java"
        original_code = """public class NotFound {
public int getValue() {
return 42;
}
}
"""
        java_file.write_text(original_code, encoding="utf-8")

        # The requested name exists only in the optimized snippet, not in the file.
        optimized_markdown = f"""```java:{java_file.relative_to(tmp_path)}
public class NotFound {{
public int nonExistent() {{
return 0;
}}
}}
```"""
        optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java")
        result = replace_function_definitions_for_language(
            function_names=["nonExistent"],
            optimized_code=optimized_code,
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        assert result is False
        # Same check the other "returns False" tests in this module perform.
        assert java_file.read_text(encoding="utf-8") == original_code

    def test_unicode_in_code(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Test that non-ASCII characters in the optimized code are written back intact."""
        java_file = tmp_path / "Unicode.java"
        original_code = """public class Unicode {
public String greet() {
return "Hello";
}
}
"""
        java_file.write_text(original_code, encoding="utf-8")

        optimized_markdown = f"""```java:{java_file.relative_to(tmp_path)}
public class Unicode {{
public String greet() {{
return "こんにちは";
}}
}}
```"""
        optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java")
        result = replace_function_definitions_for_language(
            function_names=["greet"],
            optimized_code=optimized_code,
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        assert result is True
        expected = """public class Unicode {
public String greet() {
return "こんにちは";
}
}
"""
        assert java_file.read_text(encoding="utf-8") == expected
class TestOptimizationWithStaticFields:
    """Optimizations whose new code introduces static fields the original lacked."""

    def test_add_static_lookup_table(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """The optimized method relies on a new ``HEX_DIGITS`` field, which must appear in the file."""
        java_file = tmp_path / "Buffer.java"
        java_file.write_text(
            """public class Buffer {
public static String bytesToHexString(byte[] buf, int offset, int length) {
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {
sb.append(String.format("%02x", buf[i]));
}
return sb.toString();
}
}
""",
            encoding="utf-8",
        )

        # The optimized class declares a static lookup table next to the method.
        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Buffer {{
private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray();
public static String bytesToHexString(byte[] buf, int offset, int length) {{
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {{
int v = buf[i] & 0xFF;
sb.append(HEX_DIGITS[v >>> 4]);
sb.append(HEX_DIGITS[v & 0x0F]);
}}
return sb.toString();
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["bytesToHexString"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Buffer {
private static final char[] HEX_DIGITS = "0123456789abcdef".toCharArray();
public static String bytesToHexString(byte[] buf, int offset, int length) {
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {
int v = buf[i] & 0xFF;
sb.append(HEX_DIGITS[v >>> 4]);
sb.append(HEX_DIGITS[v & 0x0F]);
}
return sb.toString();
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_add_precomputed_array(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A new static array and the method that fills it are added alongside the target."""
        java_file = tmp_path / "Encoder.java"
        java_file.write_text(
            """public class Encoder {
public static String byteToHex(byte b) {
return String.format("%02x", b);
}
}
""",
            encoding="utf-8",
        )

        # The optimized class precomputes a byte-to-hex table via a static initializer method.
        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Encoder {{
private static final String[] BYTE_TO_HEX = createByteToHex();
private static String[] createByteToHex() {{
String[] map = new String[256];
for (int i = 0; i < 256; i++) {{
map[i] = String.format("%02x", i);
}}
return map;
}}
public static String byteToHex(byte b) {{
return BYTE_TO_HEX[b & 0xFF];
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["byteToHex"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Encoder {
private static final String[] BYTE_TO_HEX = createByteToHex();
private static String[] createByteToHex() {
String[] map = new String[256];
for (int i = 0; i < 256; i++) {
map[i] = String.format("%02x", i);
}
return map;
}
public static String byteToHex(byte b) {
return BYTE_TO_HEX[b & 0xFF];
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source

    def test_preserve_existing_fields(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """The pre-existing ``MAX_VALUE`` field survives when new static members are added."""
        java_file = tmp_path / "Calculator.java"
        java_file.write_text(
            """public class Calculator {
private static final int MAX_VALUE = 1000;
public int calculate(int n) {
int result = 0;
for (int i = 0; i < n; i++) {
result += i;
}
return result;
}
}
""",
            encoding="utf-8",
        )

        # The optimized class keeps MAX_VALUE and adds a second static field plus its builder.
        markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class Calculator {{
private static final int MAX_VALUE = 1000;
private static final int[] PRECOMPUTED = precompute();
private static int[] precompute() {{
int[] arr = new int[1001];
for (int i = 1; i <= 1000; i++) {{
arr[i] = arr[i-1] + i - 1;
}}
return arr;
}}
public int calculate(int n) {{
if (n <= 1000) {{
return PRECOMPUTED[n];
}}
int result = PRECOMPUTED[1000];
for (int i = 1000; i < n; i++) {{
result += i;
}}
return result;
}}
}}
```"""
        was_modified = replace_function_definitions_for_language(
            function_names=["calculate"],
            optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
            module_abspath=java_file,
            project_root_path=tmp_path,
            lang_support=java_support,
        )

        expected_source = """public class Calculator {
private static final int MAX_VALUE = 1000;
private static final int[] PRECOMPUTED = precompute();
private static int[] precompute() {
int[] arr = new int[1001];
for (int i = 1; i <= 1000; i++) {
arr[i] = arr[i-1] + i - 1;
}
return arr;
}
public int calculate(int n) {
if (n <= 1000) {
return PRECOMPUTED[n];
}
int result = PRECOMPUTED[1000];
for (int i = 1000; i < n; i++) {
result += i;
}
return result;
}
}
"""
        assert was_modified is True
        assert java_file.read_text(encoding="utf-8") == expected_source
class TestOptimizationWithHelperMethods:
"""Tests for optimizations that add new helper methods."""
def test_add_private_helper_method(self, tmp_path: Path, java_support: JavaSupport) -> None:
    """The optimized ``reverse`` calls a new private ``swap`` helper, which must be added too."""
    java_file = tmp_path / "StringUtils.java"
    java_file.write_text(
        """public class StringUtils {
public static String reverse(String s) {
char[] chars = s.toCharArray();
int left = 0;
int right = chars.length - 1;
while (left < right) {
char temp = chars[left];
chars[left] = chars[right];
chars[right] = temp;
left++;
right--;
}
return new String(chars);
}
}
""",
        encoding="utf-8",
    )

    # The optimized class moves the swap logic into its own helper method.
    markdown_block = f"""```java:{java_file.relative_to(tmp_path)}
public class StringUtils {{
private static void swap(char[] arr, int i, int j) {{
char temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}}
public static String reverse(String s) {{
char[] chars = s.toCharArray();
for (int i = 0, j = chars.length - 1; i < j; i++, j--) {{
swap(chars, i, j);
}}
return new String(chars);
}}
}}
```"""
    was_modified = replace_function_definitions_for_language(
        function_names=["reverse"],
        optimized_code=CodeStringsMarkdown.parse_markdown_code(markdown_block, expected_language="java"),
        module_abspath=java_file,
        project_root_path=tmp_path,
        lang_support=java_support,
    )

    expected_source = """public class StringUtils {
private static void swap(char[] arr, int i, int j) {
char temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}
public static String reverse(String s) {
char[] chars = s.toCharArray();
for (int i = 0, j = chars.length - 1; i < j; i++, j--) {
swap(chars, i, j);
}
return new String(chars);
}
}
"""
    assert was_modified is True
    assert java_file.read_text(encoding="utf-8") == expected_source
def test_add_multiple_helpers(self, tmp_path: Path, java_support: JavaSupport) -> None:
"""Test optimization that adds multiple helper methods."""
java_file = tmp_path / "MathUtils.java"
original_code = """public class MathUtils {
public static int gcd(int a, int b) {
while (b != 0) {
int temp = b;
b = a % b;
a = temp;
}
return a;
}
}
"""
java_file.write_text(original_code, encoding="utf-8")
# Optimization adds multiple helper methods
optimized_markdown = f"""```java:{java_file.relative_to(tmp_path)}
public class MathUtils {{
private static int abs(int x) {{
return x < 0 ? -x : x;
}}
private static int gcdInternal(int a, int b) {{
return b == 0 ? a : gcdInternal(b, a % b);
}}
public static int gcd(int a, int b) {{
return gcdInternal(abs(a), abs(b));
}}
}}
```"""
optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java")
result = replace_function_definitions_for_language(
function_names=["gcd"],
optimized_code=optimized_code,
module_abspath=java_file,
project_root_path=tmp_path,
lang_support=java_support,
)
assert result is True
new_code = java_file.read_text(encoding="utf-8")
expected = """public class MathUtils {
private static int abs(int x) {
return x < 0 ? -x : x;
}
private static int gcdInternal(int a, int b) {
return b == 0 ? a : gcdInternal(b, a % b);
}
public static int gcd(int a, int b) {
return gcdInternal(abs(a), abs(b));
}
}
"""
assert new_code == expected
class TestOptimizationWithFieldsAndHelpers:
    """Replacement scenarios where the optimized class adds static fields and helper methods at once."""

    def test_add_field_and_helper_together(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Replacing ``fib`` brings in the two cache fields and the ``fibMemo`` helper.

        The file on disk must end up exactly equal to the optimized class text.
        """
        target = tmp_path / "Fibonacci.java"
        baseline_source = """public class Fibonacci {
public static long fib(int n) {
if (n <= 1) return n;
return fib(n - 1) + fib(n - 2);
}
}
"""
        target.write_text(baseline_source, encoding="utf-8")
        rel = target.relative_to(tmp_path)
        # Memoized candidate: two static arrays plus a helper that consults them.
        candidate_md = f"""```java:{rel}
public class Fibonacci {{
private static final long[] CACHE = new long[100];
private static final boolean[] COMPUTED = new boolean[100];
private static long fibMemo(int n) {{
if (n <= 1) return n;
if (n < 100 && COMPUTED[n]) return CACHE[n];
long result = fibMemo(n - 1) + fibMemo(n - 2);
if (n < 100) {{
CACHE[n] = result;
COMPUTED[n] = true;
}}
return result;
}}
public static long fib(int n) {{
return fibMemo(n);
}}
}}
```"""
        # Exact file content expected once the replacement has been applied.
        wanted = """public class Fibonacci {
private static final long[] CACHE = new long[100];
private static final boolean[] COMPUTED = new boolean[100];
private static long fibMemo(int n) {
if (n <= 1) return n;
if (n < 100 && COMPUTED[n]) return CACHE[n];
long result = fibMemo(n - 1) + fibMemo(n - 2);
if (n < 100) {
CACHE[n] = result;
COMPUTED[n] = true;
}
return result;
}
public static long fib(int n) {
return fibMemo(n);
}
}
"""
        candidate = CodeStringsMarkdown.parse_markdown_code(candidate_md, expected_language="java")
        applied = replace_function_definitions_for_language(
            function_names=["fib"],
            optimized_code=candidate,
            module_abspath=target,
            project_root_path=tmp_path,
            lang_support=java_support,
        )
        assert applied is True
        assert target.read_text(encoding="utf-8") == wanted

    def test_real_world_bytes_to_hex_optimization(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Exercise the ``bytesToHexString`` lookup-table optimization pattern seen on aerospike.

        The candidate adds a ``BYTE_TO_HEX`` table and its builder; ``otherMethod``
        appears unchanged in both the input and the expected output.
        """
        target = tmp_path / "Buffer.java"
        baseline_source = """package com.example;
public final class Buffer {
public static String bytesToHexString(byte[] buf, int offset, int length) {
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {
sb.append(String.format("%02x", buf[i]));
}
return sb.toString();
}
public static int otherMethod() {
return 42;
}
}
"""
        target.write_text(baseline_source, encoding="utf-8")
        rel = target.relative_to(tmp_path)
        # Shape of the optimization as it was actually produced by the AI.
        candidate_md = f"""```java:{rel}
package com.example;
public final class Buffer {{
private static final String[] BYTE_TO_HEX = createByteToHex();
private static String[] createByteToHex() {{
String[] map = new String[256];
for (int b = -128; b <= 127; b++) {{
map[b + 128] = String.format("%02x", (byte) b);
}}
return map;
}}
public static String bytesToHexString(byte[] buf, int offset, int length) {{
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {{
sb.append(BYTE_TO_HEX[buf[i] + 128]);
}}
return sb.toString();
}}
public static int otherMethod() {{
return 42;
}}
}}
```"""
        # Exact file content expected once the replacement has been applied.
        wanted = """package com.example;
public final class Buffer {
private static final String[] BYTE_TO_HEX = createByteToHex();
private static String[] createByteToHex() {
String[] map = new String[256];
for (int b = -128; b <= 127; b++) {
map[b + 128] = String.format("%02x", (byte) b);
}
return map;
}
public static String bytesToHexString(byte[] buf, int offset, int length) {
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {
sb.append(BYTE_TO_HEX[buf[i] + 128]);
}
return sb.toString();
}
public static int otherMethod() {
return 42;
}
}
"""
        candidate = CodeStringsMarkdown.parse_markdown_code(candidate_md, expected_language="java")
        applied = replace_function_definitions_for_language(
            function_names=["bytesToHexString"],
            optimized_code=candidate,
            module_abspath=target,
            project_root_path=tmp_path,
            lang_support=java_support,
        )
        assert applied is True
        assert target.read_text(encoding="utf-8") == wanted
class TestOverloadedMethods:
    """Replacement behaviour when several methods share a name but differ in signature."""

    def test_replace_specific_overload_by_line_number(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Only the overload described by the supplied ``FunctionToOptimize`` is rewritten.

        The expected output keeps the one-argument ``bytesToHexString`` as it was,
        swaps in the new three-argument body, and gains the ``HEX_CHARS`` field.
        """
        from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize

        target = tmp_path / "Buffer.java"
        baseline_source = """public final class Buffer {
public static String bytesToHexString(byte[] buf) {
if (buf == null || buf.length == 0) {
return "";
}
StringBuilder sb = new StringBuilder(buf.length * 2);
for (int i = 0; i < buf.length; i++) {
sb.append(String.format("%02x", buf[i]));
}
return sb.toString();
}
public static String bytesToHexString(byte[] buf, int offset, int length) {
StringBuilder sb = new StringBuilder(length * 2);
for (int i = offset; i < length; i++) {
sb.append(String.format("%02x", buf[i]));
}
return sb.toString();
}
}
"""
        target.write_text(baseline_source, encoding="utf-8")
        rel = target.relative_to(tmp_path)
        # The candidate contains the three-argument overload only.
        candidate_md = f"""```java:{rel}
public final class Buffer {{
private static final char[] HEX_CHARS = {{'0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f'}};
public static String bytesToHexString(byte[] buf, int offset, int length) {{
char[] out = new char[(length - offset) * 2];
for (int i = offset, j = 0; i < length; i++) {{
int v = buf[i] & 0xFF;
out[j++] = HEX_CHARS[v >>> 4];
out[j++] = HEX_CHARS[v & 0x0F];
}}
return new String(out);
}}
}}
```"""
        # Exact file content expected once the replacement has been applied.
        wanted = """public final class Buffer {
private static final char[] HEX_CHARS = {'0','1','2','3','4','5','6','7','8','9','a','b','c','d','e','f'};
public static String bytesToHexString(byte[] buf) {
if (buf == null || buf.length == 0) {
return "";
}
StringBuilder sb = new StringBuilder(buf.length * 2);
for (int i = 0; i < buf.length; i++) {
sb.append(String.format("%02x", buf[i]));
}
return sb.toString();
}
public static String bytesToHexString(byte[] buf, int offset, int length) {
char[] out = new char[(length - offset) * 2];
for (int i = offset, j = 0; i < length; i++) {
int v = buf[i] & 0xFF;
out[j++] = HEX_CHARS[v >>> 4];
out[j++] = HEX_CHARS[v & 0x0F];
}
return new String(out);
}
}
"""
        candidate = CodeStringsMarkdown.parse_markdown_code(candidate_md, expected_language="java")
        # Line information singles out the 3-arg version among the overloads.
        three_arg_overload = FunctionToOptimize(
            function_name="bytesToHexString",
            file_path=target,
            starting_line=13,  # 1-indexed first line of the 3-arg version
            ending_line=18,
            parents=[FunctionParent(name="Buffer", type="ClassDef")],
            is_method=True,
        )
        applied = replace_function_definitions_for_language(
            function_names=["bytesToHexString"],
            optimized_code=candidate,
            module_abspath=target,
            project_root_path=tmp_path,
            lang_support=java_support,
            function_to_optimize=three_arg_overload,
        )
        assert applied is True
        assert target.read_text(encoding="utf-8") == wanted
class TestWrongMethodNameGeneration:
    """Guards for the case where the LLM returns a method other than the requested one.

    If the optimizer asked for method X and the generated code defines method Y,
    applying it would overwrite X with Y's body (duplicating Y) and drop X from
    the source. These tests check that codeflash notices the mismatch and leaves
    the original source file untouched.
    """

    def test_standalone_wrong_method_name_leaves_source_unchanged(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A bare generated method with the wrong name must not overwrite the target.

        Mirrors the Unpacker.unpackObjectMap bug: the request was to optimise
        ``unpackObjectMap`` but the LLM returned a standalone ``unpackMap``.
        Applying it would duplicate ``unpackMap`` and delete ``unpackObjectMap``,
        causing compilation failures.
        """
        from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize

        target = tmp_path / "Unpacker.java"
        baseline_source = """\
public abstract class Unpacker {
public static Object unpackObjectMap(byte[] buffer, int offset, int length) {
return new Object();
}
public Object unpackMap() {
return null;
}
}
"""
        target.write_text(baseline_source, encoding="utf-8")
        rel = target.relative_to(tmp_path)
        # Generated code optimises ``unpackMap`` although ``unpackObjectMap``
        # was the method that should have been optimised.
        candidate_md = f"""```java:{rel}
public final Object unpackMap() {{
return new Object();
}}
```"""
        candidate = CodeStringsMarkdown.parse_markdown_code(candidate_md, expected_language="java")
        requested = FunctionToOptimize(
            function_name="unpackObjectMap",
            file_path=target,
            starting_line=2,
            ending_line=4,
            parents=[FunctionParent(name="Unpacker", type="ClassDef")],
            is_method=True,
        )
        applied = java_support.replace_function_definitions(
            function_names=["unpackObjectMap"],
            optimized_code=candidate,
            module_abspath=target,
            project_root_path=tmp_path,
            function_to_optimize=requested,
        )
        # Wrong method name in the generated code: nothing may be modified.
        assert applied is False
        assert target.read_text(encoding="utf-8") == baseline_source

    def test_class_wrapper_with_wrong_target_method_leaves_source_unchanged(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """A generated class that lacks the target method must not touch the source.

        Mirrors the Command.estimateKeySize bug: the generated class held only
        ``sizeTxn`` (a helper) and omitted ``estimateKeySize`` (the target).
        Applying it would duplicate ``sizeTxn`` in the source.
        """
        from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize

        target = tmp_path / "Command.java"
        baseline_source = """\
public class Command {
public int estimateKeySize(String key) {
return key.length() + 4;
}
private int sizeTxn(String key, Object txn, boolean hasWrite) {
return key.length();
}
}
"""
        target.write_text(baseline_source, encoding="utf-8")
        rel = target.relative_to(tmp_path)
        # Generated class carries ``sizeTxn`` only; the target
        # ``estimateKeySize`` is missing from it.
        candidate_md = f"""```java:{rel}
public class Command {{
private int sizeTxn(String key, Object txn, boolean hasWrite) {{
return key.length() + 1;
}}
}}
```"""
        candidate = CodeStringsMarkdown.parse_markdown_code(candidate_md, expected_language="java")
        requested = FunctionToOptimize(
            function_name="estimateKeySize",
            file_path=target,
            starting_line=2,
            ending_line=4,
            parents=[FunctionParent(name="Command", type="ClassDef")],
            is_method=True,
        )
        applied = java_support.replace_function_definitions(
            function_names=["estimateKeySize"],
            optimized_code=candidate,
            module_abspath=target,
            project_root_path=tmp_path,
            function_to_optimize=requested,
        )
        # Target method absent from the generated class: nothing may be modified.
        assert applied is False
        assert target.read_text(encoding="utf-8") == baseline_source
class TestAnonymousInnerClassMethods:
    """Checks that methods of anonymous inner classes are never hoisted as helpers.

    An optimised method may return an anonymous class (for example an inline
    Iterator). That class's own methods (hasNext, next, remove ...) must NOT be
    pulled out and inserted as top-level members of the enclosing class. Hoisted
    copies would be broken: their @Override annotations would match no supertype
    method, and they would refer to variables that exist only in the enclosing
    method's scope.
    """

    def test_anonymous_iterator_methods_not_hoisted_to_class(self, tmp_path: Path, java_support: JavaSupport) -> None:
        """Regression test for the LuaMap.keySetIterator bug.

        The optimised ``keySetIterator`` returns an anonymous ``Iterator`` that
        defines ``hasNext``, ``next`` and ``remove``. All three must stay inside
        the anonymous class body and must NOT become top-level members of the
        outer class.
        """
        from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize

        target = tmp_path / "LuaMap.java"
        baseline_source = """\
import java.util.Iterator;
import java.util.Map;
public final class LuaMap {
private final Map<String, String> map;
public LuaMap(Map<String, String> map) {
this.map = map;
}
public Iterator<String> keySetIterator() {
return map.keySet().iterator();
}
public int size() {
return map.size();
}
}
"""
        target.write_text(baseline_source, encoding="utf-8")
        rel = target.relative_to(tmp_path)
        # The optimised method hands back a custom anonymous Iterator so that
        # no keySet view is created for empty maps.
        candidate_md = f"""```java:{rel}
import java.util.Iterator;
import java.util.Map;
public final class LuaMap {{
private final Map<String, String> map;
public LuaMap(Map<String, String> map) {{
this.map = map;
}}
public Iterator<String> keySetIterator() {{
if (map.isEmpty()) {{
return java.util.Collections.emptyIterator();
}}
final Iterator<Map.Entry<String, String>> it = map.entrySet().iterator();
return new Iterator<String>() {{
@Override
public boolean hasNext() {{
return it.hasNext();
}}
@Override
public String next() {{
return it.next().getKey();
}}
@Override
public void remove() {{
it.remove();
}}
}};
}}
public int size() {{
return map.size();
}}
}}
```"""
        # Exact file content expected once the replacement has been applied.
        wanted = """\
import java.util.Iterator;
import java.util.Map;
public final class LuaMap {
private final Map<String, String> map;
public LuaMap(Map<String, String> map) {
this.map = map;
}
public Iterator<String> keySetIterator() {
if (map.isEmpty()) {
return java.util.Collections.emptyIterator();
}
final Iterator<Map.Entry<String, String>> it = map.entrySet().iterator();
return new Iterator<String>() {
@Override
public boolean hasNext() {
return it.hasNext();
}
@Override
public String next() {
return it.next().getKey();
}
@Override
public void remove() {
it.remove();
}
};
}
public int size() {
return map.size();
}
}
"""
        candidate = CodeStringsMarkdown.parse_markdown_code(candidate_md, expected_language="java")
        requested = FunctionToOptimize(
            function_name="keySetIterator",
            file_path=target,
            starting_line=11,
            ending_line=13,
            parents=[FunctionParent(name="LuaMap", type="ClassDef")],
            is_method=True,
        )
        applied = java_support.replace_function_definitions(
            function_names=["keySetIterator"],
            optimized_code=candidate,
            module_abspath=target,
            project_root_path=tmp_path,
            function_to_optimize=requested,
        )
        assert applied is True
        assert target.read_text(encoding="utf-8") == wanted
class TestFieldInjectionClassFiltering:
"""Tests that fields from inner/anonymous classes are not injected into the target class."""
def test_inner_class_fields_not_injected_into_outer(self, tmp_path: Path, java_support: JavaSupport) -> None:
    """Reproduces the Guava/Iterables.mergeSorted bug.

    When the LLM generates an optimization that includes an inner class with
    fields (e.g., generic type parameters), those fields must NOT be injected
    into the outer class where the target method lives.

    Steps: write ``Outer.java`` into ``tmp_path``, parse a generated class that
    adds an ``OFFSET`` field to ``Outer`` and a nested ``Inner`` class with a
    ``badField`` field, run ``java_support.replace_function_definitions`` for
    ``process``, then inspect the rewritten file.

    Args:
        tmp_path: Directory that holds the Java file and is passed as the project root.
        java_support: Object whose ``replace_function_definitions`` is under test.
    """
    # Types used below to describe the method being optimized.
    from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
    java_file = tmp_path / "Outer.java"
    # Leading backslash keeps the literal from starting with an empty line.
    original_code = """\
public class Outer {
private int count;
public int process(int x) {
return x + count;
}
}
"""
    java_file.write_text(original_code, encoding="utf-8")
    # LLM generates optimization with an inner class that has its own field.
    # The inner class's field should NOT be injected into Outer.
    # Doubled braces are f-string escapes; the parsed Java contains single braces.
    optimized_markdown = f"""```java:{java_file.relative_to(tmp_path)}
public class Outer {{
private int count;
private static final int OFFSET = 10;
public int process(int x) {{
return x + count + OFFSET;
}}
private static class Inner {{
private final String badField;
Inner(String s) {{
this.badField = s;
}}
}}
}}
```"""
    optimized_code = CodeStringsMarkdown.parse_markdown_code(optimized_markdown, expected_language="java")
    # Describes the target ``process`` method inside class ``Outer``.
    # NOTE(review): confirm the starting_line/ending_line span matches the
    # fixture's on-disk line numbering.
    function_to_optimize = FunctionToOptimize(
        function_name="process",
        file_path=java_file,
        starting_line=4,
        ending_line=6,
        parents=[FunctionParent(name="Outer", type="ClassDef")],
        is_method=True,
    )
    result = java_support.replace_function_definitions(
        function_names=["process"],
        optimized_code=optimized_code,
        module_abspath=java_file,
        project_root_path=tmp_path,
        function_to_optimize=function_to_optimize,
    )
    assert result is True
    new_code = java_file.read_text(encoding="utf-8")
    # Only OFFSET field should be added (belongs to Outer).
    # badField belongs to Inner and should NOT appear.
    # These are substring checks, not a comparison against a full expected file.
    assert "OFFSET" in new_code
    assert "badField" not in new_code