Merge pull request #1951 from codeflash-ai/cf-1085-cap-wildcard-import-expansion

fix: cap wildcard import expansion to avoid token explosion
This commit is contained in:
mashraf-222 2026-05-04 20:26:56 +03:00 committed by GitHub
commit 0a2ec48fa3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 193 additions and 120 deletions

View file

@ -11,10 +11,11 @@ import logging
from typing import TYPE_CHECKING
from codeflash.code_utils.code_utils import encoded_tokens_len
from codeflash.languages.base import CodeContext, HelperFunction, Language
from codeflash.languages.base import CodeContext, HelperFunction
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.language_enum import Language
if TYPE_CHECKING:
from pathlib import Path
@ -22,7 +23,8 @@ if TYPE_CHECKING:
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
from codeflash.languages.java.import_resolver import ResolvedImport
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode
logger = logging.getLogger(__name__)
@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
# Keep old function name for backwards compatibility
def _extract_class_declaration(node, source_bytes):
def _extract_class_declaration(node: Node, source_bytes: bytes) -> str:
return _extract_type_declaration(node, source_bytes, "class")
@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize)
start_line = function.doc_start_line or function.starting_line
end_line = function.ending_line
if start_line is None or end_line is None:
return ""
# Convert from 1-indexed to 0-indexed
start_idx = start_line - 1
@ -672,6 +676,8 @@ def find_helper_functions(
func_id = f"{file_path}:{func.qualified_name}"
if func_id not in visited_functions:
visited_functions.add(func_id)
if func.starting_line is None or func.ending_line is None:
continue
# Extract the function source using tree-sitter for resilient lookup
func_source = extract_function_source(source, func, analyzer=analyzer)
@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze
return "\n".join(context_parts)
def _import_to_statement(import_info) -> str:
def _import_to_statement(import_info: JavaImportInfo) -> str:
"""Convert a JavaImportInfo to an import statement string.
Args:
@ -863,6 +869,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz
# Maximum token budget for imported type skeletons to avoid bloating testgen context
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
# Maximum types to expand from a single wildcard import before filtering to referenced types only.
# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
# and almost always exceed the token budget.
MAX_WILDCARD_TYPES_UNFILTERED = 50
def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
@ -894,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]
def get_java_imported_type_skeletons(
imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = ""
imports: list[JavaImportInfo],
project_root: Path,
module_root: Path | None,
analyzer: JavaAnalyzer,
target_code: str = "",
) -> str:
"""Extract type skeletons for project-internal imported types.
@ -929,14 +943,32 @@ def get_java_imported_type_skeletons(
priority_types = _extract_type_names_from_code(target_code, analyzer)
# Pre-resolve all imports, expanding wildcards into individual types
resolved_imports: list = []
resolved_imports: list[ResolvedImport] = []
for imp in imports:
if imp.is_wildcard:
# Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
expanded = resolver.expand_wildcard_import(imp.import_path)
# First try unfiltered expansion with a cap. If the package is small enough, take all types.
# If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code.
expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
if priority_types:
expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
logger.debug(
"Wildcard %s.* exceeds %d types, filtered to %d referenced types",
imp.import_path,
MAX_WILDCARD_TYPES_UNFILTERED,
len(expanded),
)
else:
expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
logger.debug(
"Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
imp.import_path,
MAX_WILDCARD_TYPES_UNFILTERED,
)
elif expanded:
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
if expanded:
resolved_imports.extend(expanded)
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
continue
resolved = resolver.resolve_import(imp)
@ -956,7 +988,7 @@ def get_java_imported_type_skeletons(
for resolved in resolved_imports:
class_name = resolved.class_name
if not class_name:
if not class_name or resolved.file_path is None:
continue
dedup_key = (str(resolved.file_path), class_name)
@ -1078,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja
continue
node = method.node
if not node:
continue
# Check if the method is public
is_public = False

View file

@ -220,14 +220,20 @@ class JavaImportResolver:
return last_part
return None
def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
def expand_wildcard_import(
self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
) -> list[ResolvedImport]:
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.
Resolves the package path to a directory and returns a ResolvedImport for each
.java file found in that directory.
Args:
import_path: The package path (without the trailing .*).
max_types: Maximum number of types to return. 0 means no limit.
filter_names: If provided, only include types whose class name is in this set.
"""
# Convert package path to directory path
# e.g., "com.example.utils" -> "com/example/utils"
relative_dir = import_path.replace(".", "/")
resolved: list[ResolvedImport] = []
@ -237,8 +243,10 @@ class JavaImportResolver:
if candidate_dir.is_dir():
for java_file in candidate_dir.glob("*.java"):
class_name = java_file.stem
# Only include files that look like class names (start with uppercase)
if class_name and class_name[0].isupper():
if not class_name or not class_name[0].isupper():
continue
if filter_names is not None and class_name not in filter_names:
continue
resolved.append(
ResolvedImport(
import_path=f"{import_path}.{class_name}",
@ -248,6 +256,8 @@ class JavaImportResolver:
class_name=class_name,
)
)
if max_types and len(resolved) >= max_types:
return resolved
return resolved

View file

@ -2,7 +2,8 @@
from pathlib import Path
from codeflash.languages.base import FunctionFilterCriteria, Language
from codeflash.languages.base import FunctionFilterCriteria
from codeflash.languages.language_enum import Language
from codeflash.languages.java.context import (
TypeSkeleton,
_extract_public_method_signatures,
@ -23,7 +24,7 @@ NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False)
class TestExtractCodeContextBasic:
"""Tests for basic extract_code_context functionality."""
def test_simple_method(self, tmp_path: Path):
def test_simple_method(self, tmp_path: Path) -> None:
"""Test extracting context for a simple method."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -53,7 +54,7 @@ class TestExtractCodeContextBasic:
assert context.helper_functions == []
assert context.read_only_context == ""
def test_method_with_javadoc(self, tmp_path: Path):
def test_method_with_javadoc(self, tmp_path: Path) -> None:
"""Test extracting context for method with Javadoc."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -94,7 +95,7 @@ class TestExtractCodeContextBasic:
assert context.helper_functions == []
assert context.read_only_context == ""
def test_static_method(self, tmp_path: Path):
def test_static_method(self, tmp_path: Path) -> None:
"""Test extracting context for a static method."""
java_file = tmp_path / "MathUtils.java"
java_file.write_text("""public class MathUtils {
@ -123,7 +124,7 @@ class TestExtractCodeContextBasic:
assert context.helper_functions == []
assert context.read_only_context == ""
def test_private_method(self, tmp_path: Path):
def test_private_method(self, tmp_path: Path) -> None:
"""Test extracting context for a private method."""
java_file = tmp_path / "Helper.java"
java_file.write_text("""public class Helper {
@ -149,7 +150,7 @@ class TestExtractCodeContextBasic:
"""
)
def test_protected_method(self, tmp_path: Path):
def test_protected_method(self, tmp_path: Path) -> None:
"""Test extracting context for a protected method."""
java_file = tmp_path / "Base.java"
java_file.write_text("""public class Base {
@ -175,7 +176,7 @@ class TestExtractCodeContextBasic:
"""
)
def test_synchronized_method(self, tmp_path: Path):
def test_synchronized_method(self, tmp_path: Path) -> None:
"""Test extracting context for a synchronized method."""
java_file = tmp_path / "Counter.java"
java_file.write_text("""public class Counter {
@ -200,7 +201,7 @@ class TestExtractCodeContextBasic:
"""
)
def test_method_with_throws(self, tmp_path: Path):
def test_method_with_throws(self, tmp_path: Path) -> None:
"""Test extracting context for a method with throws clause."""
java_file = tmp_path / "FileHandler.java"
java_file.write_text("""public class FileHandler {
@ -225,7 +226,7 @@ class TestExtractCodeContextBasic:
"""
)
def test_method_with_varargs(self, tmp_path: Path):
def test_method_with_varargs(self, tmp_path: Path) -> None:
"""Test extracting context for a method with varargs."""
java_file = tmp_path / "Logger.java"
java_file.write_text("""public class Logger {
@ -250,7 +251,7 @@ class TestExtractCodeContextBasic:
"""
)
def test_void_method(self, tmp_path: Path):
def test_void_method(self, tmp_path: Path) -> None:
"""Test extracting context for a void method."""
java_file = tmp_path / "Printer.java"
java_file.write_text("""public class Printer {
@ -277,7 +278,7 @@ class TestExtractCodeContextBasic:
"""
)
def test_generic_return_type(self, tmp_path: Path):
def test_generic_return_type(self, tmp_path: Path) -> None:
"""Test extracting context for a method with generic return type."""
java_file = tmp_path / "Container.java"
java_file.write_text("""public class Container {
@ -306,7 +307,7 @@ class TestExtractCodeContextBasic:
class TestExtractCodeContextWithImports:
"""Tests for extract_code_context with various import types."""
def test_with_package_and_imports(self, tmp_path: Path):
def test_with_package_and_imports(self, tmp_path: Path) -> None:
"""Test context extraction with package and imports."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""package com.example;
@ -344,7 +345,7 @@ public class Calculator {
# Fields are in skeleton, so read_only_context is empty
assert context.read_only_context == ""
def test_with_static_imports(self, tmp_path: Path):
def test_with_static_imports(self, tmp_path: Path) -> None:
"""Test context extraction with static imports."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""package com.example;
@ -380,7 +381,7 @@ public class Calculator {
"import static java.lang.Math.sqrt;",
]
def test_with_wildcard_imports(self, tmp_path: Path):
def test_with_wildcard_imports(self, tmp_path: Path) -> None:
"""Test context extraction with wildcard imports."""
java_file = tmp_path / "Processor.java"
java_file.write_text("""package com.example;
@ -402,7 +403,7 @@ public class Processor {
assert context.language == Language.JAVA
assert context.imports == ["import java.util.*;", "import java.io.*;"]
def test_with_multiple_import_types(self, tmp_path: Path):
def test_with_multiple_import_types(self, tmp_path: Path) -> None:
"""Test context extraction with various import types."""
java_file = tmp_path / "Handler.java"
java_file.write_text("""package com.example;
@ -453,7 +454,7 @@ class TestExtractCodeContextWithFields:
read_only_context should be empty to avoid duplication.
"""
def test_with_instance_fields(self, tmp_path: Path):
def test_with_instance_fields(self, tmp_path: Path) -> None:
"""Test context extraction with instance fields."""
java_file = tmp_path / "Person.java"
java_file.write_text("""public class Person {
@ -488,7 +489,7 @@ class TestExtractCodeContextWithFields:
assert context.imports == []
assert context.helper_functions == []
def test_with_static_fields(self, tmp_path: Path):
def test_with_static_fields(self, tmp_path: Path) -> None:
"""Test context extraction with static fields."""
java_file = tmp_path / "Counter.java"
java_file.write_text("""public class Counter {
@ -519,7 +520,7 @@ class TestExtractCodeContextWithFields:
# Fields are in skeleton, so read_only_context is empty
assert context.read_only_context == ""
def test_with_final_fields(self, tmp_path: Path):
def test_with_final_fields(self, tmp_path: Path) -> None:
"""Test context extraction with final fields."""
java_file = tmp_path / "Config.java"
java_file.write_text("""public class Config {
@ -549,7 +550,7 @@ class TestExtractCodeContextWithFields:
)
assert context.read_only_context == ""
def test_with_static_final_constants(self, tmp_path: Path):
def test_with_static_final_constants(self, tmp_path: Path) -> None:
"""Test context extraction with static final constants."""
java_file = tmp_path / "Constants.java"
java_file.write_text("""public class Constants {
@ -581,7 +582,7 @@ class TestExtractCodeContextWithFields:
)
assert context.read_only_context == ""
def test_with_volatile_fields(self, tmp_path: Path):
def test_with_volatile_fields(self, tmp_path: Path) -> None:
"""Test context extraction with volatile fields."""
java_file = tmp_path / "ThreadSafe.java"
java_file.write_text("""public class ThreadSafe {
@ -611,7 +612,7 @@ class TestExtractCodeContextWithFields:
)
assert context.read_only_context == ""
def test_with_generic_fields(self, tmp_path: Path):
def test_with_generic_fields(self, tmp_path: Path) -> None:
"""Test context extraction with generic type fields."""
java_file = tmp_path / "Container.java"
java_file.write_text("""public class Container {
@ -643,7 +644,7 @@ class TestExtractCodeContextWithFields:
)
assert context.read_only_context == ""
def test_with_array_fields(self, tmp_path: Path):
def test_with_array_fields(self, tmp_path: Path) -> None:
"""Test context extraction with array fields."""
java_file = tmp_path / "ArrayHolder.java"
java_file.write_text("""public class ArrayHolder {
@ -679,7 +680,7 @@ class TestExtractCodeContextWithFields:
class TestExtractCodeContextWithHelpers:
"""Tests for extract_code_context with helper functions."""
def test_single_helper_method(self, tmp_path: Path):
def test_single_helper_method(self, tmp_path: Path) -> None:
"""Test context extraction with a single helper method."""
java_file = tmp_path / "Processor.java"
java_file.write_text("""public class Processor {
@ -715,7 +716,7 @@ class TestExtractCodeContextWithHelpers:
== "private String normalize(String s) {\n return s.trim().toLowerCase();\n }"
)
def test_multiple_helper_methods(self, tmp_path: Path):
def test_multiple_helper_methods(self, tmp_path: Path) -> None:
"""Test context extraction with multiple helper methods."""
java_file = tmp_path / "Processor.java"
java_file.write_text("""public class Processor {
@ -758,7 +759,7 @@ class TestExtractCodeContextWithHelpers:
helper_names = sorted([h.name for h in context.helper_functions])
assert helper_names == ["trim", "upper"]
def test_chained_helper_calls(self, tmp_path: Path):
def test_chained_helper_calls(self, tmp_path: Path) -> None:
"""Test context extraction with chained helper calls."""
java_file = tmp_path / "Processor.java"
java_file.write_text("""public class Processor {
@ -784,7 +785,7 @@ class TestExtractCodeContextWithHelpers:
helper_names = [h.name for h in context.helper_functions]
assert helper_names == ["normalize"]
def test_no_helpers_when_none_called(self, tmp_path: Path):
def test_no_helpers_when_none_called(self, tmp_path: Path) -> None:
"""Test context extraction when no helpers are called."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -814,7 +815,7 @@ class TestExtractCodeContextWithHelpers:
)
assert context.helper_functions == []
def test_static_helper_from_instance_method(self, tmp_path: Path):
def test_static_helper_from_instance_method(self, tmp_path: Path) -> None:
"""Test context extraction with static helper called from instance method."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -840,7 +841,7 @@ class TestExtractCodeContextWithHelpers:
class TestExtractCodeContextWithJavadoc:
"""Tests for extract_code_context with various Javadoc patterns."""
def test_simple_javadoc(self, tmp_path: Path):
def test_simple_javadoc(self, tmp_path: Path) -> None:
"""Test context extraction with simple Javadoc."""
java_file = tmp_path / "Example.java"
java_file.write_text("""public class Example {
@ -866,7 +867,7 @@ class TestExtractCodeContextWithJavadoc:
"""
)
def test_javadoc_with_params(self, tmp_path: Path):
def test_javadoc_with_params(self, tmp_path: Path) -> None:
"""Test context extraction with Javadoc @param tags."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -900,7 +901,7 @@ class TestExtractCodeContextWithJavadoc:
"""
)
def test_javadoc_with_return(self, tmp_path: Path):
def test_javadoc_with_return(self, tmp_path: Path) -> None:
"""Test context extraction with Javadoc @return tag."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -932,7 +933,7 @@ class TestExtractCodeContextWithJavadoc:
"""
)
def test_javadoc_with_throws(self, tmp_path: Path):
def test_javadoc_with_throws(self, tmp_path: Path) -> None:
"""Test context extraction with Javadoc @throws tag."""
java_file = tmp_path / "Divider.java"
java_file.write_text("""public class Divider {
@ -968,7 +969,7 @@ class TestExtractCodeContextWithJavadoc:
"""
)
def test_javadoc_multiline(self, tmp_path: Path):
def test_javadoc_multiline(self, tmp_path: Path) -> None:
"""Test context extraction with multi-paragraph Javadoc."""
java_file = tmp_path / "Complex.java"
java_file.write_text("""public class Complex {
@ -1020,7 +1021,7 @@ class TestExtractCodeContextWithJavadoc:
class TestExtractCodeContextWithGenerics:
"""Tests for extract_code_context with generic types."""
def test_generic_method_type_parameter(self, tmp_path: Path):
def test_generic_method_type_parameter(self, tmp_path: Path) -> None:
"""Test context extraction with generic type parameter."""
java_file = tmp_path / "Utils.java"
java_file.write_text("""public class Utils {
@ -1044,7 +1045,7 @@ class TestExtractCodeContextWithGenerics:
"""
)
def test_bounded_type_parameter(self, tmp_path: Path):
def test_bounded_type_parameter(self, tmp_path: Path) -> None:
"""Test context extraction with bounded type parameter."""
java_file = tmp_path / "Statistics.java"
java_file.write_text("""public class Statistics {
@ -1076,7 +1077,7 @@ class TestExtractCodeContextWithGenerics:
"""
)
def test_wildcard_type(self, tmp_path: Path):
def test_wildcard_type(self, tmp_path: Path) -> None:
"""Test context extraction with wildcard type."""
java_file = tmp_path / "Printer.java"
java_file.write_text("""public class Printer {
@ -1100,7 +1101,7 @@ class TestExtractCodeContextWithGenerics:
"""
)
def test_bounded_wildcard_extends(self, tmp_path: Path):
def test_bounded_wildcard_extends(self, tmp_path: Path) -> None:
"""Test context extraction with upper bounded wildcard."""
java_file = tmp_path / "Aggregator.java"
java_file.write_text("""public class Aggregator {
@ -1132,7 +1133,7 @@ class TestExtractCodeContextWithGenerics:
"""
)
def test_bounded_wildcard_super(self, tmp_path: Path):
def test_bounded_wildcard_super(self, tmp_path: Path) -> None:
"""Test context extraction with lower bounded wildcard."""
java_file = tmp_path / "Filler.java"
java_file.write_text("""public class Filler {
@ -1158,7 +1159,7 @@ class TestExtractCodeContextWithGenerics:
"""
)
def test_multiple_type_parameters(self, tmp_path: Path):
def test_multiple_type_parameters(self, tmp_path: Path) -> None:
"""Test context extraction with multiple type parameters."""
java_file = tmp_path / "Mapper.java"
java_file.write_text("""public class Mapper {
@ -1190,7 +1191,7 @@ class TestExtractCodeContextWithGenerics:
"""
)
def test_recursive_type_bound(self, tmp_path: Path):
def test_recursive_type_bound(self, tmp_path: Path) -> None:
"""Test context extraction with recursive type bound."""
java_file = tmp_path / "Sorter.java"
java_file.write_text("""public class Sorter {
@ -1218,7 +1219,7 @@ class TestExtractCodeContextWithGenerics:
class TestExtractCodeContextWithAnnotations:
"""Tests for extract_code_context with annotations."""
def test_override_annotation(self, tmp_path: Path):
def test_override_annotation(self, tmp_path: Path) -> None:
"""Test context extraction with @Override annotation."""
java_file = tmp_path / "Child.java"
java_file.write_text("""public class Child extends Parent {
@ -1244,7 +1245,7 @@ class TestExtractCodeContextWithAnnotations:
"""
)
def test_deprecated_annotation(self, tmp_path: Path):
def test_deprecated_annotation(self, tmp_path: Path) -> None:
"""Test context extraction with @Deprecated annotation."""
java_file = tmp_path / "Legacy.java"
java_file.write_text("""public class Legacy {
@ -1270,7 +1271,7 @@ class TestExtractCodeContextWithAnnotations:
"""
)
def test_suppress_warnings_annotation(self, tmp_path: Path):
def test_suppress_warnings_annotation(self, tmp_path: Path) -> None:
"""Test context extraction with @SuppressWarnings annotation."""
java_file = tmp_path / "Processor.java"
java_file.write_text("""public class Processor {
@ -1296,7 +1297,7 @@ class TestExtractCodeContextWithAnnotations:
"""
)
def test_multiple_annotations(self, tmp_path: Path):
def test_multiple_annotations(self, tmp_path: Path) -> None:
"""Test context extraction with multiple annotations."""
java_file = tmp_path / "Service.java"
java_file.write_text("""public class Service {
@ -1326,7 +1327,7 @@ class TestExtractCodeContextWithAnnotations:
"""
)
def test_annotation_with_array_value(self, tmp_path: Path):
def test_annotation_with_array_value(self, tmp_path: Path) -> None:
"""Test context extraction with annotation array value."""
java_file = tmp_path / "Handler.java"
java_file.write_text("""public class Handler {
@ -1356,7 +1357,7 @@ class TestExtractCodeContextWithAnnotations:
class TestExtractCodeContextWithInheritance:
"""Tests for extract_code_context with inheritance scenarios."""
def test_method_in_subclass(self, tmp_path: Path):
def test_method_in_subclass(self, tmp_path: Path) -> None:
"""Test context extraction for method in subclass."""
java_file = tmp_path / "AdvancedCalc.java"
java_file.write_text("""public class AdvancedCalc extends Calculator {
@ -1382,7 +1383,7 @@ class TestExtractCodeContextWithInheritance:
"""
)
def test_interface_implementation(self, tmp_path: Path):
def test_interface_implementation(self, tmp_path: Path) -> None:
"""Test context extraction for interface implementation."""
java_file = tmp_path / "MyComparable.java"
java_file.write_text("""public class MyComparable implements Comparable<MyComparable> {
@ -1414,7 +1415,7 @@ class TestExtractCodeContextWithInheritance:
# Fields are in skeleton, so read_only_context is empty (no duplication)
assert context.read_only_context == ""
def test_multiple_interfaces(self, tmp_path: Path):
def test_multiple_interfaces(self, tmp_path: Path) -> None:
"""Test context extraction for multiple interface implementations."""
java_file = tmp_path / "MultiImpl.java"
java_file.write_text("""public class MultiImpl implements Runnable, Comparable<MultiImpl> {
@ -1446,7 +1447,7 @@ class TestExtractCodeContextWithInheritance:
"""
)
def test_default_interface_method(self, tmp_path: Path):
def test_default_interface_method(self, tmp_path: Path) -> None:
"""Test context extraction for default interface method."""
java_file = tmp_path / "MyInterface.java"
java_file.write_text("""public interface MyInterface {
@ -1479,7 +1480,7 @@ class TestExtractCodeContextWithInheritance:
class TestExtractCodeContextWithInnerClasses:
"""Tests for extract_code_context with inner/nested classes."""
def test_static_nested_class_method(self, tmp_path: Path):
def test_static_nested_class_method(self, tmp_path: Path) -> None:
"""Inner class methods are excluded from discovery and cannot be context-extracted.
Methods of static nested classes are skipped in discovery because they
@ -1499,7 +1500,7 @@ class TestExtractCodeContextWithInnerClasses:
# Inner class method must NOT be discovered
assert compute_func is None
def test_inner_class_method(self, tmp_path: Path):
def test_inner_class_method(self, tmp_path: Path) -> None:
"""Inner class methods are excluded from discovery and cannot be context-extracted.
Methods of non-static inner classes are skipped in discovery because they
@ -1525,7 +1526,7 @@ class TestExtractCodeContextWithInnerClasses:
class TestExtractCodeContextWithEnumAndInterface:
"""Tests for extract_code_context with enums and interfaces."""
def test_enum_method(self, tmp_path: Path):
def test_enum_method(self, tmp_path: Path) -> None:
"""Test context extraction for enum method."""
java_file = tmp_path / "Operation.java"
java_file.write_text("""public enum Operation {
@ -1568,7 +1569,7 @@ class TestExtractCodeContextWithEnumAndInterface:
)
assert context.read_only_context == ""
def test_interface_default_method(self, tmp_path: Path):
def test_interface_default_method(self, tmp_path: Path) -> None:
"""Test context extraction for interface default method."""
java_file = tmp_path / "Greeting.java"
java_file.write_text("""public interface Greeting {
@ -1595,7 +1596,7 @@ class TestExtractCodeContextWithEnumAndInterface:
)
assert context.read_only_context == ""
def test_interface_static_method(self, tmp_path: Path):
def test_interface_static_method(self, tmp_path: Path) -> None:
"""Test context extraction for interface static method."""
java_file = tmp_path / "Factory.java"
java_file.write_text("""public interface Factory {
@ -1626,7 +1627,7 @@ class TestExtractCodeContextWithEnumAndInterface:
class TestExtractCodeContextEdgeCases:
"""Tests for extract_code_context edge cases."""
def test_empty_method(self, tmp_path: Path):
def test_empty_method(self, tmp_path: Path) -> None:
"""Test context extraction for empty method."""
java_file = tmp_path / "Empty.java"
java_file.write_text("""public class Empty {
@ -1650,7 +1651,7 @@ class TestExtractCodeContextEdgeCases:
"""
)
def test_single_line_method(self, tmp_path: Path):
def test_single_line_method(self, tmp_path: Path) -> None:
"""Test context extraction for single-line method."""
java_file = tmp_path / "OneLiner.java"
java_file.write_text("""public class OneLiner {
@ -1670,7 +1671,7 @@ class TestExtractCodeContextEdgeCases:
"""
)
def test_method_with_lambda(self, tmp_path: Path):
def test_method_with_lambda(self, tmp_path: Path) -> None:
"""Test context extraction for method with lambda."""
java_file = tmp_path / "Functional.java"
java_file.write_text("""public class Functional {
@ -1698,7 +1699,7 @@ class TestExtractCodeContextEdgeCases:
"""
)
def test_method_with_method_reference(self, tmp_path: Path):
def test_method_with_method_reference(self, tmp_path: Path) -> None:
"""Test context extraction for method with method reference."""
java_file = tmp_path / "Printer.java"
java_file.write_text("""public class Printer {
@ -1722,7 +1723,7 @@ class TestExtractCodeContextEdgeCases:
"""
)
def test_deeply_nested_blocks(self, tmp_path: Path):
def test_deeply_nested_blocks(self, tmp_path: Path) -> None:
"""Test context extraction for method with deeply nested blocks."""
java_file = tmp_path / "Nested.java"
java_file.write_text("""public class Nested {
@ -1776,7 +1777,7 @@ class TestExtractCodeContextEdgeCases:
"""
)
def test_unicode_in_source(self, tmp_path: Path):
def test_unicode_in_source(self, tmp_path: Path) -> None:
"""Test context extraction for method with unicode characters."""
java_file = tmp_path / "Unicode.java"
java_file.write_text(
@ -1803,7 +1804,7 @@ class TestExtractCodeContextEdgeCases:
"""
)
def test_file_not_found(self, tmp_path: Path):
def test_file_not_found(self, tmp_path: Path) -> None:
"""Test context extraction for missing file."""
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.function_types import FunctionParent
@ -1824,7 +1825,7 @@ class TestExtractCodeContextEdgeCases:
assert context.language == Language.JAVA
assert context.target_file == missing_file
def test_max_helper_depth_zero(self, tmp_path: Path):
def test_max_helper_depth_zero(self, tmp_path: Path) -> None:
"""Test context extraction with max_helper_depth=0."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""public class Calculator {
@ -1858,7 +1859,7 @@ class TestExtractCodeContextEdgeCases:
class TestExtractCodeContextWithConstructor:
"""Tests for extract_code_context with constructors in class skeleton."""
def test_class_with_constructor(self, tmp_path: Path):
def test_class_with_constructor(self, tmp_path: Path) -> None:
"""Test context extraction includes constructor in skeleton."""
java_file = tmp_path / "Person.java"
java_file.write_text("""public class Person {
@ -1898,7 +1899,7 @@ class TestExtractCodeContextWithConstructor:
"""
)
def test_class_with_multiple_constructors(self, tmp_path: Path):
def test_class_with_multiple_constructors(self, tmp_path: Path) -> None:
"""Test context extraction includes all constructors in skeleton."""
java_file = tmp_path / "Config.java"
java_file.write_text("""public class Config {
@ -1956,7 +1957,7 @@ class TestExtractCodeContextWithConstructor:
class TestExtractCodeContextFullIntegration:
"""Integration tests for extract_code_context with all components."""
def test_full_context_with_all_components(self, tmp_path: Path):
def test_full_context_with_all_components(self, tmp_path: Path) -> None:
"""Test context extraction with imports, fields, and helpers."""
java_file = tmp_path / "Service.java"
java_file.write_text("""package com.example;
@ -2007,7 +2008,7 @@ public class Service {
assert len(context.helper_functions) == 1
assert context.helper_functions[0].name == "transform"
def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path):
def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path) -> None:
"""Test context extraction for complex class with javadoc and annotations."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""package com.example.math;
@ -2073,7 +2074,7 @@ public class Calculator {
class TestExtractClassContext:
"""Tests for extract_class_context."""
def test_extract_class_with_imports(self, tmp_path: Path):
def test_extract_class_with_imports(self, tmp_path: Path) -> None:
"""Test extracting full class context with imports."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""package com.example;
@ -2112,7 +2113,7 @@ public class Calculator {
}"""
)
def test_extract_class_not_found(self, tmp_path: Path):
def test_extract_class_not_found(self, tmp_path: Path) -> None:
"""Test extracting non-existent class returns empty string."""
java_file = tmp_path / "Test.java"
java_file.write_text("""public class Test {
@ -2124,7 +2125,7 @@ public class Calculator {
assert context == ""
def test_extract_class_missing_file(self, tmp_path: Path):
def test_extract_class_missing_file(self, tmp_path: Path) -> None:
"""Test extracting from missing file returns empty string."""
missing_file = tmp_path / "Missing.java"
@ -2141,7 +2142,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
extraction should still find the correct function by name.
"""
def test_extraction_with_stale_line_numbers(self):
def test_extraction_with_stale_line_numbers(self) -> None:
"""Verify extraction works when pre-computed line numbers no longer match the source."""
# Original source: functionA at lines 2-4, functionB at lines 5-7
original_source = """public class Utils {
@ -2177,7 +2178,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
assert "functionB" in result
assert "return 2;" in result
def test_extraction_without_analyzer_uses_line_numbers(self):
def test_extraction_without_analyzer_uses_line_numbers(self) -> None:
"""Without analyzer, extraction falls back to pre-computed line numbers."""
source = """public class Utils {
public int functionA() {
@ -2196,7 +2197,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
assert "functionB" in result
assert "return 2;" in result
def test_extraction_with_javadoc_after_file_modification(self):
def test_extraction_with_javadoc_after_file_modification(self) -> None:
"""Verify Javadoc is included when using tree-sitter extraction on modified files."""
original_source = """public class Utils {
/** Adds two numbers. */
@ -2233,7 +2234,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
assert "public int subtract" in result
assert "return a - b;" in result
def test_extraction_with_overloaded_methods(self):
def test_extraction_with_overloaded_methods(self) -> None:
"""Verify correct overload is selected using line proximity."""
source = """public class Utils {
public int process(int x) {
@ -2247,13 +2248,15 @@ class TestExtractFunctionSourceStaleLineNumbers:
analyzer = get_java_analyzer()
functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
# Get the second overload (process(int, int))
func_two_args = [f for f in functions if f.function_name == "process" and f.ending_line > 4][0]
func_two_args = [
f for f in functions if f.function_name == "process" and f.ending_line is not None and f.ending_line > 4
][0]
result = extract_function_source(source, func_two_args, analyzer=analyzer)
assert "int x, int y" in result
assert "return x + y;" in result
def test_extraction_function_not_found_falls_back(self):
def test_extraction_function_not_found_falls_back(self) -> None:
"""If tree-sitter can't find the method, fall back to line numbers."""
source = """public class Utils {
public int functionA() {
@ -2282,7 +2285,7 @@ FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "java_maven"
class TestGetJavaImportedTypeSkeletons:
"""Tests for get_java_imported_type_skeletons()."""
def test_resolves_internal_imports(self):
def test_resolves_internal_imports(self) -> None:
"""Verify that project-internal imports are resolved and skeletons extracted."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2297,7 +2300,7 @@ class TestGetJavaImportedTypeSkeletons:
assert "MathHelper" in result
assert "Formatter" in result
def test_skeletons_contain_method_signatures(self):
def test_skeletons_contain_method_signatures(self) -> None:
"""Verify extracted skeletons include public method signatures."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2313,7 +2316,7 @@ class TestGetJavaImportedTypeSkeletons:
assert "multiply" in result
assert "factorial" in result
def test_skips_external_imports(self):
def test_skips_external_imports(self) -> None:
"""Verify that standard library and external imports are skipped."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2328,7 +2331,7 @@ class TestGetJavaImportedTypeSkeletons:
# No internal imports → empty result
assert result == ""
def test_deduplicates_imports(self):
def test_deduplicates_imports(self) -> None:
"""Verify that the same type imported twice is only included once."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2344,7 +2347,7 @@ class TestGetJavaImportedTypeSkeletons:
# Count occurrences of MathHelper — should appear exactly once
assert result.count("class MathHelper") == 1
def test_empty_imports_returns_empty(self):
def test_empty_imports_returns_empty(self) -> None:
"""Verify that empty import list returns empty string."""
project_root = FIXTURE_DIR
analyzer = get_java_analyzer()
@ -2353,7 +2356,7 @@ class TestGetJavaImportedTypeSkeletons:
assert result == ""
def test_respects_token_budget(self):
def test_respects_token_budget(self) -> None:
"""Verify that the function stops when token budget is exceeded."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2378,7 +2381,7 @@ class TestGetJavaImportedTypeSkeletons:
class TestExtractPublicMethodSignatures:
"""Tests for _extract_public_method_signatures()."""
def test_extracts_public_methods(self):
def test_extracts_public_methods(self) -> None:
"""Verify public method signatures are extracted."""
source = """public class Foo {
public int add(int a, int b) {
@ -2398,7 +2401,7 @@ class TestExtractPublicMethodSignatures:
# private method should not be included
assert not any("secret" in s for s in sigs)
def test_excludes_constructors(self):
def test_excludes_constructors(self) -> None:
"""Verify constructors are excluded from method signatures."""
source = """public class Bar {
public Bar(int x) { this.x = x; }
@ -2411,7 +2414,7 @@ class TestExtractPublicMethodSignatures:
assert "getX" in sigs[0]
assert not any("Bar(" in s for s in sigs)
def test_empty_class_returns_empty(self):
def test_empty_class_returns_empty(self) -> None:
"""Verify empty class returns no signatures."""
source = """public class Empty {}"""
analyzer = get_java_analyzer()
@ -2419,7 +2422,7 @@ class TestExtractPublicMethodSignatures:
assert sigs == []
def test_filters_by_class_name(self):
def test_filters_by_class_name(self) -> None:
"""Verify only methods from the specified class are returned."""
source = """public class A {
public int aMethod() { return 1; }
@ -2440,7 +2443,7 @@ class B {
class TestFormatSkeletonForContext:
"""Tests for _format_skeleton_for_context()."""
def test_formats_basic_skeleton(self):
def test_formats_basic_skeleton(self) -> None:
"""Verify basic skeleton formatting with fields and constructors."""
source = """public class Widget {
private int size;
@ -2467,7 +2470,7 @@ class TestFormatSkeletonForContext:
assert "getSize" in result
assert result.endswith("}")
def test_formats_enum_skeleton(self):
def test_formats_enum_skeleton(self) -> None:
"""Verify enum formatting includes constants."""
source = """public enum Color {
RED, GREEN, BLUE;
@ -2490,7 +2493,7 @@ class TestFormatSkeletonForContext:
assert "RED, GREEN, BLUE;" in result
assert "lower" in result
def test_formats_empty_class(self):
def test_formats_empty_class(self) -> None:
"""Verify formatting of a class with no fields or methods."""
source = """public class Empty {}"""
analyzer = get_java_analyzer()
@ -2512,7 +2515,7 @@ class TestFormatSkeletonForContext:
class TestGetJavaImportedTypeSkeletonsEdgeCases:
"""Additional edge case tests for get_java_imported_type_skeletons()."""
def test_wildcard_imports_are_expanded(self):
def test_wildcard_imports_are_expanded(self) -> None:
"""Wildcard imports (e.g., import com.example.helpers.*) are expanded to individual types."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2530,7 +2533,37 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
# Wildcard imports should now be expanded to individual classes found in the package directory
assert "MathHelper" in result
def test_import_to_nonexistent_class_in_file(self):
def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path) -> None:
"""Large wildcard imports are narrowed to the types the target code references.

Builds a throwaway Maven-style project whose ``com.bigpkg`` package holds
``MAX_WILDCARD_TYPES_UNFILTERED + 20`` classes, imports that package with a
wildcard, and checks that only classes named in ``target_code`` show up in
the skeleton output.
"""
from codeflash.languages.java.context import MAX_WILDCARD_TYPES_UNFILTERED
# Create a minimal Maven project structure so the resolver finds source roots
(tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
pkg_dir = tmp_path / "src" / "main" / "java" / "com" / "bigpkg"
pkg_dir.mkdir(parents=True)
# Generate 20 more classes than the cap (Type000, Type001, ...), one file per class;
# presumably this is what pushes the wildcard expansion over the unfiltered limit.
for i in range(MAX_WILDCARD_TYPES_UNFILTERED + 20):
(pkg_dir / f"Type{i:03d}.java").write_text(
f"package com.bigpkg;\npublic class Type{i:03d} {{ public int val() {{ return {i}; }} }}\n",
encoding="utf-8",
)
analyzer = get_java_analyzer()
# Target code references Type000 and Type001 only
target_code = "Type000 a = new Type000(); Type001 b = a.transform();"
# The parsed source only contributes the wildcard import; type usage comes from target_code
source = "package com.example;\nimport com.bigpkg.*;\npublic class Foo { void bar() {} }"
imports = analyzer.find_imports(source)
result = get_java_imported_type_skeletons(
imports, tmp_path, tmp_path / "src" / "main" / "java", analyzer, target_code=target_code
)
# Only referenced types should appear, not all MAX_WILDCARD_TYPES_UNFILTERED + 20 generated types
assert "Type000" in result
assert "Type001" in result
# Types not referenced in target code should be excluded
# NOTE(review): Type050 is only generated when MAX_WILDCARD_TYPES_UNFILTERED >= 31;
# with a smaller cap this assertion would pass trivially — confirm the constant's value.
assert "Type050" not in result
def test_import_to_nonexistent_class_in_file(self) -> None:
"""When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None."""
analyzer = get_java_analyzer()
@ -2540,7 +2573,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
assert skeleton is None
def test_skeleton_output_is_well_formed(self):
def test_skeleton_output_is_well_formed(self) -> None:
"""Verify the skeleton string has proper Java-like structure with braces."""
project_root = FIXTURE_DIR
module_root = FIXTURE_DIR / "src" / "main" / "java"
@ -2563,7 +2596,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
class TestExtractPublicMethodSignaturesEdgeCases:
"""Additional edge case tests for _extract_public_method_signatures()."""
def test_excludes_protected_and_package_private(self):
def test_excludes_protected_and_package_private(self) -> None:
"""Verify protected and package-private methods are excluded."""
source = """public class Visibility {
public int publicMethod() { return 1; }
@ -2580,7 +2613,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
assert not any("packagePrivateMethod" in s for s in sigs)
assert not any("privateMethod" in s for s in sigs)
def test_handles_overloaded_methods(self):
def test_handles_overloaded_methods(self) -> None:
"""Verify all public overloads are extracted."""
source = """public class Overloaded {
public int process(int x) { return x; }
@ -2594,7 +2627,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
# All should contain "process"
assert all("process" in s for s in sigs)
def test_handles_generic_methods(self):
def test_handles_generic_methods(self) -> None:
"""Verify generic method signatures are extracted correctly."""
source = """public class Generic {
public <T> T identity(T value) { return value; }
@ -2611,7 +2644,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
class TestFormatSkeletonRoundTrip:
"""Tests that verify _extract_type_skeleton → _format_skeleton_for_context produces valid output."""
def test_round_trip_produces_valid_skeleton(self):
def test_round_trip_produces_valid_skeleton(self) -> None:
"""Extract a real skeleton and format it — verify the output is sensible."""
source = """public class Service {
private final String name;
@ -2659,7 +2692,7 @@ class TestFormatSkeletonRoundTrip:
# Should end properly
assert result.strip().endswith("}")
def test_round_trip_with_fixture_mathhelper(self):
def test_round_trip_with_fixture_mathhelper(self) -> None:
"""Round-trip test using the real MathHelper fixture file."""
source = (FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "helpers" / "MathHelper.java").read_text()
analyzer = get_java_analyzer()