prek
This commit is contained in:
parent
22541e085a
commit
60a28c0843
19 changed files with 128 additions and 129 deletions
|
|
@ -715,7 +715,12 @@ def inject_profiling_into_existing_test(
|
||||||
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
|
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
|
||||||
|
|
||||||
return inject_profiling_into_existing_js_test(
|
return inject_profiling_into_existing_js_test(
|
||||||
test_string=test_string, call_positions=call_positions, function_to_optimize=function_to_optimize, tests_project_root=tests_project_root, mode= mode.value, test_path=test_path
|
test_string=test_string,
|
||||||
|
call_positions=call_positions,
|
||||||
|
function_to_optimize=function_to_optimize,
|
||||||
|
tests_project_root=tests_project_root,
|
||||||
|
mode=mode.value,
|
||||||
|
test_path=test_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_java():
|
if is_java():
|
||||||
|
|
@ -725,11 +730,14 @@ def inject_profiling_into_existing_test(
|
||||||
|
|
||||||
if function_to_optimize.is_async:
|
if function_to_optimize.is_async:
|
||||||
return inject_async_profiling_into_existing_test(
|
return inject_async_profiling_into_existing_test(
|
||||||
test_string=test_string, call_positions=call_positions, function_to_optimize=function_to_optimize, tests_project_root=tests_project_root, mode=mode.value, test_path=test_path
|
test_string=test_string,
|
||||||
|
call_positions=call_positions,
|
||||||
|
function_to_optimize=function_to_optimize,
|
||||||
|
tests_project_root=tests_project_root,
|
||||||
|
mode=mode.value,
|
||||||
|
test_path=test_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
used_frameworks = detect_frameworks_from_code(test_string)
|
used_frameworks = detect_frameworks_from_code(test_string)
|
||||||
try:
|
try:
|
||||||
tree = ast.parse(test_string)
|
tree = ast.parse(test_string)
|
||||||
|
|
|
||||||
|
|
@ -572,7 +572,7 @@ class LanguageSupport(Protocol):
|
||||||
function_to_optimize: Any,
|
function_to_optimize: Any,
|
||||||
tests_project_root: Path,
|
tests_project_root: Path,
|
||||||
mode: str,
|
mode: str,
|
||||||
test_path: Path | None
|
test_path: Path | None,
|
||||||
) -> tuple[bool, str | None]:
|
) -> tuple[bool, str | None]:
|
||||||
"""Inject profiling code into an existing test file.
|
"""Inject profiling code into an existing test file.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -184,7 +184,6 @@ def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None:
|
||||||
if test_src.exists():
|
if test_src.exists():
|
||||||
test_roots.append(test_src)
|
test_roots.append(test_src)
|
||||||
|
|
||||||
|
|
||||||
# Check for custom source directories in pom.xml <build> section
|
# Check for custom source directories in pom.xml <build> section
|
||||||
for build in [root.find("m:build", ns), root.find("build")]:
|
for build in [root.find("m:build", ns), root.find("build")]:
|
||||||
if build is not None:
|
if build is not None:
|
||||||
|
|
@ -660,7 +659,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
|
||||||
|
|
||||||
new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
|
new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
|
||||||
# Skip the original </dependencies> tag since our snippet includes it
|
# Skip the original </dependencies> tag since our snippet includes it
|
||||||
new_content += content[idx + len(closing_tag):]
|
new_content += content[idx + len(closing_tag) :]
|
||||||
|
|
||||||
pom_path.write_text(new_content, encoding="utf-8")
|
pom_path.write_text(new_content, encoding="utf-8")
|
||||||
logger.info("Added codeflash-runtime dependency to pom.xml")
|
logger.info("Added codeflash-runtime dependency to pom.xml")
|
||||||
|
|
|
||||||
|
|
@ -179,13 +179,20 @@ def compare_test_results(
|
||||||
[
|
[
|
||||||
java_exe,
|
java_exe,
|
||||||
# Java 16+ module system: Kryo needs reflective access to internal JDK classes
|
# Java 16+ module system: Kryo needs reflective access to internal JDK classes
|
||||||
"--add-opens", "java.base/java.util=ALL-UNNAMED",
|
"--add-opens",
|
||||||
"--add-opens", "java.base/java.lang=ALL-UNNAMED",
|
"java.base/java.util=ALL-UNNAMED",
|
||||||
"--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED",
|
"--add-opens",
|
||||||
"--add-opens", "java.base/java.io=ALL-UNNAMED",
|
"java.base/java.lang=ALL-UNNAMED",
|
||||||
"--add-opens", "java.base/java.math=ALL-UNNAMED",
|
"--add-opens",
|
||||||
"--add-opens", "java.base/java.net=ALL-UNNAMED",
|
"java.base/java.lang.reflect=ALL-UNNAMED",
|
||||||
"--add-opens", "java.base/java.util.zip=ALL-UNNAMED",
|
"--add-opens",
|
||||||
|
"java.base/java.io=ALL-UNNAMED",
|
||||||
|
"--add-opens",
|
||||||
|
"java.base/java.math=ALL-UNNAMED",
|
||||||
|
"--add-opens",
|
||||||
|
"java.base/java.net=ALL-UNNAMED",
|
||||||
|
"--add-opens",
|
||||||
|
"java.base/java.util.zip=ALL-UNNAMED",
|
||||||
"-cp",
|
"-cp",
|
||||||
str(jar_path),
|
str(jar_path),
|
||||||
"com.codeflash.Comparator",
|
"com.codeflash.Comparator",
|
||||||
|
|
|
||||||
|
|
@ -18,8 +18,6 @@ from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tree_sitter import Node
|
|
||||||
|
|
||||||
from codeflash.languages.base import FunctionInfo
|
from codeflash.languages.base import FunctionInfo
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -306,9 +304,7 @@ class JavaConcurrencyAnalyzer:
|
||||||
return suggestions
|
return suggestions
|
||||||
|
|
||||||
|
|
||||||
def analyze_function_concurrency(
|
def analyze_function_concurrency(func: FunctionInfo, source: str | None = None, analyzer=None) -> ConcurrencyInfo:
|
||||||
func: FunctionInfo, source: str | None = None, analyzer=None
|
|
||||||
) -> ConcurrencyInfo:
|
|
||||||
"""Analyze a function for concurrency patterns.
|
"""Analyze a function for concurrency patterns.
|
||||||
|
|
||||||
Convenience function that creates a JavaConcurrencyAnalyzer and analyzes the function.
|
Convenience function that creates a JavaConcurrencyAnalyzer and analyzes the function.
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ class JavaLineProfiler:
|
||||||
instrumented = profiler.instrument_source(source, file_path, functions)
|
instrumented = profiler.instrument_source(source, file_path, functions)
|
||||||
# Run instrumented code
|
# Run instrumented code
|
||||||
results = JavaLineProfiler.parse_results(Path("profile.json"))
|
results = JavaLineProfiler.parse_results(Path("profile.json"))
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, output_file: Path) -> None:
|
def __init__(self, output_file: Path) -> None:
|
||||||
|
|
@ -48,13 +49,7 @@ class JavaLineProfiler:
|
||||||
self.profiler_var = "__codeflashProfiler__"
|
self.profiler_var = "__codeflashProfiler__"
|
||||||
self.line_contents: dict[str, str] = {}
|
self.line_contents: dict[str, str] = {}
|
||||||
|
|
||||||
def instrument_source(
|
def instrument_source(self, source: str, file_path: Path, functions: list[FunctionInfo], analyzer=None) -> str:
|
||||||
self,
|
|
||||||
source: str,
|
|
||||||
file_path: Path,
|
|
||||||
functions: list[FunctionInfo],
|
|
||||||
analyzer=None,
|
|
||||||
) -> str:
|
|
||||||
"""Instrument Java source code with line profiling.
|
"""Instrument Java source code with line profiling.
|
||||||
|
|
||||||
Adds profiling instrumentation to track line-level execution for the
|
Adds profiling instrumentation to track line-level execution for the
|
||||||
|
|
@ -106,9 +101,7 @@ class JavaLineProfiler:
|
||||||
import_end_idx = i
|
import_end_idx = i
|
||||||
break
|
break
|
||||||
|
|
||||||
lines_with_profiler = (
|
lines_with_profiler = lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
|
||||||
lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = "".join(lines_with_profiler)
|
result = "".join(lines_with_profiler)
|
||||||
if not analyzer.validate_syntax(result):
|
if not analyzer.validate_syntax(result):
|
||||||
|
|
@ -121,7 +114,7 @@ class JavaLineProfiler:
|
||||||
# Store line contents as a simple map (embedded directly in code)
|
# Store line contents as a simple map (embedded directly in code)
|
||||||
line_contents_code = self._generate_line_contents_map()
|
line_contents_code = self._generate_line_contents_map()
|
||||||
|
|
||||||
return f'''
|
return f"""
|
||||||
/**
|
/**
|
||||||
* Codeflash line profiler - tracks per-line execution statistics.
|
* Codeflash line profiler - tracks per-line execution statistics.
|
||||||
* Auto-generated - do not modify.
|
* Auto-generated - do not modify.
|
||||||
|
|
@ -132,7 +125,7 @@ class {self.profiler_class} {{
|
||||||
private static final ThreadLocal<Long> lastLineTime = new ThreadLocal<>();
|
private static final ThreadLocal<Long> lastLineTime = new ThreadLocal<>();
|
||||||
private static final ThreadLocal<String> lastKey = new ThreadLocal<>();
|
private static final ThreadLocal<String> lastKey = new ThreadLocal<>();
|
||||||
private static final java.util.concurrent.atomic.AtomicInteger totalHits = new java.util.concurrent.atomic.AtomicInteger(0);
|
private static final java.util.concurrent.atomic.AtomicInteger totalHits = new java.util.concurrent.atomic.AtomicInteger(0);
|
||||||
private static final String OUTPUT_FILE = "{str(self.output_file)}";
|
private static final String OUTPUT_FILE = "{self.output_file!s}";
|
||||||
|
|
||||||
static class LineStats {{
|
static class LineStats {{
|
||||||
public final java.util.concurrent.atomic.AtomicLong hits = new java.util.concurrent.atomic.AtomicLong(0);
|
public final java.util.concurrent.atomic.AtomicLong hits = new java.util.concurrent.atomic.AtomicLong(0);
|
||||||
|
|
@ -247,15 +240,9 @@ class {self.profiler_class} {{
|
||||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> save()));
|
Runtime.getRuntime().addShutdownHook(new Thread(() -> save()));
|
||||||
}}
|
}}
|
||||||
}}
|
}}
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def _instrument_function(
|
def _instrument_function(self, func: FunctionInfo, lines: list[str], file_path: Path, analyzer) -> list[str]:
|
||||||
self,
|
|
||||||
func: FunctionInfo,
|
|
||||||
lines: list[str],
|
|
||||||
file_path: Path,
|
|
||||||
analyzer,
|
|
||||||
) -> list[str]:
|
|
||||||
"""Instrument a single function with line profiling.
|
"""Instrument a single function with line profiling.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -300,9 +287,7 @@ class {self.profiler_class} {{
|
||||||
|
|
||||||
# Add the line with enterFunction() call after it
|
# Add the line with enterFunction() call after it
|
||||||
instrumented_lines.append(line)
|
instrumented_lines.append(line)
|
||||||
instrumented_lines.append(
|
instrumented_lines.append(f"{body_indent}{self.profiler_class}.enterFunction();\n")
|
||||||
f"{body_indent}{self.profiler_class}.enterFunction();\n"
|
|
||||||
)
|
|
||||||
function_entry_added = True
|
function_entry_added = True
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -326,8 +311,7 @@ class {self.profiler_class} {{
|
||||||
|
|
||||||
# Add hit() call before the line
|
# Add hit() call before the line
|
||||||
profiled_line = (
|
profiled_line = (
|
||||||
f"{indent_str}{self.profiler_class}.hit("
|
f'{indent_str}{self.profiler_class}.hit("{file_path.as_posix()}", {global_line_num});\n{line}'
|
||||||
f'"{file_path.as_posix()}", {global_line_num});\n{line}'
|
|
||||||
)
|
)
|
||||||
instrumented_lines.append(profiled_line)
|
instrumented_lines.append(profiled_line)
|
||||||
else:
|
else:
|
||||||
|
|
@ -497,8 +481,6 @@ def format_line_profile_results(results: dict, file_path: Path | None = None) ->
|
||||||
avg_ms = time_ms / hits if hits > 0 else 0
|
avg_ms = time_ms / hits if hits > 0 else 0
|
||||||
content = stats.get("content", "")[:50] # Truncate long lines
|
content = stats.get("content", "")[:50] # Truncate long lines
|
||||||
|
|
||||||
output.append(
|
output.append(f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}")
|
||||||
f"{line_num:6d} | {hits:10d} | {time_ms:12.3f} | {avg_ms:12.6f} | {content}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return "\n".join(output)
|
return "\n".join(output)
|
||||||
|
|
|
||||||
|
|
@ -298,9 +298,7 @@ class JavaAssertTransformer:
|
||||||
# - Assertions.assertEquals (JUnit 5)
|
# - Assertions.assertEquals (JUnit 5)
|
||||||
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
|
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
|
||||||
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
|
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
|
||||||
pattern = re.compile(
|
pattern = re.compile(rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
|
||||||
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
|
|
||||||
)
|
|
||||||
|
|
||||||
for match in pattern.finditer(source):
|
for match in pattern.finditer(source):
|
||||||
leading_ws = match.group(1)
|
leading_ws = match.group(1)
|
||||||
|
|
@ -549,8 +547,12 @@ class JavaAssertTransformer:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _collect_target_invocations(
|
def _collect_target_invocations(
|
||||||
self, node, wrapper_bytes: bytes, content_bytes: bytes,
|
self,
|
||||||
base_offset: int, out: list[TargetCall],
|
node,
|
||||||
|
wrapper_bytes: bytes,
|
||||||
|
content_bytes: bytes,
|
||||||
|
base_offset: int,
|
||||||
|
out: list[TargetCall],
|
||||||
seen_top_level: set[tuple[int, int]] | None = None,
|
seen_top_level: set[tuple[int, int]] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name.
|
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name.
|
||||||
|
|
@ -574,22 +576,24 @@ class JavaAssertTransformer:
|
||||||
seen_top_level.add(range_key)
|
seen_top_level.add(range_key)
|
||||||
start = top_node.start_byte - prefix_len
|
start = top_node.start_byte - prefix_len
|
||||||
end = top_node.end_byte - prefix_len
|
end = top_node.end_byte - prefix_len
|
||||||
if 0 <= start and end <= len(content_bytes):
|
if start >= 0 and end <= len(content_bytes):
|
||||||
full_call = self.analyzer.get_node_text(top_node, wrapper_bytes)
|
full_call = self.analyzer.get_node_text(top_node, wrapper_bytes)
|
||||||
start_char = len(content_bytes[:start].decode("utf8"))
|
start_char = len(content_bytes[:start].decode("utf8"))
|
||||||
end_char = len(content_bytes[:end].decode("utf8"))
|
end_char = len(content_bytes[:end].decode("utf8"))
|
||||||
out.append(TargetCall(
|
out.append(
|
||||||
receiver=None,
|
TargetCall(
|
||||||
method_name=self.func_name,
|
receiver=None,
|
||||||
arguments="",
|
method_name=self.func_name,
|
||||||
full_call=full_call,
|
arguments="",
|
||||||
start_pos=base_offset + start_char,
|
full_call=full_call,
|
||||||
end_pos=base_offset + end_char,
|
start_pos=base_offset + start_char,
|
||||||
))
|
end_pos=base_offset + end_char,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
start = node.start_byte - prefix_len
|
start = node.start_byte - prefix_len
|
||||||
end = node.end_byte - prefix_len
|
end = node.end_byte - prefix_len
|
||||||
if 0 <= start and end <= len(content_bytes):
|
if start >= 0 and end <= len(content_bytes):
|
||||||
out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
|
out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -597,8 +601,7 @@ class JavaAssertTransformer:
|
||||||
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level)
|
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level)
|
||||||
|
|
||||||
def _build_target_call(
|
def _build_target_call(
|
||||||
self, node, wrapper_bytes: bytes, content_bytes: bytes,
|
self, node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int
|
||||||
start_byte: int, end_byte: int, base_offset: int,
|
|
||||||
) -> TargetCall:
|
) -> TargetCall:
|
||||||
"""Build a TargetCall from a tree-sitter method_invocation node."""
|
"""Build a TargetCall from a tree-sitter method_invocation node."""
|
||||||
get_text = self.analyzer.get_node_text
|
get_text = self.analyzer.get_node_text
|
||||||
|
|
@ -679,7 +682,6 @@ class JavaAssertTransformer:
|
||||||
# Handle generic types: Type<Generic> varName = ...
|
# Handle generic types: Type<Generic> varName = ...
|
||||||
match = self._assign_re.search(source, line_start, assertion_start)
|
match = self._assign_re.search(source, line_start, assertion_start)
|
||||||
|
|
||||||
|
|
||||||
if match:
|
if match:
|
||||||
var_type = match.group(1).strip()
|
var_type = match.group(1).strip()
|
||||||
var_name = match.group(2).strip()
|
var_name = match.group(2).strip()
|
||||||
|
|
@ -934,18 +936,12 @@ class JavaAssertTransformer:
|
||||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
f"catch (Exception _cf_ignored{counter}) {{}}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return f"{ws}try {{ {code_to_run} }} catch (Exception _cf_ignored{counter}) {{}}"
|
||||||
f"{ws}try {{ {code_to_run} }} "
|
|
||||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# If no lambda body found, try to extract from target calls
|
# If no lambda body found, try to extract from target calls
|
||||||
if assertion.target_calls:
|
if assertion.target_calls:
|
||||||
call = assertion.target_calls[0]
|
call = assertion.target_calls[0]
|
||||||
return (
|
return f"{ws}try {{ {call.full_call}; }} catch (Exception _cf_ignored{counter}) {{}}"
|
||||||
f"{ws}try {{ {call.full_call}; }} "
|
|
||||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback: comment out the assertion
|
# Fallback: comment out the assertion
|
||||||
return f"{ws}// Removed assertThrows: could not extract callable"
|
return f"{ws}// Removed assertThrows: could not extract callable"
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,8 @@ from typing import TYPE_CHECKING, Any
|
||||||
from codeflash.languages.base import Language, LanguageSupport
|
from codeflash.languages.base import Language, LanguageSupport
|
||||||
from codeflash.languages.java.build_tools import find_test_root
|
from codeflash.languages.java.build_tools import find_test_root
|
||||||
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
|
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
|
||||||
|
from codeflash.languages.java.concurrency_analyzer import analyze_function_concurrency
|
||||||
from codeflash.languages.java.config import detect_java_project
|
from codeflash.languages.java.config import detect_java_project
|
||||||
from codeflash.languages.java.concurrency_analyzer import (
|
|
||||||
JavaConcurrencyAnalyzer,
|
|
||||||
analyze_function_concurrency,
|
|
||||||
)
|
|
||||||
from codeflash.languages.java.context import extract_code_context, find_helper_functions
|
from codeflash.languages.java.context import extract_code_context, find_helper_functions
|
||||||
from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source
|
from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source
|
||||||
from codeflash.languages.java.formatter import format_java_code, normalize_java_code
|
from codeflash.languages.java.formatter import format_java_code, normalize_java_code
|
||||||
|
|
@ -288,14 +285,11 @@ class JavaSupport(LanguageSupport):
|
||||||
function_to_optimize: Any,
|
function_to_optimize: Any,
|
||||||
tests_project_root: Path,
|
tests_project_root: Path,
|
||||||
mode: str,
|
mode: str,
|
||||||
test_path: Path | None
|
test_path: Path | None,
|
||||||
) -> tuple[bool, str | None]:
|
) -> tuple[bool, str | None]:
|
||||||
"""Inject profiling code into an existing test file."""
|
"""Inject profiling code into an existing test file."""
|
||||||
return instrument_existing_test(
|
return instrument_existing_test(
|
||||||
test_string=test_string,
|
test_string=test_string, function_to_optimize=function_to_optimize, mode=mode, test_path=test_path
|
||||||
function_to_optimize=function_to_optimize,
|
|
||||||
mode=mode,
|
|
||||||
test_path=test_path
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def instrument_source_for_line_profiler(
|
def instrument_source_for_line_profiler(
|
||||||
|
|
|
||||||
|
|
@ -610,13 +610,20 @@ def _run_tests_direct(
|
||||||
cmd = [
|
cmd = [
|
||||||
str(java),
|
str(java),
|
||||||
# Java 16+ module system: Kryo needs reflective access to internal JDK classes
|
# Java 16+ module system: Kryo needs reflective access to internal JDK classes
|
||||||
"--add-opens", "java.base/java.util=ALL-UNNAMED",
|
"--add-opens",
|
||||||
"--add-opens", "java.base/java.lang=ALL-UNNAMED",
|
"java.base/java.util=ALL-UNNAMED",
|
||||||
"--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED",
|
"--add-opens",
|
||||||
"--add-opens", "java.base/java.io=ALL-UNNAMED",
|
"java.base/java.lang=ALL-UNNAMED",
|
||||||
"--add-opens", "java.base/java.math=ALL-UNNAMED",
|
"--add-opens",
|
||||||
"--add-opens", "java.base/java.net=ALL-UNNAMED",
|
"java.base/java.lang.reflect=ALL-UNNAMED",
|
||||||
"--add-opens", "java.base/java.util.zip=ALL-UNNAMED",
|
"--add-opens",
|
||||||
|
"java.base/java.io=ALL-UNNAMED",
|
||||||
|
"--add-opens",
|
||||||
|
"java.base/java.math=ALL-UNNAMED",
|
||||||
|
"--add-opens",
|
||||||
|
"java.base/java.net=ALL-UNNAMED",
|
||||||
|
"--add-opens",
|
||||||
|
"java.base/java.util.zip=ALL-UNNAMED",
|
||||||
"-cp",
|
"-cp",
|
||||||
classpath,
|
classpath,
|
||||||
"org.junit.platform.console.ConsoleLauncher",
|
"org.junit.platform.console.ConsoleLauncher",
|
||||||
|
|
|
||||||
|
|
@ -1941,7 +1941,7 @@ class JavaScriptSupport:
|
||||||
function_to_optimize: Any,
|
function_to_optimize: Any,
|
||||||
tests_project_root: Path,
|
tests_project_root: Path,
|
||||||
mode: str,
|
mode: str,
|
||||||
test_path: Path|None,
|
test_path: Path | None,
|
||||||
) -> tuple[bool, str | None]:
|
) -> tuple[bool, str | None]:
|
||||||
"""Inject profiling code into an existing JavaScript test file.
|
"""Inject profiling code into an existing JavaScript test file.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -800,7 +800,9 @@ class FunctionOptimizer:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})"
|
f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})"
|
||||||
)
|
)
|
||||||
logger.debug(f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}")
|
logger.debug(
|
||||||
|
f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}"
|
||||||
|
)
|
||||||
return java_sources_root
|
return java_sources_root
|
||||||
|
|
||||||
# If no standard package prefix found, check if there's a 'java' directory
|
# If no standard package prefix found, check if there's a 'java' directory
|
||||||
|
|
@ -810,7 +812,9 @@ class FunctionOptimizer:
|
||||||
# Return up to and including 'java'
|
# Return up to and including 'java'
|
||||||
java_sources_root = Path(*parts[: i + 1])
|
java_sources_root = Path(*parts[: i + 1])
|
||||||
logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}")
|
logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}")
|
||||||
logger.debug(f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}")
|
logger.debug(
|
||||||
|
f"[JAVA-ROOT] Returning Java sources root: {java_sources_root}, tests_root was: {tests_root}"
|
||||||
|
)
|
||||||
return java_sources_root
|
return java_sources_root
|
||||||
|
|
||||||
# Default: return tests_root as-is (original behavior)
|
# Default: return tests_root as-is (original behavior)
|
||||||
|
|
@ -862,7 +866,7 @@ class FunctionOptimizer:
|
||||||
if main_match:
|
if main_match:
|
||||||
main_module_name = main_match.group(1)
|
main_module_name = main_match.group(1)
|
||||||
if package_name.startswith(main_module_name):
|
if package_name.startswith(main_module_name):
|
||||||
suffix = package_name[len(main_module_name):]
|
suffix = package_name[len(main_module_name) :]
|
||||||
new_package = test_module_name + suffix
|
new_package = test_module_name + suffix
|
||||||
old_decl = f"package {package_name};"
|
old_decl = f"package {package_name};"
|
||||||
new_decl = f"package {new_package};"
|
new_decl = f"package {new_package};"
|
||||||
|
|
|
||||||
|
|
@ -164,7 +164,15 @@ def _find_project_root(start_path: Path) -> Path | None:
|
||||||
|
|
||||||
while current != current.parent:
|
while current != current.parent:
|
||||||
# Check for project markers
|
# Check for project markers
|
||||||
markers = [".git", "pyproject.toml", "package.json", "Cargo.toml", "pom.xml", "build.gradle", "build.gradle.kts"]
|
markers = [
|
||||||
|
".git",
|
||||||
|
"pyproject.toml",
|
||||||
|
"package.json",
|
||||||
|
"Cargo.toml",
|
||||||
|
"pom.xml",
|
||||||
|
"build.gradle",
|
||||||
|
"build.gradle.kts",
|
||||||
|
]
|
||||||
for marker in markers:
|
for marker in markers:
|
||||||
if (current / marker).exists():
|
if (current / marker).exists():
|
||||||
return current
|
return current
|
||||||
|
|
@ -489,10 +497,17 @@ def _detect_tests_root(project_root: Path, language: str) -> tuple[Path | None,
|
||||||
for elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]:
|
for elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]:
|
||||||
if elem is not None and elem.text:
|
if elem is not None and elem.text:
|
||||||
# Resolve ${project.basedir}/src -> test_module_dir/src
|
# Resolve ${project.basedir}/src -> test_module_dir/src
|
||||||
dir_text = elem.text.strip().replace("${project.basedir}/", "").replace("${project.basedir}", ".")
|
dir_text = (
|
||||||
|
elem.text.strip()
|
||||||
|
.replace("${project.basedir}/", "")
|
||||||
|
.replace("${project.basedir}", ".")
|
||||||
|
)
|
||||||
resolved = test_module_dir / dir_text
|
resolved = test_module_dir / dir_text
|
||||||
if resolved.is_dir():
|
if resolved.is_dir():
|
||||||
return resolved, f"{test_module_name}/{dir_text} (from {test_module_name}/pom.xml testSourceDirectory)"
|
return (
|
||||||
|
resolved,
|
||||||
|
f"{test_module_name}/{dir_text} (from {test_module_name}/pom.xml testSourceDirectory)",
|
||||||
|
)
|
||||||
except ET.ParseError:
|
except ET.ParseError:
|
||||||
pass
|
pass
|
||||||
# Test module exists but no custom testSourceDirectory - use the module root
|
# Test module exists but no custom testSourceDirectory - use the module root
|
||||||
|
|
@ -548,8 +563,6 @@ def _detect_test_runner(project_root: Path, language: str) -> tuple[str, str]:
|
||||||
|
|
||||||
def _detect_java_test_runner(project_root: Path) -> tuple[str, str]:
|
def _detect_java_test_runner(project_root: Path) -> tuple[str, str]:
|
||||||
"""Detect Java test framework."""
|
"""Detect Java test framework."""
|
||||||
import xml.etree.ElementTree as ET
|
|
||||||
|
|
||||||
pom_path = project_root / "pom.xml"
|
pom_path = project_root / "pom.xml"
|
||||||
if pom_path.exists():
|
if pom_path.exists():
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -231,7 +231,9 @@ class JacocoCoverageUtils:
|
||||||
f"File preview: {content_preview!r}"
|
f"File preview: {content_preview!r}"
|
||||||
)
|
)
|
||||||
except Exception as read_err:
|
except Exception as read_err:
|
||||||
logger.warning(f"Failed to parse JaCoCo XML file at '{jacoco_xml_path}': {e}. Could not read file: {read_err}")
|
logger.warning(
|
||||||
|
f"Failed to parse JaCoCo XML file at '{jacoco_xml_path}': {e}. Could not read file: {read_err}"
|
||||||
|
)
|
||||||
return CoverageData.create_empty(source_code_path, function_name, code_context)
|
return CoverageData.create_empty(source_code_path, function_name, code_context)
|
||||||
|
|
||||||
# Determine expected source file name from path
|
# Determine expected source file name from path
|
||||||
|
|
|
||||||
|
|
@ -27,9 +27,7 @@ def safe_repr(obj: object) -> str:
|
||||||
return f"<repr failed: {type(e).__name__}: {e}>"
|
return f"<repr failed: {type(e).__name__}: {e}>"
|
||||||
|
|
||||||
|
|
||||||
def compare_test_results(
|
def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
|
||||||
original_results: TestResults, candidate_results: TestResults
|
|
||||||
) -> tuple[bool, list[TestDiff]]:
|
|
||||||
# This is meant to be only called with test results for the first loop index
|
# This is meant to be only called with test results for the first loop index
|
||||||
if len(original_results) == 0 or len(candidate_results) == 0:
|
if len(original_results) == 0 or len(candidate_results) == 0:
|
||||||
return False, [] # empty test results are not equal
|
return False, [] # empty test results are not equal
|
||||||
|
|
@ -102,9 +100,7 @@ def compare_test_results(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif not comparator(
|
elif not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
|
||||||
original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj
|
|
||||||
):
|
|
||||||
test_diffs.append(
|
test_diffs.append(
|
||||||
TestDiff(
|
TestDiff(
|
||||||
scope=TestDiffScope.RETURN_VALUE,
|
scope=TestDiffScope.RETURN_VALUE,
|
||||||
|
|
@ -129,9 +125,8 @@ def compare_test_results(
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error(e)
|
||||||
elif (
|
elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
|
||||||
(original_test_result.stdout and cdd_test_result.stdout)
|
original_test_result.stdout, cdd_test_result.stdout
|
||||||
and not comparator(original_test_result.stdout, cdd_test_result.stdout)
|
|
||||||
):
|
):
|
||||||
test_diffs.append(
|
test_diffs.append(
|
||||||
TestDiff(
|
TestDiff(
|
||||||
|
|
|
||||||
|
|
@ -1002,7 +1002,9 @@ def parse_test_xml(
|
||||||
# Always use tests_project_rootdir since pytest is now the test runner for all frameworks
|
# Always use tests_project_rootdir since pytest is now the test runner for all frameworks
|
||||||
base_dir = test_config.tests_project_rootdir
|
base_dir = test_config.tests_project_rootdir
|
||||||
logger.debug(f"[PARSE-XML] base_dir for resolution: {base_dir}")
|
logger.debug(f"[PARSE-XML] base_dir for resolution: {base_dir}")
|
||||||
logger.debug(f"[PARSE-XML] Registered test files: {[str(tf.instrumented_behavior_file_path) for tf in test_files.test_files]}")
|
logger.debug(
|
||||||
|
f"[PARSE-XML] Registered test files: {[str(tf.instrumented_behavior_file_path) for tf in test_files.test_files]}"
|
||||||
|
)
|
||||||
|
|
||||||
# For Java: pre-parse fallback stdout once (not per testcase) to avoid O(n²) complexity
|
# For Java: pre-parse fallback stdout once (not per testcase) to avoid O(n²) complexity
|
||||||
java_fallback_stdout = None
|
java_fallback_stdout = None
|
||||||
|
|
@ -1067,7 +1069,9 @@ def parse_test_xml(
|
||||||
test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir)
|
test_file_path = resolve_test_file_from_class_path(test_class_path, base_dir)
|
||||||
|
|
||||||
if test_file_path is None:
|
if test_file_path is None:
|
||||||
logger.error(f"[PARSE-XML] ERROR: Could not resolve test_class_path={test_class_path}, base_dir={base_dir}")
|
logger.error(
|
||||||
|
f"[PARSE-XML] ERROR: Could not resolve test_class_path={test_class_path}, base_dir={base_dir}"
|
||||||
|
)
|
||||||
logger.warning(f"Could not find the test for file name - {test_class_path} ")
|
logger.warning(f"Could not find the test for file name - {test_class_path} ")
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
|
|
@ -1271,9 +1275,7 @@ def parse_test_xml(
|
||||||
str(test_file.instrumented_behavior_file_path or test_file.original_file_path)
|
str(test_file.instrumented_behavior_file_path or test_file.original_file_path)
|
||||||
for test_file in test_files.test_files
|
for test_file in test_files.test_files
|
||||||
]
|
]
|
||||||
logger.info(
|
logger.info(f"Tests {test_paths_display} failed to run, skipping")
|
||||||
f"Tests {test_paths_display} failed to run, skipping"
|
|
||||||
)
|
|
||||||
if run_result is not None:
|
if run_result is not None:
|
||||||
stdout, stderr = "", ""
|
stdout, stderr = "", ""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -109,7 +109,11 @@ def generate_tests(
|
||||||
|
|
||||||
# Instrument for behavior verification (renames class)
|
# Instrument for behavior verification (renames class)
|
||||||
instrumented_behavior_test_source = instrument_generated_java_test(
|
instrumented_behavior_test_source = instrument_generated_java_test(
|
||||||
test_code=generated_test_source, function_name=func_name, qualified_name=qualified_name, mode="behavior", function_to_optimize=function_to_optimize
|
test_code=generated_test_source,
|
||||||
|
function_name=func_name,
|
||||||
|
qualified_name=qualified_name,
|
||||||
|
mode="behavior",
|
||||||
|
function_to_optimize=function_to_optimize,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Instrument for performance measurement (adds timing markers)
|
# Instrument for performance measurement (adds timing markers)
|
||||||
|
|
@ -118,7 +122,7 @@ def generate_tests(
|
||||||
function_name=func_name,
|
function_name=func_name,
|
||||||
qualified_name=qualified_name,
|
qualified_name=qualified_name,
|
||||||
mode="performance",
|
mode="performance",
|
||||||
function_to_optimize=function_to_optimize
|
function_to_optimize=function_to_optimize,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"Instrumented Java tests locally for {func_name}")
|
logger.debug(f"Instrumented Java tests locally for {func_name}")
|
||||||
|
|
|
||||||
|
|
@ -1153,7 +1153,7 @@ void testWithThreadSleep() throws InterruptedException {
|
||||||
assert result == expected
|
assert result == expected
|
||||||
|
|
||||||
def test_synchronized_method_signature_preserved(self):
|
def test_synchronized_method_signature_preserved(self):
|
||||||
"""synchronized modifier on a test method is preserved after transformation."""
|
"""Synchronized modifier on a test method is preserved after transformation."""
|
||||||
source = """\
|
source = """\
|
||||||
@Test
|
@Test
|
||||||
synchronized void testSyncMethod() {
|
synchronized void testSyncMethod() {
|
||||||
|
|
|
||||||
|
|
@ -12,8 +12,6 @@ Also includes end-to-end execution tests that:
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -24,7 +22,6 @@ os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
from codeflash.languages.base import Language
|
from codeflash.languages.base import Language
|
||||||
from codeflash.languages.current import set_current_language
|
from codeflash.languages.current import set_current_language
|
||||||
from codeflash.models.function_types import FunctionParent
|
|
||||||
from codeflash.languages.java.build_tools import find_maven_executable
|
from codeflash.languages.java.build_tools import find_maven_executable
|
||||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||||
from codeflash.languages.java.instrumentation import (
|
from codeflash.languages.java.instrumentation import (
|
||||||
|
|
@ -1148,7 +1145,7 @@ public class TargetBenchmark {
|
||||||
"public class TargetBenchmark {\n"
|
"public class TargetBenchmark {\n"
|
||||||
"\n"
|
"\n"
|
||||||
" @Test\n"
|
" @Test\n"
|
||||||
" @DisplayName(\"Benchmark multiply\")\n"
|
' @DisplayName("Benchmark multiply")\n'
|
||||||
" public void benchmarkMultiply() {\n"
|
" public void benchmarkMultiply() {\n"
|
||||||
" \n" # Empty test_setup_code with 8-space indent
|
" \n" # Empty test_setup_code with 8-space indent
|
||||||
"\n"
|
"\n"
|
||||||
|
|
@ -1167,7 +1164,7 @@ public class TargetBenchmark {
|
||||||
" long totalNanos = endTime - startTime;\n"
|
" long totalNanos = endTime - startTime;\n"
|
||||||
" long avgNanos = totalNanos / 5000;\n"
|
" long avgNanos = totalNanos / 5000;\n"
|
||||||
"\n"
|
"\n"
|
||||||
" System.out.println(\"CODEFLASH_BENCHMARK:multiply:total_ns=\" + totalNanos + \",avg_ns=\" + avgNanos + \",iterations=5000\");\n"
|
' System.out.println("CODEFLASH_BENCHMARK:multiply:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations=5000");\n'
|
||||||
" }\n"
|
" }\n"
|
||||||
"}\n"
|
"}\n"
|
||||||
)
|
)
|
||||||
|
|
@ -1934,9 +1931,6 @@ class TestRunAndParseTests:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def java_project(self, tmp_path: Path):
|
def java_project(self, tmp_path: Path):
|
||||||
"""Create a temporary Maven project and set up Java language context."""
|
"""Create a temporary Maven project and set up Java language context."""
|
||||||
from codeflash.languages.base import Language
|
|
||||||
from codeflash.languages.current import set_current_language
|
|
||||||
|
|
||||||
# Force set the language to Java (reset the singleton first)
|
# Force set the language to Java (reset the singleton first)
|
||||||
import codeflash.languages.current as current_module
|
import codeflash.languages.current as current_module
|
||||||
current_module._current_language = None
|
current_module._current_language = None
|
||||||
|
|
@ -2432,7 +2426,6 @@ public class BrokenCalcTest {
|
||||||
def test_behavior_mode_writes_to_sqlite(self, java_project):
|
def test_behavior_mode_writes_to_sqlite(self, java_project):
|
||||||
"""Test that behavior mode correctly writes results to SQLite file."""
|
"""Test that behavior mode correctly writes results to SQLite file."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||||
|
|
|
||||||
|
|
@ -16,10 +16,7 @@ Covers:
|
||||||
- Edge cases: static calls, qualified calls, method chaining
|
- Edge cases: static calls, qualified calls, method chaining
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from codeflash.languages.java.remove_asserts import (
|
from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions
|
||||||
JavaAssertTransformer,
|
|
||||||
transform_java_assertions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestJUnit4Assertions:
|
class TestJUnit4Assertions:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue