Merge pull request #1880 from codeflash-ai/java-config-redesign

feat: zero-config Java projects + smart ReplayHelper for end-to-end optimization
This commit is contained in:
Saurabh Misra 2026-04-01 15:04:26 -07:00 committed by GitHub
commit 15a92613e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 992 additions and 121 deletions

View file

@ -3,17 +3,9 @@ name: E2E - Java Tracer
on:
pull_request:
paths:
- 'codeflash/**'
- 'codeflash-java-runtime/**'
- 'tests/**'
- '.github/workflows/e2e-java-tracer.yaml'
workflow_dispatch:

View file

@ -12,20 +12,181 @@ import org.objectweb.asm.Type;
public class ReplayHelper {
private final Connection traceDb;
// Codeflash instrumentation state read from environment variables once
private final String mode; // "behavior", "performance", or null
private final int loopIndex;
private final String testIteration;
private final String outputFile; // SQLite path for behavior capture
private final int innerIterations; // for performance looping
// Behavior mode: lazily opened SQLite connection for writing results
private Connection behaviorDb;
private boolean behaviorDbInitialized;
public ReplayHelper(String traceDbPath) {
try {
this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
} catch (SQLException e) {
throw new RuntimeException("Failed to open trace database: " + traceDbPath, e);
}
// Read codeflash instrumentation env vars (set by the test runner)
this.mode = System.getenv("CODEFLASH_MODE");
this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1);
this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0");
this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE");
this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10);
}
public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception {
// Query the function_calls table for this method at the given index
// Deserialize args and resolve method (done once, outside timing)
Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex);
Class<?> targetClass = Class.forName(className);
Type[] paramTypes = Type.getArgumentTypes(descriptor);
Class<?>[] paramClasses = new Class<?>[paramTypes.length];
for (int i = 0; i < paramTypes.length; i++) {
paramClasses[i] = typeToClass(paramTypes[i]);
}
Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
method.setAccessible(true);
boolean isStatic = Modifier.isStatic(method.getModifiers());
Object instance = null;
if (!isStatic) {
try {
java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
ctor.setAccessible(true);
instance = ctor.newInstance();
} catch (NoSuchMethodException e) {
instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
}
}
// Get the calling test method name from the stack trace
String testMethodName = getCallingTestMethodName();
// Module name = the test class that called us
String testClassName = getCallingTestClassName();
if ("behavior".equals(mode)) {
replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName);
} else if ("performance".equals(mode)) {
replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName);
} else {
// No codeflash mode: just invoke (trace-only or manual testing)
method.invoke(instance, allArgs);
}
}
private void replayBehavior(Method method, Object instance, Object[] args,
String className, String methodName,
String testClassName, String testMethodName) throws Exception {
// testIteration goes at the END so the Comparator's lastUnderscore stripping
// removes it, making baseline (iteration=0) and candidate (iteration=N) keys match.
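// Example (values hypothetical): baseline invId "replay_add_0_0" (iteration 0) and candidate
// invId "replay_add_0_7" (iteration 7) both reduce to "replay_add_0" after stripping,
// so their results are compared against each other.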
String invId = testMethodName + "_" + testIteration;
// Print start marker (same format as behavior instrumentation)
System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopIndex + ":" + invId + "######$!");
long startNs = System.nanoTime();
Object result;
try {
result = method.invoke(instance, args);
} catch (java.lang.reflect.InvocationTargetException e) {
throw (Exception) e.getCause();
}
long durationNs = System.nanoTime() - startNs;
// Print end marker
System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!");
// Write return value to SQLite for correctness comparison
if (outputFile != null && !outputFile.isEmpty()) {
writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result);
}
}
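// For illustration (values hypothetical), one behavior-mode invocation of replay_add_0
// emits a marker pair like:
//   !$######codeflash.replay.ReplayTest_Calculator:codeflash.replay.ReplayTest_Calculator.replay_add_0:add:1:replay_add_0_0######$!
//   !######codeflash.replay.ReplayTest_Calculator:codeflash.replay.ReplayTest_Calculator.replay_add_0:add:1:replay_add_0_0:52341######!
// The trailing field of the end marker is the measured duration in nanoseconds.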
private void replayPerformance(Method method, Object instance, Object[] args,
String className, String methodName,
String testClassName, String testMethodName) throws Exception {
// Performance mode: run inner loop for JIT warmup, print timing for each iteration
int maxInner = innerIterations;
for (int inner = 0; inner < maxInner; inner++) {
int loopId = (loopIndex - 1) * maxInner + inner;
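// e.g. loopIndex=2 with maxInner=10 yields loopId 10..19, keeping iteration ids
// globally unique across outer test-runner loops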
String invId = testMethodName;
// Print start marker
System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopId + ":" + invId + "######$!");
long startNs = System.nanoTime();
try {
method.invoke(instance, args);
} catch (java.lang.reflect.InvocationTargetException e) {
// Swallow the exception: performance mode doesn't check correctness
}
long durationNs = System.nanoTime() - startNs;
// Print end marker
System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!");
}
}
private void writeBehaviorResult(String testClassName, String testMethodName,
String functionName, String invId,
long durationNs, Object result) {
try {
ensureBehaviorDb();
String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) {
ps.setString(1, testClassName); // test_module_path
ps.setString(2, testClassName); // test_class_name
ps.setString(3, testMethodName); // test_function_name
ps.setString(4, functionName); // function_getting_tested
ps.setInt(5, loopIndex); // loop_index
ps.setString(6, invId); // iteration_id
ps.setLong(7, durationNs); // runtime
ps.setBytes(8, serializeResult(result)); // return_value
ps.setString(9, "function_call"); // verification_type
ps.executeUpdate();
}
} catch (Exception e) {
System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage());
}
}
private void ensureBehaviorDb() throws SQLException {
if (behaviorDbInitialized) return;
behaviorDbInitialized = true;
behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile);
try (java.sql.Statement stmt = behaviorDb.createStatement()) {
stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" +
"test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +
"function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +
"runtime INTEGER, return_value BLOB, verification_type TEXT)");
}
}
private byte[] serializeResult(Object result) {
if (result == null) return null;
try {
return Serializer.serialize(result);
} catch (Exception e) {
// Fall back to String.valueOf if Kryo fails
return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8);
}
}
private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex)
throws SQLException {
byte[] argsBlob;
try (PreparedStatement stmt = traceDb.prepareStatement(
"SELECT args FROM function_calls " +
"WHERE classname = ? AND function = ? AND descriptor = ? " +
"ORDER BY time_ns LIMIT 1 OFFSET ?")) {
@ -43,46 +204,35 @@ public class ReplayHelper {
}
}
// Deserialize args
Object deserialized = Serializer.deserialize(argsBlob);
if (!(deserialized instanceof Object[])) {
throw new RuntimeException("Deserialized args is not Object[], got: "
+ (deserialized == null ? "null" : deserialized.getClass().getName()));
}
return (Object[]) deserialized;
}
private static String getCallingTestMethodName() {
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
// Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method
for (int i = 3; i < stack.length; i++) {
String method = stack[i].getMethodName();
if (method.startsWith("replay_")) {
return method;
}
}
return stack.length > 3 ? stack[3].getMethodName() : "unknown";
}
private static String getCallingTestClassName() {
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
for (int i = 3; i < stack.length; i++) {
String cls = stack[i].getClassName();
if (cls.contains("ReplayTest") || cls.contains("replay")) {
return cls;
}
}
return stack.length > 3 ? stack[3].getClassName() : "unknown";
}
private static Class<?> typeToClass(Type type) throws ClassNotFoundException {
@ -106,11 +256,23 @@ public class ReplayHelper {
}
}
private static int parseIntEnv(String name, int defaultValue) {
String val = System.getenv(name);
if (val == null || val.isEmpty()) return defaultValue;
try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; }
}
private static String getEnvOrDefault(String name, String defaultValue) {
String val = System.getenv(name);
return (val != null && !val.isEmpty()) ? val : defaultValue;
}
public void close() {
try { if (traceDb != null) traceDb.close(); } catch (SQLException e) {
System.err.println("Error closing ReplayHelper trace db: " + e.getMessage());
}
try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) {
System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage());
}
}
}
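For orientation, here is a minimal sketch of the environment contract ReplayHelper reads at construction time. The harness shape below (ProcessBuilder, the mvn command, the paths) is hypothetical; the CODEFLASH_* variable names and defaults come from the constructor above:

import java.io.IOException;

public class ReplayHarnessSketch {
    public static void main(String[] args) throws IOException, InterruptedException {
        // Hypothetical harness: the real runner is Python-side, but the env contract is the same.
        ProcessBuilder pb = new ProcessBuilder("mvn", "test");
        pb.environment().put("CODEFLASH_MODE", "behavior");        // or "performance"; unset = plain invoke
        pb.environment().put("CODEFLASH_LOOP_INDEX", "1");         // outer loop counter (default 1)
        pb.environment().put("CODEFLASH_TEST_ITERATION", "0");     // 0 = baseline, N = candidate N
        pb.environment().put("CODEFLASH_OUTPUT_FILE", "/tmp/behavior.db"); // SQLite sink for results
        pb.environment().put("CODEFLASH_INNER_ITERATIONS", "10");  // perf-mode warmup loop (default 10)
        pb.inheritIO().start().waitFor();
    }
}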

View file

@ -22,6 +22,7 @@ public final class TraceRecorder {
private final TracerConfig config;
private final TraceWriter writer;
private final ConcurrentHashMap<String, AtomicInteger> functionCounts = new ConcurrentHashMap<>();
private final AtomicInteger droppedCaptures = new AtomicInteger(0);
private final int maxFunctionCount;
private final ExecutorService serializerExecutor;
@ -82,11 +83,13 @@ public final class TraceRecorder {
argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
} catch (TimeoutException e) {
future.cancel(true);
droppedCaptures.incrementAndGet();
System.err.println("[codeflash-tracer] Serialization timed out for " + className + "."
+ methodName);
return;
} catch (Exception e) {
Throwable cause = e.getCause() != null ? e.getCause() : e;
droppedCaptures.incrementAndGet();
System.err.println("[codeflash-tracer] Serialization failed for " + className + "."
+ methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage());
return;
@ -113,11 +116,15 @@ public final class TraceRecorder {
}
metadata.put("totalCaptures", String.valueOf(totalCaptures));
int dropped = droppedCaptures.get();
metadata.put("droppedCaptures", String.valueOf(dropped));
writer.writeMetadata(metadata);
writer.flush();
writer.close();
System.err.println("[codeflash-tracer] Captured " + totalCaptures
+ " invocations across " + functionCounts.size() + " methods");
+ " invocations across " + functionCounts.size() + " methods"
+ (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : ""));
}
}
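The timeout-and-drop accounting above, condensed into a self-contained sketch; class and member names here are illustrative, not from the codebase:

import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

class DropCountingSerializer {
    private final ExecutorService executor = Executors.newSingleThreadExecutor();
    private final AtomicInteger dropped = new AtomicInteger(0);

    byte[] serializeOrDrop(Callable<byte[]> task, long timeoutMs) {
        Future<byte[]> future = executor.submit(task);
        try {
            return future.get(timeoutMs, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            future.cancel(true);          // abandon the hung serialization attempt
            dropped.incrementAndGet();    // counted, then reported once at shutdown
            return null;
        } catch (Exception e) {
            dropped.incrementAndGet();
            return null;
        }
    }

    int droppedCount() { return dropped.get(); }
}

In this sketch, cancel(true) interrupts the worker so one pathological object cannot stall later captures, and the counter surfaces silently-dropped invocations in the final summary line.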

View file

@ -4,14 +4,20 @@ import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import java.util.Collections;
import java.util.Map;
public class TracingClassVisitor extends ClassVisitor {
private final String internalClassName;
private final Map<String, Integer> methodLineNumbers;
private String sourceFile;
public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName,
Map<String, Integer> methodLineNumbers) {
super(Opcodes.ASM9, classVisitor);
this.internalClassName = internalClassName;
this.methodLineNumbers = methodLineNumbers != null ? methodLineNumbers : Collections.emptyMap();
}
@Override
@ -37,7 +43,8 @@ public class TracingClassVisitor extends ClassVisitor {
return mv;
}
int lineNumber = methodLineNumbers.getOrDefault(name + descriptor, 0);
return new TracingMethodAdapter(mv, access, name, descriptor,
internalClassName, lineNumber, sourceFile != null ? sourceFile : "");
}
}

View file

@ -1,10 +1,16 @@
package com.codeflash.tracer;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;
import java.util.HashMap;
import java.util.Map;
public class TracingTransformer implements ClassFileTransformer {
@ -22,11 +28,6 @@ public class TracingTransformer implements ClassFileTransformer {
return null;
}
// Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization)
if (TraceRecorder.isRecording()) {
return null;
}
// Skip internal JDK, framework, and synthetic classes
if (className.startsWith("java/")
|| className.startsWith("javax/")
@ -51,6 +52,30 @@ public class TracingTransformer implements ClassFileTransformer {
private byte[] instrumentClass(String internalClassName, byte[] bytecode) {
ClassReader cr = new ClassReader(bytecode);
// Pre-scan: collect the first source line number for each method.
// ASM's visitMethod() doesn't provide line info; it arrives later via visitLineNumber().
// We do a lightweight read pass first so the instrumentation pass has accurate line numbers.
Map<String, Integer> methodLineNumbers = new HashMap<>();
cr.accept(new ClassVisitor(Opcodes.ASM9) {
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor,
String signature, String[] exceptions) {
String key = name + descriptor;
return new MethodVisitor(Opcodes.ASM9) {
private boolean captured = false;
@Override
public void visitLineNumber(int line, Label start) {
if (!captured) {
methodLineNumbers.put(key, line);
captured = true;
}
}
};
}
}, ClassReader.SKIP_FRAMES);
// Use COMPUTE_MAXS only (not COMPUTE_FRAMES) to preserve original stack map frames.
// COMPUTE_FRAMES recomputes all frames and calls getCommonSuperClass() which either
// triggers classloader deadlocks or produces incorrect frames when returning "java/lang/Object".
@ -58,7 +83,7 @@ public class TracingTransformer implements ClassFileTransformer {
// adjusts offsets for injected code. Our AdviceAdapter only injects at method entry
// (before any branch points), so existing frames remain valid.
ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS);
TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName, methodLineNumbers);
cr.accept(cv, ClassReader.EXPAND_FRAMES);
return cw.toByteArray();
}

View file

@ -376,6 +376,7 @@ def _build_parser() -> ArgumentParser:
subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False)
auth_parser = subparsers.add_parser("auth", help="Authentication commands")
auth_subparsers = auth_parser.add_subparsers(dest="auth_command", help="Auth sub-commands")
auth_subparsers.add_parser("login", help="Log in to Codeflash via OAuth")
@ -391,8 +392,6 @@ def _build_parser() -> ArgumentParser:
compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)")
compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
trace_optimize.add_argument(
"--max-function-count",
type=int,

View file

@ -554,11 +554,13 @@ def get_all_replay_test_functions(
def _get_java_replay_test_functions(
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path | str
) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
"""Parse Java replay test files to extract functions and trace file path."""
from codeflash.languages.java.replay_test import parse_replay_test_metadata
project_root_path = Path(project_root_path)
trace_file_path: Path | None = None
functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list)
@ -602,7 +604,7 @@ def _get_java_replay_test_functions(
all_functions = lang_support.discover_functions(source_code, source_file)
for func in all_functions:
if func.function_name in function_names or func.qualified_name in function_names:
functions[source_file].append(func)
if trace_file_path is None:

View file

@ -745,26 +745,35 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str,
if _is_test_annotation(stripped):
    if not helper_added:
        helper_added = True
    # Check if the @Test line already contains the method signature and opening brace
    # (common in compact test styles like replay tests: @Test void replay_foo_0() throws Exception {)
    if "{" in line:
        # The annotation line IS the method signature — don't look for a separate one
        result.append(line)
        i += 1
        method_lines = [line]
    else:
        result.append(line)
        i += 1
        # Collect any additional annotations
        while i < len(lines) and lines[i].strip().startswith("@"):
            result.append(lines[i])
            i += 1
        # Now find the method signature and opening brace
        method_lines = []
        while i < len(lines):
            method_lines.append(lines[i])
            if "{" in lines[i]:
                break
            i += 1
        # Add the method signature lines
        for ml in method_lines:
            result.append(ml)
            i += 1
    # Extract the test method name from the method signature
    test_method_name = _extract_test_method_name(method_lines)

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
import os
import shutil
import subprocess
from datetime import datetime
@ -42,6 +43,13 @@ class JfrProfile:
candidate = Path(home) / "bin" / "jfr"
if candidate.exists():
return str(candidate)
java_home_env = os.environ.get("JAVA_HOME")
if java_home_env:
candidate = Path(java_home_env) / "bin" / "jfr"
if candidate.exists():
return str(candidate)
return None
def _parse(self) -> None:
@ -152,6 +160,8 @@ class JfrProfile:
method_name = method.get("name", "")
if not class_name or not method_name:
return None
# JFR uses / separators (JVM internal format); normalize to dots for package matching
class_name = class_name.replace("/", ".")
return f"{class_name}.{method_name}"
def _store_method_info(self, key: str, frame: dict[str, Any]) -> None:
@ -159,7 +169,7 @@ class JfrProfile:
return
method = frame.get("method", {})
self._method_info[key] = {
"class_name": method.get("type", {}).get("name", ""),
"class_name": method.get("type", {}).get("name", "").replace("/", "."),
"method_name": method.get("name", ""),
"descriptor": method.get("descriptor", ""),
"line_number": str(frame.get("lineNumber", 0)),

View file

@ -906,7 +906,15 @@ class MavenStrategy(BuildToolStrategy):
" --add-opens java.base/java.net=ALL-UNNAMED"
" --add-opens java.base/java.util.zip=ALL-UNNAMED"
)
if enable_coverage:
# When coverage is enabled, JaCoCo's prepare-agent goal sets argLine via
# @{argLine}. Overriding -DargLine would clobber the JaCoCo agent flag.
# Pass add-opens and javaagent via JDK_JAVA_OPTIONS instead.
jdk_opts_parts = [add_opens_flags]
if javaagent_arg:
jdk_opts_parts.insert(0, javaagent_arg)
env["JDK_JAVA_OPTIONS"] = " ".join(jdk_opts_parts)
elif javaagent_arg:
cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}")
else:
cmd.append(f"-DargLine={add_opens_flags}")

View file

@ -12,9 +12,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def generate_replay_tests(
trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256, test_framework: str = "junit5"
) -> int:
"""Generate JUnit replay test files from a trace SQLite database.
Supports both JUnit 5 (default) and JUnit 4.
Returns the number of test files generated.
"""
if not trace_db_path.exists():
@ -44,22 +47,29 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P
test_methods_code: list[str] = []
class_function_names: list[str] = []
# Global test counter to avoid duplicate method names for overloaded Java methods
method_name_counters: dict[str, int] = {}
for method_name, descriptor in method_list:
# Count invocations for this method
count_result = conn.execute(
"SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?",
(classname, method_name, descriptor),
).fetchone()
invocation_count = min(count_result[0], max_run_count)
simple_class = classname.rsplit(".", 1)[-1]
class_function_names.append(f"{simple_class}.{method_name}")
safe_method = _sanitize_identifier(method_name)
for i in range(invocation_count):
# Use a global counter per method name to avoid collisions on overloaded methods
test_idx = method_name_counters.get(safe_method, 0)
method_name_counters[safe_method] = test_idx + 1
escaped_descriptor = descriptor.replace('"', '\\"')
access = "public " if test_framework == "junit4" else ""
test_methods_code.append(
f" @Test void replay_{safe_method}_{i}() throws Exception {{\n"
f" @Test {access}void replay_{safe_method}_{test_idx}() throws Exception {{\n"
f' helper.replay("{classname}", "{method_name}", '
f'"{escaped_descriptor}", {i});\n'
f" }}"
@ -69,18 +79,28 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P
# Generate the test file
functions_comment = ",".join(class_function_names)
if test_framework == "junit4":
test_imports = "import org.junit.Test;\nimport org.junit.AfterClass;\n"
cleanup_annotation = "@AfterClass"
class_modifier = "public "
else:
test_imports = "import org.junit.jupiter.api.Test;\nimport org.junit.jupiter.api.AfterAll;\n"
cleanup_annotation = "@AfterAll"
class_modifier = ""
test_content = (
f"// codeflash:functions={functions_comment}\n"
f"// codeflash:trace_file={trace_db_path.as_posix()}\n"
f"// codeflash:classname={classname}\n"
f"package codeflash.replay;\n\n"
f"import org.junit.jupiter.api.Test;\n"
f"import org.junit.jupiter.api.AfterAll;\n"
f"{test_imports}"
f"import com.codeflash.ReplayHelper;\n\n"
f"class {test_class_name} {{\n"
f"{class_modifier}class {test_class_name} {{\n"
f" private static final ReplayHelper helper =\n"
f' new ReplayHelper("{trace_db_path.as_posix()}");\n\n'
f" @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n"
f" {cleanup_annotation} public static void cleanup() {{ helper.close(); }}\n\n"
+ "\n\n".join(test_methods_code)
+ "\n"
"}\n"
)
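Under the junit4 branch this template yields a file of roughly the following shape (class name, trace path, and target method are illustrative); the JUnit 5 default differs only in the jupiter imports, the package-private class and test methods, and @AfterAll in place of @AfterClass:

// codeflash:functions=Calculator.add
// codeflash:trace_file=/tmp/trace.db
// codeflash:classname=com.example.Calculator
package codeflash.replay;

import org.junit.Test;
import org.junit.AfterClass;
import com.codeflash.ReplayHelper;

public class ReplayTest_Calculator {
    private static final ReplayHelper helper =
        new ReplayHelper("/tmp/trace.db");

    @AfterClass public static void cleanup() { helper.close(); }

    @Test public void replay_add_0() throws Exception {
        // Replays the first recorded invocation of Calculator.add(int, int)
        helper.replay("com.example.Calculator", "add", "(II)I", 0);
    }
}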

View file

@ -14,6 +14,39 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL
def _run_java_with_graceful_timeout(
java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
) -> None:
"""Run a Java command with graceful timeout handling.
Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.
"""
if not timeout:
subprocess.run(java_command, env=env, check=False)
return
import signal
proc = subprocess.Popen(java_command, env=env)
try:
proc.wait(timeout=timeout)
except subprocess.TimeoutExpired:
logger.warning(
"%s stage timed out after %d seconds, sending SIGTERM for graceful shutdown...", stage_name, timeout
)
proc.send_signal(signal.SIGTERM)
try:
proc.wait(timeout=GRACEFUL_SHUTDOWN_WAIT)
except subprocess.TimeoutExpired:
logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
proc.kill()
proc.wait()
# --add-opens flags needed for Kryo serialization on Java 16+
ADD_OPENS_FLAGS = (
"--add-opens=java.base/java.util=ALL-UNNAMED "
@ -48,10 +81,7 @@ class JavaTracer:
# Stage 1: JFR Profiling
logger.info("Stage 1: Running JFR profiling...")
jfr_env = self.build_jfr_env(jfr_file)
_run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling")
if not jfr_file.exists():
logger.warning("JFR file was not created at %s", jfr_file)
@ -62,10 +92,7 @@ class JavaTracer:
trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout
)
agent_env = self.build_agent_env(config_path)
_run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture")
if not trace_db_path.exists():
logger.error("Trace database was not created at %s", trace_db_path)
@ -95,7 +122,12 @@ class JavaTracer:
def build_jfr_env(self, jfr_file: Path) -> dict[str, str]:
env = os.environ.copy()
jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true"
# Use profile settings with increased sampling frequency (1ms instead of default 10ms)
# This captures more samples for short-running programs
jfr_opts = (
f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true"
",jdk.ExecutionSample#period=1ms"
)
existing = env.get("JAVA_TOOL_OPTIONS", "")
env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip()
return env
@ -133,7 +165,7 @@ class JavaTracer:
if stripped.startswith("package "):
pkg = stripped[8:].rstrip(";").strip()
parts = pkg.split(".")
prefix = ".".join(parts[: min(2, len(parts))])
prefix = ".".join(parts[: min(3, len(parts))])
packages.add(prefix)
break
if stripped and not stripped.startswith("//"):
@ -153,6 +185,7 @@ def run_java_tracer(
max_function_count: int = 256,
timeout: int = 0,
max_run_count: int = 256,
test_framework: str = "junit5",
) -> tuple[Path, Path, int]:
"""High-level entry point: trace a Java command and generate replay tests.
@ -169,7 +202,11 @@ def run_java_tracer(
)
test_count = generate_replay_tests(
trace_db_path=trace_db,
output_dir=output_dir,
project_root=project_root,
max_run_count=max_run_count,
test_framework=test_framework,
)
return trace_db, jfr_file, test_count

View file

@ -950,7 +950,7 @@ class TestResults(BaseModel): # noqa: PLW1641
by_id: dict[InvocationId, list[int]] = {}
for result in self.test_results:
if result.did_pass:
if result.runtime is not None:
by_id.setdefault(result.id, []).append(result.runtime)
else:
msg = (

View file

@ -349,8 +349,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
max_function_count = getattr(config, "max_function_count", 256)
timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0)
console.print("[bold]Java project detected[/]")
console.print(f" Project root: {project_root}")
console.print(f" Module root: {getattr(config, 'module_root', '?')}")
console.print(f" Tests root: {getattr(config, 'tests_root', '?')}")
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.languages.java.tracer import JavaTracer, run_java_tracer
tracer = JavaTracer()
@ -360,12 +364,16 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
trace_db_path = get_run_tmp_file(Path("java_trace.db"))
# Place replay tests in the project's test source tree so Maven/Gradle can compile them.
# Use the config's tests_root (correctly resolved for multi-module projects), not find_test_root().
tests_root = Path(getattr(config, "tests_root", ""))
if tests_root.is_dir():
    output_dir = tests_root / "codeflash" / "replay"
else:
    from codeflash.languages.java.build_tools import find_test_root
    test_root = find_test_root(project_root)
    output_dir = (test_root or project_root / "src" / "test" / "java") / "codeflash" / "replay"
output_dir.mkdir(parents=True, exist_ok=True)
# Remaining args after our flags are the Java command
@ -377,6 +385,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
sys.exit(1)
java_command = remaining
# Detect test framework for replay test generation
from codeflash.languages.java.config import detect_java_project
java_config = detect_java_project(project_root)
test_framework = java_config.test_framework if java_config else "junit5"
trace_db, jfr_file, test_count = run_java_tracer(
java_command=java_command,
trace_db_path=trace_db_path,
@ -385,6 +399,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
output_dir=output_dir,
max_function_count=max_function_count,
timeout=timeout,
test_framework=test_framework,
)
console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated")
@ -404,7 +419,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
config.replay_test = replay_test_paths
config.previous_checkpoint_functions = None
config.effort = EffortLevel.HIGH.value
config.no_pr = getattr(config, "no_pr", False)
config.file = None
config.function = None
config.test_project_root = project_root

View file

@ -50,6 +50,7 @@ def run_test(expected_improvement_pct: int) -> bool:
"--no-project",
"-m",
"codeflash.main",
"--no-pr",
"optimize",
"java",
"-cp",
@ -59,6 +60,7 @@ def run_test(expected_improvement_pct: int) -> bool:
env = os.environ.copy()
env["PYTHONIOENCODING"] = "utf-8"
env["PYTHONUNBUFFERED"] = "1"
logging.info(f"Running command: {' '.join(command)}")
logging.info(f"Working directory: {fixture_dir}")
process = subprocess.Popen(
@ -73,13 +75,11 @@ def run_test(expected_improvement_pct: int) -> bool:
output = []
for line in process.stdout:
print(line, end="", flush=True)
output.append(line)
return_code = process.wait()
stdout = "".join(output)
if return_code != 0:
logging.error(f"Command returned exit code {return_code}")
@ -90,7 +90,7 @@ def run_test(expected_improvement_pct: int) -> bool:
logging.error("Failed to find replay test generation message")
return False
# Validate: replay tests were discovered (global count)
replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout)
if not replay_match:
logging.error("Failed to find replay test discovery message")
@ -101,6 +101,17 @@ def run_test(expected_improvement_pct: int) -> bool:
return False
logging.info(f"Replay tests discovered: {num_replay}")
# Validate: replay test files were used per-function
replay_file_match = re.search(r"Discovered \d+ existing unit test files?, (\d+) replay test files?", stdout)
if not replay_file_match:
logging.error("Failed to find per-function replay test file discovery message")
return False
num_replay_files = int(replay_file_match.group(1))
if num_replay_files == 0:
logging.error("No replay test files discovered per-function")
return False
logging.info(f"Replay test files per-function: {num_replay_files}")
# Validate: at least one optimization was found
if "⚡️ Optimization successful! 📄 " not in stdout:
logging.error("Failed to find optimization success message")

View file

@ -36,20 +36,30 @@ public class Workload {
}
public static void main(String[] args) {
// Run methods with large inputs so JFR can capture CPU samples.
// Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval.
for (int round = 0; round < 1000; round++) {
computeSum(100_000);
repeatString("hello world ", 1000);
List<Integer> nums = new ArrayList<>();
for (int i = 1; i <= 10_000; i++) nums.add(i);
filterEvens(nums);
Workload w = new Workload();
w.instanceMethod(100_000, 42);
}
// Also call with small inputs for variety in traced args
System.out.println("computeSum(100) = " + computeSum(100));
System.out.println("computeSum(50) = " + computeSum(50));
System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3));
System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5));
List<Integer> small = new ArrayList<>();
for (int i = 1; i <= 10; i++) small.add(i);
System.out.println("filterEvens(1..10) = " + filterEvens(small));
Workload w = new Workload();
System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3));
System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2));
System.out.println("Workload complete.");
}

View file

@ -0,0 +1,302 @@
"""Tests for JFR parser — class name normalization, package filtering, addressable time."""
from __future__ import annotations
import json
import subprocess
from pathlib import Path
from unittest.mock import patch
import pytest
from codeflash.languages.java.jfr_parser import JfrProfile
def _make_jfr_json(events: list[dict]) -> str:
"""Create fake JFR JSON output matching the jfr print format."""
return json.dumps({"recording": {"events": events}})
def _make_execution_sample(class_name: str, method_name: str, start_time: str = "2026-01-01T00:00:00Z") -> dict:
return {
"type": "jdk.ExecutionSample",
"values": {
"startTime": start_time,
"stackTrace": {
"frames": [
{
"method": {
"type": {"name": class_name},
"name": method_name,
"descriptor": "()V",
},
"lineNumber": 42,
}
],
},
},
}
class TestClassNameNormalization:
"""Test that JVM internal class names (com/example/Foo) are normalized to dots (com.example.Foo)."""
def test_slash_separators_normalized_to_dots(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"),
_make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"),
_make_execution_sample("com/aerospike/client/util/Utf8", "encodedLength"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.aerospike"])
assert profile._total_samples == 3
assert len(profile._method_samples) == 2
# Keys should use dots, not slashes
assert "com.aerospike.client.command.Buffer.bytesToInt" in profile._method_samples
assert "com.aerospike.client.util.Utf8.encodedLength" in profile._method_samples
def test_method_info_uses_dot_class_names(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[_make_execution_sample("com/example/MyClass", "myMethod")]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
info = profile._method_info.get("com.example.MyClass.myMethod")
assert info is not None
assert info["class_name"] == "com.example.MyClass"
assert info["method_name"] == "myMethod"
class TestPackageFiltering:
def test_filters_by_package_prefix(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/aerospike/client/Value", "get"),
_make_execution_sample("java/util/HashMap", "put"),
_make_execution_sample("com/aerospike/benchmarks/Main", "main"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.aerospike"])
# Only com.aerospike classes should be in samples
assert len(profile._method_samples) == 2
assert "com.aerospike.client.Value.get" in profile._method_samples
assert "com.aerospike.benchmarks.Main.main" in profile._method_samples
assert "java.util.HashMap.put" not in profile._method_samples
def test_empty_packages_includes_all(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/example/Foo", "bar"),
_make_execution_sample("java/lang/String", "length"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, [])
assert len(profile._method_samples) == 2
class TestAddressableTime:
def test_addressable_time_proportional_to_samples(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
# 3 samples for methodA, 1 for methodB, spanning 10 seconds
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:00Z"),
_make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:03Z"),
_make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:06Z"),
_make_execution_sample("com/example/Foo", "methodB", "2026-01-01T00:00:10Z"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
time_a = profile.get_addressable_time_ns("com.example.Foo", "methodA")
time_b = profile.get_addressable_time_ns("com.example.Foo", "methodB")
# methodA has 3x the samples of methodB, so 3x the addressable time
assert time_a > 0
assert time_b > 0
assert time_a == pytest.approx(time_b * 3, rel=0.01)
def test_addressable_time_zero_for_unknown_method(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[_make_execution_sample("com/example/Foo", "bar")]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
assert profile.get_addressable_time_ns("com.example.Foo", "nonExistent") == 0.0
class TestMethodRanking:
def test_ranking_ordered_by_sample_count(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/example/A", "hot"),
_make_execution_sample("com/example/A", "hot"),
_make_execution_sample("com/example/A", "hot"),
_make_execution_sample("com/example/B", "warm"),
_make_execution_sample("com/example/B", "warm"),
_make_execution_sample("com/example/C", "cold"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
ranking = profile.get_method_ranking()
assert len(ranking) == 3
assert ranking[0]["method_name"] == "hot"
assert ranking[0]["sample_count"] == 3
assert ranking[1]["method_name"] == "warm"
assert ranking[1]["sample_count"] == 2
assert ranking[2]["method_name"] == "cold"
assert ranking[2]["sample_count"] == 1
def test_empty_ranking_when_no_samples(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json([])
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
assert profile.get_method_ranking() == []
def test_ranking_uses_dot_class_names(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[_make_execution_sample("com/example/nested/Deep", "method")]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
ranking = profile.get_method_ranking()
assert len(ranking) == 1
assert ranking[0]["class_name"] == "com.example.nested.Deep"
class TestGracefulTimeout:
"""Test that _run_java_with_graceful_timeout sends SIGTERM before SIGKILL."""
def test_sends_sigterm_on_timeout(self) -> None:
from codeflash.languages.java.tracer import _run_java_with_graceful_timeout
# Run a sleep command with a 1s timeout — should get SIGTERM'd
import os
env = os.environ.copy()
_run_java_with_graceful_timeout(["sleep", "60"], env, timeout=1, stage_name="test")
# If we get here, the process was killed (didn't hang for 60s)
def test_no_timeout_runs_normally(self) -> None:
import os
from codeflash.languages.java.tracer import _run_java_with_graceful_timeout
env = os.environ.copy()
_run_java_with_graceful_timeout(["echo", "hello"], env, timeout=0, stage_name="test")
# Should complete without error
class TestProjectRootResolution:
"""Test that project_root is correctly set for Java multi-module projects."""
def test_java_project_root_is_build_root_not_module(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""For multi-module Maven, project_root should be the root with <modules>, not a sub-module."""
# Create a multi-module project
(tmp_path / "pom.xml").write_text(
'<project xmlns="http://maven.apache.org/POM/4.0.0"><modules><module>client</module></modules></project>',
encoding="utf-8",
)
client = tmp_path / "client"
client.mkdir()
(client / "pom.xml").write_text("<project/>", encoding="utf-8")
src = client / "src" / "main" / "java"
src.mkdir(parents=True)
test = tmp_path / "src" / "test" / "java"
test.mkdir(parents=True)
monkeypatch.chdir(tmp_path)
from codeflash.code_utils.config_parser import parse_config_file
config, config_path = parse_config_file()
assert config["language"] == "java"
# config_path should be the project root directory
assert config_path == tmp_path
def test_project_root_is_path_not_string(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""project_root from process_pyproject_config should be a Path for Java projects."""
from argparse import Namespace
(tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
src = tmp_path / "src" / "main" / "java"
src.mkdir(parents=True)
test = tmp_path / "src" / "test" / "java"
test.mkdir(parents=True)
monkeypatch.chdir(tmp_path)
from codeflash.cli_cmds.cli import process_pyproject_config
# Create a minimal args namespace matching what parse_args produces
args = Namespace(
config_file=None, module_root=None, tests_root=None, benchmarks_root=None,
ignore_paths=None, pytest_cmd=None, formatter_cmds=None, disable_telemetry=None,
disable_imports_sorting=None, git_remote=None, override_fixtures=None,
benchmark=False, verbose=False, version=False, show_config=False, reset_config=False,
)
args = process_pyproject_config(args)
assert hasattr(args, "project_root")
assert isinstance(args.project_root, Path)
assert args.project_root == tmp_path

View file

@ -0,0 +1,255 @@
"""Tests for Java replay test generation — JUnit 4/5 support, overload handling, instrumentation skip."""
from __future__ import annotations
import sqlite3
from pathlib import Path
import pytest
from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata
@pytest.fixture
def trace_db(tmp_path: Path) -> Path:
"""Create a trace database with sample function calls."""
db_path = tmp_path / "trace.db"
conn = sqlite3.connect(str(db_path))
conn.execute(
"CREATE TABLE function_calls("
"type TEXT, function TEXT, classname TEXT, filename TEXT, "
"line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)"
)
conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 1000, b"\x00"),
)
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 2000, b"\x00"),
)
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "multiply", "com.example.Calculator", "Calculator.java", 20, "(II)I", 3000, b"\x00"),
)
conn.commit()
conn.close()
return db_path
@pytest.fixture
def trace_db_overloaded(tmp_path: Path) -> Path:
"""Create a trace database with overloaded methods (same name, different descriptors)."""
db_path = tmp_path / "trace_overloaded.db"
conn = sqlite3.connect(str(db_path))
conn.execute(
"CREATE TABLE function_calls("
"type TEXT, function TEXT, classname TEXT, filename TEXT, "
"line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)"
)
conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
# Two overloads of estimateKeySize with different descriptors
for i in range(3):
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "estimateKeySize", "com.example.Command", "Command.java", 10, "(I)I", i * 1000, b"\x00"),
)
for i in range(2):
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(
"call",
"estimateKeySize",
"com.example.Command",
"Command.java",
15,
"(Ljava/lang/String;)I",
(i + 10) * 1000,
b"\x00",
),
)
conn.commit()
conn.close()
return db_path
class TestGenerateReplayTestsJunit5:
def test_generates_junit5_by_default(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
count = generate_replay_tests(trace_db, output_dir, tmp_path)
assert count == 1
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "import org.junit.jupiter.api.Test;" in content
assert "import org.junit.jupiter.api.AfterAll;" in content
assert "@Test void replay_add_0()" in content
def test_junit5_class_is_package_private(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path)
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "class ReplayTest_" in content
assert "public class ReplayTest_" not in content
class TestGenerateReplayTestsJunit4:
def test_generates_junit4_imports(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
count = generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
assert count == 1
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "import org.junit.Test;" in content
assert "import org.junit.AfterClass;" in content
assert "org.junit.jupiter" not in content
def test_junit4_methods_are_public(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "@Test public void replay_add_0()" in content
def test_junit4_class_is_public(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "public class ReplayTest_" in content
def test_junit4_cleanup_uses_afterclass(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "@AfterClass" in content
assert "@AfterAll" not in content
class TestOverloadedMethods:
def test_no_duplicate_method_names(self, trace_db_overloaded: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
count = generate_replay_tests(trace_db_overloaded, output_dir, tmp_path)
assert count == 1
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
# Should have 5 unique methods (3 from first overload + 2 from second)
assert "replay_estimateKeySize_0" in content
assert "replay_estimateKeySize_1" in content
assert "replay_estimateKeySize_2" in content
assert "replay_estimateKeySize_3" in content
assert "replay_estimateKeySize_4" in content
# Verify no duplicates by counting occurrences
lines = content.splitlines()
method_lines = [l for l in lines if "void replay_estimateKeySize_" in l]
method_names = [l.split("void ")[1].split("(")[0] for l in method_lines]
assert len(method_names) == len(set(method_names)), f"Duplicate methods: {method_names}"
class TestReplayTestInstrumentation:
def test_replay_tests_instrumented_correctly(self, trace_db: Path, tmp_path: Path) -> None:
"""Replay tests with compact @Test lines should be instrumented without orphaned code."""
from codeflash.languages.java.discovery import discover_functions_from_source
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path)
test_file = list(output_dir.glob("*.java"))[0]
src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
target = funcs[0]
from codeflash.languages.java.support import JavaSupport
support = JavaSupport()
success, instrumented = support.instrument_existing_test(
test_path=test_file,
call_positions=[],
function_to_optimize=target,
tests_project_root=tmp_path,
mode="behavior",
)
assert success
assert instrumented is not None
assert "__perfinstrumented" in instrumented
# Verify no code outside class body
lines = instrumented.splitlines()
class_closed = False
for line in lines:
if line.strip() == "}" and not line.startswith(" "):
class_closed = True
elif class_closed and line.strip() and not line.strip().startswith("//"):
pytest.fail(f"Orphaned code outside class: {line}")
def test_replay_tests_perf_instrumented(self, trace_db: Path, tmp_path: Path) -> None:
from codeflash.languages.java.discovery import discover_functions_from_source
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path)
test_file = list(output_dir.glob("*.java"))[0]
src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
target = funcs[0]
from codeflash.languages.java.support import JavaSupport
support = JavaSupport()
success, instrumented = support.instrument_existing_test(
test_path=test_file,
call_positions=[],
function_to_optimize=target,
tests_project_root=tmp_path,
mode="performance",
)
assert success
assert "__perfonlyinstrumented" in instrumented
def test_regular_tests_still_instrumented(self, tmp_path: Path) -> None:
from codeflash.languages.java.discovery import discover_functions_from_source
src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
target = funcs[0]
test_file = tmp_path / "CalculatorTest.java"
test_file.write_text(
"""
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
calc.add(1, 2);
}
}
""",
encoding="utf-8",
)
from codeflash.languages.java.support import JavaSupport
support = JavaSupport()
success, instrumented = support.instrument_existing_test(
test_path=test_file,
call_positions=[],
function_to_optimize=target,
tests_project_root=tmp_path,
mode="behavior",
)
assert success
assert "CODEFLASH_LOOP_INDEX" in instrumented