better unit test discovery java

This commit is contained in:
misrasaurabh1 2026-02-05 23:57:13 -08:00
parent 95cc60397d
commit 0ff54b5043
3 changed files with 2518 additions and 104 deletions

View file

@ -2,6 +2,11 @@
This module provides functionality to discover tests that exercise
specific functions, mapping source functions to their tests.
The core matching strategy traces method invocations in test code back to their
declaring class by resolving variable types from declarations, field types, static
imports, and constructor expressions. This is analogous to how Python test discovery
uses jedi's "goto" functionality.
"""
from __future__ import annotations
@ -19,6 +24,8 @@ if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
@ -30,11 +37,8 @@ def discover_tests(
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
Uses several heuristics to match tests to functions:
1. Test method name contains function name
2. Test class name matches source class name
3. Imports analysis
4. Method call analysis in test code
Resolves method invocations in test code back to their declaring class by
tracing variable types, field types, static imports, and constructor calls.
Args:
test_root: Root directory containing tests.
@ -47,18 +51,16 @@ def discover_tests(
"""
analyzer = analyzer or get_java_analyzer()
# Build a map of function names for quick lookup
function_map: dict[str, FunctionToOptimize] = {}
for func in source_functions:
function_map[func.function_name] = func
function_map[func.qualified_name] = func
# Find all test files (various naming conventions)
test_files = (
list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
)
# Deduplicate (a file like FooTest.java could match multiple patterns)
test_files = list(dict.fromkeys(test_files))
# Result map
result: dict[str, list[TestInfo]] = defaultdict(list)
for test_file in test_files:
@ -67,7 +69,6 @@ def discover_tests(
source = test_file.read_text(encoding="utf-8")
for test_method in test_methods:
# Find which source functions this test might exercise
matched_functions = _match_test_to_functions(test_method, source, function_map, analyzer)
for func_name in matched_functions:
@ -89,135 +90,230 @@ def _match_test_to_functions(
function_map: dict[str, FunctionToOptimize],
analyzer: JavaAnalyzer,
) -> list[str]:
"""Match a test method to source functions it might exercise.
"""Match a test method to source functions it exercises.
Resolves each method invocation in the test to ClassName.methodName by:
1. Building a variable-to-type map from local declarations and class fields.
2. Building a static import map (method -> class).
3. For each method_invocation, resolving the receiver to a class name.
4. Matching resolved ClassName.methodName against the function map.
Args:
test_method: The test method.
test_source: Full source code of the test file.
function_map: Map of function names to FunctionToOptimize.
function_map: Map of qualified names to FunctionToOptimize.
analyzer: JavaAnalyzer instance.
Returns:
List of function qualified names that this test might exercise.
List of function qualified names that this test exercises.
"""
matched: list[str] = []
# Strategy 1: Test method name contains function name
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
test_name_lower = test_method.function_name.lower()
for func_info in function_map.values():
if func_info.function_name.lower() in test_name_lower:
matched.append(func_info.qualified_name)
# Strategy 2: Method call analysis
# Look for direct method calls in the test code
source_bytes = test_source.encode("utf8")
tree = analyzer.parse(source_bytes)
# Find method calls within the test method's line range
method_calls = _find_method_calls_in_range(
# Build type resolution context
field_types = _build_field_type_map(tree.root_node, source_bytes, analyzer, test_method.class_name)
local_types = _build_local_type_map(
tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer
)
# Locals shadow fields
type_map = {**field_types, **local_types}
for call_name in method_calls:
if call_name in function_map:
qualified = function_map[call_name].qualified_name
if qualified not in matched:
matched.append(qualified)
static_import_map = _build_static_import_map(tree.root_node, source_bytes, analyzer)
# Strategy 3: Test class naming convention
# e.g., CalculatorTest tests Calculator, TestCalculator tests Calculator
if test_method.class_name:
# Remove "Test/Tests" suffix or "Test" prefix
source_class_name = test_method.class_name
if source_class_name.endswith("Tests"):
source_class_name = source_class_name[:-5]
elif source_class_name.endswith("Test"):
source_class_name = source_class_name[:-4]
elif source_class_name.startswith("Test"):
source_class_name = source_class_name[4:]
# Resolve method calls to ClassName.methodName
resolved_calls = _resolve_method_calls_in_range(
tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer, type_map,
static_import_map,
)
# Look for functions in the matching class
for func_info in function_map.values():
if func_info.class_name == source_class_name:
if func_info.qualified_name not in matched:
matched.append(func_info.qualified_name)
# Strategy 4: Import-based matching
# If the test file imports a class containing the target function, consider it a match
# This handles cases like TestQueryBlob importing Buffer and calling Buffer methods
imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer)
for func_info in function_map.values():
if func_info.qualified_name in matched:
continue
# Check if the function's class is imported
if func_info.class_name and func_info.class_name in imported_classes:
matched.append(func_info.qualified_name)
matched: list[str] = []
for call in resolved_calls:
if call in function_map and call not in matched:
matched.append(call)
return matched
def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
"""Extract imported class names from a Java file.
# ---------------------------------------------------------------------------
# Type resolution helpers
# ---------------------------------------------------------------------------
Args:
node: Tree-sitter root node.
source_bytes: Source code as bytes.
analyzer: JavaAnalyzer instance.
Returns:
Set of imported class names (simple names, not fully qualified).
def _strip_generics(type_name: str) -> str:
"""Strip generic type parameters: ``List<String>`` -> ``List``."""
idx = type_name.find("<")
if idx != -1:
return type_name[:idx].strip()
return type_name.strip()
def _build_local_type_map(
    node: Node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
) -> dict[str, str]:
    """Collect ``variable name -> declared type`` for declarations in a line range.

    The tree under *node* is walked in pre-order; any subtree that lies entirely
    before *start_line* or entirely after *end_line* (1-indexed, inclusive) is
    pruned. Three kinds of declaration are recorded:

    * ``local_variable_declaration`` - every declarator gets the declared type.
      When the declared type is ``var`` the type is taken from a
      ``new Type(...)`` initializer; if there is none, the variable is skipped.
    * ``enhanced_for_statement`` - the loop variable gets the type that precedes it.
    * ``resource`` (try-with-resources) - the resource gets its declared type.

    Generic arguments are stripped from every recorded type. A later declaration
    of the same name overwrites an earlier one.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code the tree was parsed from.
        start_line: First line of the range (1-indexed).
        end_line: Last line of the range (1-indexed).
        analyzer: JavaAnalyzer used to read node text.

    Returns:
        Mapping of variable name to type name.
    """
    declared: dict[str, str] = {}
    # Node kinds that can carry the element type of an enhanced-for loop.
    loop_type_kinds = ("type_identifier", "generic_type", "scoped_type_identifier", "array_type")

    def text_of(target: Node) -> str:
        return analyzer.get_node_text(target, source_bytes)

    def constructor_type(declarator: Node) -> str | None:
        """Return the class named in a ``new Type(...)`` initializer, if any."""
        initializer = declarator.child_by_field_name("value")
        if initializer is None or initializer.type != "object_creation_expression":
            return None
        created = initializer.child_by_field_name("type")
        return _strip_generics(text_of(created)) if created else None

    def record_local(declaration: Node) -> None:
        declared_type = declaration.child_by_field_name("type")
        if not declared_type:
            return
        type_text = _strip_generics(text_of(declared_type))
        for part in declaration.children:
            if part.type != "variable_declarator":
                continue
            identifier = part.child_by_field_name("name")
            if not identifier:
                continue
            variable = text_of(identifier)
            if type_text != "var":
                declared[variable] = type_text
                continue
            # ``var``: the only type information available is the initializer.
            inferred = constructor_type(part)
            if inferred:
                declared[variable] = inferred

    def record_loop_variable(loop: Node) -> None:
        # for (Type item : iterable) - type and name are positional children,
        # so remember the most recent type and attach it to the next identifier.
        pending: str | None = None
        for part in loop.children:
            if part.type in loop_type_kinds:
                pending = _strip_generics(text_of(part))
            elif part.type == "identifier" and pending is not None:
                declared[text_of(part)] = pending
                pending = None

    def record_resource(resource: Node) -> None:
        # try-with-resources: try (Type res = ...) { ... }
        resource_type = resource.child_by_field_name("type")
        resource_name = resource.child_by_field_name("name")
        if resource_type and resource_name:
            declared[text_of(resource_name)] = _strip_generics(text_of(resource_type))

    handlers = {
        "local_variable_declaration": record_local,
        "enhanced_for_statement": record_loop_variable,
        "resource": record_resource,
    }

    def walk(current: Node) -> None:
        first_line = current.start_point[0] + 1  # tree-sitter rows are 0-indexed
        last_line = current.end_point[0] + 1
        if not (start_line <= last_line and first_line <= end_line):
            return
        handler = handlers.get(current.type)
        if handler is not None:
            handler(current)
        for sub in current.children:
            walk(sub)

    walk(node)
    return declared
def _build_field_type_map(
    node: Node, source_bytes: bytes, analyzer: JavaAnalyzer, test_class_name: str | None
) -> dict[str, str]:
    """Collect ``field name -> declared type`` for fields of one class.

    The tree under *node* is walked while tracking the name of the nearest
    enclosing class, interface or enum declaration. A ``field_declaration`` is
    recorded only when that enclosing name equals *test_class_name*. Generic
    arguments are stripped from the recorded type.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code the tree was parsed from.
        analyzer: JavaAnalyzer used to read node text.
        test_class_name: Name of the class whose fields should be collected.

    Returns:
        Mapping of field name to type name.
    """
    fields: dict[str, str] = {}
    type_declarations = ("class_declaration", "interface_declaration", "enum_declaration")

    def text_of(target: Node) -> str:
        return analyzer.get_node_text(target, source_bytes)

    def record_field(declaration: Node) -> None:
        declared_type = declaration.child_by_field_name("type")
        if not declared_type:
            return
        type_text = _strip_generics(text_of(declared_type))
        # One declaration may introduce several fields: ``Foo a, b;``
        for part in declaration.children:
            if part.type != "variable_declarator":
                continue
            identifier = part.child_by_field_name("name")
            if identifier:
                fields[text_of(identifier)] = type_text

    def walk(current: Node, enclosing: str | None = None) -> None:
        if current.type in type_declarations:
            class_identifier = current.child_by_field_name("name")
            if class_identifier:
                # Descendants now belong to this (possibly nested) type.
                enclosing = text_of(class_identifier)

        if current.type == "field_declaration" and enclosing == test_class_name:
            record_field(current)

        for sub in current.children:
            walk(sub, enclosing)

    walk(node)
    return fields
def _build_static_import_map(node: Node, source_bytes: bytes, analyzer: JavaAnalyzer) -> dict[str, str]:
"""Map statically imported member names to their declaring class.
For ``import static com.example.Calculator.add;`` the result is
``{"add": "Calculator"}``.
"""
static_map: dict[str, str] = {}
def visit(n: Node) -> None:
if n.type == "import_declaration":
import_text = analyzer.get_node_text(n, source_bytes)
if "import static" not in import_text:
for child in n.children:
visit(child)
return
path = import_text.replace("import static", "").replace(";", "").strip()
if path.endswith(".*") or "." not in path:
for child in n.children:
visit(child)
return
parts = path.rsplit(".", 2)
if len(parts) >= 2:
member_name = parts[-1]
class_name = parts[-2]
if class_name and class_name[0].isupper():
static_map[member_name] = class_name
for child in n.children:
visit(child)
visit(node)
return static_map
def _extract_imports(node: Node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
"""Extract imported class names (simple names) from a Java file."""
imports: set[str] = set()
def visit(n):
def visit(n: Node) -> None:
if n.type == "import_declaration":
import_text = analyzer.get_node_text(n, source_bytes)
# Check if it's a wildcard import - skip these as we can't know specific classes
if import_text.rstrip(";").endswith(".*"):
# For static wildcard imports like "import static com.example.Utils.*"
# we CAN extract the class name (Utils)
if "import static" in import_text:
# Extract class from "import static com.example.Utils.*"
# Remove "import static " prefix and ".*;" suffix
path = import_text.replace("import static ", "").rstrip(";").rstrip(".*")
if "." in path:
class_name = path.rsplit(".", 1)[-1]
if class_name and class_name[0].isupper(): # Ensure it's a class name
if class_name and class_name[0].isupper():
imports.add(class_name)
# For regular wildcards like "import com.example.*", skip entirely
return
# Check if it's a static import of a specific method/field
if "import static" in import_text:
# "import static com.example.Utils.format;"
# We want to extract "Utils" (the class), not "format" (the method)
path = import_text.replace("import static ", "").rstrip(";")
parts = path.rsplit(".", 2) # Split into [package..., Class, member]
parts = path.rsplit(".", 2)
if len(parts) >= 2:
# The second-to-last part is the class name
class_name = parts[-2]
if class_name and class_name[0].isupper(): # Ensure it's a class name
if class_name and class_name[0].isupper():
imports.add(class_name)
return
# Regular import: extract class name from scoped_identifier
for child in n.children:
if child.type in {"scoped_identifier", "identifier"}:
import_path = analyzer.get_node_text(child, source_bytes)
# Extract just the class name (last part)
# e.g., "com.example.Buffer" -> "Buffer"
if "." in import_path:
class_name = import_path.rsplit(".", 1)[-1]
else:
class_name = import_path
# Skip if it looks like a package name (lowercase)
if class_name and class_name[0].isupper():
imports.add(class_name)
@ -228,25 +324,119 @@ def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[s
return imports
def _find_method_calls_in_range(
node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
) -> list[str]:
"""Find method calls within a line range.
# ---------------------------------------------------------------------------
# Method call resolution
# ---------------------------------------------------------------------------
Args:
node: Tree-sitter node to search.
source_bytes: Source code as bytes.
start_line: Start line (1-indexed).
end_line: End line (1-indexed).
analyzer: JavaAnalyzer instance.
Returns:
List of method names called.
def _resolve_method_calls_in_range(
    node: Node,
    source_bytes: bytes,
    start_line: int,
    end_line: int,
    analyzer: JavaAnalyzer,
    type_map: dict[str, str],
    static_import_map: dict[str, str],
) -> set[str]:
    """Resolve method invocations and constructor calls within a line range.

    Returns resolved references as ``ClassName.methodName`` strings. Subtrees
    lying entirely outside ``start_line..end_line`` (1-indexed, inclusive) are
    not visited.

    Receivers of a method invocation are resolved as follows:

    - ``variable.method()`` - variable type is looked up in *type_map*.
    - ``ClassName.staticMethod()`` - an identifier that is not in *type_map*
      and starts with an uppercase letter is taken to be a class.
    - ``new ClassName().method()`` - type is read from the constructor.
    - ``((ClassName) expr).method()`` - type is read from the cast.
    - ``this.field.method()`` - field type is looked up in *type_map*.
    - ``method()`` with no receiver - class is looked up in *static_import_map*.

    Any other receiver shape is left unresolved and contributes nothing.

    Constructor calls ``new ClassName(...)`` emit both ``ClassName.ClassName``
    and ``ClassName.<init>``.

    Args:
        node: Tree-sitter node to search (typically the file's root node).
        source_bytes: Source code the tree was parsed from.
        start_line: First line of the range (1-indexed).
        end_line: Last line of the range (1-indexed).
        analyzer: JavaAnalyzer used to read node text.
        type_map: Variable/field name to type name.
        static_import_map: Statically imported member name to class name.

    Returns:
        Set of ``ClassName.methodName`` references found in the range.
    """
    found: set[str] = set()

    def text_of(target: Node) -> str:
        return analyzer.get_node_text(target, source_bytes)

    def declared_type_of(holder: Node) -> str | None:
        """Read the ``type`` field of a constructor or cast expression."""
        declared = holder.child_by_field_name("type")
        return _strip_generics(text_of(declared)) if declared else None

    def receiver_class(receiver: Node) -> str | None:
        """Best-effort class name for the object a method is invoked on."""
        kind = receiver.type

        if kind == "identifier":
            label = text_of(receiver)
            if label in type_map:
                return type_map[label]
            # Unmapped identifier with an uppercase first letter: likely a
            # class used as the receiver of a static call.
            return label if label[:1].isupper() else None

        if kind == "object_creation_expression":
            return declared_type_of(receiver)

        if kind == "field_access":
            # Only ``this.field`` is resolved, via the field's entry in type_map.
            member = receiver.child_by_field_name("field")
            owner = receiver.child_by_field_name("object")
            if not (member and owner):
                return None
            member_name = text_of(member)
            if owner.type == "this" and member_name in type_map:
                return type_map[member_name]
            return None

        if kind == "parenthesized_expression":
            # Look through the parentheses for a cast; the first one decides.
            cast = next((part for part in receiver.children if part.type == "cast_expression"), None)
            return declared_type_of(cast) if cast is not None else None

        return None

    def record_invocation(call: Node) -> None:
        called = call.child_by_field_name("name")
        receiver = call.child_by_field_name("object")
        if not called:
            return
        method_name = text_of(called)
        if receiver:
            owner_class = receiver_class(receiver)
            if owner_class:
                found.add(f"{owner_class}.{method_name}")
            return
        # No receiver - the call can only be attributed through a static import.
        if method_name in static_import_map:
            found.add(f"{static_import_map[method_name]}.{method_name}")

    def record_construction(creation: Node) -> None:
        # Emit both common qualified-name conventions so the function map can
        # use either ClassName.ClassName or ClassName.<init>.
        created = creation.child_by_field_name("type")
        if created:
            owner_class = _strip_generics(text_of(created))
            found.add(f"{owner_class}.{owner_class}")
            found.add(f"{owner_class}.<init>")

    handlers = {
        "method_invocation": record_invocation,
        "object_creation_expression": record_construction,
    }

    def walk(current: Node) -> None:
        first_line = current.start_point[0] + 1  # tree-sitter rows are 0-indexed
        last_line = current.end_point[0] + 1
        if not (start_line <= last_line and first_line <= end_line):
            return
        handler = handlers.get(current.type)
        if handler is not None:
            handler(current)
        # Always descend: receivers and arguments may contain further calls.
        for sub in current.children:
            walk(sub)

    walk(node)
    return found
def _find_method_calls_in_range(
node: Node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
) -> list[str]:
"""Find bare method call names within a line range (legacy helper)."""
calls: list[str] = []
# Check if this node is within the range (convert to 0-indexed)
node_start = node.start_point[0] + 1
node_end = node.end_point[0] + 1

View file

@ -6,10 +6,7 @@ regression test code that captures function return values.
All tests assert for full string equality, no substring matching.
"""
from codeflash.languages.java.remove_asserts import (
JavaAssertTransformer,
transform_java_assertions,
)
from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions
class TestBasicJUnit5Assertions:

File diff suppressed because it is too large Load diff