This commit is contained in:
HeshamHM28 2026-02-17 23:27:05 +02:00
parent 22541e085a
commit 60a28c0843
19 changed files with 128 additions and 129 deletions

View file

@ -715,7 +715,12 @@ def inject_profiling_into_existing_test(
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
return inject_profiling_into_existing_js_test( return inject_profiling_into_existing_js_test(
test_string=test_string, call_positions=call_positions, function_to_optimize=function_to_optimize, tests_project_root=tests_project_root, mode= mode.value, test_path=test_path test_string=test_string,
call_positions=call_positions,
function_to_optimize=function_to_optimize,
tests_project_root=tests_project_root,
mode=mode.value,
test_path=test_path,
) )
if is_java(): if is_java():
@ -725,11 +730,14 @@ def inject_profiling_into_existing_test(
if function_to_optimize.is_async: if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test( return inject_async_profiling_into_existing_test(
test_string=test_string, call_positions=call_positions, function_to_optimize=function_to_optimize, tests_project_root=tests_project_root, mode=mode.value, test_path=test_path test_string=test_string,
call_positions=call_positions,
function_to_optimize=function_to_optimize,
tests_project_root=tests_project_root,
mode=mode.value,
test_path=test_path,
) )
used_frameworks = detect_frameworks_from_code(test_string) used_frameworks = detect_frameworks_from_code(test_string)
try: try:
tree = ast.parse(test_string) tree = ast.parse(test_string)

View file

@ -572,7 +572,7 @@ class LanguageSupport(Protocol):
function_to_optimize: Any, function_to_optimize: Any,
tests_project_root: Path, tests_project_root: Path,
mode: str, mode: str,
test_path: Path | None test_path: Path | None,
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file. """Inject profiling code into an existing test file.

View file

@ -184,7 +184,6 @@ def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None:
if test_src.exists(): if test_src.exists():
test_roots.append(test_src) test_roots.append(test_src)
# Check for custom source directories in pom.xml <build> section # Check for custom source directories in pom.xml <build> section
for build in [root.find("m:build", ns), root.find("build")]: for build in [root.find("m:build", ns), root.find("build")]:
if build is not None: if build is not None:
@ -660,7 +659,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
# Skip the original </dependencies> tag since our snippet includes it # Skip the original </dependencies> tag since our snippet includes it
new_content += content[idx + len(closing_tag):] new_content += content[idx + len(closing_tag) :]
pom_path.write_text(new_content, encoding="utf-8") pom_path.write_text(new_content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to pom.xml") logger.info("Added codeflash-runtime dependency to pom.xml")

View file

@ -179,13 +179,20 @@ def compare_test_results(
[ [
java_exe, java_exe,
# Java 16+ module system: Kryo needs reflective access to internal JDK classes # Java 16+ module system: Kryo needs reflective access to internal JDK classes
"--add-opens", "java.base/java.util=ALL-UNNAMED", "--add-opens",
"--add-opens", "java.base/java.lang=ALL-UNNAMED", "java.base/java.util=ALL-UNNAMED",
"--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", "--add-opens",
"--add-opens", "java.base/java.io=ALL-UNNAMED", "java.base/java.lang=ALL-UNNAMED",
"--add-opens", "java.base/java.math=ALL-UNNAMED", "--add-opens",
"--add-opens", "java.base/java.net=ALL-UNNAMED", "java.base/java.lang.reflect=ALL-UNNAMED",
"--add-opens", "java.base/java.util.zip=ALL-UNNAMED", "--add-opens",
"java.base/java.io=ALL-UNNAMED",
"--add-opens",
"java.base/java.math=ALL-UNNAMED",
"--add-opens",
"java.base/java.net=ALL-UNNAMED",
"--add-opens",
"java.base/java.util.zip=ALL-UNNAMED",
"-cp", "-cp",
str(jar_path), str(jar_path),
"com.codeflash.Comparator", "com.codeflash.Comparator",

View file

@ -18,8 +18,6 @@ from pathlib import Path
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from tree_sitter import Node
from codeflash.languages.base import FunctionInfo from codeflash.languages.base import FunctionInfo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -306,9 +304,7 @@ class JavaConcurrencyAnalyzer:
return suggestions return suggestions
def analyze_function_concurrency( def analyze_function_concurrency(func: FunctionInfo, source: str | None = None, analyzer=None) -> ConcurrencyInfo:
func: FunctionInfo, source: str | None = None, analyzer=None
) -> ConcurrencyInfo:
"""Analyze a function for concurrency patterns. """Analyze a function for concurrency patterns.
Convenience function that creates a JavaConcurrencyAnalyzer and analyzes the function. Convenience function that creates a JavaConcurrencyAnalyzer and analyzes the function.

View file

@ -34,6 +34,7 @@ class JavaLineProfiler:
instrumented = profiler.instrument_source(source, file_path, functions) instrumented = profiler.instrument_source(source, file_path, functions)
# Run instrumented code # Run instrumented code
results = JavaLineProfiler.parse_results(Path("profile.json")) results = JavaLineProfiler.parse_results(Path("profile.json"))
""" """
def __init__(self, output_file: Path) -> None: def __init__(self, output_file: Path) -> None:
@ -48,13 +49,7 @@ class JavaLineProfiler:
self.profiler_var = "__codeflashProfiler__" self.profiler_var = "__codeflashProfiler__"
self.line_contents: dict[str, str] = {} self.line_contents: dict[str, str] = {}
def instrument_source( def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo], analyzer=None) -> str:
self,
source: str,
file_path: Path,
functions: list[FunctionInfo],
analyzer=None,
) -> str:
"""Instrument Java source code with line profiling. """Instrument Java source code with line profiling.
Adds profiling instrumentation to track line-level execution for the Adds profiling instrumentation to track line-level execution for the
@ -106,9 +101,7 @@ class JavaLineProfiler:
import_end_idx = i import_end_idx = i
break break
lines_with_profiler = ( lines_with_profiler = lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
)
result = "".join(lines_with_profiler) result = "".join(lines_with_profiler)
if not analyzer.validate_syntax(result): if not analyzer.validate_syntax(result):
@ -121,7 +114,7 @@ class JavaLineProfiler:
# Store line contents as a simple map (embedded directly in code) # Store line contents as a simple map (embedded directly in code)
line_contents_code = self._generate_line_contents_map() line_contents_code = self._generate_line_contents_map()
return f''' return f"""
/** /**
* Codeflash line profiler - tracks per-line execution statistics. * Codeflash line profiler - tracks per-line execution statistics.
* Auto-generated - do not modify. * Auto-generated - do not modify.
@ -132,7 +125,7 @@ class {self.profiler_class} {{
private static final ThreadLocal<Long> lastLineTime = new ThreadLocal<>(); private static final ThreadLocal<Long> lastLineTime = new ThreadLocal<>();
private static final ThreadLocal<String> lastKey = new ThreadLocal<>(); private static final ThreadLocal<String> lastKey = new ThreadLocal<>();
private static final java.util.concurrent.atomic.AtomicInteger totalHits = new java.util.concurrent.atomic.AtomicInteger(0); private static final java.util.concurrent.atomic.AtomicInteger totalHits = new java.util.concurrent.atomic.AtomicInteger(0);
private static final String OUTPUT_FILE = "{str(self.output_file)}"; private static final String OUTPUT_FILE = "{self.output_file!s}";
static class LineStats {{ static class LineStats {{
public final java.util.concurrent.atomic.AtomicLong hits = new java.util.concurrent.atomic.AtomicLong(0); public final java.util.concurrent.atomic.AtomicLong hits = new java.util.concurrent.atomic.AtomicLong(0);
@ -247,15 +240,9 @@ class {self.profiler_class} {{
Runtime.getRuntime().addShutdownHook(new Thread(() -> save())); Runtime.getRuntime().addShutdownHook(new Thread(() -> save()));
}} }}
}} }}
''' """
def _instrument_function( def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path, analyzer) -> list[str]:
self,
func: FunctionInfo,
lines: list[str],
file_path: Path,
analyzer,
) -> list[str]:
"""Instrument a single function with line profiling. """Instrument a single function with line profiling.
Args: Args:
@ -300,9 +287,7 @@ class {self.profiler_class} {{
# Add the line with enterFunction() call after it # Add the line with enterFunction() call after it
instrumented_lines.append(line) instrumented_lines.append(line)
instrumented_lines.append( instrumented_lines.append(f"{body_indent}{self.profiler_class}.enterFunction();\n")
f"{body_indent}{self.profiler_class}.enterFunction();\n"
)
function_entry_added = True function_entry_added = True
continue continue
@ -326,8 +311,7 @@ class {self.profiler_class} {{
# Add hit() call before the line # Add hit() call before the line
profiled_line = ( profiled_line = (
f"{indent_str}{self.profiler_class}.hit(" f'{indent_str}{self.profiler_class}.hit("{file_path.as_posix()}", {global_line_num});\n{line}'
f'"{file_path.as_posix()}", {global_line_num});\n{line}'
) )
instrumented_lines.append(profiled_line) instrumented_lines.append(profiled_line)
else: else:
@ -497,8 +481,6 @@ def format_line_profile_results(results: dict, file_path: Path | None = None) ->
avg_ms = time_ms / hits if hits > 0 else 0 avg_ms = time_ms / hits if hits > 0 else 0
content = stats.get("content", "")[:50] # Truncate long lines content = stats.get("content", "")[:50] # Truncate long lines
output.append( output.append(f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}")
f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}"
)
return "\n".join(output) return "\n".join(output)

View file

@ -298,9 +298,7 @@ class JavaAssertTransformer:
# - Assertions.assertEquals (JUnit 5) # - Assertions.assertEquals (JUnit 5)
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified) # - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS) all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
pattern = re.compile( pattern = re.compile(rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
)
for match in pattern.finditer(source): for match in pattern.finditer(source):
leading_ws = match.group(1) leading_ws = match.group(1)
@ -549,8 +547,12 @@ class JavaAssertTransformer:
return results return results
def _collect_target_invocations( def _collect_target_invocations(
self, node, wrapper_bytes: bytes, content_bytes: bytes, self,
base_offset: int, out: list[TargetCall], node,
wrapper_bytes: bytes,
content_bytes: bytes,
base_offset: int,
out: list[TargetCall],
seen_top_level: set[tuple[int, int]] | None = None, seen_top_level: set[tuple[int, int]] | None = None,
) -> None: ) -> None:
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name. """Recursively walk the AST and collect method_invocation nodes that match self.func_name.
@ -574,22 +576,24 @@ class JavaAssertTransformer:
seen_top_level.add(range_key) seen_top_level.add(range_key)
start = top_node.start_byte - prefix_len start = top_node.start_byte - prefix_len
end = top_node.end_byte - prefix_len end = top_node.end_byte - prefix_len
if 0 <= start and end <= len(content_bytes): if start >= 0 and end <= len(content_bytes):
full_call = self.analyzer.get_node_text(top_node, wrapper_bytes) full_call = self.analyzer.get_node_text(top_node, wrapper_bytes)
start_char = len(content_bytes[:start].decode("utf8")) start_char = len(content_bytes[:start].decode("utf8"))
end_char = len(content_bytes[:end].decode("utf8")) end_char = len(content_bytes[:end].decode("utf8"))
out.append(TargetCall( out.append(
receiver=None, TargetCall(
method_name=self.func_name, receiver=None,
arguments="", method_name=self.func_name,
full_call=full_call, arguments="",
start_pos=base_offset + start_char, full_call=full_call,
end_pos=base_offset + end_char, start_pos=base_offset + start_char,
)) end_pos=base_offset + end_char,
)
)
else: else:
start = node.start_byte - prefix_len start = node.start_byte - prefix_len
end = node.end_byte - prefix_len end = node.end_byte - prefix_len
if 0 <= start and end <= len(content_bytes): if start >= 0 and end <= len(content_bytes):
out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset)) out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
return return
@ -597,8 +601,7 @@ class JavaAssertTransformer:
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level) self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level)
def _build_target_call( def _build_target_call(
self, node, wrapper_bytes: bytes, content_bytes: bytes, self, node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int
start_byte: int, end_byte: int, base_offset: int,
) -> TargetCall: ) -> TargetCall:
"""Build a TargetCall from a tree-sitter method_invocation node.""" """Build a TargetCall from a tree-sitter method_invocation node."""
get_text = self.analyzer.get_node_text get_text = self.analyzer.get_node_text
@ -679,7 +682,6 @@ class JavaAssertTransformer:
# Handle generic types: Type<Generic> varName = ... # Handle generic types: Type<Generic> varName = ...
match = self._assign_re.search(source, line_start, assertion_start) match = self._assign_re.search(source, line_start, assertion_start)
if match: if match:
var_type = match.group(1).strip() var_type = match.group(1).strip()
var_name = match.group(2).strip() var_name = match.group(2).strip()
@ -934,18 +936,12 @@ class JavaAssertTransformer:
f"catch (Exception _cf_ignored{counter}) {{}}" f"catch (Exception _cf_ignored{counter}) {{}}"
) )
return ( return f"{ws}try {{ {code_to_run} }} catch (Exception _cf_ignored{counter}) {{}}"
f"{ws}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
# If no lambda body found, try to extract from target calls # If no lambda body found, try to extract from target calls
if assertion.target_calls: if assertion.target_calls:
call = assertion.target_calls[0] call = assertion.target_calls[0]
return ( return f"{ws}try {{ {call.full_call}; }} catch (Exception _cf_ignored{counter}) {{}}"
f"{ws}try {{ {call.full_call}; }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
# Fallback: comment out the assertion # Fallback: comment out the assertion
return f"{ws}// Removed assertThrows: could not extract callable" return f"{ws}// Removed assertThrows: could not extract callable"

View file

@ -12,11 +12,8 @@ from typing import TYPE_CHECKING, Any
from codeflash.languages.base import Language, LanguageSupport from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.java.build_tools import find_test_root from codeflash.languages.java.build_tools import find_test_root
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
from codeflash.languages.java.concurrency_analyzer import analyze_function_concurrency
from codeflash.languages.java.config import detect_java_project from codeflash.languages.java.config import detect_java_project
from codeflash.languages.java.concurrency_analyzer import (
JavaConcurrencyAnalyzer,
analyze_function_concurrency,
)
from codeflash.languages.java.context import extract_code_context, find_helper_functions from codeflash.languages.java.context import extract_code_context, find_helper_functions
from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source
from codeflash.languages.java.formatter import format_java_code, normalize_java_code from codeflash.languages.java.formatter import format_java_code, normalize_java_code
@ -288,14 +285,11 @@ class JavaSupport(LanguageSupport):
function_to_optimize: Any, function_to_optimize: Any,
tests_project_root: Path, tests_project_root: Path,
mode: str, mode: str,
test_path: Path | None test_path: Path | None,
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file.""" """Inject profiling code into an existing test file."""
return instrument_existing_test( return instrument_existing_test(
test_string=test_string, test_string=test_string, function_to_optimize=function_to_optimize, mode=mode, test_path=test_path
function_to_optimize=function_to_optimize,
mode=mode,
test_path=test_path
) )
def instrument_source_for_line_profiler( def instrument_source_for_line_profiler(

View file

@ -610,13 +610,20 @@ def _run_tests_direct(
cmd = [ cmd = [
str(java), str(java),
# Java 16+ module system: Kryo needs reflective access to internal JDK classes # Java 16+ module system: Kryo needs reflective access to internal JDK classes
"--add-opens", "java.base/java.util=ALL-UNNAMED", "--add-opens",
"--add-opens", "java.base/java.lang=ALL-UNNAMED", "java.base/java.util=ALL-UNNAMED",
"--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED", "--add-opens",
"--add-opens", "java.base/java.io=ALL-UNNAMED", "java.base/java.lang=ALL-UNNAMED",
"--add-opens", "java.base/java.math=ALL-UNNAMED", "--add-opens",
"--add-opens", "java.base/java.net=ALL-UNNAMED", "java.base/java.lang.reflect=ALL-UNNAMED",
"--add-opens", "java.base/java.util.zip=ALL-UNNAMED", "--add-opens",
"java.base/java.io=ALL-UNNAMED",
"--add-opens",
"java.base/java.math=ALL-UNNAMED",
"--add-opens",
"java.base/java.net=ALL-UNNAMED",
"--add-opens",
"java.base/java.util.zip=ALL-UNNAMED",
"-cp", "-cp",
classpath, classpath,
"org.junit.platform.console.ConsoleLauncher", "org.junit.platform.console.ConsoleLauncher",

View file

@ -1941,7 +1941,7 @@ class JavaScriptSupport:
function_to_optimize: Any, function_to_optimize: Any,
tests_project_root: Path, tests_project_root: Path,
mode: str, mode: str,
test_path: Path|None, test_path: Path | None,
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
"""Inject profiling code into an existing JavaScript test file. """Inject profiling code into an existing JavaScript test file.

View file

@ -800,7 +800,9 @@ class FunctionOptimizer:
logger.debug( logger.debug(
f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})" f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})"
) )
logger.debug(f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}") logger.debug(
f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}"
)
return java_sources_root return java_sources_root
# If no standard package prefix found, check if there's a 'java' directory # If no standard package prefix found, check if there's a 'java' directory
@ -810,7 +812,9 @@ class FunctionOptimizer:
# Return up to and including 'java' # Return up to and including 'java'
java_sources_root = Path(*parts[: i + 1]) java_sources_root = Path(*parts[: i + 1])
logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}") logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}")
logger.debug(f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}") logger.debug(
f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}"
)
return java_sources_root return java_sources_root
# Default: return tests_root as-is (original behavior) # Default: return tests_root as-is (original behavior)
@ -862,7 +866,7 @@ class FunctionOptimizer:
if main_match: if main_match:
main_module_name = main_match.group(1) main_module_name = main_match.group(1)
if package_name.startswith(main_module_name): if package_name.startswith(main_module_name):
suffix = package_name[len(main_module_name):] suffix = package_name[len(main_module_name) :]
new_package = test_module_name + suffix new_package = test_module_name + suffix
old_decl = f"package {package_name};" old_decl = f"package {package_name};"
new_decl = f"package {new_package};" new_decl = f"package {new_package};"

View file

@ -164,7 +164,15 @@ def _find_project_root(start_path: Path) -> Path | None:
while current != current.parent: while current != current.parent:
# Check for project markers # Check for project markers
markers = [".git", "pyproject.toml", "package.json", "Cargo.toml", "pom.xml", "build.gradle", "build.gradle.kts"] markers = [
".git",
"pyproject.toml",
"package.json",
"Cargo.toml",
"pom.xml",
"build.gradle",
"build.gradle.kts",
]
for marker in markers: for marker in markers:
if (current / marker).exists(): if (current / marker).exists():
return current return current
@ -489,10 +497,17 @@ def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None,
for elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]: for elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]:
if elem is not None and elem.text: if elem is not None and elem.text:
# Resolve ${project.basedir}/src -> test_module_dir/src # Resolve ${project.basedir}/src -> test_module_dir/src
dir_text = elem.text.strip().replace("${project.basedir}/", "").replace("${project.basedir}", ".") dir_text = (
elem.text.strip()
.replace("${project.basedir}/", "")
.replace("${project.basedir}", ".")
)
resolved = test_module_dir / dir_text resolved = test_module_dir / dir_text
if resolved.is_dir(): if resolved.is_dir():
return resolved, f"{test_module_name}/{dir_text} (from {test_module_name}/pom.xml testSourceDirectory)" return (
resolved,
f"{test_module_name}/{dir_text} (from {test_module_name}/pom.xml testSourceDirectory)",
)
except ET.ParseError: except ET.ParseError:
pass pass
# Test module exists but no custom testSourceDirectory - use the module root # Test module exists but no custom testSourceDirectory - use the module root
@ -548,8 +563,6 @@ def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]:
def _detect_java_test_runner(project_root: Path) -> tuple[str, str]: def _detect_java_test_runner(project_root: Path) -> tuple[str, str]:
"""Detect Java test framework.""" """Detect Java test framework."""
import xml.etree.ElementTree as ET
pom_path = project_root / "pom.xml" pom_path = project_root / "pom.xml"
if pom_path.exists(): if pom_path.exists():
try: try:

View file

@ -231,7 +231,9 @@ class JacocoCoverageUtils:
f"File preview: {content_preview!r}" f"File preview: {content_preview!r}"
) )
except Exception as read_err: except Exception as read_err:
logger.warning(f"Failed to parse JaCoCo XML file at '{jacoco_xml_path}': {e}. Could not read file: {read_err}") logger.warning(
f"Failed to parse JaCoCo XML file at '{jacoco_xml_path}': {e}. Could not read file: {read_err}"
)
return CoverageData.create_empty(source_code_path, function_name, code_context) return CoverageData.create_empty(source_code_path, function_name, code_context)
# Determine expected source file name from path # Determine expected source file name from path

View file

@ -27,9 +27,7 @@ def safe_repr(obj: object) -> str:
return f"<repr failed: {type(e).__name__}: {e}>" return f"<repr failed: {type(e).__name__}: {e}>"
def compare_test_results( def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
original_results: TestResults, candidate_results: TestResults
) -> tuple[bool, list[TestDiff]]:
# This is meant to be only called with test results for the first loop index # This is meant to be only called with test results for the first loop index
if len(original_results) == 0 or len(candidate_results) == 0: if len(original_results) == 0 or len(candidate_results) == 0:
return False, [] # empty test results are not equal return False, [] # empty test results are not equal
@ -102,9 +100,7 @@ def compare_test_results(
) )
) )
elif not comparator( elif not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj
):
test_diffs.append( test_diffs.append(
TestDiff( TestDiff(
scope=TestDiffScope.RETURN_VALUE, scope=TestDiffScope.RETURN_VALUE,
@ -129,9 +125,8 @@ def compare_test_results(
) )
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
elif ( elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
(original_test_result.stdout and cdd_test_result.stdout) original_test_result.stdout, cdd_test_result.stdout
and not comparator(original_test_result.stdout, cdd_test_result.stdout)
): ):
test_diffs.append( test_diffs.append(
TestDiff( TestDiff(

View file

@ -1002,7 +1002,9 @@ def parse_test_xml(
# Always use tests_project_rootdir since pytest is now the test runner for all frameworks # Always use tests_project_rootdir since pytest is now the test runner for all frameworks
base_dir = test_config.tests_project_rootdir base_dir = test_config.tests_project_rootdir
logger.debug(f"[PARSE-XML] base_dir for resolution: {base_dir}") logger.debug(f"[PARSE-XML] base_dir for resolution: {base_dir}")
logger.debug(f"[PARSE-XML] Registered test files: {[str(tf.instrumented_behavior_file_path) for tf in test_files.test_files]}") logger.debug(
f"[PARSE-XML] Registered test files: {[str(tf.instrumented_behavior_file_path) for tf in test_files.test_files]}"
)
# For Java: pre-parse fallback stdout once (not per testcase) to avoid O(n²) complexity # For Java: pre-parse fallback stdout once (not per testcase) to avoid O(n²) complexity
java_fallback_stdout = None java_fallback_stdout = None
@ -1067,7 +1069,9 @@ def parse_test_xml(
test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir) test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir)
if test_file_path is None: if test_file_path is None:
logger.error(f"[PARSE-XML] ERROR: Could not resolve test_class_path={test_class_path}, base_dir={base_dir}") logger.error(
f"[PARSE-XML] ERROR: Could not resolve test_class_path={test_class_path}, base_dir={base_dir}"
)
logger.warning(f"Could not find the test for file name - {test_class_path} ") logger.warning(f"Could not find the test for file name - {test_class_path} ")
continue continue
else: else:
@ -1271,9 +1275,7 @@ def parse_test_xml(
str(test_file.instrumented_behavior_file_path or test_file.original_file_path) str(test_file.instrumented_behavior_file_path or test_file.original_file_path)
for test_file in test_files.test_files for test_file in test_files.test_files
] ]
logger.info( logger.info(f"Tests {test_paths_display} failed to run, skipping")
f"Tests {test_paths_display} failed to run, skipping"
)
if run_result is not None: if run_result is not None:
stdout, stderr = "", "" stdout, stderr = "", ""
try: try:

View file

@ -109,7 +109,11 @@ def generate_tests(
# Instrument for behavior verification (renames class) # Instrument for behavior verification (renames class)
instrumented_behavior_test_source = instrument_generated_java_test( instrumented_behavior_test_source = instrument_generated_java_test(
test_code=generated_test_source, function_name=func_name, qualified_name=qualified_name, mode="behavior", function_to_optimize=function_to_optimize test_code=generated_test_source,
function_name=func_name,
qualified_name=qualified_name,
mode="behavior",
function_to_optimize=function_to_optimize,
) )
# Instrument for performance measurement (adds timing markers) # Instrument for performance measurement (adds timing markers)
@ -118,7 +122,7 @@ def generate_tests(
function_name=func_name, function_name=func_name,
qualified_name=qualified_name, qualified_name=qualified_name,
mode="performance", mode="performance",
function_to_optimize=function_to_optimize function_to_optimize=function_to_optimize,
) )
logger.debug(f"Instrumented Java tests locally for {func_name}") logger.debug(f"Instrumented Java tests locally for {func_name}")

View file

@ -1153,7 +1153,7 @@ void testWithThreadSleep() throws InterruptedException {
assert result == expected assert result == expected
def test_synchronized_method_signature_preserved(self): def test_synchronized_method_signature_preserved(self):
"""synchronized modifier on a test method is preserved after transformation.""" """Synchronized modifier on a test method is preserved after transformation."""
source = """\ source = """\
@Test @Test
synchronized void testSyncMethod() { synchronized void testSyncMethod() {

View file

@ -12,8 +12,6 @@ Also includes end-to-end execution tests that:
import os import os
import re import re
import shutil
import subprocess
from pathlib import Path from pathlib import Path
import pytest import pytest
@ -24,7 +22,6 @@ os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language from codeflash.languages.current import set_current_language
from codeflash.models.function_types import FunctionParent
from codeflash.languages.java.build_tools import find_maven_executable from codeflash.languages.java.build_tools import find_maven_executable
from codeflash.languages.java.discovery import discover_functions_from_source from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.instrumentation import ( from codeflash.languages.java.instrumentation import (
@ -1148,7 +1145,7 @@ public class TargetBenchmark {
"public class TargetBenchmark {\n" "public class TargetBenchmark {\n"
"\n" "\n"
" @Test\n" " @Test\n"
" @DisplayName(\"Benchmark multiply\")\n" ' @DisplayName("Benchmark multiply")\n'
" public void benchmarkMultiply() {\n" " public void benchmarkMultiply() {\n"
" \n" # Empty test_setup_code with 8-space indent " \n" # Empty test_setup_code with 8-space indent
"\n" "\n"
@ -1167,7 +1164,7 @@ public class TargetBenchmark {
" long totalNanos = endTime - startTime;\n" " long totalNanos = endTime - startTime;\n"
" long avgNanos = totalNanos / 5000;\n" " long avgNanos = totalNanos / 5000;\n"
"\n" "\n"
" System.out.println(\"CODEFLASH_BENCHMARK:multiply:total_ns=\" + totalNanos + \",avg_ns=\" + avgNanos + \",iterations=5000\");\n" ' System.out.println("CODEFLASH_BENCHMARK:multiply:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations=5000");\n'
" }\n" " }\n"
"}\n" "}\n"
) )
@ -1934,9 +1931,6 @@ class TestRunAndParseTests:
@pytest.fixture @pytest.fixture
def java_project(self, tmp_path: Path): def java_project(self, tmp_path: Path):
"""Create a temporary Maven project and set up Java language context.""" """Create a temporary Maven project and set up Java language context."""
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
# Force set the language to Java (reset the singleton first) # Force set the language to Java (reset the singleton first)
import codeflash.languages.current as current_module import codeflash.languages.current as current_module
current_module._current_language = None current_module._current_language = None
@ -2432,7 +2426,6 @@ public class BrokenCalcTest {
def test_behavior_mode_writes_to_sqlite(self, java_project): def test_behavior_mode_writes_to_sqlite(self, java_project):
"""Test that behavior mode correctly writes results to SQLite file.""" """Test that behavior mode correctly writes results to SQLite file."""
import sqlite3 import sqlite3
from argparse import Namespace from argparse import Namespace
from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.code_utils import get_run_tmp_file

View file

@ -16,10 +16,7 @@ Covers:
- Edge cases: static calls, qualified calls, method chaining - Edge cases: static calls, qualified calls, method chaining
""" """
from codeflash.languages.java.remove_asserts import ( from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions
JavaAssertTransformer,
transform_java_assertions,
)
class TestJUnit4Assertions: class TestJUnit4Assertions: