"""Tests for Java code context extraction.""" from pathlib import Path import pytest from codeflash.languages.base import FunctionFilterCriteria, Language, ParentInfo from codeflash.languages.java.context import ( extract_class_context, extract_code_context, extract_function_source, extract_read_only_context, find_helper_functions, ) from codeflash.languages.java.discovery import discover_functions_from_source from codeflash.languages.java.parser import get_java_analyzer # Filter criteria that includes void methods NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False) class TestExtractCodeContextBasic: """Tests for basic extract_code_context functionality.""" def test_simple_method(self, tmp_path: Path): """Test extracting context for a simple method.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { public int add(int a, int b) { return a + b; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert context.target_file == java_file # Method is wrapped in class skeleton assert ( context.target_code == """public class Calculator { public int add(int a, int b) { return a + b; } } """ ) assert context.imports == [] assert context.helper_functions == [] assert context.read_only_context == "" def test_method_with_javadoc(self, tmp_path: Path): """Test extracting context for method with Javadoc.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { /** * Adds two numbers. 
* @param a first number * @param b second number * @return sum */ public int add(int a, int b) { return a + b; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert context.target_file == java_file assert ( context.target_code == """public class Calculator { /** * Adds two numbers. * @param a first number * @param b second number * @return sum */ public int add(int a, int b) { return a + b; } } """ ) assert context.imports == [] assert context.helper_functions == [] assert context.read_only_context == "" def test_static_method(self, tmp_path: Path): """Test extracting context for a static method.""" java_file = tmp_path / "MathUtils.java" java_file.write_text("""public class MathUtils { public static int multiply(int a, int b) { return a * b; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert context.target_file == java_file assert ( context.target_code == """public class MathUtils { public static int multiply(int a, int b) { return a * b; } } """ ) assert context.imports == [] assert context.helper_functions == [] assert context.read_only_context == "" def test_private_method(self, tmp_path: Path): """Test extracting context for a private method.""" java_file = tmp_path / "Helper.java" java_file.write_text("""public class Helper { private int getValue() { return 42; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert context.target_file == java_file assert ( context.target_code == """public class Helper { private int getValue() { return 42; } } """ ) 
def test_protected_method(self, tmp_path: Path): """Test extracting context for a protected method.""" java_file = tmp_path / "Base.java" java_file.write_text("""public class Base { protected int compute(int x) { return x * 2; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert context.target_file == java_file assert ( context.target_code == """public class Base { protected int compute(int x) { return x * 2; } } """ ) def test_synchronized_method(self, tmp_path: Path): """Test extracting context for a synchronized method.""" java_file = tmp_path / "Counter.java" java_file.write_text("""public class Counter { public synchronized int getCount() { return count; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert ( context.target_code == """public class Counter { public synchronized int getCount() { return count; } } """ ) def test_method_with_throws(self, tmp_path: Path): """Test extracting context for a method with throws clause.""" java_file = tmp_path / "FileHandler.java" java_file.write_text("""public class FileHandler { public String readFile(String path) throws IOException, FileNotFoundException { return Files.readString(Path.of(path)); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert ( context.target_code == """public class FileHandler { public String readFile(String path) throws IOException, FileNotFoundException { return Files.readString(Path.of(path)); } } """ ) def test_method_with_varargs(self, tmp_path: Path): """Test extracting context for 
a method with varargs.""" java_file = tmp_path / "Logger.java" java_file.write_text("""public class Logger { public String format(String... messages) { return String.join(", ", messages); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert ( context.target_code == """public class Logger { public String format(String... messages) { return String.join(", ", messages); } } """ ) def test_void_method(self, tmp_path: Path): """Test extracting context for a void method.""" java_file = tmp_path / "Printer.java" java_file.write_text("""public class Printer { public void print(String text) { System.out.println(text); } } """) functions = discover_functions_from_source( java_file.read_text(), file_path=java_file, filter_criteria=NO_RETURN_FILTER ) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert ( context.target_code == """public class Printer { public void print(String text) { System.out.println(text); } } """ ) def test_generic_return_type(self, tmp_path: Path): """Test extracting context for a method with generic return type.""" java_file = tmp_path / "Container.java" java_file.write_text("""public class Container { public List getNames() { return new ArrayList<>(); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert ( context.target_code == """public class Container { public List getNames() { return new ArrayList<>(); } } """ ) class TestExtractCodeContextWithImports: """Tests for extract_code_context with various import types.""" def test_with_package_and_imports(self, tmp_path: Path): """Test context extraction with package and imports.""" 
java_file = tmp_path / "Calculator.java" java_file.write_text("""package com.example; import java.util.List; public class Calculator { private int base = 0; public int add(int a, int b) { return a + b + base; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) add_func = next((f for f in functions if f.function_name == "add"), None) assert add_func is not None context = extract_code_context(add_func, tmp_path) assert context.language == Language.JAVA assert context.target_file == java_file # Class skeleton includes fields assert ( context.target_code == """public class Calculator { private int base = 0; public int add(int a, int b) { return a + b + base; } } """ ) assert context.imports == ["import java.util.List;"] # Fields are in skeleton, so read_only_context is empty assert context.read_only_context == "" def test_with_static_imports(self, tmp_path: Path): """Test context extraction with static imports.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""package com.example; import java.util.List; import static java.lang.Math.PI; import static java.lang.Math.sqrt; public class Calculator { public double circleArea(double radius) { return PI * radius * radius; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert ( context.target_code == """public class Calculator { public double circleArea(double radius) { return PI * radius * radius; } } """ ) assert context.imports == [ "import java.util.List;", "import static java.lang.Math.PI;", "import static java.lang.Math.sqrt;", ] def test_with_wildcard_imports(self, tmp_path: Path): """Test context extraction with wildcard imports.""" java_file = tmp_path / "Processor.java" java_file.write_text("""package com.example; import java.util.*; import java.io.*; public class Processor { public List 
process(String input) { return Arrays.asList(input.split(",")); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA assert context.imports == ["import java.util.*;", "import java.io.*;"] def test_with_multiple_import_types(self, tmp_path: Path): """Test context extraction with various import types.""" java_file = tmp_path / "Handler.java" java_file.write_text("""package com.example; import java.util.List; import java.util.Map; import java.util.ArrayList; import static java.util.Collections.sort; import static java.util.Collections.reverse; public class Handler { public List sortNumbers(List nums) { sort(nums); return nums; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Handler { public List sortNumbers(List nums) { sort(nums); return nums; } } """ ) assert context.imports == [ "import java.util.List;", "import java.util.Map;", "import java.util.ArrayList;", "import static java.util.Collections.sort;", "import static java.util.Collections.reverse;", ] assert context.read_only_context == "" assert context.helper_functions == [] class TestExtractCodeContextWithFields: """Tests for extract_code_context with class fields. Note: When fields are included in the class skeleton (target_code), read_only_context should be empty to avoid duplication. 
""" def test_with_instance_fields(self, tmp_path: Path): """Test context extraction with instance fields.""" java_file = tmp_path / "Person.java" java_file.write_text("""public class Person { private String name; private int age; public String getName() { return name; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert context.language == Language.JAVA # Class skeleton includes fields assert ( context.target_code == """public class Person { private String name; private int age; public String getName() { return name; } } """ ) # Fields are in skeleton, so read_only_context is empty (no duplication) assert context.read_only_context == "" assert context.imports == [] assert context.helper_functions == [] def test_with_static_fields(self, tmp_path: Path): """Test context extraction with static fields.""" java_file = tmp_path / "Counter.java" java_file.write_text("""public class Counter { private static int instanceCount = 0; private static String prefix = "counter_"; public int getCount() { return instanceCount; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Counter { private static int instanceCount = 0; private static String prefix = "counter_"; public int getCount() { return instanceCount; } } """ ) # Fields are in skeleton, so read_only_context is empty assert context.read_only_context == "" def test_with_final_fields(self, tmp_path: Path): """Test context extraction with final fields.""" java_file = tmp_path / "Config.java" java_file.write_text("""public class Config { private final String name; private final int maxSize; public String getName() { return name; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) 
assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Config { private final String name; private final int maxSize; public String getName() { return name; } } """ ) assert context.read_only_context == "" def test_with_static_final_constants(self, tmp_path: Path): """Test context extraction with static final constants.""" java_file = tmp_path / "Constants.java" java_file.write_text("""public class Constants { public static final double PI = 3.14159; public static final int MAX_VALUE = 100; private static final String PREFIX = "const_"; public double getPI() { return PI; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Constants { public static final double PI = 3.14159; public static final int MAX_VALUE = 100; private static final String PREFIX = "const_"; public double getPI() { return PI; } } """ ) assert context.read_only_context == "" def test_with_volatile_fields(self, tmp_path: Path): """Test context extraction with volatile fields.""" java_file = tmp_path / "ThreadSafe.java" java_file.write_text("""public class ThreadSafe { private volatile boolean running = true; private volatile int counter = 0; public boolean isRunning() { return running; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class ThreadSafe { private volatile boolean running = true; private volatile int counter = 0; public boolean isRunning() { return running; } } """ ) assert context.read_only_context == "" def test_with_generic_fields(self, tmp_path: Path): """Test context extraction with generic type fields.""" java_file = tmp_path / "Container.java" 
java_file.write_text("""public class Container { private List names; private Map scores; private Set ids; public List getNames() { return names; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Container { private List names; private Map scores; private Set ids; public List getNames() { return names; } } """ ) assert context.read_only_context == "" def test_with_array_fields(self, tmp_path: Path): """Test context extraction with array fields.""" java_file = tmp_path / "ArrayHolder.java" java_file.write_text("""public class ArrayHolder { private int[] numbers; private String[] names; private double[][] matrix; public int[] getNumbers() { return numbers; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class ArrayHolder { private int[] numbers; private String[] names; private double[][] matrix; public int[] getNumbers() { return numbers; } } """ ) assert context.read_only_context == "" class TestExtractCodeContextWithHelpers: """Tests for extract_code_context with helper functions.""" def test_single_helper_method(self, tmp_path: Path): """Test context extraction with a single helper method.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { public String process(String input) { return normalize(input); } private String normalize(String s) { return s.trim().toLowerCase(); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) process_func = next((f for f in functions if f.function_name == "process"), None) assert process_func is not None context = extract_code_context(process_func, tmp_path) assert context.language == Language.JAVA assert ( 
context.target_code == """public class Processor { public String process(String input) { return normalize(input); } } """ ) assert len(context.helper_functions) == 1 assert context.helper_functions[0].name == "normalize" assert ( context.helper_functions[0].source_code == "private String normalize(String s) {\n return s.trim().toLowerCase();\n }" ) def test_multiple_helper_methods(self, tmp_path: Path): """Test context extraction with multiple helper methods.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { public String process(String input) { String trimmed = trim(input); return upper(trimmed); } private String trim(String s) { return s.trim(); } private String upper(String s) { return s.toUpperCase(); } private String unused(String s) { return s; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) process_func = next((f for f in functions if f.function_name == "process"), None) assert process_func is not None context = extract_code_context(process_func, tmp_path) assert ( context.target_code == """public class Processor { public String process(String input) { String trimmed = trim(input); return upper(trimmed); } } """ ) assert context.read_only_context == "" assert context.imports == [] helper_names = sorted([h.name for h in context.helper_functions]) assert helper_names == ["trim", "upper"] def test_chained_helper_calls(self, tmp_path: Path): """Test context extraction with chained helper calls.""" java_file = tmp_path / "Processor.java" java_file.write_text("""public class Processor { public String process(String input) { return normalize(input); } private String normalize(String s) { return sanitize(s).toLowerCase(); } private String sanitize(String s) { return s.trim(); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) process_func = next((f for f in functions if f.function_name == "process"), None) assert process_func is not 
None context = extract_code_context(process_func, tmp_path) helper_names = [h.name for h in context.helper_functions] assert helper_names == ["normalize"] def test_no_helpers_when_none_called(self, tmp_path: Path): """Test context extraction when no helpers are called.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { public int add(int a, int b) { return a + b; } private int unused(int x) { return x * 2; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) add_func = next((f for f in functions if f.function_name == "add"), None) assert add_func is not None context = extract_code_context(add_func, tmp_path) assert ( context.target_code == """public class Calculator { public int add(int a, int b) { return a + b; } } """ ) assert context.helper_functions == [] def test_static_helper_from_instance_method(self, tmp_path: Path): """Test context extraction with static helper called from instance method.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { public int calculate(int x) { return staticHelper(x); } private static int staticHelper(int x) { return x * 2; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) calc_func = next((f for f in functions if f.function_name == "calculate"), None) assert calc_func is not None context = extract_code_context(calc_func, tmp_path) helper_names = [h.name for h in context.helper_functions] assert helper_names == ["staticHelper"] class TestExtractCodeContextWithJavadoc: """Tests for extract_code_context with various Javadoc patterns.""" def test_simple_javadoc(self, tmp_path: Path): """Test context extraction with simple Javadoc.""" java_file = tmp_path / "Example.java" java_file.write_text("""public class Example { /** Simple description. 
*/ public void doSomething() { } } """) functions = discover_functions_from_source( java_file.read_text(), file_path=java_file, filter_criteria=NO_RETURN_FILTER ) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Example { /** Simple description. */ public void doSomething() { } } """ ) def test_javadoc_with_params(self, tmp_path: Path): """Test context extraction with Javadoc @param tags.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { /** * Adds two numbers. * @param a the first number * @param b the second number */ public int add(int a, int b) { return a + b; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Calculator { /** * Adds two numbers. * @param a the first number * @param b the second number */ public int add(int a, int b) { return a + b; } } """ ) def test_javadoc_with_return(self, tmp_path: Path): """Test context extraction with Javadoc @return tag.""" java_file = tmp_path / "Calculator.java" java_file.write_text("""public class Calculator { /** * Computes the sum. * @return the sum of a and b */ public int add(int a, int b) { return a + b; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Calculator { /** * Computes the sum. * @return the sum of a and b */ public int add(int a, int b) { return a + b; } } """ ) def test_javadoc_with_throws(self, tmp_path: Path): """Test context extraction with Javadoc @throws tag.""" java_file = tmp_path / "Divider.java" java_file.write_text("""public class Divider { /** * Divides two numbers. 
* @throws ArithmeticException if divisor is zero * @throws IllegalArgumentException if inputs are negative */ public double divide(double a, double b) { if (b == 0) throw new ArithmeticException(); return a / b; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Divider { /** * Divides two numbers. * @throws ArithmeticException if divisor is zero * @throws IllegalArgumentException if inputs are negative */ public double divide(double a, double b) { if (b == 0) throw new ArithmeticException(); return a / b; } } """ ) def test_javadoc_multiline(self, tmp_path: Path): """Test context extraction with multi-paragraph Javadoc.""" java_file = tmp_path / "Complex.java" java_file.write_text("""public class Complex { /** * This is a complex method. * *

It does many things:

*
    *
  • First thing
  • *
  • Second thing
  • *
* * @param input the input value * @return the processed result */ public String process(String input) { return input.toUpperCase(); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Complex { /** * This is a complex method. * *

It does many things:

*
    *
  • First thing
  • *
  • Second thing
  • *
* * @param input the input value * @return the processed result */ public String process(String input) { return input.toUpperCase(); } } """ ) class TestExtractCodeContextWithGenerics: """Tests for extract_code_context with generic types.""" def test_generic_method_type_parameter(self, tmp_path: Path): """Test context extraction with generic type parameter.""" java_file = tmp_path / "Utils.java" java_file.write_text("""public class Utils { public T identity(T value) { return value; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Utils { public T identity(T value) { return value; } } """ ) def test_bounded_type_parameter(self, tmp_path: Path): """Test context extraction with bounded type parameter.""" java_file = tmp_path / "Statistics.java" java_file.write_text("""public class Statistics { public double average(List numbers) { double sum = 0; for (T num : numbers) { sum += num.doubleValue(); } return sum / numbers.size(); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Statistics { public double average(List numbers) { double sum = 0; for (T num : numbers) { sum += num.doubleValue(); } return sum / numbers.size(); } } """ ) def test_wildcard_type(self, tmp_path: Path): """Test context extraction with wildcard type.""" java_file = tmp_path / "Printer.java" java_file.write_text("""public class Printer { public int countItems(List items) { return items.size(); } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Printer { public int countItems(List 
items) { return items.size(); } } """ ) def test_bounded_wildcard_extends(self, tmp_path: Path): """Test context extraction with upper bounded wildcard.""" java_file = tmp_path / "Aggregator.java" java_file.write_text("""public class Aggregator { public double sum(List numbers) { double total = 0; for (Number n : numbers) { total += n.doubleValue(); } return total; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Aggregator { public double sum(List numbers) { double total = 0; for (Number n : numbers) { total += n.doubleValue(); } return total; } } """ ) def test_bounded_wildcard_super(self, tmp_path: Path): """Test context extraction with lower bounded wildcard.""" java_file = tmp_path / "Filler.java" java_file.write_text("""public class Filler { public boolean fill(List list, Integer value) { list.add(value); return true; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Filler { public boolean fill(List list, Integer value) { list.add(value); return true; } } """ ) def test_multiple_type_parameters(self, tmp_path: Path): """Test context extraction with multiple type parameters.""" java_file = tmp_path / "Mapper.java" java_file.write_text("""public class Mapper { public Map invert(Map map) { Map result = new HashMap<>(); for (Map.Entry entry : map.entrySet()) { result.put(entry.getValue(), entry.getKey()); } return result; } } """) functions = discover_functions_from_source(java_file.read_text(), file_path=java_file) assert len(functions) == 1 context = extract_code_context(functions[0], tmp_path) assert ( context.target_code == """public class Mapper { public Map invert(Map map) { Map result = new HashMap<>(); for 
(Map.Entry entry : map.entrySet()) { result.put(entry.getValue(), entry.getKey()); } return result; } } """
        )

    def test_recursive_type_bound(self, tmp_path: Path):
        """Test context extraction with recursive type bound."""
        # NOTE(review): `public > T max(...)` is not valid Java -- a generic method's
        # type-parameter section must be written inside `<...>`. The bound (presumably
        # `<T extends Comparable<T>>`) looks stripped from both the fixture and the
        # expected skeleton; confirm against version control.
        java_file = tmp_path / "Sorter.java"
        java_file.write_text("""public class Sorter { public > T max(T a, T b) { return a.compareTo(b) > 0 ? a : b; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Sorter { public > T max(T a, T b) { return a.compareTo(b) > 0 ? a : b; } } """
        )


# NOTE(review): most Java fixtures in the test classes below are written on a single
# line. Where `target_code` is compared for exact equality, confirm this layout is the
# one the extractor is actually expected to reproduce.
class TestExtractCodeContextWithAnnotations:
    """Tests for extract_code_context with annotations."""

    def test_override_annotation(self, tmp_path: Path):
        """Test context extraction with @Override annotation."""
        java_file = tmp_path / "Child.java"
        java_file.write_text("""public class Child extends Parent { @Override public String toString() { return "Child"; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        # The annotation in front of the method is expected to survive into the skeleton.
        assert (
            context.target_code
            == """public class Child extends Parent { @Override public String toString() { return "Child"; } } """
        )

    def test_deprecated_annotation(self, tmp_path: Path):
        """Test context extraction with @Deprecated annotation."""
        java_file = tmp_path / "Legacy.java"
        java_file.write_text("""public class Legacy { @Deprecated public int oldMethod() { return 0; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Legacy { @Deprecated public int oldMethod() { return 0; } } """
        )

    def test_suppress_warnings_annotation(self, tmp_path: Path):
        """Test context extraction with @SuppressWarnings annotation."""
        # NOTE(review): casting to the raw type `(List)` does not itself trigger an
        # unchecked warning, so type arguments (e.g. `List<String>`) may have been
        # stripped from this fixture -- confirm.
        java_file = tmp_path / "Processor.java"
        java_file.write_text("""public class Processor { @SuppressWarnings("unchecked") public List process(Object input) { return (List) input; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Processor { @SuppressWarnings("unchecked") public List process(Object input) { return (List) input; } } """
        )

    def test_multiple_annotations(self, tmp_path: Path):
        """Test context extraction with multiple annotations."""
        java_file = tmp_path / "Service.java"
        java_file.write_text("""public class Service { @Override @Deprecated @SuppressWarnings("deprecation") public String legacyMethod() { return "legacy"; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        # All stacked annotations are expected to be kept, in their original order.
        assert (
            context.target_code
            == """public class Service { @Override @Deprecated @SuppressWarnings("deprecation") public String legacyMethod() { return "legacy"; } } """
        )

    def test_annotation_with_array_value(self, tmp_path: Path):
        """Test context extraction with annotation array value."""
        java_file = tmp_path / "Handler.java"
        java_file.write_text("""public class Handler { @SuppressWarnings({"unchecked", "rawtypes"}) public Object handle(Object input) { return input; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Handler { @SuppressWarnings({"unchecked", "rawtypes"}) public Object handle(Object input) { return input; } } """
        )


class TestExtractCodeContextWithInheritance:
    """Tests for extract_code_context with inheritance scenarios."""

    def test_method_in_subclass(self, tmp_path: Path):
        """Test context extraction for method in subclass."""
        java_file = tmp_path / "AdvancedCalc.java"
        java_file.write_text("""public class AdvancedCalc extends Calculator { public int multiply(int a, int b) { return a * b; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert context.language == Language.JAVA
        # Class skeleton includes extends clause
        assert (
            context.target_code
            == """public class AdvancedCalc extends Calculator { public int multiply(int a, int b) { return a * b; } } """
        )

    def test_interface_implementation(self, tmp_path: Path):
        """Test context extraction for interface implementation."""
        # NOTE(review): with the raw `implements Comparable`, javac would reject
        # `@Override compareTo(MyComparable)`; the type argument may have been stripped
        # from this fixture. Tree-sitter parsing does not depend on it -- confirm intent.
        java_file = tmp_path / "MyComparable.java"
        java_file.write_text("""public class MyComparable implements Comparable { private int value; @Override public int compareTo(MyComparable other) { return Integer.compare(this.value, other.value); } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        # Class skeleton includes implements clause and fields
        assert (
            context.target_code
            == """public class MyComparable implements Comparable { private int value; @Override public int compareTo(MyComparable other) { return Integer.compare(this.value, other.value); } } """
        )
        # Fields are in skeleton, so read_only_context is empty (no duplication)
        assert context.read_only_context == ""

    def test_multiple_interfaces(self, tmp_path: Path):
        """Test context extraction for multiple interface implementations."""
        java_file = tmp_path / "MultiImpl.java"
        java_file.write_text("""public class MultiImpl implements Runnable, Comparable { public void run() { System.out.println("Running"); } public int compareTo(MultiImpl other) { return 0; } } """)
        # NO_RETURN_FILTER is needed so the void `run()` is discovered next to `compareTo`.
        functions = discover_functions_from_source(
            java_file.read_text(), file_path=java_file, filter_criteria=NO_RETURN_FILTER
        )
        assert len(functions) == 2
        run_func = next((f for f in functions if f.function_name == "run"), None)
        assert run_func is not None
        context = extract_code_context(run_func, tmp_path)
        # Only the target method is expected in the skeleton; the sibling compareTo is omitted.
        assert (
            context.target_code
            == """public class MultiImpl implements Runnable, Comparable { public void run() { System.out.println("Running"); } } """
        )

    def test_default_interface_method(self, tmp_path: Path):
        """Test context extraction for default interface method."""
        java_file = tmp_path / "MyInterface.java"
        java_file.write_text("""public interface MyInterface { default String greet() { return "Hello"; } void doSomething(); } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        greet_func = next((f for f in functions if f.function_name == "greet"), None)
        assert greet_func is not None
        context = extract_code_context(greet_func, tmp_path)
        # Interface methods are wrapped in interface skeleton
        # (the abstract `doSomething();` declaration is not part of the expected skeleton).
        assert (
            context.target_code
            == """public interface MyInterface { default String greet() { return "Hello"; } } """
        )
        assert context.read_only_context == ""


class TestExtractCodeContextWithInnerClasses:
    """Tests for extract_code_context with inner/nested classes."""

    def test_static_nested_class_method(self, tmp_path: Path):
        """Test context extraction for static nested class method."""
        java_file = tmp_path / "Container.java"
        java_file.write_text("""public class Container { public static class Nested { public int compute(int x) { return x * 2; } } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        compute_func = next((f for f in functions if f.function_name == "compute"), None)
        assert compute_func is not None
        context = extract_code_context(compute_func, tmp_path)
        # Inner class wrapped in outer class skeleton
        assert (
            context.target_code
            == """public class Container { public static class Nested { public int compute(int x) { return x * 2; } } } """
        )
        assert context.read_only_context == ""

    def test_inner_class_method(self, tmp_path: Path):
        """Test context extraction for inner class method."""
        java_file = tmp_path / "Outer.java"
        java_file.write_text("""public class Outer { private int value = 10; public class Inner { public int getValue() { return value; } } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        get_func = next((f for f in functions if f.function_name == "getValue"), None)
        assert get_func is not None
        context = extract_code_context(get_func, tmp_path)
        # Inner class wrapped in outer class skeleton
        # NOTE(review): Inner.getValue() reads the outer field `value`, yet neither the
        # expected skeleton nor read_only_context contains it -- confirm this is intended.
        assert (
            context.target_code
            == """public class Outer { public class Inner { public int getValue() { return value; } } } """
        )
        assert context.read_only_context == ""


class TestExtractCodeContextWithEnumAndInterface:
    """Tests for extract_code_context with enums and interfaces."""

    def test_enum_method(self, tmp_path: Path):
        """Test context extraction for enum method."""
        java_file = tmp_path / "Operation.java"
        java_file.write_text("""public enum Operation { ADD, SUBTRACT, MULTIPLY, DIVIDE; public int apply(int a, int b) { switch (this) { case ADD: return a + b; case SUBTRACT: return a - b; case MULTIPLY: return a * b; case DIVIDE: return a / b; default: throw new AssertionError(); } } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        apply_func = next((f for f in functions if f.function_name == "apply"), None)
        assert apply_func is not None
        context = extract_code_context(apply_func, tmp_path)
        # Enum methods are wrapped in enum skeleton with constants
        assert (
            context.target_code
            == """public enum Operation { ADD, SUBTRACT, MULTIPLY, DIVIDE; public int apply(int a, int b) { switch (this) { case ADD: return a + b; case SUBTRACT: return a - b; case MULTIPLY: return a * b; case DIVIDE: return a / b; default: throw new AssertionError(); } } } """
        )
        assert context.read_only_context == ""

    def test_interface_default_method(self, tmp_path: Path):
        """Test context extraction for interface default method."""
        java_file = tmp_path / "Greeting.java"
        java_file.write_text("""public interface Greeting { default String greet(String name) { return "Hello, " + name; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        greet_func = next((f for f in functions if f.function_name == "greet"), None)
        assert greet_func is not None
        context = extract_code_context(greet_func, tmp_path)
        # Interface methods are wrapped in interface skeleton
        assert (
            context.target_code
            == """public interface Greeting { default String greet(String name) { return "Hello, " + name; } } """
        )
        assert context.read_only_context == ""

    def test_interface_static_method(self, tmp_path: Path):
        """Test context extraction for interface static method."""
        java_file = tmp_path / "Factory.java"
        java_file.write_text("""public interface Factory { static Factory create() { return null; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        create_func = next((f for f in functions if f.function_name == "create"), None)
        assert create_func is not None
        context = extract_code_context(create_func, tmp_path)
        # Interface methods are wrapped in interface skeleton
        assert (
            context.target_code
            == """public interface Factory { static Factory create() { return null; } } """
        )
        assert context.read_only_context == ""


class TestExtractCodeContextEdgeCases:
    """Tests for extract_code_context edge cases."""

    def test_empty_method(self, tmp_path: Path):
        """Test context extraction for empty method."""
        java_file = tmp_path / "Empty.java"
        java_file.write_text("""public class Empty { public void doNothing() { } } """)
        # `doNothing` is void, so NO_RETURN_FILTER is required for it to be discovered.
        functions = discover_functions_from_source(
            java_file.read_text(), file_path=java_file, filter_criteria=NO_RETURN_FILTER
        )
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Empty { public void doNothing() { } } """
        )

    def test_single_line_method(self, tmp_path: Path):
        """Test context extraction for single-line method."""
        java_file = tmp_path / "OneLiner.java"
        java_file.write_text("""public class OneLiner { public int get() { return 42; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class OneLiner { public int get() { return 42; } } """
        )

    def test_method_with_lambda(self, tmp_path: Path):
        """Test context extraction for method with lambda."""
        java_file = tmp_path / "Functional.java"
        java_file.write_text("""public class Functional { public List filter(List items) { return items.stream() .filter(s -> s != null && !s.isEmpty()) .collect(Collectors.toList()); } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Functional { public List filter(List items) { return items.stream() .filter(s -> s != null && !s.isEmpty()) .collect(Collectors.toList()); } } """
        )

    def test_method_with_method_reference(self, tmp_path: Path):
        """Test context extraction for method with method reference."""
        java_file = tmp_path / "Printer.java"
        java_file.write_text("""public class Printer { public List toUpper(List items) { return items.stream().map(String::toUpperCase).collect(Collectors.toList()); } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Printer { public List toUpper(List items) { return items.stream().map(String::toUpperCase).collect(Collectors.toList()); } } """
        )

    def test_deeply_nested_blocks(self, tmp_path: Path):
        """Test context extraction for method with deeply nested blocks."""
        java_file = tmp_path / "Nested.java"
        java_file.write_text("""public class Nested { public int deepMethod(int n) { int result = 0; if (n > 0) { for (int i = 0; i < n; i++) { while (i > 0) { try { if (i % 2 == 0) { result += i; } } catch (Exception e) { result = -1; } break; } } } return result; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Nested { public int deepMethod(int n) { int result = 0; if (n > 0) { for (int i = 0; i < n; i++) { while (i > 0) { try { if (i % 2 == 0) { result += i; } } catch (Exception e) { result = -1; } break; } } } return result; } } """
        )

    def test_unicode_in_source(self, tmp_path: Path):
        """Test context extraction for method with unicode characters."""
        java_file = tmp_path / "Unicode.java"
        java_file.write_text("""public class Unicode { public String greet() { return "こんにちは世界"; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        assert len(functions) == 1
        context = extract_code_context(functions[0], tmp_path)
        assert (
            context.target_code
            == """public class Unicode { public String greet() { return "こんにちは世界"; } } """
        )

    def test_file_not_found(self, tmp_path: Path):
        """Test context extraction for missing file."""
        from codeflash.discovery.functions_to_optimize import FunctionToOptimize
        from codeflash.models.function_types import FunctionParent

        # There is no file to discover from, so the function descriptor is built by hand.
        missing_file = tmp_path / "NonExistent.java"
        func = FunctionToOptimize(
            function_name="test",
            file_path=missing_file,
            starting_line=1,
            ending_line=5,
            parents=[FunctionParent(name="Test", type="ClassDef")],
            language="java",
        )
        # A missing file is expected to yield an empty target_code rather than raise.
        context = extract_code_context(func, tmp_path)
        assert context.target_code == ""
        assert context.language == Language.JAVA
        assert context.target_file == missing_file

    def test_max_helper_depth_zero(self, tmp_path: Path):
        """Test context extraction with max_helper_depth=0."""
        java_file = tmp_path / "Calculator.java"
        java_file.write_text("""public class Calculator { public int calculate(int x) { return helper(x); } private int helper(int x) { return x * 2; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        calc_func = next((f for f in functions if f.function_name == "calculate"), None)
        assert calc_func is not None
        context = extract_code_context(calc_func, tmp_path, max_helper_depth=0)
        # With max_depth=0, cross-file helpers should be empty, but same-file helpers are still found
        # NOTE(review): only target_code is asserted below; consider also asserting
        # context.helper_functions so the same-file-helper claim above is pinned.
        assert (
            context.target_code
            == """public class Calculator { public int calculate(int x) { return helper(x); } } """
        )


class TestExtractCodeContextWithConstructor:
    """Tests for extract_code_context with constructors in class skeleton."""

    def test_class_with_constructor(self, tmp_path: Path):
        """Test context extraction includes constructor in skeleton."""
        java_file = tmp_path / "Person.java"
        java_file.write_text("""public class Person { private String name; private int age; public Person(String name, int age) { this.name = name; this.age = age; } public String getName() { return name; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        get_func = next((f for f in functions if f.function_name == "getName"), None)
        assert get_func is not None
        context = extract_code_context(get_func, tmp_path)
        # Class skeleton includes fields and constructor
        assert (
            context.target_code
            == """public class Person { private String name; private int age; public Person(String name, int age) { this.name = name; this.age = age; } public String getName() { return name; } } """
        )

    def test_class_with_multiple_constructors(self, tmp_path: Path):
        """Test context extraction includes all constructors in skeleton."""
        java_file = tmp_path / "Config.java"
        # NOTE(review): this fixture contains a line break between `public` and
        # `Config(String name, int value)`, while the expected skeleton below has a space
        # there; confirm the intended whitespace of both literals.
        java_file.write_text("""public class Config { private String name; private int value; public Config() { this("default", 0); } public Config(String name) { this(name, 0); } public
Config(String name, int value) { this.name = name; this.value = value; } public String getName() { return name; } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        get_func = next((f for f in functions if f.function_name == "getName"), None)
        assert get_func is not None
        context = extract_code_context(get_func, tmp_path)
        # Class skeleton includes fields and all constructors
        assert (
            context.target_code
            == """public class Config { private String name; private int value; public Config() { this("default", 0); } public Config(String name) { this(name, 0); } public Config(String name, int value) { this.name = name; this.value = value; } public String getName() { return name; } } """
        )


class TestExtractCodeContextFullIntegration:
    """Integration tests for extract_code_context with all components."""

    def test_full_context_with_all_components(self, tmp_path: Path):
        """Test context extraction with imports, fields, and helpers."""
        java_file = tmp_path / "Service.java"
        java_file.write_text("""package com.example; import java.util.List; import java.util.ArrayList; public class Service { private static final String PREFIX = "service_"; private List history = new ArrayList<>(); public String process(String input) { String result = transform(input); history.add(result); return result; } private String transform(String s) { return PREFIX + s.toUpperCase(); } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        process_func = next((f for f in functions if f.function_name == "process"), None)
        assert process_func is not None
        context = extract_code_context(process_func, tmp_path)
        assert context.language == Language.JAVA
        assert context.target_file == java_file
        # Class skeleton includes fields
        # NOTE(review): the expected literal below breaks the line after `String result =`,
        # whereas the fixture above has a single space there; confirm the intended whitespace.
        assert (
            context.target_code
            == """public class Service { private static final String PREFIX = "service_"; private List history = new ArrayList<>(); public String process(String input) { String result =
transform(input); history.add(result); return result; } } """
        )
        # Imports are reported separately from target_code, in source order.
        assert context.imports == ["import java.util.List;", "import java.util.ArrayList;"]
        # Fields are in skeleton, so read_only_context is empty (no duplication)
        assert context.read_only_context == ""
        # The same-file private method called from `process` is reported as a helper.
        assert len(context.helper_functions) == 1
        assert context.helper_functions[0].name == "transform"

    def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path):
        """Test context extraction for complex class with javadoc and annotations."""
        java_file = tmp_path / "Calculator.java"
        java_file.write_text("""package com.example.math; import java.util.Objects; import static java.lang.Math.sqrt; public class Calculator { private double precision = 0.0001; /** * Calculates the square root using Newton's method. * @param n the number to calculate square root for * @return the approximate square root * @throws IllegalArgumentException if n is negative */ @SuppressWarnings("unused") public double sqrtNewton(double n) { if (n < 0) throw new IllegalArgumentException(); return approximate(n, n / 2); } private double approximate(double n, double guess) { double next = (guess + n / guess) / 2; if (Math.abs(guess - next) < precision) return next; return approximate(n, next); } } """)
        functions = discover_functions_from_source(java_file.read_text(), file_path=java_file)
        sqrt_func = next((f for f in functions if f.function_name == "sqrtNewton"), None)
        assert sqrt_func is not None
        context = extract_code_context(sqrt_func, tmp_path)
        assert context.language == Language.JAVA
        # Class skeleton includes fields and Javadoc
        # NOTE(review): the expected literal below breaks the line after
        # "Newton's method.", whereas the fixture above has a single space there;
        # confirm the intended whitespace.
        assert (
            context.target_code
            == """public class Calculator { private double precision = 0.0001; /** * Calculates the square root using Newton's method.
* @param n the number to calculate square root for * @return the approximate square root * @throws IllegalArgumentException if n is negative */ @SuppressWarnings("unused") public double sqrtNewton(double n) { if (n < 0) throw new IllegalArgumentException(); return approximate(n, n / 2); } } """
        )
        # Both regular and static imports are reported.
        assert context.imports == ["import java.util.Objects;", "import static java.lang.Math.sqrt;"]
        # Fields are in skeleton, so read_only_context is empty (no duplication)
        assert context.read_only_context == ""
        # The recursive private method called from sqrtNewton is reported as the one helper.
        assert len(context.helper_functions) == 1
        assert context.helper_functions[0].name == "approximate"


class TestExtractClassContext:
    """Tests for extract_class_context."""

    def test_extract_class_with_imports(self, tmp_path: Path):
        """Test extracting full class context with imports."""
        java_file = tmp_path / "Calculator.java"
        java_file.write_text("""package com.example; import java.util.List; import java.util.ArrayList; public class Calculator { private List history = new ArrayList<>(); public int add(int a, int b) { int result = a + b; history.add(result); return result; } } """)
        context = extract_class_context(java_file, "Calculator")
        # The expected string keeps the package/import header and, unlike the fixture,
        # ends at the final closing brace with no trailing whitespace.
        assert (
            context
            == """package com.example; import java.util.List; import java.util.ArrayList; public class Calculator { private List history = new ArrayList<>(); public int add(int a, int b) { int result = a + b; history.add(result); return result; } }"""
        )

    def test_extract_class_not_found(self, tmp_path: Path):
        """Test extracting non-existent class returns empty string."""
        java_file = tmp_path / "Test.java"
        java_file.write_text("""public class Test { public void test() {} } """)
        context = extract_class_context(java_file, "NonExistent")
        assert context == ""

    def test_extract_class_missing_file(self, tmp_path: Path):
        """Test extracting from missing file returns empty string."""
        missing_file = tmp_path / "Missing.java"
        context = extract_class_context(missing_file, "Missing")
        assert context == ""


class TestExtractFunctionSourceStaleLineNumbers:
    """Tests for tree-sitter based function extraction resilience to stale line numbers.

    When running --all mode, a prior optimization may modify the source file,
    shifting line numbers for subsequent functions. The tree-sitter based
    extraction should still find the correct function by name.
    """

    def test_extraction_with_stale_line_numbers(self):
        """Verify extraction works when pre-computed line numbers no longer match the source."""
        # Original source: functionA at lines 2-4, functionB at lines 5-7
        # NOTE(review): the line numbers above assume a multi-line fixture; the literal
        # below is written on one line -- confirm its layout.
        original_source = """public class Utils { public int functionA() { return 1; } public int functionB() { return 2; } } """
        analyzer = get_java_analyzer()
        functions = discover_functions_from_source(original_source, file_path=Path("Utils.java"))
        func_b = [f for f in functions if f.function_name == "functionB"][0]
        # NOTE(review): original_b_start is assigned but never used in this test.
        original_b_start = func_b.starting_line
        # Simulate a prior optimization adding lines to functionA
        modified_source = """public class Utils { public int functionA() { int x = 1; int y = 2; int z = 3; return x + y + z; } public int functionB() { return 2; } } """
        # func_b still has the STALE line numbers from the original source
        # With tree-sitter, extraction should still work correctly
        result = extract_function_source(modified_source, func_b, analyzer=analyzer)
        assert "functionB" in result
        assert "return 2;" in result

    def test_extraction_without_analyzer_uses_line_numbers(self):
        """Without analyzer, extraction falls back to pre-computed line numbers."""
        source = """public class Utils { public int functionA() { return 1; } public int functionB() { return 2; } } """
        functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
        func_b = [f for f in functions if f.function_name == "functionB"][0]
        # Without analyzer, should still work with correct line numbers
        result = extract_function_source(source, func_b)
        assert "functionB" in result
        assert "return 2;" in result

    def test_extraction_with_javadoc_after_file_modification(self):
        """Verify Javadoc is included when using tree-sitter extraction on modified files."""
        original_source = """public class Utils { /** Adds two numbers. */ public int add(int a, int b) { return a + b; } /** Subtracts two numbers. */ public int subtract(int a, int b) { return a - b; } } """
        analyzer = get_java_analyzer()
        functions = discover_functions_from_source(original_source, file_path=Path("Utils.java"))
        func_sub = [f for f in functions if f.function_name == "subtract"][0]
        # Simulate prior optimization expanding the add method
        modified_source = """public class Utils { /** Adds two numbers. */ public int add(int a, int b) { // Optimized with null check if (a == 0) return b; if (b == 0) return a; return a + b; } /** Subtracts two numbers. */ public int subtract(int a, int b) { return a - b; } } """
        result = extract_function_source(modified_source, func_sub, analyzer=analyzer)
        # The Javadoc directly above `subtract` must be part of the extracted source.
        assert "/** Subtracts two numbers. */" in result
        assert "public int subtract" in result
        assert "return a - b;" in result

    def test_extraction_with_overloaded_methods(self):
        """Verify correct overload is selected using line proximity."""
        source = """public class Utils { public int process(int x) { return x * 2; } public int process(int x, int y) { return x + y; } } """
        analyzer = get_java_analyzer()
        functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
        # Get the second overload (process(int, int))
        # NOTE(review): `ending_line > 4` relies on the first overload ending on or before
        # line 4 of `source`, so it depends on the fixture's exact line layout.
        func_two_args = [f for f in functions if f.function_name == "process" and f.ending_line > 4][0]
        result = extract_function_source(source, func_two_args, analyzer=analyzer)
        assert "int x, int y" in result
        assert "return x + y;" in result

    def test_extraction_function_not_found_falls_back(self):
        """If tree-sitter can't find the method, fall back to line numbers."""
        source = """public class Utils { public int functionA() { return 1; } } """
        analyzer = get_java_analyzer()
        functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
        func_a = functions[0]
        # Create a copy with a non-existent name so tree-sitter can't find it
        # (dataclasses.replace returns a new instance; func_a itself is left unchanged).
        from dataclasses import replace

        func_fake = replace(func_a, function_name="nonExistentMethod")
        # Should fall back to line-number extraction (which still works since source is unmodified)
        result = extract_function_source(source, func_fake, analyzer=analyzer)
        assert "functionA" in result
        assert "return 1;" in result