mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge pull request #1951 from codeflash-ai/cf-1085-cap-wildcard-import-expansion
fix: cap wildcard import expansion to avoid token explosion
This commit is contained in:
commit
0a2ec48fa3
3 changed files with 193 additions and 120 deletions
|
|
@ -11,10 +11,11 @@ import logging
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.code_utils.code_utils import encoded_tokens_len
|
||||
from codeflash.languages.base import CodeContext, HelperFunction, Language
|
||||
from codeflash.languages.base import CodeContext, HelperFunction
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.language_enum import Language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
|
@ -22,7 +23,8 @@ if TYPE_CHECKING:
|
|||
from tree_sitter import Node
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
|
||||
from codeflash.languages.java.import_resolver import ResolvedImport
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
|
|||
|
||||
|
||||
# Keep old function name for backwards compatibility
|
||||
def _extract_class_declaration(node, source_bytes):
|
||||
def _extract_class_declaration(node: Node, source_bytes: bytes) -> str:
|
||||
return _extract_type_declaration(node, source_bytes, "class")
|
||||
|
||||
|
||||
|
|
@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize)
|
|||
|
||||
start_line = function.doc_start_line or function.starting_line
|
||||
end_line = function.ending_line
|
||||
if start_line is None or end_line is None:
|
||||
return ""
|
||||
|
||||
# Convert from 1-indexed to 0-indexed
|
||||
start_idx = start_line - 1
|
||||
|
|
@ -672,6 +676,8 @@ def find_helper_functions(
|
|||
func_id = f"{file_path}:{func.qualified_name}"
|
||||
if func_id not in visited_functions:
|
||||
visited_functions.add(func_id)
|
||||
if func.starting_line is None or func.ending_line is None:
|
||||
continue
|
||||
|
||||
# Extract the function source using tree-sitter for resilient lookup
|
||||
func_source = extract_function_source(source, func, analyzer=analyzer)
|
||||
|
|
@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze
|
|||
return "\n".join(context_parts)
|
||||
|
||||
|
||||
def _import_to_statement(import_info) -> str:
|
||||
def _import_to_statement(import_info: JavaImportInfo) -> str:
|
||||
"""Convert a JavaImportInfo to an import statement string.
|
||||
|
||||
Args:
|
||||
|
|
@ -863,6 +869,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz
|
|||
|
||||
# Maximum token budget for imported type skeletons to avoid bloating testgen context
|
||||
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
|
||||
# Maximum types to expand from a single wildcard import before filtering to referenced types only.
|
||||
# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
|
||||
# and almost always exceed the token budget.
|
||||
MAX_WILDCARD_TYPES_UNFILTERED = 50
|
||||
|
||||
|
||||
def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
|
||||
|
|
@ -894,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]
|
|||
|
||||
|
||||
def get_java_imported_type_skeletons(
|
||||
imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = ""
|
||||
imports: list[JavaImportInfo],
|
||||
project_root: Path,
|
||||
module_root: Path | None,
|
||||
analyzer: JavaAnalyzer,
|
||||
target_code: str = "",
|
||||
) -> str:
|
||||
"""Extract type skeletons for project-internal imported types.
|
||||
|
||||
|
|
@ -929,14 +943,32 @@ def get_java_imported_type_skeletons(
|
|||
priority_types = _extract_type_names_from_code(target_code, analyzer)
|
||||
|
||||
# Pre-resolve all imports, expanding wildcards into individual types
|
||||
resolved_imports: list = []
|
||||
resolved_imports: list[ResolvedImport] = []
|
||||
for imp in imports:
|
||||
if imp.is_wildcard:
|
||||
# Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
|
||||
expanded = resolver.expand_wildcard_import(imp.import_path)
|
||||
# First try unfiltered expansion with a cap. If the package is small enough, take all types.
|
||||
# If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code.
|
||||
expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
|
||||
if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
|
||||
if priority_types:
|
||||
expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
|
||||
logger.debug(
|
||||
"Wildcard %s.* exceeds %d types, filtered to %d referenced types",
|
||||
imp.import_path,
|
||||
MAX_WILDCARD_TYPES_UNFILTERED,
|
||||
len(expanded),
|
||||
)
|
||||
else:
|
||||
expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
|
||||
logger.debug(
|
||||
"Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
|
||||
imp.import_path,
|
||||
MAX_WILDCARD_TYPES_UNFILTERED,
|
||||
)
|
||||
elif expanded:
|
||||
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
|
||||
if expanded:
|
||||
resolved_imports.extend(expanded)
|
||||
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
|
||||
continue
|
||||
|
||||
resolved = resolver.resolve_import(imp)
|
||||
|
|
@ -956,7 +988,7 @@ def get_java_imported_type_skeletons(
|
|||
|
||||
for resolved in resolved_imports:
|
||||
class_name = resolved.class_name
|
||||
if not class_name:
|
||||
if not class_name or resolved.file_path is None:
|
||||
continue
|
||||
|
||||
dedup_key = (str(resolved.file_path), class_name)
|
||||
|
|
@ -1078,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja
|
|||
continue
|
||||
|
||||
node = method.node
|
||||
if not node:
|
||||
continue
|
||||
|
||||
# Check if the method is public
|
||||
is_public = False
|
||||
|
|
|
|||
|
|
@ -220,14 +220,20 @@ class JavaImportResolver:
|
|||
return last_part
|
||||
return None
|
||||
|
||||
def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
|
||||
def expand_wildcard_import(
|
||||
self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
|
||||
) -> list[ResolvedImport]:
|
||||
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.
|
||||
|
||||
Resolves the package path to a directory and returns a ResolvedImport for each
|
||||
.java file found in that directory.
|
||||
|
||||
Args:
|
||||
import_path: The package path (without the trailing .*).
|
||||
max_types: Maximum number of types to return. 0 means no limit.
|
||||
filter_names: If provided, only include types whose class name is in this set.
|
||||
|
||||
"""
|
||||
# Convert package path to directory path
|
||||
# e.g., "com.example.utils" -> "com/example/utils"
|
||||
relative_dir = import_path.replace(".", "/")
|
||||
|
||||
resolved: list[ResolvedImport] = []
|
||||
|
|
@ -237,8 +243,10 @@ class JavaImportResolver:
|
|||
if candidate_dir.is_dir():
|
||||
for java_file in candidate_dir.glob("*.java"):
|
||||
class_name = java_file.stem
|
||||
# Only include files that look like class names (start with uppercase)
|
||||
if class_name and class_name[0].isupper():
|
||||
if not class_name or not class_name[0].isupper():
|
||||
continue
|
||||
if filter_names is not None and class_name not in filter_names:
|
||||
continue
|
||||
resolved.append(
|
||||
ResolvedImport(
|
||||
import_path=f"{import_path}.{class_name}",
|
||||
|
|
@ -248,6 +256,8 @@ class JavaImportResolver:
|
|||
class_name=class_name,
|
||||
)
|
||||
)
|
||||
if max_types and len(resolved) >= max_types:
|
||||
return resolved
|
||||
|
||||
return resolved
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.base import FunctionFilterCriteria, Language
|
||||
from codeflash.languages.base import FunctionFilterCriteria
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.languages.java.context import (
|
||||
TypeSkeleton,
|
||||
_extract_public_method_signatures,
|
||||
|
|
@ -23,7 +24,7 @@ NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False)
|
|||
class TestExtractCodeContextBasic:
|
||||
"""Tests for basic extract_code_context functionality."""
|
||||
|
||||
def test_simple_method(self, tmp_path: Path):
|
||||
def test_simple_method(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a simple method."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -53,7 +54,7 @@ class TestExtractCodeContextBasic:
|
|||
assert context.helper_functions == []
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_method_with_javadoc(self, tmp_path: Path):
|
||||
def test_method_with_javadoc(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for method with Javadoc."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -94,7 +95,7 @@ class TestExtractCodeContextBasic:
|
|||
assert context.helper_functions == []
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_static_method(self, tmp_path: Path):
|
||||
def test_static_method(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a static method."""
|
||||
java_file = tmp_path / "MathUtils.java"
|
||||
java_file.write_text("""public class MathUtils {
|
||||
|
|
@ -123,7 +124,7 @@ class TestExtractCodeContextBasic:
|
|||
assert context.helper_functions == []
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_private_method(self, tmp_path: Path):
|
||||
def test_private_method(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a private method."""
|
||||
java_file = tmp_path / "Helper.java"
|
||||
java_file.write_text("""public class Helper {
|
||||
|
|
@ -149,7 +150,7 @@ class TestExtractCodeContextBasic:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_protected_method(self, tmp_path: Path):
|
||||
def test_protected_method(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a protected method."""
|
||||
java_file = tmp_path / "Base.java"
|
||||
java_file.write_text("""public class Base {
|
||||
|
|
@ -175,7 +176,7 @@ class TestExtractCodeContextBasic:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_synchronized_method(self, tmp_path: Path):
|
||||
def test_synchronized_method(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a synchronized method."""
|
||||
java_file = tmp_path / "Counter.java"
|
||||
java_file.write_text("""public class Counter {
|
||||
|
|
@ -200,7 +201,7 @@ class TestExtractCodeContextBasic:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_method_with_throws(self, tmp_path: Path):
|
||||
def test_method_with_throws(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a method with throws clause."""
|
||||
java_file = tmp_path / "FileHandler.java"
|
||||
java_file.write_text("""public class FileHandler {
|
||||
|
|
@ -225,7 +226,7 @@ class TestExtractCodeContextBasic:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_method_with_varargs(self, tmp_path: Path):
|
||||
def test_method_with_varargs(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a method with varargs."""
|
||||
java_file = tmp_path / "Logger.java"
|
||||
java_file.write_text("""public class Logger {
|
||||
|
|
@ -250,7 +251,7 @@ class TestExtractCodeContextBasic:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_void_method(self, tmp_path: Path):
|
||||
def test_void_method(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a void method."""
|
||||
java_file = tmp_path / "Printer.java"
|
||||
java_file.write_text("""public class Printer {
|
||||
|
|
@ -277,7 +278,7 @@ class TestExtractCodeContextBasic:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_generic_return_type(self, tmp_path: Path):
|
||||
def test_generic_return_type(self, tmp_path: Path) -> None:
|
||||
"""Test extracting context for a method with generic return type."""
|
||||
java_file = tmp_path / "Container.java"
|
||||
java_file.write_text("""public class Container {
|
||||
|
|
@ -306,7 +307,7 @@ class TestExtractCodeContextBasic:
|
|||
class TestExtractCodeContextWithImports:
|
||||
"""Tests for extract_code_context with various import types."""
|
||||
|
||||
def test_with_package_and_imports(self, tmp_path: Path):
|
||||
def test_with_package_and_imports(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with package and imports."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""package com.example;
|
||||
|
|
@ -344,7 +345,7 @@ public class Calculator {
|
|||
# Fields are in skeleton, so read_only_context is empty
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_with_static_imports(self, tmp_path: Path):
|
||||
def test_with_static_imports(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with static imports."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""package com.example;
|
||||
|
|
@ -380,7 +381,7 @@ public class Calculator {
|
|||
"import static java.lang.Math.sqrt;",
|
||||
]
|
||||
|
||||
def test_with_wildcard_imports(self, tmp_path: Path):
|
||||
def test_with_wildcard_imports(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with wildcard imports."""
|
||||
java_file = tmp_path / "Processor.java"
|
||||
java_file.write_text("""package com.example;
|
||||
|
|
@ -402,7 +403,7 @@ public class Processor {
|
|||
assert context.language == Language.JAVA
|
||||
assert context.imports == ["import java.util.*;", "import java.io.*;"]
|
||||
|
||||
def test_with_multiple_import_types(self, tmp_path: Path):
|
||||
def test_with_multiple_import_types(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with various import types."""
|
||||
java_file = tmp_path / "Handler.java"
|
||||
java_file.write_text("""package com.example;
|
||||
|
|
@ -453,7 +454,7 @@ class TestExtractCodeContextWithFields:
|
|||
read_only_context should be empty to avoid duplication.
|
||||
"""
|
||||
|
||||
def test_with_instance_fields(self, tmp_path: Path):
|
||||
def test_with_instance_fields(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with instance fields."""
|
||||
java_file = tmp_path / "Person.java"
|
||||
java_file.write_text("""public class Person {
|
||||
|
|
@ -488,7 +489,7 @@ class TestExtractCodeContextWithFields:
|
|||
assert context.imports == []
|
||||
assert context.helper_functions == []
|
||||
|
||||
def test_with_static_fields(self, tmp_path: Path):
|
||||
def test_with_static_fields(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with static fields."""
|
||||
java_file = tmp_path / "Counter.java"
|
||||
java_file.write_text("""public class Counter {
|
||||
|
|
@ -519,7 +520,7 @@ class TestExtractCodeContextWithFields:
|
|||
# Fields are in skeleton, so read_only_context is empty
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_with_final_fields(self, tmp_path: Path):
|
||||
def test_with_final_fields(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with final fields."""
|
||||
java_file = tmp_path / "Config.java"
|
||||
java_file.write_text("""public class Config {
|
||||
|
|
@ -549,7 +550,7 @@ class TestExtractCodeContextWithFields:
|
|||
)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_with_static_final_constants(self, tmp_path: Path):
|
||||
def test_with_static_final_constants(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with static final constants."""
|
||||
java_file = tmp_path / "Constants.java"
|
||||
java_file.write_text("""public class Constants {
|
||||
|
|
@ -581,7 +582,7 @@ class TestExtractCodeContextWithFields:
|
|||
)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_with_volatile_fields(self, tmp_path: Path):
|
||||
def test_with_volatile_fields(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with volatile fields."""
|
||||
java_file = tmp_path / "ThreadSafe.java"
|
||||
java_file.write_text("""public class ThreadSafe {
|
||||
|
|
@ -611,7 +612,7 @@ class TestExtractCodeContextWithFields:
|
|||
)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_with_generic_fields(self, tmp_path: Path):
|
||||
def test_with_generic_fields(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with generic type fields."""
|
||||
java_file = tmp_path / "Container.java"
|
||||
java_file.write_text("""public class Container {
|
||||
|
|
@ -643,7 +644,7 @@ class TestExtractCodeContextWithFields:
|
|||
)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_with_array_fields(self, tmp_path: Path):
|
||||
def test_with_array_fields(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with array fields."""
|
||||
java_file = tmp_path / "ArrayHolder.java"
|
||||
java_file.write_text("""public class ArrayHolder {
|
||||
|
|
@ -679,7 +680,7 @@ class TestExtractCodeContextWithFields:
|
|||
class TestExtractCodeContextWithHelpers:
|
||||
"""Tests for extract_code_context with helper functions."""
|
||||
|
||||
def test_single_helper_method(self, tmp_path: Path):
|
||||
def test_single_helper_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with a single helper method."""
|
||||
java_file = tmp_path / "Processor.java"
|
||||
java_file.write_text("""public class Processor {
|
||||
|
|
@ -715,7 +716,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
== "private String normalize(String s) {\n return s.trim().toLowerCase();\n }"
|
||||
)
|
||||
|
||||
def test_multiple_helper_methods(self, tmp_path: Path):
|
||||
def test_multiple_helper_methods(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with multiple helper methods."""
|
||||
java_file = tmp_path / "Processor.java"
|
||||
java_file.write_text("""public class Processor {
|
||||
|
|
@ -758,7 +759,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
helper_names = sorted([h.name for h in context.helper_functions])
|
||||
assert helper_names == ["trim", "upper"]
|
||||
|
||||
def test_chained_helper_calls(self, tmp_path: Path):
|
||||
def test_chained_helper_calls(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with chained helper calls."""
|
||||
java_file = tmp_path / "Processor.java"
|
||||
java_file.write_text("""public class Processor {
|
||||
|
|
@ -784,7 +785,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
helper_names = [h.name for h in context.helper_functions]
|
||||
assert helper_names == ["normalize"]
|
||||
|
||||
def test_no_helpers_when_none_called(self, tmp_path: Path):
|
||||
def test_no_helpers_when_none_called(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction when no helpers are called."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -814,7 +815,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
)
|
||||
assert context.helper_functions == []
|
||||
|
||||
def test_static_helper_from_instance_method(self, tmp_path: Path):
|
||||
def test_static_helper_from_instance_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with static helper called from instance method."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -840,7 +841,7 @@ class TestExtractCodeContextWithHelpers:
|
|||
class TestExtractCodeContextWithJavadoc:
|
||||
"""Tests for extract_code_context with various Javadoc patterns."""
|
||||
|
||||
def test_simple_javadoc(self, tmp_path: Path):
|
||||
def test_simple_javadoc(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with simple Javadoc."""
|
||||
java_file = tmp_path / "Example.java"
|
||||
java_file.write_text("""public class Example {
|
||||
|
|
@ -866,7 +867,7 @@ class TestExtractCodeContextWithJavadoc:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_javadoc_with_params(self, tmp_path: Path):
|
||||
def test_javadoc_with_params(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with Javadoc @param tags."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -900,7 +901,7 @@ class TestExtractCodeContextWithJavadoc:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_javadoc_with_return(self, tmp_path: Path):
|
||||
def test_javadoc_with_return(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with Javadoc @return tag."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -932,7 +933,7 @@ class TestExtractCodeContextWithJavadoc:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_javadoc_with_throws(self, tmp_path: Path):
|
||||
def test_javadoc_with_throws(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with Javadoc @throws tag."""
|
||||
java_file = tmp_path / "Divider.java"
|
||||
java_file.write_text("""public class Divider {
|
||||
|
|
@ -968,7 +969,7 @@ class TestExtractCodeContextWithJavadoc:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_javadoc_multiline(self, tmp_path: Path):
|
||||
def test_javadoc_multiline(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with multi-paragraph Javadoc."""
|
||||
java_file = tmp_path / "Complex.java"
|
||||
java_file.write_text("""public class Complex {
|
||||
|
|
@ -1020,7 +1021,7 @@ class TestExtractCodeContextWithJavadoc:
|
|||
class TestExtractCodeContextWithGenerics:
|
||||
"""Tests for extract_code_context with generic types."""
|
||||
|
||||
def test_generic_method_type_parameter(self, tmp_path: Path):
|
||||
def test_generic_method_type_parameter(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with generic type parameter."""
|
||||
java_file = tmp_path / "Utils.java"
|
||||
java_file.write_text("""public class Utils {
|
||||
|
|
@ -1044,7 +1045,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_bounded_type_parameter(self, tmp_path: Path):
|
||||
def test_bounded_type_parameter(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with bounded type parameter."""
|
||||
java_file = tmp_path / "Statistics.java"
|
||||
java_file.write_text("""public class Statistics {
|
||||
|
|
@ -1076,7 +1077,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_wildcard_type(self, tmp_path: Path):
|
||||
def test_wildcard_type(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with wildcard type."""
|
||||
java_file = tmp_path / "Printer.java"
|
||||
java_file.write_text("""public class Printer {
|
||||
|
|
@ -1100,7 +1101,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_bounded_wildcard_extends(self, tmp_path: Path):
|
||||
def test_bounded_wildcard_extends(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with upper bounded wildcard."""
|
||||
java_file = tmp_path / "Aggregator.java"
|
||||
java_file.write_text("""public class Aggregator {
|
||||
|
|
@ -1132,7 +1133,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_bounded_wildcard_super(self, tmp_path: Path):
|
||||
def test_bounded_wildcard_super(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with lower bounded wildcard."""
|
||||
java_file = tmp_path / "Filler.java"
|
||||
java_file.write_text("""public class Filler {
|
||||
|
|
@ -1158,7 +1159,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_multiple_type_parameters(self, tmp_path: Path):
|
||||
def test_multiple_type_parameters(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with multiple type parameters."""
|
||||
java_file = tmp_path / "Mapper.java"
|
||||
java_file.write_text("""public class Mapper {
|
||||
|
|
@ -1190,7 +1191,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_recursive_type_bound(self, tmp_path: Path):
|
||||
def test_recursive_type_bound(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with recursive type bound."""
|
||||
java_file = tmp_path / "Sorter.java"
|
||||
java_file.write_text("""public class Sorter {
|
||||
|
|
@ -1218,7 +1219,7 @@ class TestExtractCodeContextWithGenerics:
|
|||
class TestExtractCodeContextWithAnnotations:
|
||||
"""Tests for extract_code_context with annotations."""
|
||||
|
||||
def test_override_annotation(self, tmp_path: Path):
|
||||
def test_override_annotation(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with @Override annotation."""
|
||||
java_file = tmp_path / "Child.java"
|
||||
java_file.write_text("""public class Child extends Parent {
|
||||
|
|
@ -1244,7 +1245,7 @@ class TestExtractCodeContextWithAnnotations:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_deprecated_annotation(self, tmp_path: Path):
|
||||
def test_deprecated_annotation(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with @Deprecated annotation."""
|
||||
java_file = tmp_path / "Legacy.java"
|
||||
java_file.write_text("""public class Legacy {
|
||||
|
|
@ -1270,7 +1271,7 @@ class TestExtractCodeContextWithAnnotations:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_suppress_warnings_annotation(self, tmp_path: Path):
|
||||
def test_suppress_warnings_annotation(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with @SuppressWarnings annotation."""
|
||||
java_file = tmp_path / "Processor.java"
|
||||
java_file.write_text("""public class Processor {
|
||||
|
|
@ -1296,7 +1297,7 @@ class TestExtractCodeContextWithAnnotations:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_multiple_annotations(self, tmp_path: Path):
|
||||
def test_multiple_annotations(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with multiple annotations."""
|
||||
java_file = tmp_path / "Service.java"
|
||||
java_file.write_text("""public class Service {
|
||||
|
|
@ -1326,7 +1327,7 @@ class TestExtractCodeContextWithAnnotations:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_annotation_with_array_value(self, tmp_path: Path):
|
||||
def test_annotation_with_array_value(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with annotation array value."""
|
||||
java_file = tmp_path / "Handler.java"
|
||||
java_file.write_text("""public class Handler {
|
||||
|
|
@ -1356,7 +1357,7 @@ class TestExtractCodeContextWithAnnotations:
|
|||
class TestExtractCodeContextWithInheritance:
|
||||
"""Tests for extract_code_context with inheritance scenarios."""
|
||||
|
||||
def test_method_in_subclass(self, tmp_path: Path):
|
||||
def test_method_in_subclass(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for method in subclass."""
|
||||
java_file = tmp_path / "AdvancedCalc.java"
|
||||
java_file.write_text("""public class AdvancedCalc extends Calculator {
|
||||
|
|
@ -1382,7 +1383,7 @@ class TestExtractCodeContextWithInheritance:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_interface_implementation(self, tmp_path: Path):
|
||||
def test_interface_implementation(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for interface implementation."""
|
||||
java_file = tmp_path / "MyComparable.java"
|
||||
java_file.write_text("""public class MyComparable implements Comparable<MyComparable> {
|
||||
|
|
@ -1414,7 +1415,7 @@ class TestExtractCodeContextWithInheritance:
|
|||
# Fields are in skeleton, so read_only_context is empty (no duplication)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_multiple_interfaces(self, tmp_path: Path):
|
||||
def test_multiple_interfaces(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for multiple interface implementations."""
|
||||
java_file = tmp_path / "MultiImpl.java"
|
||||
java_file.write_text("""public class MultiImpl implements Runnable, Comparable<MultiImpl> {
|
||||
|
|
@ -1446,7 +1447,7 @@ class TestExtractCodeContextWithInheritance:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_default_interface_method(self, tmp_path: Path):
|
||||
def test_default_interface_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for default interface method."""
|
||||
java_file = tmp_path / "MyInterface.java"
|
||||
java_file.write_text("""public interface MyInterface {
|
||||
|
|
@ -1479,7 +1480,7 @@ class TestExtractCodeContextWithInheritance:
|
|||
class TestExtractCodeContextWithInnerClasses:
|
||||
"""Tests for extract_code_context with inner/nested classes."""
|
||||
|
||||
def test_static_nested_class_method(self, tmp_path: Path):
|
||||
def test_static_nested_class_method(self, tmp_path: Path) -> None:
|
||||
"""Inner class methods are excluded from discovery and cannot be context-extracted.
|
||||
|
||||
Methods of static nested classes are skipped in discovery because they
|
||||
|
|
@ -1499,7 +1500,7 @@ class TestExtractCodeContextWithInnerClasses:
|
|||
# Inner class method must NOT be discovered
|
||||
assert compute_func is None
|
||||
|
||||
def test_inner_class_method(self, tmp_path: Path):
|
||||
def test_inner_class_method(self, tmp_path: Path) -> None:
|
||||
"""Inner class methods are excluded from discovery and cannot be context-extracted.
|
||||
|
||||
Methods of non-static inner classes are skipped in discovery because they
|
||||
|
|
@ -1525,7 +1526,7 @@ class TestExtractCodeContextWithInnerClasses:
|
|||
class TestExtractCodeContextWithEnumAndInterface:
|
||||
"""Tests for extract_code_context with enums and interfaces."""
|
||||
|
||||
def test_enum_method(self, tmp_path: Path):
|
||||
def test_enum_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for enum method."""
|
||||
java_file = tmp_path / "Operation.java"
|
||||
java_file.write_text("""public enum Operation {
|
||||
|
|
@ -1568,7 +1569,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
|||
)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_interface_default_method(self, tmp_path: Path):
|
||||
def test_interface_default_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for interface default method."""
|
||||
java_file = tmp_path / "Greeting.java"
|
||||
java_file.write_text("""public interface Greeting {
|
||||
|
|
@ -1595,7 +1596,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
|||
)
|
||||
assert context.read_only_context == ""
|
||||
|
||||
def test_interface_static_method(self, tmp_path: Path):
|
||||
def test_interface_static_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for interface static method."""
|
||||
java_file = tmp_path / "Factory.java"
|
||||
java_file.write_text("""public interface Factory {
|
||||
|
|
@ -1626,7 +1627,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
|||
class TestExtractCodeContextEdgeCases:
|
||||
"""Tests for extract_code_context edge cases."""
|
||||
|
||||
def test_empty_method(self, tmp_path: Path):
|
||||
def test_empty_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for empty method."""
|
||||
java_file = tmp_path / "Empty.java"
|
||||
java_file.write_text("""public class Empty {
|
||||
|
|
@ -1650,7 +1651,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_single_line_method(self, tmp_path: Path):
|
||||
def test_single_line_method(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for single-line method."""
|
||||
java_file = tmp_path / "OneLiner.java"
|
||||
java_file.write_text("""public class OneLiner {
|
||||
|
|
@ -1670,7 +1671,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_method_with_lambda(self, tmp_path: Path):
|
||||
def test_method_with_lambda(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for method with lambda."""
|
||||
java_file = tmp_path / "Functional.java"
|
||||
java_file.write_text("""public class Functional {
|
||||
|
|
@ -1698,7 +1699,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_method_with_method_reference(self, tmp_path: Path):
|
||||
def test_method_with_method_reference(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for method with method reference."""
|
||||
java_file = tmp_path / "Printer.java"
|
||||
java_file.write_text("""public class Printer {
|
||||
|
|
@ -1722,7 +1723,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_deeply_nested_blocks(self, tmp_path: Path):
|
||||
def test_deeply_nested_blocks(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for method with deeply nested blocks."""
|
||||
java_file = tmp_path / "Nested.java"
|
||||
java_file.write_text("""public class Nested {
|
||||
|
|
@ -1776,7 +1777,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_unicode_in_source(self, tmp_path: Path):
|
||||
def test_unicode_in_source(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for method with unicode characters."""
|
||||
java_file = tmp_path / "Unicode.java"
|
||||
java_file.write_text(
|
||||
|
|
@ -1803,7 +1804,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_file_not_found(self, tmp_path: Path):
|
||||
def test_file_not_found(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for missing file."""
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
|
@ -1824,7 +1825,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
assert context.language == Language.JAVA
|
||||
assert context.target_file == missing_file
|
||||
|
||||
def test_max_helper_depth_zero(self, tmp_path: Path):
|
||||
def test_max_helper_depth_zero(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with max_helper_depth=0."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""public class Calculator {
|
||||
|
|
@ -1858,7 +1859,7 @@ class TestExtractCodeContextEdgeCases:
|
|||
class TestExtractCodeContextWithConstructor:
|
||||
"""Tests for extract_code_context with constructors in class skeleton."""
|
||||
|
||||
def test_class_with_constructor(self, tmp_path: Path):
|
||||
def test_class_with_constructor(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction includes constructor in skeleton."""
|
||||
java_file = tmp_path / "Person.java"
|
||||
java_file.write_text("""public class Person {
|
||||
|
|
@ -1898,7 +1899,7 @@ class TestExtractCodeContextWithConstructor:
|
|||
"""
|
||||
)
|
||||
|
||||
def test_class_with_multiple_constructors(self, tmp_path: Path):
|
||||
def test_class_with_multiple_constructors(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction includes all constructors in skeleton."""
|
||||
java_file = tmp_path / "Config.java"
|
||||
java_file.write_text("""public class Config {
|
||||
|
|
@ -1956,7 +1957,7 @@ class TestExtractCodeContextWithConstructor:
|
|||
class TestExtractCodeContextFullIntegration:
|
||||
"""Integration tests for extract_code_context with all components."""
|
||||
|
||||
def test_full_context_with_all_components(self, tmp_path: Path):
|
||||
def test_full_context_with_all_components(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction with imports, fields, and helpers."""
|
||||
java_file = tmp_path / "Service.java"
|
||||
java_file.write_text("""package com.example;
|
||||
|
|
@ -2007,7 +2008,7 @@ public class Service {
|
|||
assert len(context.helper_functions) == 1
|
||||
assert context.helper_functions[0].name == "transform"
|
||||
|
||||
def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path):
|
||||
def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path) -> None:
|
||||
"""Test context extraction for complex class with javadoc and annotations."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""package com.example.math;
|
||||
|
|
@ -2073,7 +2074,7 @@ public class Calculator {
|
|||
class TestExtractClassContext:
|
||||
"""Tests for extract_class_context."""
|
||||
|
||||
def test_extract_class_with_imports(self, tmp_path: Path):
|
||||
def test_extract_class_with_imports(self, tmp_path: Path) -> None:
|
||||
"""Test extracting full class context with imports."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""package com.example;
|
||||
|
|
@ -2112,7 +2113,7 @@ public class Calculator {
|
|||
}"""
|
||||
)
|
||||
|
||||
def test_extract_class_not_found(self, tmp_path: Path):
|
||||
def test_extract_class_not_found(self, tmp_path: Path) -> None:
|
||||
"""Test extracting non-existent class returns empty string."""
|
||||
java_file = tmp_path / "Test.java"
|
||||
java_file.write_text("""public class Test {
|
||||
|
|
@ -2124,7 +2125,7 @@ public class Calculator {
|
|||
|
||||
assert context == ""
|
||||
|
||||
def test_extract_class_missing_file(self, tmp_path: Path):
|
||||
def test_extract_class_missing_file(self, tmp_path: Path) -> None:
|
||||
"""Test extracting from missing file returns empty string."""
|
||||
missing_file = tmp_path / "Missing.java"
|
||||
|
||||
|
|
@ -2141,7 +2142,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
|||
extraction should still find the correct function by name.
|
||||
"""
|
||||
|
||||
def test_extraction_with_stale_line_numbers(self):
|
||||
def test_extraction_with_stale_line_numbers(self) -> None:
|
||||
"""Verify extraction works when pre-computed line numbers no longer match the source."""
|
||||
# Original source: functionA at lines 2-4, functionB at lines 5-7
|
||||
original_source = """public class Utils {
|
||||
|
|
@ -2177,7 +2178,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
|||
assert "functionB" in result
|
||||
assert "return 2;" in result
|
||||
|
||||
def test_extraction_without_analyzer_uses_line_numbers(self):
|
||||
def test_extraction_without_analyzer_uses_line_numbers(self) -> None:
|
||||
"""Without analyzer, extraction falls back to pre-computed line numbers."""
|
||||
source = """public class Utils {
|
||||
public int functionA() {
|
||||
|
|
@ -2196,7 +2197,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
|||
assert "functionB" in result
|
||||
assert "return 2;" in result
|
||||
|
||||
def test_extraction_with_javadoc_after_file_modification(self):
|
||||
def test_extraction_with_javadoc_after_file_modification(self) -> None:
|
||||
"""Verify Javadoc is included when using tree-sitter extraction on modified files."""
|
||||
original_source = """public class Utils {
|
||||
/** Adds two numbers. */
|
||||
|
|
@ -2233,7 +2234,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
|||
assert "public int subtract" in result
|
||||
assert "return a - b;" in result
|
||||
|
||||
def test_extraction_with_overloaded_methods(self):
|
||||
def test_extraction_with_overloaded_methods(self) -> None:
|
||||
"""Verify correct overload is selected using line proximity."""
|
||||
source = """public class Utils {
|
||||
public int process(int x) {
|
||||
|
|
@ -2247,13 +2248,15 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
|||
analyzer = get_java_analyzer()
|
||||
functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
|
||||
# Get the second overload (process(int, int))
|
||||
func_two_args = [f for f in functions if f.function_name == "process" and f.ending_line > 4][0]
|
||||
func_two_args = [
|
||||
f for f in functions if f.function_name == "process" and f.ending_line is not None and f.ending_line > 4
|
||||
][0]
|
||||
|
||||
result = extract_function_source(source, func_two_args, analyzer=analyzer)
|
||||
assert "int x, int y" in result
|
||||
assert "return x + y;" in result
|
||||
|
||||
def test_extraction_function_not_found_falls_back(self):
|
||||
def test_extraction_function_not_found_falls_back(self) -> None:
|
||||
"""If tree-sitter can't find the method, fall back to line numbers."""
|
||||
source = """public class Utils {
|
||||
public int functionA() {
|
||||
|
|
@ -2282,7 +2285,7 @@ FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "java_maven"
|
|||
class TestGetJavaImportedTypeSkeletons:
|
||||
"""Tests for get_java_imported_type_skeletons()."""
|
||||
|
||||
def test_resolves_internal_imports(self):
|
||||
def test_resolves_internal_imports(self) -> None:
|
||||
"""Verify that project-internal imports are resolved and skeletons extracted."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2297,7 +2300,7 @@ class TestGetJavaImportedTypeSkeletons:
|
|||
assert "MathHelper" in result
|
||||
assert "Formatter" in result
|
||||
|
||||
def test_skeletons_contain_method_signatures(self):
|
||||
def test_skeletons_contain_method_signatures(self) -> None:
|
||||
"""Verify extracted skeletons include public method signatures."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2313,7 +2316,7 @@ class TestGetJavaImportedTypeSkeletons:
|
|||
assert "multiply" in result
|
||||
assert "factorial" in result
|
||||
|
||||
def test_skips_external_imports(self):
|
||||
def test_skips_external_imports(self) -> None:
|
||||
"""Verify that standard library and external imports are skipped."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2328,7 +2331,7 @@ class TestGetJavaImportedTypeSkeletons:
|
|||
# No internal imports → empty result
|
||||
assert result == ""
|
||||
|
||||
def test_deduplicates_imports(self):
|
||||
def test_deduplicates_imports(self) -> None:
|
||||
"""Verify that the same type imported twice is only included once."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2344,7 +2347,7 @@ class TestGetJavaImportedTypeSkeletons:
|
|||
# Count occurrences of MathHelper — should appear exactly once
|
||||
assert result.count("class MathHelper") == 1
|
||||
|
||||
def test_empty_imports_returns_empty(self):
|
||||
def test_empty_imports_returns_empty(self) -> None:
|
||||
"""Verify that empty import list returns empty string."""
|
||||
project_root = FIXTURE_DIR
|
||||
analyzer = get_java_analyzer()
|
||||
|
|
@ -2353,7 +2356,7 @@ class TestGetJavaImportedTypeSkeletons:
|
|||
|
||||
assert result == ""
|
||||
|
||||
def test_respects_token_budget(self):
|
||||
def test_respects_token_budget(self) -> None:
|
||||
"""Verify that the function stops when token budget is exceeded."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2378,7 +2381,7 @@ class TestGetJavaImportedTypeSkeletons:
|
|||
class TestExtractPublicMethodSignatures:
|
||||
"""Tests for _extract_public_method_signatures()."""
|
||||
|
||||
def test_extracts_public_methods(self):
|
||||
def test_extracts_public_methods(self) -> None:
|
||||
"""Verify public method signatures are extracted."""
|
||||
source = """public class Foo {
|
||||
public int add(int a, int b) {
|
||||
|
|
@ -2398,7 +2401,7 @@ class TestExtractPublicMethodSignatures:
|
|||
# private method should not be included
|
||||
assert not any("secret" in s for s in sigs)
|
||||
|
||||
def test_excludes_constructors(self):
|
||||
def test_excludes_constructors(self) -> None:
|
||||
"""Verify constructors are excluded from method signatures."""
|
||||
source = """public class Bar {
|
||||
public Bar(int x) { this.x = x; }
|
||||
|
|
@ -2411,7 +2414,7 @@ class TestExtractPublicMethodSignatures:
|
|||
assert "getX" in sigs[0]
|
||||
assert not any("Bar(" in s for s in sigs)
|
||||
|
||||
def test_empty_class_returns_empty(self):
|
||||
def test_empty_class_returns_empty(self) -> None:
|
||||
"""Verify empty class returns no signatures."""
|
||||
source = """public class Empty {}"""
|
||||
analyzer = get_java_analyzer()
|
||||
|
|
@ -2419,7 +2422,7 @@ class TestExtractPublicMethodSignatures:
|
|||
|
||||
assert sigs == []
|
||||
|
||||
def test_filters_by_class_name(self):
|
||||
def test_filters_by_class_name(self) -> None:
|
||||
"""Verify only methods from the specified class are returned."""
|
||||
source = """public class A {
|
||||
public int aMethod() { return 1; }
|
||||
|
|
@ -2440,7 +2443,7 @@ class B {
|
|||
class TestFormatSkeletonForContext:
|
||||
"""Tests for _format_skeleton_for_context()."""
|
||||
|
||||
def test_formats_basic_skeleton(self):
|
||||
def test_formats_basic_skeleton(self) -> None:
|
||||
"""Verify basic skeleton formatting with fields and constructors."""
|
||||
source = """public class Widget {
|
||||
private int size;
|
||||
|
|
@ -2467,7 +2470,7 @@ class TestFormatSkeletonForContext:
|
|||
assert "getSize" in result
|
||||
assert result.endswith("}")
|
||||
|
||||
def test_formats_enum_skeleton(self):
|
||||
def test_formats_enum_skeleton(self) -> None:
|
||||
"""Verify enum formatting includes constants."""
|
||||
source = """public enum Color {
|
||||
RED, GREEN, BLUE;
|
||||
|
|
@ -2490,7 +2493,7 @@ class TestFormatSkeletonForContext:
|
|||
assert "RED, GREEN, BLUE;" in result
|
||||
assert "lower" in result
|
||||
|
||||
def test_formats_empty_class(self):
|
||||
def test_formats_empty_class(self) -> None:
|
||||
"""Verify formatting of a class with no fields or methods."""
|
||||
source = """public class Empty {}"""
|
||||
analyzer = get_java_analyzer()
|
||||
|
|
@ -2512,7 +2515,7 @@ class TestFormatSkeletonForContext:
|
|||
class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
||||
"""Additional edge case tests for get_java_imported_type_skeletons()."""
|
||||
|
||||
def test_wildcard_imports_are_expanded(self):
|
||||
def test_wildcard_imports_are_expanded(self) -> None:
|
||||
"""Wildcard imports (e.g., import com.example.helpers.*) are expanded to individual types."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2530,7 +2533,37 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
|||
# Wildcard imports should now be expanded to individual classes found in the package directory
|
||||
assert "MathHelper" in result
|
||||
|
||||
def test_import_to_nonexistent_class_in_file(self):
|
||||
def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path) -> None:
|
||||
"""When wildcard expands to >50 types, only types referenced in target code are included."""
|
||||
from codeflash.languages.java.context import MAX_WILDCARD_TYPES_UNFILTERED
|
||||
|
||||
# Create a minimal Maven project structure so the resolver finds source roots
|
||||
(tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
|
||||
pkg_dir = tmp_path / "src" / "main" / "java" / "com" / "bigpkg"
|
||||
pkg_dir.mkdir(parents=True)
|
||||
for i in range(MAX_WILDCARD_TYPES_UNFILTERED + 20):
|
||||
(pkg_dir / f"Type{i:03d}.java").write_text(
|
||||
f"package com.bigpkg;\npublic class Type{i:03d} {{ public int val() {{ return {i}; }} }}\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
# Target code references Type000 and Type001 only
|
||||
target_code = "Type000 a = new Type000(); Type001 b = a.transform();"
|
||||
source = "package com.example;\nimport com.bigpkg.*;\npublic class Foo { void bar() {} }"
|
||||
imports = analyzer.find_imports(source)
|
||||
|
||||
result = get_java_imported_type_skeletons(
|
||||
imports, tmp_path, tmp_path / "src" / "main" / "java", analyzer, target_code=target_code
|
||||
)
|
||||
|
||||
# Only referenced types should appear, not all 70
|
||||
assert "Type000" in result
|
||||
assert "Type001" in result
|
||||
# Types not referenced in target code should be excluded
|
||||
assert "Type050" not in result
|
||||
|
||||
def test_import_to_nonexistent_class_in_file(self) -> None:
|
||||
"""When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None."""
|
||||
analyzer = get_java_analyzer()
|
||||
|
||||
|
|
@ -2540,7 +2573,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
|||
|
||||
assert skeleton is None
|
||||
|
||||
def test_skeleton_output_is_well_formed(self):
|
||||
def test_skeleton_output_is_well_formed(self) -> None:
|
||||
"""Verify the skeleton string has proper Java-like structure with braces."""
|
||||
project_root = FIXTURE_DIR
|
||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||
|
|
@ -2563,7 +2596,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
|||
class TestExtractPublicMethodSignaturesEdgeCases:
|
||||
"""Additional edge case tests for _extract_public_method_signatures()."""
|
||||
|
||||
def test_excludes_protected_and_package_private(self):
|
||||
def test_excludes_protected_and_package_private(self) -> None:
|
||||
"""Verify protected and package-private methods are excluded."""
|
||||
source = """public class Visibility {
|
||||
public int publicMethod() { return 1; }
|
||||
|
|
@ -2580,7 +2613,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
|
|||
assert not any("packagePrivateMethod" in s for s in sigs)
|
||||
assert not any("privateMethod" in s for s in sigs)
|
||||
|
||||
def test_handles_overloaded_methods(self):
|
||||
def test_handles_overloaded_methods(self) -> None:
|
||||
"""Verify all public overloads are extracted."""
|
||||
source = """public class Overloaded {
|
||||
public int process(int x) { return x; }
|
||||
|
|
@ -2594,7 +2627,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
|
|||
# All should contain "process"
|
||||
assert all("process" in s for s in sigs)
|
||||
|
||||
def test_handles_generic_methods(self):
|
||||
def test_handles_generic_methods(self) -> None:
|
||||
"""Verify generic method signatures are extracted correctly."""
|
||||
source = """public class Generic {
|
||||
public <T> T identity(T value) { return value; }
|
||||
|
|
@ -2611,7 +2644,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
|
|||
class TestFormatSkeletonRoundTrip:
|
||||
"""Tests that verify _extract_type_skeleton → _format_skeleton_for_context produces valid output."""
|
||||
|
||||
def test_round_trip_produces_valid_skeleton(self):
|
||||
def test_round_trip_produces_valid_skeleton(self) -> None:
|
||||
"""Extract a real skeleton and format it — verify the output is sensible."""
|
||||
source = """public class Service {
|
||||
private final String name;
|
||||
|
|
@ -2659,7 +2692,7 @@ class TestFormatSkeletonRoundTrip:
|
|||
# Should end properly
|
||||
assert result.strip().endswith("}")
|
||||
|
||||
def test_round_trip_with_fixture_mathhelper(self):
|
||||
def test_round_trip_with_fixture_mathhelper(self) -> None:
|
||||
"""Round-trip test using the real MathHelper fixture file."""
|
||||
source = (FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "helpers" / "MathHelper.java").read_text()
|
||||
analyzer = get_java_analyzer()
|
||||
|
|
|
|||
Loading…
Reference in a new issue