mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge pull request #1951 from codeflash-ai/cf-1085-cap-wildcard-import-expansion
fix: cap wildcard import expansion to avoid token explosion
This commit is contained in:
commit
0a2ec48fa3
3 changed files with 193 additions and 120 deletions
|
|
@ -11,10 +11,11 @@ import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from codeflash.code_utils.code_utils import encoded_tokens_len
|
from codeflash.code_utils.code_utils import encoded_tokens_len
|
||||||
from codeflash.languages.base import CodeContext, HelperFunction, Language
|
from codeflash.languages.base import CodeContext, HelperFunction
|
||||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||||
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
|
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
|
||||||
from codeflash.languages.java.parser import get_java_analyzer
|
from codeflash.languages.java.parser import get_java_analyzer
|
||||||
|
from codeflash.languages.language_enum import Language
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
@ -22,7 +23,8 @@ if TYPE_CHECKING:
|
||||||
from tree_sitter import Node
|
from tree_sitter import Node
|
||||||
|
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
|
from codeflash.languages.java.import_resolver import ResolvedImport
|
||||||
|
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, JavaMethodNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -360,7 +362,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
|
||||||
|
|
||||||
|
|
||||||
# Keep old function name for backwards compatibility
|
# Keep old function name for backwards compatibility
|
||||||
def _extract_class_declaration(node, source_bytes):
|
def _extract_class_declaration(node: Node, source_bytes: bytes) -> str:
|
||||||
return _extract_type_declaration(node, source_bytes, "class")
|
return _extract_type_declaration(node, source_bytes, "class")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -629,6 +631,8 @@ def _extract_function_source_by_lines(source: str, function: FunctionToOptimize)
|
||||||
|
|
||||||
start_line = function.doc_start_line or function.starting_line
|
start_line = function.doc_start_line or function.starting_line
|
||||||
end_line = function.ending_line
|
end_line = function.ending_line
|
||||||
|
if start_line is None or end_line is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
# Convert from 1-indexed to 0-indexed
|
# Convert from 1-indexed to 0-indexed
|
||||||
start_idx = start_line - 1
|
start_idx = start_line - 1
|
||||||
|
|
@ -672,6 +676,8 @@ def find_helper_functions(
|
||||||
func_id = f"{file_path}:{func.qualified_name}"
|
func_id = f"{file_path}:{func.qualified_name}"
|
||||||
if func_id not in visited_functions:
|
if func_id not in visited_functions:
|
||||||
visited_functions.add(func_id)
|
visited_functions.add(func_id)
|
||||||
|
if func.starting_line is None or func.ending_line is None:
|
||||||
|
continue
|
||||||
|
|
||||||
# Extract the function source using tree-sitter for resilient lookup
|
# Extract the function source using tree-sitter for resilient lookup
|
||||||
func_source = extract_function_source(source, func, analyzer=analyzer)
|
func_source = extract_function_source(source, func, analyzer=analyzer)
|
||||||
|
|
@ -795,7 +801,7 @@ def extract_read_only_context(source: str, function: FunctionToOptimize, analyze
|
||||||
return "\n".join(context_parts)
|
return "\n".join(context_parts)
|
||||||
|
|
||||||
|
|
||||||
def _import_to_statement(import_info) -> str:
|
def _import_to_statement(import_info: JavaImportInfo) -> str:
|
||||||
"""Convert a JavaImportInfo to an import statement string.
|
"""Convert a JavaImportInfo to an import statement string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -863,6 +869,10 @@ def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyz
|
||||||
|
|
||||||
# Maximum token budget for imported type skeletons to avoid bloating testgen context
|
# Maximum token budget for imported type skeletons to avoid bloating testgen context
|
||||||
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
|
IMPORTED_SKELETON_TOKEN_BUDGET = 4000
|
||||||
|
# Maximum types to expand from a single wildcard import before filtering to referenced types only.
|
||||||
|
# Packages with more types than this (e.g. org.jooq with 870+) would waste minutes of disk I/O
|
||||||
|
# and almost always exceed the token budget.
|
||||||
|
MAX_WILDCARD_TYPES_UNFILTERED = 50
|
||||||
|
|
||||||
|
|
||||||
def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
|
def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]:
|
||||||
|
|
@ -894,7 +904,11 @@ def _extract_type_names_from_code(code: str, analyzer: JavaAnalyzer) -> set[str]
|
||||||
|
|
||||||
|
|
||||||
def get_java_imported_type_skeletons(
|
def get_java_imported_type_skeletons(
|
||||||
imports: list, project_root: Path, module_root: Path | None, analyzer: JavaAnalyzer, target_code: str = ""
|
imports: list[JavaImportInfo],
|
||||||
|
project_root: Path,
|
||||||
|
module_root: Path | None,
|
||||||
|
analyzer: JavaAnalyzer,
|
||||||
|
target_code: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Extract type skeletons for project-internal imported types.
|
"""Extract type skeletons for project-internal imported types.
|
||||||
|
|
||||||
|
|
@ -929,14 +943,32 @@ def get_java_imported_type_skeletons(
|
||||||
priority_types = _extract_type_names_from_code(target_code, analyzer)
|
priority_types = _extract_type_names_from_code(target_code, analyzer)
|
||||||
|
|
||||||
# Pre-resolve all imports, expanding wildcards into individual types
|
# Pre-resolve all imports, expanding wildcards into individual types
|
||||||
resolved_imports: list = []
|
resolved_imports: list[ResolvedImport] = []
|
||||||
for imp in imports:
|
for imp in imports:
|
||||||
if imp.is_wildcard:
|
if imp.is_wildcard:
|
||||||
# Expand wildcard imports (e.g., com.aerospike.client.policy.*) into individual types
|
# First try unfiltered expansion with a cap. If the package is small enough, take all types.
|
||||||
expanded = resolver.expand_wildcard_import(imp.import_path)
|
# If it's huge (e.g. org.jooq.* with 870+ types), filter to only types referenced in the target code.
|
||||||
|
expanded = resolver.expand_wildcard_import(imp.import_path, max_types=MAX_WILDCARD_TYPES_UNFILTERED + 1)
|
||||||
|
if len(expanded) > MAX_WILDCARD_TYPES_UNFILTERED:
|
||||||
|
if priority_types:
|
||||||
|
expanded = resolver.expand_wildcard_import(imp.import_path, filter_names=priority_types)
|
||||||
|
logger.debug(
|
||||||
|
"Wildcard %s.* exceeds %d types, filtered to %d referenced types",
|
||||||
|
imp.import_path,
|
||||||
|
MAX_WILDCARD_TYPES_UNFILTERED,
|
||||||
|
len(expanded),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expanded = expanded[:MAX_WILDCARD_TYPES_UNFILTERED]
|
||||||
|
logger.debug(
|
||||||
|
"Wildcard %s.* exceeds %d types, capped (no target types to filter by)",
|
||||||
|
imp.import_path,
|
||||||
|
MAX_WILDCARD_TYPES_UNFILTERED,
|
||||||
|
)
|
||||||
|
elif expanded:
|
||||||
|
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
|
||||||
if expanded:
|
if expanded:
|
||||||
resolved_imports.extend(expanded)
|
resolved_imports.extend(expanded)
|
||||||
logger.debug("Expanded wildcard import %s.* into %d types", imp.import_path, len(expanded))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
resolved = resolver.resolve_import(imp)
|
resolved = resolver.resolve_import(imp)
|
||||||
|
|
@ -956,7 +988,7 @@ def get_java_imported_type_skeletons(
|
||||||
|
|
||||||
for resolved in resolved_imports:
|
for resolved in resolved_imports:
|
||||||
class_name = resolved.class_name
|
class_name = resolved.class_name
|
||||||
if not class_name:
|
if not class_name or resolved.file_path is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
dedup_key = (str(resolved.file_path), class_name)
|
dedup_key = (str(resolved.file_path), class_name)
|
||||||
|
|
@ -1078,8 +1110,6 @@ def _extract_public_method_signatures(source: str, class_name: str, analyzer: Ja
|
||||||
continue
|
continue
|
||||||
|
|
||||||
node = method.node
|
node = method.node
|
||||||
if not node:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if the method is public
|
# Check if the method is public
|
||||||
is_public = False
|
is_public = False
|
||||||
|
|
|
||||||
|
|
@ -220,14 +220,20 @@ class JavaImportResolver:
|
||||||
return last_part
|
return last_part
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def expand_wildcard_import(self, import_path: str) -> list[ResolvedImport]:
|
def expand_wildcard_import(
|
||||||
|
self, import_path: str, max_types: int = 0, filter_names: set[str] | None = None
|
||||||
|
) -> list[ResolvedImport]:
|
||||||
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.
|
"""Expand a wildcard import (e.g., com.example.utils.*) to individual class imports.
|
||||||
|
|
||||||
Resolves the package path to a directory and returns a ResolvedImport for each
|
Resolves the package path to a directory and returns a ResolvedImport for each
|
||||||
.java file found in that directory.
|
.java file found in that directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
import_path: The package path (without the trailing .*).
|
||||||
|
max_types: Maximum number of types to return. 0 means no limit.
|
||||||
|
filter_names: If provided, only include types whose class name is in this set.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# Convert package path to directory path
|
|
||||||
# e.g., "com.example.utils" -> "com/example/utils"
|
|
||||||
relative_dir = import_path.replace(".", "/")
|
relative_dir = import_path.replace(".", "/")
|
||||||
|
|
||||||
resolved: list[ResolvedImport] = []
|
resolved: list[ResolvedImport] = []
|
||||||
|
|
@ -237,8 +243,10 @@ class JavaImportResolver:
|
||||||
if candidate_dir.is_dir():
|
if candidate_dir.is_dir():
|
||||||
for java_file in candidate_dir.glob("*.java"):
|
for java_file in candidate_dir.glob("*.java"):
|
||||||
class_name = java_file.stem
|
class_name = java_file.stem
|
||||||
# Only include files that look like class names (start with uppercase)
|
if not class_name or not class_name[0].isupper():
|
||||||
if class_name and class_name[0].isupper():
|
continue
|
||||||
|
if filter_names is not None and class_name not in filter_names:
|
||||||
|
continue
|
||||||
resolved.append(
|
resolved.append(
|
||||||
ResolvedImport(
|
ResolvedImport(
|
||||||
import_path=f"{import_path}.{class_name}",
|
import_path=f"{import_path}.{class_name}",
|
||||||
|
|
@ -248,6 +256,8 @@ class JavaImportResolver:
|
||||||
class_name=class_name,
|
class_name=class_name,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if max_types and len(resolved) >= max_types:
|
||||||
|
return resolved
|
||||||
|
|
||||||
return resolved
|
return resolved
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,8 @@
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from codeflash.languages.base import FunctionFilterCriteria, Language
|
from codeflash.languages.base import FunctionFilterCriteria
|
||||||
|
from codeflash.languages.language_enum import Language
|
||||||
from codeflash.languages.java.context import (
|
from codeflash.languages.java.context import (
|
||||||
TypeSkeleton,
|
TypeSkeleton,
|
||||||
_extract_public_method_signatures,
|
_extract_public_method_signatures,
|
||||||
|
|
@ -23,7 +24,7 @@ NO_RETURN_FILTER = FunctionFilterCriteria(require_return=False)
|
||||||
class TestExtractCodeContextBasic:
|
class TestExtractCodeContextBasic:
|
||||||
"""Tests for basic extract_code_context functionality."""
|
"""Tests for basic extract_code_context functionality."""
|
||||||
|
|
||||||
def test_simple_method(self, tmp_path: Path):
|
def test_simple_method(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a simple method."""
|
"""Test extracting context for a simple method."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -53,7 +54,7 @@ class TestExtractCodeContextBasic:
|
||||||
assert context.helper_functions == []
|
assert context.helper_functions == []
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_method_with_javadoc(self, tmp_path: Path):
|
def test_method_with_javadoc(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for method with Javadoc."""
|
"""Test extracting context for method with Javadoc."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -94,7 +95,7 @@ class TestExtractCodeContextBasic:
|
||||||
assert context.helper_functions == []
|
assert context.helper_functions == []
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_static_method(self, tmp_path: Path):
|
def test_static_method(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a static method."""
|
"""Test extracting context for a static method."""
|
||||||
java_file = tmp_path / "MathUtils.java"
|
java_file = tmp_path / "MathUtils.java"
|
||||||
java_file.write_text("""public class MathUtils {
|
java_file.write_text("""public class MathUtils {
|
||||||
|
|
@ -123,7 +124,7 @@ class TestExtractCodeContextBasic:
|
||||||
assert context.helper_functions == []
|
assert context.helper_functions == []
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_private_method(self, tmp_path: Path):
|
def test_private_method(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a private method."""
|
"""Test extracting context for a private method."""
|
||||||
java_file = tmp_path / "Helper.java"
|
java_file = tmp_path / "Helper.java"
|
||||||
java_file.write_text("""public class Helper {
|
java_file.write_text("""public class Helper {
|
||||||
|
|
@ -149,7 +150,7 @@ class TestExtractCodeContextBasic:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_protected_method(self, tmp_path: Path):
|
def test_protected_method(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a protected method."""
|
"""Test extracting context for a protected method."""
|
||||||
java_file = tmp_path / "Base.java"
|
java_file = tmp_path / "Base.java"
|
||||||
java_file.write_text("""public class Base {
|
java_file.write_text("""public class Base {
|
||||||
|
|
@ -175,7 +176,7 @@ class TestExtractCodeContextBasic:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_synchronized_method(self, tmp_path: Path):
|
def test_synchronized_method(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a synchronized method."""
|
"""Test extracting context for a synchronized method."""
|
||||||
java_file = tmp_path / "Counter.java"
|
java_file = tmp_path / "Counter.java"
|
||||||
java_file.write_text("""public class Counter {
|
java_file.write_text("""public class Counter {
|
||||||
|
|
@ -200,7 +201,7 @@ class TestExtractCodeContextBasic:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_method_with_throws(self, tmp_path: Path):
|
def test_method_with_throws(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a method with throws clause."""
|
"""Test extracting context for a method with throws clause."""
|
||||||
java_file = tmp_path / "FileHandler.java"
|
java_file = tmp_path / "FileHandler.java"
|
||||||
java_file.write_text("""public class FileHandler {
|
java_file.write_text("""public class FileHandler {
|
||||||
|
|
@ -225,7 +226,7 @@ class TestExtractCodeContextBasic:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_method_with_varargs(self, tmp_path: Path):
|
def test_method_with_varargs(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a method with varargs."""
|
"""Test extracting context for a method with varargs."""
|
||||||
java_file = tmp_path / "Logger.java"
|
java_file = tmp_path / "Logger.java"
|
||||||
java_file.write_text("""public class Logger {
|
java_file.write_text("""public class Logger {
|
||||||
|
|
@ -250,7 +251,7 @@ class TestExtractCodeContextBasic:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_void_method(self, tmp_path: Path):
|
def test_void_method(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a void method."""
|
"""Test extracting context for a void method."""
|
||||||
java_file = tmp_path / "Printer.java"
|
java_file = tmp_path / "Printer.java"
|
||||||
java_file.write_text("""public class Printer {
|
java_file.write_text("""public class Printer {
|
||||||
|
|
@ -277,7 +278,7 @@ class TestExtractCodeContextBasic:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_generic_return_type(self, tmp_path: Path):
|
def test_generic_return_type(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting context for a method with generic return type."""
|
"""Test extracting context for a method with generic return type."""
|
||||||
java_file = tmp_path / "Container.java"
|
java_file = tmp_path / "Container.java"
|
||||||
java_file.write_text("""public class Container {
|
java_file.write_text("""public class Container {
|
||||||
|
|
@ -306,7 +307,7 @@ class TestExtractCodeContextBasic:
|
||||||
class TestExtractCodeContextWithImports:
|
class TestExtractCodeContextWithImports:
|
||||||
"""Tests for extract_code_context with various import types."""
|
"""Tests for extract_code_context with various import types."""
|
||||||
|
|
||||||
def test_with_package_and_imports(self, tmp_path: Path):
|
def test_with_package_and_imports(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with package and imports."""
|
"""Test context extraction with package and imports."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""package com.example;
|
java_file.write_text("""package com.example;
|
||||||
|
|
@ -344,7 +345,7 @@ public class Calculator {
|
||||||
# Fields are in skeleton, so read_only_context is empty
|
# Fields are in skeleton, so read_only_context is empty
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_with_static_imports(self, tmp_path: Path):
|
def test_with_static_imports(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with static imports."""
|
"""Test context extraction with static imports."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""package com.example;
|
java_file.write_text("""package com.example;
|
||||||
|
|
@ -380,7 +381,7 @@ public class Calculator {
|
||||||
"import static java.lang.Math.sqrt;",
|
"import static java.lang.Math.sqrt;",
|
||||||
]
|
]
|
||||||
|
|
||||||
def test_with_wildcard_imports(self, tmp_path: Path):
|
def test_with_wildcard_imports(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with wildcard imports."""
|
"""Test context extraction with wildcard imports."""
|
||||||
java_file = tmp_path / "Processor.java"
|
java_file = tmp_path / "Processor.java"
|
||||||
java_file.write_text("""package com.example;
|
java_file.write_text("""package com.example;
|
||||||
|
|
@ -402,7 +403,7 @@ public class Processor {
|
||||||
assert context.language == Language.JAVA
|
assert context.language == Language.JAVA
|
||||||
assert context.imports == ["import java.util.*;", "import java.io.*;"]
|
assert context.imports == ["import java.util.*;", "import java.io.*;"]
|
||||||
|
|
||||||
def test_with_multiple_import_types(self, tmp_path: Path):
|
def test_with_multiple_import_types(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with various import types."""
|
"""Test context extraction with various import types."""
|
||||||
java_file = tmp_path / "Handler.java"
|
java_file = tmp_path / "Handler.java"
|
||||||
java_file.write_text("""package com.example;
|
java_file.write_text("""package com.example;
|
||||||
|
|
@ -453,7 +454,7 @@ class TestExtractCodeContextWithFields:
|
||||||
read_only_context should be empty to avoid duplication.
|
read_only_context should be empty to avoid duplication.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_with_instance_fields(self, tmp_path: Path):
|
def test_with_instance_fields(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with instance fields."""
|
"""Test context extraction with instance fields."""
|
||||||
java_file = tmp_path / "Person.java"
|
java_file = tmp_path / "Person.java"
|
||||||
java_file.write_text("""public class Person {
|
java_file.write_text("""public class Person {
|
||||||
|
|
@ -488,7 +489,7 @@ class TestExtractCodeContextWithFields:
|
||||||
assert context.imports == []
|
assert context.imports == []
|
||||||
assert context.helper_functions == []
|
assert context.helper_functions == []
|
||||||
|
|
||||||
def test_with_static_fields(self, tmp_path: Path):
|
def test_with_static_fields(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with static fields."""
|
"""Test context extraction with static fields."""
|
||||||
java_file = tmp_path / "Counter.java"
|
java_file = tmp_path / "Counter.java"
|
||||||
java_file.write_text("""public class Counter {
|
java_file.write_text("""public class Counter {
|
||||||
|
|
@ -519,7 +520,7 @@ class TestExtractCodeContextWithFields:
|
||||||
# Fields are in skeleton, so read_only_context is empty
|
# Fields are in skeleton, so read_only_context is empty
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_with_final_fields(self, tmp_path: Path):
|
def test_with_final_fields(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with final fields."""
|
"""Test context extraction with final fields."""
|
||||||
java_file = tmp_path / "Config.java"
|
java_file = tmp_path / "Config.java"
|
||||||
java_file.write_text("""public class Config {
|
java_file.write_text("""public class Config {
|
||||||
|
|
@ -549,7 +550,7 @@ class TestExtractCodeContextWithFields:
|
||||||
)
|
)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_with_static_final_constants(self, tmp_path: Path):
|
def test_with_static_final_constants(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with static final constants."""
|
"""Test context extraction with static final constants."""
|
||||||
java_file = tmp_path / "Constants.java"
|
java_file = tmp_path / "Constants.java"
|
||||||
java_file.write_text("""public class Constants {
|
java_file.write_text("""public class Constants {
|
||||||
|
|
@ -581,7 +582,7 @@ class TestExtractCodeContextWithFields:
|
||||||
)
|
)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_with_volatile_fields(self, tmp_path: Path):
|
def test_with_volatile_fields(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with volatile fields."""
|
"""Test context extraction with volatile fields."""
|
||||||
java_file = tmp_path / "ThreadSafe.java"
|
java_file = tmp_path / "ThreadSafe.java"
|
||||||
java_file.write_text("""public class ThreadSafe {
|
java_file.write_text("""public class ThreadSafe {
|
||||||
|
|
@ -611,7 +612,7 @@ class TestExtractCodeContextWithFields:
|
||||||
)
|
)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_with_generic_fields(self, tmp_path: Path):
|
def test_with_generic_fields(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with generic type fields."""
|
"""Test context extraction with generic type fields."""
|
||||||
java_file = tmp_path / "Container.java"
|
java_file = tmp_path / "Container.java"
|
||||||
java_file.write_text("""public class Container {
|
java_file.write_text("""public class Container {
|
||||||
|
|
@ -643,7 +644,7 @@ class TestExtractCodeContextWithFields:
|
||||||
)
|
)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_with_array_fields(self, tmp_path: Path):
|
def test_with_array_fields(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with array fields."""
|
"""Test context extraction with array fields."""
|
||||||
java_file = tmp_path / "ArrayHolder.java"
|
java_file = tmp_path / "ArrayHolder.java"
|
||||||
java_file.write_text("""public class ArrayHolder {
|
java_file.write_text("""public class ArrayHolder {
|
||||||
|
|
@ -679,7 +680,7 @@ class TestExtractCodeContextWithFields:
|
||||||
class TestExtractCodeContextWithHelpers:
|
class TestExtractCodeContextWithHelpers:
|
||||||
"""Tests for extract_code_context with helper functions."""
|
"""Tests for extract_code_context with helper functions."""
|
||||||
|
|
||||||
def test_single_helper_method(self, tmp_path: Path):
|
def test_single_helper_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with a single helper method."""
|
"""Test context extraction with a single helper method."""
|
||||||
java_file = tmp_path / "Processor.java"
|
java_file = tmp_path / "Processor.java"
|
||||||
java_file.write_text("""public class Processor {
|
java_file.write_text("""public class Processor {
|
||||||
|
|
@ -715,7 +716,7 @@ class TestExtractCodeContextWithHelpers:
|
||||||
== "private String normalize(String s) {\n return s.trim().toLowerCase();\n }"
|
== "private String normalize(String s) {\n return s.trim().toLowerCase();\n }"
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multiple_helper_methods(self, tmp_path: Path):
|
def test_multiple_helper_methods(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with multiple helper methods."""
|
"""Test context extraction with multiple helper methods."""
|
||||||
java_file = tmp_path / "Processor.java"
|
java_file = tmp_path / "Processor.java"
|
||||||
java_file.write_text("""public class Processor {
|
java_file.write_text("""public class Processor {
|
||||||
|
|
@ -758,7 +759,7 @@ class TestExtractCodeContextWithHelpers:
|
||||||
helper_names = sorted([h.name for h in context.helper_functions])
|
helper_names = sorted([h.name for h in context.helper_functions])
|
||||||
assert helper_names == ["trim", "upper"]
|
assert helper_names == ["trim", "upper"]
|
||||||
|
|
||||||
def test_chained_helper_calls(self, tmp_path: Path):
|
def test_chained_helper_calls(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with chained helper calls."""
|
"""Test context extraction with chained helper calls."""
|
||||||
java_file = tmp_path / "Processor.java"
|
java_file = tmp_path / "Processor.java"
|
||||||
java_file.write_text("""public class Processor {
|
java_file.write_text("""public class Processor {
|
||||||
|
|
@ -784,7 +785,7 @@ class TestExtractCodeContextWithHelpers:
|
||||||
helper_names = [h.name for h in context.helper_functions]
|
helper_names = [h.name for h in context.helper_functions]
|
||||||
assert helper_names == ["normalize"]
|
assert helper_names == ["normalize"]
|
||||||
|
|
||||||
def test_no_helpers_when_none_called(self, tmp_path: Path):
|
def test_no_helpers_when_none_called(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction when no helpers are called."""
|
"""Test context extraction when no helpers are called."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -814,7 +815,7 @@ class TestExtractCodeContextWithHelpers:
|
||||||
)
|
)
|
||||||
assert context.helper_functions == []
|
assert context.helper_functions == []
|
||||||
|
|
||||||
def test_static_helper_from_instance_method(self, tmp_path: Path):
|
def test_static_helper_from_instance_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with static helper called from instance method."""
|
"""Test context extraction with static helper called from instance method."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -840,7 +841,7 @@ class TestExtractCodeContextWithHelpers:
|
||||||
class TestExtractCodeContextWithJavadoc:
|
class TestExtractCodeContextWithJavadoc:
|
||||||
"""Tests for extract_code_context with various Javadoc patterns."""
|
"""Tests for extract_code_context with various Javadoc patterns."""
|
||||||
|
|
||||||
def test_simple_javadoc(self, tmp_path: Path):
|
def test_simple_javadoc(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with simple Javadoc."""
|
"""Test context extraction with simple Javadoc."""
|
||||||
java_file = tmp_path / "Example.java"
|
java_file = tmp_path / "Example.java"
|
||||||
java_file.write_text("""public class Example {
|
java_file.write_text("""public class Example {
|
||||||
|
|
@ -866,7 +867,7 @@ class TestExtractCodeContextWithJavadoc:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_javadoc_with_params(self, tmp_path: Path):
|
def test_javadoc_with_params(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with Javadoc @param tags."""
|
"""Test context extraction with Javadoc @param tags."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -900,7 +901,7 @@ class TestExtractCodeContextWithJavadoc:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_javadoc_with_return(self, tmp_path: Path):
|
def test_javadoc_with_return(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with Javadoc @return tag."""
|
"""Test context extraction with Javadoc @return tag."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -932,7 +933,7 @@ class TestExtractCodeContextWithJavadoc:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_javadoc_with_throws(self, tmp_path: Path):
|
def test_javadoc_with_throws(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with Javadoc @throws tag."""
|
"""Test context extraction with Javadoc @throws tag."""
|
||||||
java_file = tmp_path / "Divider.java"
|
java_file = tmp_path / "Divider.java"
|
||||||
java_file.write_text("""public class Divider {
|
java_file.write_text("""public class Divider {
|
||||||
|
|
@ -968,7 +969,7 @@ class TestExtractCodeContextWithJavadoc:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_javadoc_multiline(self, tmp_path: Path):
|
def test_javadoc_multiline(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with multi-paragraph Javadoc."""
|
"""Test context extraction with multi-paragraph Javadoc."""
|
||||||
java_file = tmp_path / "Complex.java"
|
java_file = tmp_path / "Complex.java"
|
||||||
java_file.write_text("""public class Complex {
|
java_file.write_text("""public class Complex {
|
||||||
|
|
@ -1020,7 +1021,7 @@ class TestExtractCodeContextWithJavadoc:
|
||||||
class TestExtractCodeContextWithGenerics:
|
class TestExtractCodeContextWithGenerics:
|
||||||
"""Tests for extract_code_context with generic types."""
|
"""Tests for extract_code_context with generic types."""
|
||||||
|
|
||||||
def test_generic_method_type_parameter(self, tmp_path: Path):
|
def test_generic_method_type_parameter(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with generic type parameter."""
|
"""Test context extraction with generic type parameter."""
|
||||||
java_file = tmp_path / "Utils.java"
|
java_file = tmp_path / "Utils.java"
|
||||||
java_file.write_text("""public class Utils {
|
java_file.write_text("""public class Utils {
|
||||||
|
|
@ -1044,7 +1045,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_bounded_type_parameter(self, tmp_path: Path):
|
def test_bounded_type_parameter(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with bounded type parameter."""
|
"""Test context extraction with bounded type parameter."""
|
||||||
java_file = tmp_path / "Statistics.java"
|
java_file = tmp_path / "Statistics.java"
|
||||||
java_file.write_text("""public class Statistics {
|
java_file.write_text("""public class Statistics {
|
||||||
|
|
@ -1076,7 +1077,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_wildcard_type(self, tmp_path: Path):
|
def test_wildcard_type(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with wildcard type."""
|
"""Test context extraction with wildcard type."""
|
||||||
java_file = tmp_path / "Printer.java"
|
java_file = tmp_path / "Printer.java"
|
||||||
java_file.write_text("""public class Printer {
|
java_file.write_text("""public class Printer {
|
||||||
|
|
@ -1100,7 +1101,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_bounded_wildcard_extends(self, tmp_path: Path):
|
def test_bounded_wildcard_extends(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with upper bounded wildcard."""
|
"""Test context extraction with upper bounded wildcard."""
|
||||||
java_file = tmp_path / "Aggregator.java"
|
java_file = tmp_path / "Aggregator.java"
|
||||||
java_file.write_text("""public class Aggregator {
|
java_file.write_text("""public class Aggregator {
|
||||||
|
|
@ -1132,7 +1133,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_bounded_wildcard_super(self, tmp_path: Path):
|
def test_bounded_wildcard_super(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with lower bounded wildcard."""
|
"""Test context extraction with lower bounded wildcard."""
|
||||||
java_file = tmp_path / "Filler.java"
|
java_file = tmp_path / "Filler.java"
|
||||||
java_file.write_text("""public class Filler {
|
java_file.write_text("""public class Filler {
|
||||||
|
|
@ -1158,7 +1159,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multiple_type_parameters(self, tmp_path: Path):
|
def test_multiple_type_parameters(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with multiple type parameters."""
|
"""Test context extraction with multiple type parameters."""
|
||||||
java_file = tmp_path / "Mapper.java"
|
java_file = tmp_path / "Mapper.java"
|
||||||
java_file.write_text("""public class Mapper {
|
java_file.write_text("""public class Mapper {
|
||||||
|
|
@ -1190,7 +1191,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_recursive_type_bound(self, tmp_path: Path):
|
def test_recursive_type_bound(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with recursive type bound."""
|
"""Test context extraction with recursive type bound."""
|
||||||
java_file = tmp_path / "Sorter.java"
|
java_file = tmp_path / "Sorter.java"
|
||||||
java_file.write_text("""public class Sorter {
|
java_file.write_text("""public class Sorter {
|
||||||
|
|
@ -1218,7 +1219,7 @@ class TestExtractCodeContextWithGenerics:
|
||||||
class TestExtractCodeContextWithAnnotations:
|
class TestExtractCodeContextWithAnnotations:
|
||||||
"""Tests for extract_code_context with annotations."""
|
"""Tests for extract_code_context with annotations."""
|
||||||
|
|
||||||
def test_override_annotation(self, tmp_path: Path):
|
def test_override_annotation(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with @Override annotation."""
|
"""Test context extraction with @Override annotation."""
|
||||||
java_file = tmp_path / "Child.java"
|
java_file = tmp_path / "Child.java"
|
||||||
java_file.write_text("""public class Child extends Parent {
|
java_file.write_text("""public class Child extends Parent {
|
||||||
|
|
@ -1244,7 +1245,7 @@ class TestExtractCodeContextWithAnnotations:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_deprecated_annotation(self, tmp_path: Path):
|
def test_deprecated_annotation(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with @Deprecated annotation."""
|
"""Test context extraction with @Deprecated annotation."""
|
||||||
java_file = tmp_path / "Legacy.java"
|
java_file = tmp_path / "Legacy.java"
|
||||||
java_file.write_text("""public class Legacy {
|
java_file.write_text("""public class Legacy {
|
||||||
|
|
@ -1270,7 +1271,7 @@ class TestExtractCodeContextWithAnnotations:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_suppress_warnings_annotation(self, tmp_path: Path):
|
def test_suppress_warnings_annotation(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with @SuppressWarnings annotation."""
|
"""Test context extraction with @SuppressWarnings annotation."""
|
||||||
java_file = tmp_path / "Processor.java"
|
java_file = tmp_path / "Processor.java"
|
||||||
java_file.write_text("""public class Processor {
|
java_file.write_text("""public class Processor {
|
||||||
|
|
@ -1296,7 +1297,7 @@ class TestExtractCodeContextWithAnnotations:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_multiple_annotations(self, tmp_path: Path):
|
def test_multiple_annotations(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with multiple annotations."""
|
"""Test context extraction with multiple annotations."""
|
||||||
java_file = tmp_path / "Service.java"
|
java_file = tmp_path / "Service.java"
|
||||||
java_file.write_text("""public class Service {
|
java_file.write_text("""public class Service {
|
||||||
|
|
@ -1326,7 +1327,7 @@ class TestExtractCodeContextWithAnnotations:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_annotation_with_array_value(self, tmp_path: Path):
|
def test_annotation_with_array_value(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with annotation array value."""
|
"""Test context extraction with annotation array value."""
|
||||||
java_file = tmp_path / "Handler.java"
|
java_file = tmp_path / "Handler.java"
|
||||||
java_file.write_text("""public class Handler {
|
java_file.write_text("""public class Handler {
|
||||||
|
|
@ -1356,7 +1357,7 @@ class TestExtractCodeContextWithAnnotations:
|
||||||
class TestExtractCodeContextWithInheritance:
|
class TestExtractCodeContextWithInheritance:
|
||||||
"""Tests for extract_code_context with inheritance scenarios."""
|
"""Tests for extract_code_context with inheritance scenarios."""
|
||||||
|
|
||||||
def test_method_in_subclass(self, tmp_path: Path):
|
def test_method_in_subclass(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for method in subclass."""
|
"""Test context extraction for method in subclass."""
|
||||||
java_file = tmp_path / "AdvancedCalc.java"
|
java_file = tmp_path / "AdvancedCalc.java"
|
||||||
java_file.write_text("""public class AdvancedCalc extends Calculator {
|
java_file.write_text("""public class AdvancedCalc extends Calculator {
|
||||||
|
|
@ -1382,7 +1383,7 @@ class TestExtractCodeContextWithInheritance:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_interface_implementation(self, tmp_path: Path):
|
def test_interface_implementation(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for interface implementation."""
|
"""Test context extraction for interface implementation."""
|
||||||
java_file = tmp_path / "MyComparable.java"
|
java_file = tmp_path / "MyComparable.java"
|
||||||
java_file.write_text("""public class MyComparable implements Comparable<MyComparable> {
|
java_file.write_text("""public class MyComparable implements Comparable<MyComparable> {
|
||||||
|
|
@ -1414,7 +1415,7 @@ class TestExtractCodeContextWithInheritance:
|
||||||
# Fields are in skeleton, so read_only_context is empty (no duplication)
|
# Fields are in skeleton, so read_only_context is empty (no duplication)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_multiple_interfaces(self, tmp_path: Path):
|
def test_multiple_interfaces(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for multiple interface implementations."""
|
"""Test context extraction for multiple interface implementations."""
|
||||||
java_file = tmp_path / "MultiImpl.java"
|
java_file = tmp_path / "MultiImpl.java"
|
||||||
java_file.write_text("""public class MultiImpl implements Runnable, Comparable<MultiImpl> {
|
java_file.write_text("""public class MultiImpl implements Runnable, Comparable<MultiImpl> {
|
||||||
|
|
@ -1446,7 +1447,7 @@ class TestExtractCodeContextWithInheritance:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_default_interface_method(self, tmp_path: Path):
|
def test_default_interface_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for default interface method."""
|
"""Test context extraction for default interface method."""
|
||||||
java_file = tmp_path / "MyInterface.java"
|
java_file = tmp_path / "MyInterface.java"
|
||||||
java_file.write_text("""public interface MyInterface {
|
java_file.write_text("""public interface MyInterface {
|
||||||
|
|
@ -1479,7 +1480,7 @@ class TestExtractCodeContextWithInheritance:
|
||||||
class TestExtractCodeContextWithInnerClasses:
|
class TestExtractCodeContextWithInnerClasses:
|
||||||
"""Tests for extract_code_context with inner/nested classes."""
|
"""Tests for extract_code_context with inner/nested classes."""
|
||||||
|
|
||||||
def test_static_nested_class_method(self, tmp_path: Path):
|
def test_static_nested_class_method(self, tmp_path: Path) -> None:
|
||||||
"""Inner class methods are excluded from discovery and cannot be context-extracted.
|
"""Inner class methods are excluded from discovery and cannot be context-extracted.
|
||||||
|
|
||||||
Methods of static nested classes are skipped in discovery because they
|
Methods of static nested classes are skipped in discovery because they
|
||||||
|
|
@ -1499,7 +1500,7 @@ class TestExtractCodeContextWithInnerClasses:
|
||||||
# Inner class method must NOT be discovered
|
# Inner class method must NOT be discovered
|
||||||
assert compute_func is None
|
assert compute_func is None
|
||||||
|
|
||||||
def test_inner_class_method(self, tmp_path: Path):
|
def test_inner_class_method(self, tmp_path: Path) -> None:
|
||||||
"""Inner class methods are excluded from discovery and cannot be context-extracted.
|
"""Inner class methods are excluded from discovery and cannot be context-extracted.
|
||||||
|
|
||||||
Methods of non-static inner classes are skipped in discovery because they
|
Methods of non-static inner classes are skipped in discovery because they
|
||||||
|
|
@ -1525,7 +1526,7 @@ class TestExtractCodeContextWithInnerClasses:
|
||||||
class TestExtractCodeContextWithEnumAndInterface:
|
class TestExtractCodeContextWithEnumAndInterface:
|
||||||
"""Tests for extract_code_context with enums and interfaces."""
|
"""Tests for extract_code_context with enums and interfaces."""
|
||||||
|
|
||||||
def test_enum_method(self, tmp_path: Path):
|
def test_enum_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for enum method."""
|
"""Test context extraction for enum method."""
|
||||||
java_file = tmp_path / "Operation.java"
|
java_file = tmp_path / "Operation.java"
|
||||||
java_file.write_text("""public enum Operation {
|
java_file.write_text("""public enum Operation {
|
||||||
|
|
@ -1568,7 +1569,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
||||||
)
|
)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_interface_default_method(self, tmp_path: Path):
|
def test_interface_default_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for interface default method."""
|
"""Test context extraction for interface default method."""
|
||||||
java_file = tmp_path / "Greeting.java"
|
java_file = tmp_path / "Greeting.java"
|
||||||
java_file.write_text("""public interface Greeting {
|
java_file.write_text("""public interface Greeting {
|
||||||
|
|
@ -1595,7 +1596,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
||||||
)
|
)
|
||||||
assert context.read_only_context == ""
|
assert context.read_only_context == ""
|
||||||
|
|
||||||
def test_interface_static_method(self, tmp_path: Path):
|
def test_interface_static_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for interface static method."""
|
"""Test context extraction for interface static method."""
|
||||||
java_file = tmp_path / "Factory.java"
|
java_file = tmp_path / "Factory.java"
|
||||||
java_file.write_text("""public interface Factory {
|
java_file.write_text("""public interface Factory {
|
||||||
|
|
@ -1626,7 +1627,7 @@ class TestExtractCodeContextWithEnumAndInterface:
|
||||||
class TestExtractCodeContextEdgeCases:
|
class TestExtractCodeContextEdgeCases:
|
||||||
"""Tests for extract_code_context edge cases."""
|
"""Tests for extract_code_context edge cases."""
|
||||||
|
|
||||||
def test_empty_method(self, tmp_path: Path):
|
def test_empty_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for empty method."""
|
"""Test context extraction for empty method."""
|
||||||
java_file = tmp_path / "Empty.java"
|
java_file = tmp_path / "Empty.java"
|
||||||
java_file.write_text("""public class Empty {
|
java_file.write_text("""public class Empty {
|
||||||
|
|
@ -1650,7 +1651,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_single_line_method(self, tmp_path: Path):
|
def test_single_line_method(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for single-line method."""
|
"""Test context extraction for single-line method."""
|
||||||
java_file = tmp_path / "OneLiner.java"
|
java_file = tmp_path / "OneLiner.java"
|
||||||
java_file.write_text("""public class OneLiner {
|
java_file.write_text("""public class OneLiner {
|
||||||
|
|
@ -1670,7 +1671,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_method_with_lambda(self, tmp_path: Path):
|
def test_method_with_lambda(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for method with lambda."""
|
"""Test context extraction for method with lambda."""
|
||||||
java_file = tmp_path / "Functional.java"
|
java_file = tmp_path / "Functional.java"
|
||||||
java_file.write_text("""public class Functional {
|
java_file.write_text("""public class Functional {
|
||||||
|
|
@ -1698,7 +1699,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_method_with_method_reference(self, tmp_path: Path):
|
def test_method_with_method_reference(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for method with method reference."""
|
"""Test context extraction for method with method reference."""
|
||||||
java_file = tmp_path / "Printer.java"
|
java_file = tmp_path / "Printer.java"
|
||||||
java_file.write_text("""public class Printer {
|
java_file.write_text("""public class Printer {
|
||||||
|
|
@ -1722,7 +1723,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_deeply_nested_blocks(self, tmp_path: Path):
|
def test_deeply_nested_blocks(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for method with deeply nested blocks."""
|
"""Test context extraction for method with deeply nested blocks."""
|
||||||
java_file = tmp_path / "Nested.java"
|
java_file = tmp_path / "Nested.java"
|
||||||
java_file.write_text("""public class Nested {
|
java_file.write_text("""public class Nested {
|
||||||
|
|
@ -1776,7 +1777,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_unicode_in_source(self, tmp_path: Path):
|
def test_unicode_in_source(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for method with unicode characters."""
|
"""Test context extraction for method with unicode characters."""
|
||||||
java_file = tmp_path / "Unicode.java"
|
java_file = tmp_path / "Unicode.java"
|
||||||
java_file.write_text(
|
java_file.write_text(
|
||||||
|
|
@ -1803,7 +1804,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_file_not_found(self, tmp_path: Path):
|
def test_file_not_found(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for missing file."""
|
"""Test context extraction for missing file."""
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
from codeflash.models.function_types import FunctionParent
|
from codeflash.models.function_types import FunctionParent
|
||||||
|
|
@ -1824,7 +1825,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
assert context.language == Language.JAVA
|
assert context.language == Language.JAVA
|
||||||
assert context.target_file == missing_file
|
assert context.target_file == missing_file
|
||||||
|
|
||||||
def test_max_helper_depth_zero(self, tmp_path: Path):
|
def test_max_helper_depth_zero(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with max_helper_depth=0."""
|
"""Test context extraction with max_helper_depth=0."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""public class Calculator {
|
java_file.write_text("""public class Calculator {
|
||||||
|
|
@ -1858,7 +1859,7 @@ class TestExtractCodeContextEdgeCases:
|
||||||
class TestExtractCodeContextWithConstructor:
|
class TestExtractCodeContextWithConstructor:
|
||||||
"""Tests for extract_code_context with constructors in class skeleton."""
|
"""Tests for extract_code_context with constructors in class skeleton."""
|
||||||
|
|
||||||
def test_class_with_constructor(self, tmp_path: Path):
|
def test_class_with_constructor(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction includes constructor in skeleton."""
|
"""Test context extraction includes constructor in skeleton."""
|
||||||
java_file = tmp_path / "Person.java"
|
java_file = tmp_path / "Person.java"
|
||||||
java_file.write_text("""public class Person {
|
java_file.write_text("""public class Person {
|
||||||
|
|
@ -1898,7 +1899,7 @@ class TestExtractCodeContextWithConstructor:
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_class_with_multiple_constructors(self, tmp_path: Path):
|
def test_class_with_multiple_constructors(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction includes all constructors in skeleton."""
|
"""Test context extraction includes all constructors in skeleton."""
|
||||||
java_file = tmp_path / "Config.java"
|
java_file = tmp_path / "Config.java"
|
||||||
java_file.write_text("""public class Config {
|
java_file.write_text("""public class Config {
|
||||||
|
|
@ -1956,7 +1957,7 @@ class TestExtractCodeContextWithConstructor:
|
||||||
class TestExtractCodeContextFullIntegration:
|
class TestExtractCodeContextFullIntegration:
|
||||||
"""Integration tests for extract_code_context with all components."""
|
"""Integration tests for extract_code_context with all components."""
|
||||||
|
|
||||||
def test_full_context_with_all_components(self, tmp_path: Path):
|
def test_full_context_with_all_components(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction with imports, fields, and helpers."""
|
"""Test context extraction with imports, fields, and helpers."""
|
||||||
java_file = tmp_path / "Service.java"
|
java_file = tmp_path / "Service.java"
|
||||||
java_file.write_text("""package com.example;
|
java_file.write_text("""package com.example;
|
||||||
|
|
@ -2007,7 +2008,7 @@ public class Service {
|
||||||
assert len(context.helper_functions) == 1
|
assert len(context.helper_functions) == 1
|
||||||
assert context.helper_functions[0].name == "transform"
|
assert context.helper_functions[0].name == "transform"
|
||||||
|
|
||||||
def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path):
|
def test_complex_class_with_javadoc_and_annotations(self, tmp_path: Path) -> None:
|
||||||
"""Test context extraction for complex class with javadoc and annotations."""
|
"""Test context extraction for complex class with javadoc and annotations."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""package com.example.math;
|
java_file.write_text("""package com.example.math;
|
||||||
|
|
@ -2073,7 +2074,7 @@ public class Calculator {
|
||||||
class TestExtractClassContext:
|
class TestExtractClassContext:
|
||||||
"""Tests for extract_class_context."""
|
"""Tests for extract_class_context."""
|
||||||
|
|
||||||
def test_extract_class_with_imports(self, tmp_path: Path):
|
def test_extract_class_with_imports(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting full class context with imports."""
|
"""Test extracting full class context with imports."""
|
||||||
java_file = tmp_path / "Calculator.java"
|
java_file = tmp_path / "Calculator.java"
|
||||||
java_file.write_text("""package com.example;
|
java_file.write_text("""package com.example;
|
||||||
|
|
@ -2112,7 +2113,7 @@ public class Calculator {
|
||||||
}"""
|
}"""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_extract_class_not_found(self, tmp_path: Path):
|
def test_extract_class_not_found(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting non-existent class returns empty string."""
|
"""Test extracting non-existent class returns empty string."""
|
||||||
java_file = tmp_path / "Test.java"
|
java_file = tmp_path / "Test.java"
|
||||||
java_file.write_text("""public class Test {
|
java_file.write_text("""public class Test {
|
||||||
|
|
@ -2124,7 +2125,7 @@ public class Calculator {
|
||||||
|
|
||||||
assert context == ""
|
assert context == ""
|
||||||
|
|
||||||
def test_extract_class_missing_file(self, tmp_path: Path):
|
def test_extract_class_missing_file(self, tmp_path: Path) -> None:
|
||||||
"""Test extracting from missing file returns empty string."""
|
"""Test extracting from missing file returns empty string."""
|
||||||
missing_file = tmp_path / "Missing.java"
|
missing_file = tmp_path / "Missing.java"
|
||||||
|
|
||||||
|
|
@ -2141,7 +2142,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
||||||
extraction should still find the correct function by name.
|
extraction should still find the correct function by name.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def test_extraction_with_stale_line_numbers(self):
|
def test_extraction_with_stale_line_numbers(self) -> None:
|
||||||
"""Verify extraction works when pre-computed line numbers no longer match the source."""
|
"""Verify extraction works when pre-computed line numbers no longer match the source."""
|
||||||
# Original source: functionA at lines 2-4, functionB at lines 5-7
|
# Original source: functionA at lines 2-4, functionB at lines 5-7
|
||||||
original_source = """public class Utils {
|
original_source = """public class Utils {
|
||||||
|
|
@ -2177,7 +2178,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
||||||
assert "functionB" in result
|
assert "functionB" in result
|
||||||
assert "return 2;" in result
|
assert "return 2;" in result
|
||||||
|
|
||||||
def test_extraction_without_analyzer_uses_line_numbers(self):
|
def test_extraction_without_analyzer_uses_line_numbers(self) -> None:
|
||||||
"""Without analyzer, extraction falls back to pre-computed line numbers."""
|
"""Without analyzer, extraction falls back to pre-computed line numbers."""
|
||||||
source = """public class Utils {
|
source = """public class Utils {
|
||||||
public int functionA() {
|
public int functionA() {
|
||||||
|
|
@ -2196,7 +2197,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
||||||
assert "functionB" in result
|
assert "functionB" in result
|
||||||
assert "return 2;" in result
|
assert "return 2;" in result
|
||||||
|
|
||||||
def test_extraction_with_javadoc_after_file_modification(self):
|
def test_extraction_with_javadoc_after_file_modification(self) -> None:
|
||||||
"""Verify Javadoc is included when using tree-sitter extraction on modified files."""
|
"""Verify Javadoc is included when using tree-sitter extraction on modified files."""
|
||||||
original_source = """public class Utils {
|
original_source = """public class Utils {
|
||||||
/** Adds two numbers. */
|
/** Adds two numbers. */
|
||||||
|
|
@ -2233,7 +2234,7 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
||||||
assert "public int subtract" in result
|
assert "public int subtract" in result
|
||||||
assert "return a - b;" in result
|
assert "return a - b;" in result
|
||||||
|
|
||||||
def test_extraction_with_overloaded_methods(self):
|
def test_extraction_with_overloaded_methods(self) -> None:
|
||||||
"""Verify correct overload is selected using line proximity."""
|
"""Verify correct overload is selected using line proximity."""
|
||||||
source = """public class Utils {
|
source = """public class Utils {
|
||||||
public int process(int x) {
|
public int process(int x) {
|
||||||
|
|
@ -2247,13 +2248,15 @@ class TestExtractFunctionSourceStaleLineNumbers:
|
||||||
analyzer = get_java_analyzer()
|
analyzer = get_java_analyzer()
|
||||||
functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
|
functions = discover_functions_from_source(source, file_path=Path("Utils.java"))
|
||||||
# Get the second overload (process(int, int))
|
# Get the second overload (process(int, int))
|
||||||
func_two_args = [f for f in functions if f.function_name == "process" and f.ending_line > 4][0]
|
func_two_args = [
|
||||||
|
f for f in functions if f.function_name == "process" and f.ending_line is not None and f.ending_line > 4
|
||||||
|
][0]
|
||||||
|
|
||||||
result = extract_function_source(source, func_two_args, analyzer=analyzer)
|
result = extract_function_source(source, func_two_args, analyzer=analyzer)
|
||||||
assert "int x, int y" in result
|
assert "int x, int y" in result
|
||||||
assert "return x + y;" in result
|
assert "return x + y;" in result
|
||||||
|
|
||||||
def test_extraction_function_not_found_falls_back(self):
|
def test_extraction_function_not_found_falls_back(self) -> None:
|
||||||
"""If tree-sitter can't find the method, fall back to line numbers."""
|
"""If tree-sitter can't find the method, fall back to line numbers."""
|
||||||
source = """public class Utils {
|
source = """public class Utils {
|
||||||
public int functionA() {
|
public int functionA() {
|
||||||
|
|
@ -2282,7 +2285,7 @@ FIXTURE_DIR = Path(__file__).parent.parent / "fixtures" / "java_maven"
|
||||||
class TestGetJavaImportedTypeSkeletons:
|
class TestGetJavaImportedTypeSkeletons:
|
||||||
"""Tests for get_java_imported_type_skeletons()."""
|
"""Tests for get_java_imported_type_skeletons()."""
|
||||||
|
|
||||||
def test_resolves_internal_imports(self):
|
def test_resolves_internal_imports(self) -> None:
|
||||||
"""Verify that project-internal imports are resolved and skeletons extracted."""
|
"""Verify that project-internal imports are resolved and skeletons extracted."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2297,7 +2300,7 @@ class TestGetJavaImportedTypeSkeletons:
|
||||||
assert "MathHelper" in result
|
assert "MathHelper" in result
|
||||||
assert "Formatter" in result
|
assert "Formatter" in result
|
||||||
|
|
||||||
def test_skeletons_contain_method_signatures(self):
|
def test_skeletons_contain_method_signatures(self) -> None:
|
||||||
"""Verify extracted skeletons include public method signatures."""
|
"""Verify extracted skeletons include public method signatures."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2313,7 +2316,7 @@ class TestGetJavaImportedTypeSkeletons:
|
||||||
assert "multiply" in result
|
assert "multiply" in result
|
||||||
assert "factorial" in result
|
assert "factorial" in result
|
||||||
|
|
||||||
def test_skips_external_imports(self):
|
def test_skips_external_imports(self) -> None:
|
||||||
"""Verify that standard library and external imports are skipped."""
|
"""Verify that standard library and external imports are skipped."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2328,7 +2331,7 @@ class TestGetJavaImportedTypeSkeletons:
|
||||||
# No internal imports → empty result
|
# No internal imports → empty result
|
||||||
assert result == ""
|
assert result == ""
|
||||||
|
|
||||||
def test_deduplicates_imports(self):
|
def test_deduplicates_imports(self) -> None:
|
||||||
"""Verify that the same type imported twice is only included once."""
|
"""Verify that the same type imported twice is only included once."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2344,7 +2347,7 @@ class TestGetJavaImportedTypeSkeletons:
|
||||||
# Count occurrences of MathHelper — should appear exactly once
|
# Count occurrences of MathHelper — should appear exactly once
|
||||||
assert result.count("class MathHelper") == 1
|
assert result.count("class MathHelper") == 1
|
||||||
|
|
||||||
def test_empty_imports_returns_empty(self):
|
def test_empty_imports_returns_empty(self) -> None:
|
||||||
"""Verify that empty import list returns empty string."""
|
"""Verify that empty import list returns empty string."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
analyzer = get_java_analyzer()
|
analyzer = get_java_analyzer()
|
||||||
|
|
@ -2353,7 +2356,7 @@ class TestGetJavaImportedTypeSkeletons:
|
||||||
|
|
||||||
assert result == ""
|
assert result == ""
|
||||||
|
|
||||||
def test_respects_token_budget(self):
|
def test_respects_token_budget(self) -> None:
|
||||||
"""Verify that the function stops when token budget is exceeded."""
|
"""Verify that the function stops when token budget is exceeded."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2378,7 +2381,7 @@ class TestGetJavaImportedTypeSkeletons:
|
||||||
class TestExtractPublicMethodSignatures:
|
class TestExtractPublicMethodSignatures:
|
||||||
"""Tests for _extract_public_method_signatures()."""
|
"""Tests for _extract_public_method_signatures()."""
|
||||||
|
|
||||||
def test_extracts_public_methods(self):
|
def test_extracts_public_methods(self) -> None:
|
||||||
"""Verify public method signatures are extracted."""
|
"""Verify public method signatures are extracted."""
|
||||||
source = """public class Foo {
|
source = """public class Foo {
|
||||||
public int add(int a, int b) {
|
public int add(int a, int b) {
|
||||||
|
|
@ -2398,7 +2401,7 @@ class TestExtractPublicMethodSignatures:
|
||||||
# private method should not be included
|
# private method should not be included
|
||||||
assert not any("secret" in s for s in sigs)
|
assert not any("secret" in s for s in sigs)
|
||||||
|
|
||||||
def test_excludes_constructors(self):
|
def test_excludes_constructors(self) -> None:
|
||||||
"""Verify constructors are excluded from method signatures."""
|
"""Verify constructors are excluded from method signatures."""
|
||||||
source = """public class Bar {
|
source = """public class Bar {
|
||||||
public Bar(int x) { this.x = x; }
|
public Bar(int x) { this.x = x; }
|
||||||
|
|
@ -2411,7 +2414,7 @@ class TestExtractPublicMethodSignatures:
|
||||||
assert "getX" in sigs[0]
|
assert "getX" in sigs[0]
|
||||||
assert not any("Bar(" in s for s in sigs)
|
assert not any("Bar(" in s for s in sigs)
|
||||||
|
|
||||||
def test_empty_class_returns_empty(self):
|
def test_empty_class_returns_empty(self) -> None:
|
||||||
"""Verify empty class returns no signatures."""
|
"""Verify empty class returns no signatures."""
|
||||||
source = """public class Empty {}"""
|
source = """public class Empty {}"""
|
||||||
analyzer = get_java_analyzer()
|
analyzer = get_java_analyzer()
|
||||||
|
|
@ -2419,7 +2422,7 @@ class TestExtractPublicMethodSignatures:
|
||||||
|
|
||||||
assert sigs == []
|
assert sigs == []
|
||||||
|
|
||||||
def test_filters_by_class_name(self):
|
def test_filters_by_class_name(self) -> None:
|
||||||
"""Verify only methods from the specified class are returned."""
|
"""Verify only methods from the specified class are returned."""
|
||||||
source = """public class A {
|
source = """public class A {
|
||||||
public int aMethod() { return 1; }
|
public int aMethod() { return 1; }
|
||||||
|
|
@ -2440,7 +2443,7 @@ class B {
|
||||||
class TestFormatSkeletonForContext:
|
class TestFormatSkeletonForContext:
|
||||||
"""Tests for _format_skeleton_for_context()."""
|
"""Tests for _format_skeleton_for_context()."""
|
||||||
|
|
||||||
def test_formats_basic_skeleton(self):
|
def test_formats_basic_skeleton(self) -> None:
|
||||||
"""Verify basic skeleton formatting with fields and constructors."""
|
"""Verify basic skeleton formatting with fields and constructors."""
|
||||||
source = """public class Widget {
|
source = """public class Widget {
|
||||||
private int size;
|
private int size;
|
||||||
|
|
@ -2467,7 +2470,7 @@ class TestFormatSkeletonForContext:
|
||||||
assert "getSize" in result
|
assert "getSize" in result
|
||||||
assert result.endswith("}")
|
assert result.endswith("}")
|
||||||
|
|
||||||
def test_formats_enum_skeleton(self):
|
def test_formats_enum_skeleton(self) -> None:
|
||||||
"""Verify enum formatting includes constants."""
|
"""Verify enum formatting includes constants."""
|
||||||
source = """public enum Color {
|
source = """public enum Color {
|
||||||
RED, GREEN, BLUE;
|
RED, GREEN, BLUE;
|
||||||
|
|
@ -2490,7 +2493,7 @@ class TestFormatSkeletonForContext:
|
||||||
assert "RED, GREEN, BLUE;" in result
|
assert "RED, GREEN, BLUE;" in result
|
||||||
assert "lower" in result
|
assert "lower" in result
|
||||||
|
|
||||||
def test_formats_empty_class(self):
|
def test_formats_empty_class(self) -> None:
|
||||||
"""Verify formatting of a class with no fields or methods."""
|
"""Verify formatting of a class with no fields or methods."""
|
||||||
source = """public class Empty {}"""
|
source = """public class Empty {}"""
|
||||||
analyzer = get_java_analyzer()
|
analyzer = get_java_analyzer()
|
||||||
|
|
@ -2512,7 +2515,7 @@ class TestFormatSkeletonForContext:
|
||||||
class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
||||||
"""Additional edge case tests for get_java_imported_type_skeletons()."""
|
"""Additional edge case tests for get_java_imported_type_skeletons()."""
|
||||||
|
|
||||||
def test_wildcard_imports_are_expanded(self):
|
def test_wildcard_imports_are_expanded(self) -> None:
|
||||||
"""Wildcard imports (e.g., import com.example.helpers.*) are expanded to individual types."""
|
"""Wildcard imports (e.g., import com.example.helpers.*) are expanded to individual types."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2530,7 +2533,37 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
||||||
# Wildcard imports should now be expanded to individual classes found in the package directory
|
# Wildcard imports should now be expanded to individual classes found in the package directory
|
||||||
assert "MathHelper" in result
|
assert "MathHelper" in result
|
||||||
|
|
||||||
def test_import_to_nonexistent_class_in_file(self):
|
def test_large_wildcard_is_filtered_to_referenced_types(self, tmp_path: Path) -> None:
|
||||||
|
"""When wildcard expands to >50 types, only types referenced in target code are included."""
|
||||||
|
from codeflash.languages.java.context import MAX_WILDCARD_TYPES_UNFILTERED
|
||||||
|
|
||||||
|
# Create a minimal Maven project structure so the resolver finds source roots
|
||||||
|
(tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
|
||||||
|
pkg_dir = tmp_path / "src" / "main" / "java" / "com" / "bigpkg"
|
||||||
|
pkg_dir.mkdir(parents=True)
|
||||||
|
for i in range(MAX_WILDCARD_TYPES_UNFILTERED + 20):
|
||||||
|
(pkg_dir / f"Type{i:03d}.java").write_text(
|
||||||
|
f"package com.bigpkg;\npublic class Type{i:03d} {{ public int val() {{ return {i}; }} }}\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
analyzer = get_java_analyzer()
|
||||||
|
# Target code references Type000 and Type001 only
|
||||||
|
target_code = "Type000 a = new Type000(); Type001 b = a.transform();"
|
||||||
|
source = "package com.example;\nimport com.bigpkg.*;\npublic class Foo { void bar() {} }"
|
||||||
|
imports = analyzer.find_imports(source)
|
||||||
|
|
||||||
|
result = get_java_imported_type_skeletons(
|
||||||
|
imports, tmp_path, tmp_path / "src" / "main" / "java", analyzer, target_code=target_code
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only referenced types should appear, not all 70
|
||||||
|
assert "Type000" in result
|
||||||
|
assert "Type001" in result
|
||||||
|
# Types not referenced in target code should be excluded
|
||||||
|
assert "Type050" not in result
|
||||||
|
|
||||||
|
def test_import_to_nonexistent_class_in_file(self) -> None:
|
||||||
"""When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None."""
|
"""When an import resolves to a file but the class doesn't exist in it, skeleton extraction returns None."""
|
||||||
analyzer = get_java_analyzer()
|
analyzer = get_java_analyzer()
|
||||||
|
|
||||||
|
|
@ -2540,7 +2573,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
||||||
|
|
||||||
assert skeleton is None
|
assert skeleton is None
|
||||||
|
|
||||||
def test_skeleton_output_is_well_formed(self):
|
def test_skeleton_output_is_well_formed(self) -> None:
|
||||||
"""Verify the skeleton string has proper Java-like structure with braces."""
|
"""Verify the skeleton string has proper Java-like structure with braces."""
|
||||||
project_root = FIXTURE_DIR
|
project_root = FIXTURE_DIR
|
||||||
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
module_root = FIXTURE_DIR / "src" / "main" / "java"
|
||||||
|
|
@ -2563,7 +2596,7 @@ class TestGetJavaImportedTypeSkeletonsEdgeCases:
|
||||||
class TestExtractPublicMethodSignaturesEdgeCases:
|
class TestExtractPublicMethodSignaturesEdgeCases:
|
||||||
"""Additional edge case tests for _extract_public_method_signatures()."""
|
"""Additional edge case tests for _extract_public_method_signatures()."""
|
||||||
|
|
||||||
def test_excludes_protected_and_package_private(self):
|
def test_excludes_protected_and_package_private(self) -> None:
|
||||||
"""Verify protected and package-private methods are excluded."""
|
"""Verify protected and package-private methods are excluded."""
|
||||||
source = """public class Visibility {
|
source = """public class Visibility {
|
||||||
public int publicMethod() { return 1; }
|
public int publicMethod() { return 1; }
|
||||||
|
|
@ -2580,7 +2613,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
|
||||||
assert not any("packagePrivateMethod" in s for s in sigs)
|
assert not any("packagePrivateMethod" in s for s in sigs)
|
||||||
assert not any("privateMethod" in s for s in sigs)
|
assert not any("privateMethod" in s for s in sigs)
|
||||||
|
|
||||||
def test_handles_overloaded_methods(self):
|
def test_handles_overloaded_methods(self) -> None:
|
||||||
"""Verify all public overloads are extracted."""
|
"""Verify all public overloads are extracted."""
|
||||||
source = """public class Overloaded {
|
source = """public class Overloaded {
|
||||||
public int process(int x) { return x; }
|
public int process(int x) { return x; }
|
||||||
|
|
@ -2594,7 +2627,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
|
||||||
# All should contain "process"
|
# All should contain "process"
|
||||||
assert all("process" in s for s in sigs)
|
assert all("process" in s for s in sigs)
|
||||||
|
|
||||||
def test_handles_generic_methods(self):
|
def test_handles_generic_methods(self) -> None:
|
||||||
"""Verify generic method signatures are extracted correctly."""
|
"""Verify generic method signatures are extracted correctly."""
|
||||||
source = """public class Generic {
|
source = """public class Generic {
|
||||||
public <T> T identity(T value) { return value; }
|
public <T> T identity(T value) { return value; }
|
||||||
|
|
@ -2611,7 +2644,7 @@ class TestExtractPublicMethodSignaturesEdgeCases:
|
||||||
class TestFormatSkeletonRoundTrip:
|
class TestFormatSkeletonRoundTrip:
|
||||||
"""Tests that verify _extract_type_skeleton → _format_skeleton_for_context produces valid output."""
|
"""Tests that verify _extract_type_skeleton → _format_skeleton_for_context produces valid output."""
|
||||||
|
|
||||||
def test_round_trip_produces_valid_skeleton(self):
|
def test_round_trip_produces_valid_skeleton(self) -> None:
|
||||||
"""Extract a real skeleton and format it — verify the output is sensible."""
|
"""Extract a real skeleton and format it — verify the output is sensible."""
|
||||||
source = """public class Service {
|
source = """public class Service {
|
||||||
private final String name;
|
private final String name;
|
||||||
|
|
@ -2659,7 +2692,7 @@ class TestFormatSkeletonRoundTrip:
|
||||||
# Should end properly
|
# Should end properly
|
||||||
assert result.strip().endswith("}")
|
assert result.strip().endswith("}")
|
||||||
|
|
||||||
def test_round_trip_with_fixture_mathhelper(self):
|
def test_round_trip_with_fixture_mathhelper(self) -> None:
|
||||||
"""Round-trip test using the real MathHelper fixture file."""
|
"""Round-trip test using the real MathHelper fixture file."""
|
||||||
source = (FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "helpers" / "MathHelper.java").read_text()
|
source = (FIXTURE_DIR / "src" / "main" / "java" / "com" / "example" / "helpers" / "MathHelper.java").read_text()
|
||||||
analyzer = get_java_analyzer()
|
analyzer = get_java_analyzer()
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue