mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
better unit test discovery java
This commit is contained in:
parent
95cc60397d
commit
0ff54b5043
3 changed files with 2518 additions and 104 deletions
|
|
@ -2,6 +2,11 @@
|
|||
|
||||
This module provides functionality to discover tests that exercise
|
||||
specific functions, mapping source functions to their tests.
|
||||
|
||||
The core matching strategy traces method invocations in test code back to their
|
||||
declaring class by resolving variable types from declarations, field types, static
|
||||
imports, and constructor expressions. This is analogous to how Python test discovery
|
||||
uses jedi's "goto" functionality.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -19,6 +24,8 @@ if TYPE_CHECKING:
|
|||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
|
|
@ -30,11 +37,8 @@ def discover_tests(
|
|||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
Uses several heuristics to match tests to functions:
|
||||
1. Test method name contains function name
|
||||
2. Test class name matches source class name
|
||||
3. Imports analysis
|
||||
4. Method call analysis in test code
|
||||
Resolves method invocations in test code back to their declaring class by
|
||||
tracing variable types, field types, static imports, and constructor calls.
|
||||
|
||||
Args:
|
||||
test_root: Root directory containing tests.
|
||||
|
|
@ -47,18 +51,16 @@ def discover_tests(
|
|||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Build a map of function names for quick lookup
|
||||
function_map: dict[str, FunctionToOptimize] = {}
|
||||
for func in source_functions:
|
||||
function_map[func.function_name] = func
|
||||
function_map[func.qualified_name] = func
|
||||
|
||||
# Find all test files (various naming conventions)
|
||||
test_files = (
|
||||
list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
|
||||
)
|
||||
# Deduplicate (a file like FooTest.java could match multiple patterns)
|
||||
test_files = list(dict.fromkeys(test_files))
|
||||
|
||||
# Result map
|
||||
result: dict[str, list[TestInfo]] = defaultdict(list)
|
||||
|
||||
for test_file in test_files:
|
||||
|
|
@ -67,7 +69,6 @@ def discover_tests(
|
|||
source = test_file.read_text(encoding="utf-8")
|
||||
|
||||
for test_method in test_methods:
|
||||
# Find which source functions this test might exercise
|
||||
matched_functions = _match_test_to_functions(test_method, source, function_map, analyzer)
|
||||
|
||||
for func_name in matched_functions:
|
||||
|
|
@ -89,135 +90,230 @@ def _match_test_to_functions(
|
|||
function_map: dict[str, FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[str]:
|
||||
"""Match a test method to source functions it might exercise.
|
||||
"""Match a test method to source functions it exercises.
|
||||
|
||||
Resolves each method invocation in the test to ClassName.methodName by:
|
||||
1. Building a variable-to-type map from local declarations and class fields.
|
||||
2. Building a static import map (method -> class).
|
||||
3. For each method_invocation, resolving the receiver to a class name.
|
||||
4. Matching resolved ClassName.methodName against the function map.
|
||||
|
||||
Args:
|
||||
test_method: The test method.
|
||||
test_source: Full source code of the test file.
|
||||
function_map: Map of function names to FunctionToOptimize.
|
||||
function_map: Map of qualified names to FunctionToOptimize.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of function qualified names that this test might exercise.
|
||||
List of function qualified names that this test exercises.
|
||||
|
||||
"""
|
||||
matched: list[str] = []
|
||||
|
||||
# Strategy 1: Test method name contains function name
|
||||
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
|
||||
test_name_lower = test_method.function_name.lower()
|
||||
|
||||
for func_info in function_map.values():
|
||||
if func_info.function_name.lower() in test_name_lower:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
||||
# Strategy 2: Method call analysis
|
||||
# Look for direct method calls in the test code
|
||||
source_bytes = test_source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
# Find method calls within the test method's line range
|
||||
method_calls = _find_method_calls_in_range(
|
||||
# Build type resolution context
|
||||
field_types = _build_field_type_map(tree.root_node, source_bytes, analyzer, test_method.class_name)
|
||||
local_types = _build_local_type_map(
|
||||
tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer
|
||||
)
|
||||
# Locals shadow fields
|
||||
type_map = {**field_types, **local_types}
|
||||
|
||||
for call_name in method_calls:
|
||||
if call_name in function_map:
|
||||
qualified = function_map[call_name].qualified_name
|
||||
if qualified not in matched:
|
||||
matched.append(qualified)
|
||||
static_import_map = _build_static_import_map(tree.root_node, source_bytes, analyzer)
|
||||
|
||||
# Strategy 3: Test class naming convention
|
||||
# e.g., CalculatorTest tests Calculator, TestCalculator tests Calculator
|
||||
if test_method.class_name:
|
||||
# Remove "Test/Tests" suffix or "Test" prefix
|
||||
source_class_name = test_method.class_name
|
||||
if source_class_name.endswith("Tests"):
|
||||
source_class_name = source_class_name[:-5]
|
||||
elif source_class_name.endswith("Test"):
|
||||
source_class_name = source_class_name[:-4]
|
||||
elif source_class_name.startswith("Test"):
|
||||
source_class_name = source_class_name[4:]
|
||||
# Resolve method calls to ClassName.methodName
|
||||
resolved_calls = _resolve_method_calls_in_range(
|
||||
tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer, type_map,
|
||||
static_import_map,
|
||||
)
|
||||
|
||||
# Look for functions in the matching class
|
||||
for func_info in function_map.values():
|
||||
if func_info.class_name == source_class_name:
|
||||
if func_info.qualified_name not in matched:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
||||
# Strategy 4: Import-based matching
|
||||
# If the test file imports a class containing the target function, consider it a match
|
||||
# This handles cases like TestQueryBlob importing Buffer and calling Buffer methods
|
||||
imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer)
|
||||
|
||||
for func_info in function_map.values():
|
||||
if func_info.qualified_name in matched:
|
||||
continue
|
||||
|
||||
# Check if the function's class is imported
|
||||
if func_info.class_name and func_info.class_name in imported_classes:
|
||||
matched.append(func_info.qualified_name)
|
||||
matched: list[str] = []
|
||||
for call in resolved_calls:
|
||||
if call in function_map and call not in matched:
|
||||
matched.append(call)
|
||||
|
||||
return matched
|
||||
|
||||
|
||||
def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
|
||||
"""Extract imported class names from a Java file.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type resolution helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Args:
|
||||
node: Tree-sitter root node.
|
||||
source_bytes: Source code as bytes.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Set of imported class names (simple names, not fully qualified).
|
||||
def _strip_generics(type_name: str) -> str:
    """Strip generic type parameters: ``List<String>`` -> ``List``.

    Everything from the first ``<`` onwards is dropped and surrounding
    whitespace is removed. Names without type parameters are returned
    stripped but otherwise unchanged.

    Args:
        type_name: Source text of a Java type, e.g. ``Map<String, Integer>``.

    Returns:
        The type name without its generic parameter list.

    """
    # partition() yields the whole string as the head when "<" is absent.
    base, _, _ = type_name.partition("<")
    return base.strip()
|
||||
|
||||
|
||||
def _build_local_type_map(
    node: Node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
) -> dict[str, str]:
    """Map variable names to their declared types within a line range.

    Walks the subtree rooted at *node*, pruning every node that lies entirely
    outside ``[start_line, end_line]``, and records a type for:

    * local variable declarations (``Foo x = ...``),
    * enhanced-for loop variables (``for (Foo x : xs)``),
    * try-with-resources variables (``try (Foo x = ...)``).

    The literal ``var`` is never recorded as a type. For local declarations
    and resources it is replaced by the constructed class when the
    initializer is ``new T(...)``; in every other case (including ``var``
    loop variables) no entry is added for that variable.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code as bytes.
        start_line: First line of the range, compared against row + 1.
        end_line: Last line of the range, compared against row + 1.
        analyzer: JavaAnalyzer instance used to extract node text.

    Returns:
        Mapping of variable name to type name with generic parameters stripped.

    """
    type_map: dict[str, str] = {}

    def _infer_var_type(declarator: Node) -> str | None:
        """Return the constructed class if *declarator* is initialized with ``new T(...)``."""
        value_node = declarator.child_by_field_name("value")
        if value_node is None or value_node.type != "object_creation_expression":
            return None
        type_node = value_node.child_by_field_name("type")
        if type_node:
            return _strip_generics(analyzer.get_node_text(type_node, source_bytes))
        return None

    def _record(name_node: Node, declared_type: str, initializer_owner: Node) -> None:
        """Store the variable's type, resolving ``var`` through its initializer when possible."""
        var_name = analyzer.get_node_text(name_node, source_bytes)
        if declared_type != "var":
            type_map[var_name] = declared_type
            return
        inferred = _infer_var_type(initializer_owner)
        if inferred:
            type_map[var_name] = inferred

    def visit(n: Node) -> None:
        # Node rows are shifted by one so they compare against the line range.
        n_start = n.start_point[0] + 1
        n_end = n.end_point[0] + 1
        if n_end < start_line or n_start > end_line:
            return

        if n.type == "local_variable_declaration":
            type_node = n.child_by_field_name("type")
            if type_node:
                type_name = _strip_generics(analyzer.get_node_text(type_node, source_bytes))
                # One declaration may introduce several variables: ``Foo a, b;``
                for child in n.children:
                    if child.type == "variable_declarator":
                        name_node = child.child_by_field_name("name")
                        if name_node:
                            _record(name_node, type_name, child)

        elif n.type == "enhanced_for_statement":
            # for (Type item : iterable) - type and name are positional children
            prev_type: str | None = None
            for child in n.children:
                if child.type in ("type_identifier", "generic_type", "scoped_type_identifier", "array_type"):
                    prev_type = _strip_generics(analyzer.get_node_text(child, source_bytes))
                elif child.type == "identifier" and prev_type is not None:
                    # A ``var`` loop variable has no declared type worth recording.
                    if prev_type != "var":
                        type_map[analyzer.get_node_text(child, source_bytes)] = prev_type
                    # Reset so a later identifier (the iterable) is not mistaken for the variable.
                    prev_type = None

        elif n.type == "resource":
            # try-with-resources: try (Type res = ...) { ... }
            type_node = n.child_by_field_name("type")
            name_node = n.child_by_field_name("name")
            if type_node and name_node:
                # The resource node itself carries the initializer as its "value" field.
                _record(name_node, _strip_generics(analyzer.get_node_text(type_node, source_bytes)), n)

        for child in n.children:
            visit(child)

    visit(node)
    return type_map
|
||||
|
||||
|
||||
def _build_field_type_map(
    node: Node, source_bytes: bytes, analyzer: JavaAnalyzer, test_class_name: str | None
) -> dict[str, str]:
    """Map field names to their declared types for the given class.

    Only ``field_declaration`` nodes whose innermost enclosing class,
    interface or enum is named *test_class_name* contribute entries.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code as bytes.
        analyzer: JavaAnalyzer instance used to extract node text.
        test_class_name: Name of the class whose fields should be collected.

    Returns:
        Mapping of field name to type name with generic parameters stripped.

    """
    type_map: dict[str, str] = {}
    type_declarations = ("class_declaration", "interface_declaration", "enum_declaration")

    def visit(n: Node, current_class: str | None = None) -> None:
        # Entering a type declaration changes the class that nested fields belong to.
        # The value is passed down by argument, so it reverts automatically on the way out.
        if n.type in type_declarations:
            name_node = n.child_by_field_name("name")
            if name_node:
                current_class = analyzer.get_node_text(name_node, source_bytes)

        if n.type == "field_declaration" and current_class == test_class_name:
            type_node = n.child_by_field_name("type")
            if type_node:
                type_name = _strip_generics(analyzer.get_node_text(type_node, source_bytes))
                # One declaration may introduce several fields: ``Foo a, b;``
                for child in n.children:
                    if child.type == "variable_declarator":
                        name_node = child.child_by_field_name("name")
                        if name_node:
                            type_map[analyzer.get_node_text(name_node, source_bytes)] = type_name

        for child in n.children:
            visit(child, current_class)

    visit(node)
    return type_map
|
||||
|
||||
|
||||
def _build_static_import_map(node: Node, source_bytes: bytes, analyzer: JavaAnalyzer) -> dict[str, str]:
    """Map statically imported member names to their declaring class.

    For ``import static com.example.Calculator.add;`` the result is
    ``{"add": "Calculator"}``.

    Wildcard static imports (``import static pkg.Cls.*;``), regular imports
    and imports whose would-be class segment does not start with an uppercase
    letter are ignored.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code as bytes.
        analyzer: JavaAnalyzer instance used to extract node text.

    Returns:
        Mapping of imported member name to the simple name of its class.

    """
    static_map: dict[str, str] = {}

    def _record_static_import(import_text: str) -> None:
        """Add an entry for *import_text* if it is a single-member static import."""
        # Tokenize instead of substring matching: this tolerates arbitrary
        # whitespace after ``import`` and does not mistake a regular import of a
        # package such as ``statics.Util.Helper`` for a static import.
        tokens = import_text.replace(";", " ").split()
        if len(tokens) < 3 or tokens[0] != "import" or tokens[1] != "static":
            return

        path = "".join(tokens[2:])
        if path.endswith(".*") or "." not in path:
            return

        # The last segment is the member and the one before it the declaring class.
        *_, class_name, member_name = path.rsplit(".", 2)
        if class_name and class_name[0].isupper():
            static_map[member_name] = class_name

    def visit(n: Node) -> None:
        if n.type == "import_declaration":
            _record_static_import(analyzer.get_node_text(n, source_bytes))
        for child in n.children:
            visit(child)

    visit(node)
    return static_map
|
||||
|
||||
|
||||
def _extract_imports(node: Node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
|
||||
"""Extract imported class names (simple names) from a Java file."""
|
||||
imports: set[str] = set()
|
||||
|
||||
def visit(n):
|
||||
def visit(n: Node) -> None:
|
||||
if n.type == "import_declaration":
|
||||
import_text = analyzer.get_node_text(n, source_bytes)
|
||||
|
||||
# Check if it's a wildcard import - skip these as we can't know specific classes
|
||||
if import_text.rstrip(";").endswith(".*"):
|
||||
# For static wildcard imports like "import static com.example.Utils.*"
|
||||
# we CAN extract the class name (Utils)
|
||||
if "import static" in import_text:
|
||||
# Extract class from "import static com.example.Utils.*"
|
||||
# Remove "import static " prefix and ".*;" suffix
|
||||
path = import_text.replace("import static ", "").rstrip(";").rstrip(".*")
|
||||
if "." in path:
|
||||
class_name = path.rsplit(".", 1)[-1]
|
||||
if class_name and class_name[0].isupper(): # Ensure it's a class name
|
||||
if class_name and class_name[0].isupper():
|
||||
imports.add(class_name)
|
||||
# For regular wildcards like "import com.example.*", skip entirely
|
||||
return
|
||||
|
||||
# Check if it's a static import of a specific method/field
|
||||
if "import static" in import_text:
|
||||
# "import static com.example.Utils.format;"
|
||||
# We want to extract "Utils" (the class), not "format" (the method)
|
||||
path = import_text.replace("import static ", "").rstrip(";")
|
||||
parts = path.rsplit(".", 2) # Split into [package..., Class, member]
|
||||
parts = path.rsplit(".", 2)
|
||||
if len(parts) >= 2:
|
||||
# The second-to-last part is the class name
|
||||
class_name = parts[-2]
|
||||
if class_name and class_name[0].isupper(): # Ensure it's a class name
|
||||
if class_name and class_name[0].isupper():
|
||||
imports.add(class_name)
|
||||
return
|
||||
|
||||
# Regular import: extract class name from scoped_identifier
|
||||
for child in n.children:
|
||||
if child.type in {"scoped_identifier", "identifier"}:
|
||||
import_path = analyzer.get_node_text(child, source_bytes)
|
||||
# Extract just the class name (last part)
|
||||
# e.g., "com.example.Buffer" -> "Buffer"
|
||||
if "." in import_path:
|
||||
class_name = import_path.rsplit(".", 1)[-1]
|
||||
else:
|
||||
class_name = import_path
|
||||
# Skip if it looks like a package name (lowercase)
|
||||
if class_name and class_name[0].isupper():
|
||||
imports.add(class_name)
|
||||
|
||||
|
|
@ -228,25 +324,119 @@ def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[s
|
|||
return imports
|
||||
|
||||
|
||||
def _find_method_calls_in_range(
|
||||
node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
|
||||
) -> list[str]:
|
||||
"""Find method calls within a line range.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Method call resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Args:
|
||||
node: Tree-sitter node to search.
|
||||
source_bytes: Source code as bytes.
|
||||
start_line: Start line (1-indexed).
|
||||
end_line: End line (1-indexed).
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of method names called.
|
||||
def _resolve_method_calls_in_range(
    node: Node,
    source_bytes: bytes,
    start_line: int,
    end_line: int,
    analyzer: JavaAnalyzer,
    type_map: dict[str, str],
    static_import_map: dict[str, str],
) -> set[str]:
    """Resolve method invocations and constructor calls within a line range.

    Returns resolved references as ``ClassName.methodName`` strings.

    Handles method invocations:
    - ``variable.method()`` - looks up variable type in *type_map*.
    - ``ClassName.staticMethod()`` - uppercase-first identifier treated as class.
    - ``new ClassName().method()`` - extracts type from constructor.
    - ``((ClassName) expr).method()`` - extracts type from cast.
    - ``this.field.method()`` - resolves field type via *type_map*.
    - ``method()`` with no receiver - checks *static_import_map*.

    Handles constructor calls:
    - ``new ClassName(...)`` - emits ``ClassName.ClassName`` and ``ClassName.<init>``.

    Invocations whose receiver cannot be resolved are silently skipped.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code as bytes.
        start_line: First line of the range, compared against row + 1.
        end_line: Last line of the range, compared against row + 1.
        analyzer: JavaAnalyzer instance used to extract node text.
        type_map: Variable/field name to type name.
        static_import_map: Statically imported member name to declaring class.

    Returns:
        Set of resolved ``ClassName.methodName`` references.

    """
    resolved: set[str] = set()

    def _type_from_object_node(obj: Node) -> str | None:
        """Try to determine the class name from a method invocation's object."""
        if obj.type == "identifier":
            text = analyzer.get_node_text(obj, source_bytes)
            if text in type_map:
                return type_map[text]
            # Uppercase-first identifier without a type mapping -> likely a class (static call)
            if text and text[0].isupper():
                return text
            return None

        if obj.type == "object_creation_expression":
            type_node = obj.child_by_field_name("type")
            if type_node:
                return _strip_generics(analyzer.get_node_text(type_node, source_bytes))
            return None

        if obj.type == "field_access":
            # this.field -> look up field in type_map
            field_node = obj.child_by_field_name("field")
            obj_child = obj.child_by_field_name("object")
            if field_node and obj_child:
                field_name = analyzer.get_node_text(field_node, source_bytes)
                if obj_child.type == "this" and field_name in type_map:
                    return type_map[field_name]
            return None

        if obj.type == "parenthesized_expression":
            # Unwrap parentheses, look for cast_expression
            for child in obj.children:
                if child.type == "cast_expression":
                    type_node = child.child_by_field_name("type")
                    if type_node:
                        return _strip_generics(analyzer.get_node_text(type_node, source_bytes))
            return None

        return None

    def visit(n: Node) -> None:
        # Node rows are shifted by one so they compare against the line range.
        n_start = n.start_point[0] + 1
        n_end = n.end_point[0] + 1
        if n_end < start_line or n_start > end_line:
            return

        if n.type == "method_invocation":
            name_node = n.child_by_field_name("name")
            object_node = n.child_by_field_name("object")

            if name_node:
                method_name = analyzer.get_node_text(name_node, source_bytes)

                if object_node:
                    class_name = _type_from_object_node(object_node)
                    if class_name:
                        resolved.add(f"{class_name}.{method_name}")
                # No receiver - check static imports
                elif method_name in static_import_map:
                    resolved.add(f"{static_import_map[method_name]}.{method_name}")

        elif n.type == "object_creation_expression":
            # Constructor call: new ClassName(...)
            # Emit both common qualified-name conventions so the function_map
            # can use either ClassName.ClassName or ClassName.<init>.
            type_node = n.child_by_field_name("type")
            if type_node:
                class_name = _strip_generics(analyzer.get_node_text(type_node, source_bytes))
                resolved.add(f"{class_name}.{class_name}")
                resolved.add(f"{class_name}.<init>")

        # Always descend: arguments and receivers may contain further calls.
        for child in n.children:
            visit(child)

    visit(node)
    return resolved
|
||||
|
||||
|
||||
def _find_method_calls_in_range(
|
||||
node: Node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
|
||||
) -> list[str]:
|
||||
"""Find bare method call names within a line range (legacy helper)."""
|
||||
calls: list[str] = []
|
||||
|
||||
# Check if this node is within the range (convert tree-sitter's 0-indexed rows to 1-indexed lines)
|
||||
node_start = node.start_point[0] + 1
|
||||
node_end = node.end_point[0] + 1
|
||||
|
||||
|
|
|
|||
|
|
@ -6,10 +6,7 @@ regression test code that captures function return values.
|
|||
All tests assert for full string equality, no substring matching.
|
||||
"""
|
||||
|
||||
from codeflash.languages.java.remove_asserts import (
|
||||
JavaAssertTransformer,
|
||||
transform_java_assertions,
|
||||
)
|
||||
from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions
|
||||
|
||||
|
||||
class TestBasicJUnit5Assertions:
|
||||
|
|
|
|||
2227
tests/test_java_test_discovery.py
Normal file
2227
tests/test_java_test_discovery.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue