Mirror of https://github.com/codeflash-ai/codeflash.git, synced 2026-05-04 18:25:17 +00:00

Merge pull request #1880 from codeflash-ai/java-config-redesign
feat: zero-config Java projects + smart ReplayHelper for end-to-end optimization

Commit 15a92613e2: 19 changed files with 992 additions and 121 deletions
.github/workflows/e2e-java-tracer.yaml (vendored, 12 changes)
@@ -3,17 +3,9 @@ name: E2E - Java Tracer
on:
  pull_request:
    paths:
-     - 'codeflash/languages/java/**'
-     - 'codeflash/languages/base.py'
-     - 'codeflash/languages/registry.py'
-     - 'codeflash/tracer.py'
-     - 'codeflash/benchmarking/function_ranker.py'
-     - 'codeflash/discovery/functions_to_optimize.py'
-     - 'codeflash/optimization/**'
-     - 'codeflash/verification/**'
+     - 'codeflash/**'
      - 'codeflash-java-runtime/**'
-     - 'tests/test_languages/fixtures/java_tracer_e2e/**'
-     - 'tests/scripts/end_to_end_test_java_tracer.py'
+     - 'tests/**'
      - '.github/workflows/e2e-java-tracer.yaml'

  workflow_dispatch:
ReplayHelper.java
@@ -12,20 +12,181 @@ import org.objectweb.asm.Type;

public class ReplayHelper {

-   private final Connection db;
+   private final Connection traceDb;

    // Codeflash instrumentation state — read from environment variables once
    private final String mode; // "behavior", "performance", or null
    private final int loopIndex;
    private final String testIteration;
    private final String outputFile; // SQLite path for behavior capture
    private final int innerIterations; // for performance looping

    // Behavior mode: lazily opened SQLite connection for writing results
    private Connection behaviorDb;
    private boolean behaviorDbInitialized;

    public ReplayHelper(String traceDbPath) {
        try {
-           this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
+           this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
        } catch (SQLException e) {
            throw new RuntimeException("Failed to open trace database: " + traceDbPath, e);
        }

        // Read codeflash instrumentation env vars (set by the test runner)
        this.mode = System.getenv("CODEFLASH_MODE");
        this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1);
        this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0");
        this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE");
        this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10);
    }

    public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception {
        // Query the function_calls table for this method at the given index.
        // Deserialize args and resolve the method (done once, outside timing).
        Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex);
        Class<?> targetClass = Class.forName(className);

        Type[] paramTypes = Type.getArgumentTypes(descriptor);
        Class<?>[] paramClasses = new Class<?>[paramTypes.length];
        for (int i = 0; i < paramTypes.length; i++) {
            paramClasses[i] = typeToClass(paramTypes[i]);
        }

        Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
        method.setAccessible(true);
        boolean isStatic = Modifier.isStatic(method.getModifiers());

        Object instance = null;
        if (!isStatic) {
            try {
                java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
                ctor.setAccessible(true);
                instance = ctor.newInstance();
            } catch (NoSuchMethodException e) {
                instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
            }
        }

        // Get the calling test method name from the stack trace
        String testMethodName = getCallingTestMethodName();
        // Module name = the test class that called us
        String testClassName = getCallingTestClassName();

        if ("behavior".equals(mode)) {
            replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName);
        } else if ("performance".equals(mode)) {
            replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName);
        } else {
            // No codeflash mode — just invoke (trace-only or manual testing)
            method.invoke(instance, allArgs);
        }
    }

    private void replayBehavior(Method method, Object instance, Object[] args,
                                String className, String methodName,
                                String testClassName, String testMethodName) throws Exception {
        // testIteration goes at the END so the Comparator's lastUnderscore stripping
        // removes it, making baseline (iteration=0) and candidate (iteration=N) keys match.
        String invId = testMethodName + "_" + testIteration;

        // Print start marker (same format as behavior instrumentation)
        System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
                + ":" + methodName + ":" + loopIndex + ":" + invId + "######$!");

        long startNs = System.nanoTime();
        Object result;
        try {
            result = method.invoke(instance, args);
        } catch (java.lang.reflect.InvocationTargetException e) {
            throw (Exception) e.getCause();
        }
        long durationNs = System.nanoTime() - startNs;

        // Print end marker
        System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
                + ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!");

        // Write return value to SQLite for correctness comparison
        if (outputFile != null && !outputFile.isEmpty()) {
            writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result);
        }
    }

    private void replayPerformance(Method method, Object instance, Object[] args,
                                   String className, String methodName,
                                   String testClassName, String testMethodName) throws Exception {
        // Performance mode: run inner loop for JIT warmup, print timing for each iteration
        int maxInner = innerIterations;
        for (int inner = 0; inner < maxInner; inner++) {
            int loopId = (loopIndex - 1) * maxInner + inner;
            String invId = testMethodName;

            // Print start marker
            System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
                    + ":" + methodName + ":" + loopId + ":" + invId + "######$!");

            long startNs = System.nanoTime();
            try {
                method.invoke(instance, args);
            } catch (java.lang.reflect.InvocationTargetException e) {
                // Swallow — performance mode doesn't check correctness
            }
            long durationNs = System.nanoTime() - startNs;

            // Print end marker
            System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
                    + ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!");
        }
    }

    private void writeBehaviorResult(String testClassName, String testMethodName,
                                     String functionName, String invId,
                                     long durationNs, Object result) {
        try {
            ensureBehaviorDb();
            String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
            try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) {
                ps.setString(1, testClassName);          // test_module_path
                ps.setString(2, testClassName);          // test_class_name
                ps.setString(3, testMethodName);         // test_function_name
                ps.setString(4, functionName);           // function_getting_tested
                ps.setInt(5, loopIndex);                 // loop_index
                ps.setString(6, invId);                  // iteration_id
                ps.setLong(7, durationNs);               // runtime
                ps.setBytes(8, serializeResult(result)); // return_value
                ps.setString(9, "function_call");        // verification_type
                ps.executeUpdate();
            }
        } catch (Exception e) {
            System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage());
        }
    }

    private void ensureBehaviorDb() throws SQLException {
        if (behaviorDbInitialized) return;
        behaviorDbInitialized = true;
        behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile);
        try (java.sql.Statement stmt = behaviorDb.createStatement()) {
            stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" +
                    "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +
                    "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +
                    "runtime INTEGER, return_value BLOB, verification_type TEXT)");
        }
    }

    private byte[] serializeResult(Object result) {
        if (result == null) return null;
        try {
            return Serializer.serialize(result);
        } catch (Exception e) {
            // Fall back to String.valueOf if Kryo fails
            return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8);
        }
    }

    private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex)
            throws SQLException {
        byte[] argsBlob;
-       try (PreparedStatement stmt = db.prepareStatement(
+       try (PreparedStatement stmt = traceDb.prepareStatement(
                "SELECT args FROM function_calls " +
                "WHERE classname = ? AND function = ? AND descriptor = ? " +
                "ORDER BY time_ns LIMIT 1 OFFSET ?")) {
@@ -43,46 +204,35 @@ public class ReplayHelper {
            }
        }

        // Deserialize args
        Object deserialized = Serializer.deserialize(argsBlob);
        if (!(deserialized instanceof Object[])) {
            throw new RuntimeException("Deserialized args is not Object[], got: "
                    + (deserialized == null ? "null" : deserialized.getClass().getName()));
        }
-       Object[] allArgs = (Object[]) deserialized;
+       return (Object[]) deserialized;
    }

-       // Load the target class
-       Class<?> targetClass = Class.forName(className);
-
-       // Parse descriptor to find parameter types
-       Type[] paramTypes = Type.getArgumentTypes(descriptor);
-       Class<?>[] paramClasses = new Class<?>[paramTypes.length];
-       for (int i = 0; i < paramTypes.length; i++) {
-           paramClasses[i] = typeToClass(paramTypes[i]);
-       }
-
-       // Find the method
-       Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
-       method.setAccessible(true);
-
-       boolean isStatic = Modifier.isStatic(method.getModifiers());
-
-       if (isStatic) {
-           method.invoke(null, allArgs);
-       } else {
-           // Args contain only explicit parameters (no 'this').
-           // Create a default instance via no-arg constructor or Kryo.
-           Object instance;
-           try {
-               java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
-               ctor.setAccessible(true);
-               instance = ctor.newInstance();
-           } catch (NoSuchMethodException e) {
-               // Fall back to Objenesis instantiation (no constructor needed)
-               instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
-           }
-           method.invoke(instance, allArgs);
-       }
-   }
+   private static String getCallingTestMethodName() {
+       StackTraceElement[] stack = Thread.currentThread().getStackTrace();
+       // Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method
+       for (int i = 3; i < stack.length; i++) {
+           String method = stack[i].getMethodName();
+           if (method.startsWith("replay_")) {
+               return method;
+           }
+       }
+       return stack.length > 3 ? stack[3].getMethodName() : "unknown";
+   }
+
+   private static String getCallingTestClassName() {
+       StackTraceElement[] stack = Thread.currentThread().getStackTrace();
+       for (int i = 3; i < stack.length; i++) {
+           String cls = stack[i].getClassName();
+           if (cls.contains("ReplayTest") || cls.contains("replay")) {
+               return cls;
+           }
+       }
+       return stack.length > 3 ? stack[3].getClassName() : "unknown";
+   }

    private static Class<?> typeToClass(Type type) throws ClassNotFoundException {
@@ -106,11 +256,23 @@ public class ReplayHelper {
        }
    }

+   private static int parseIntEnv(String name, int defaultValue) {
+       String val = System.getenv(name);
+       if (val == null || val.isEmpty()) return defaultValue;
+       try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; }
+   }
+
+   private static String getEnvOrDefault(String name, String defaultValue) {
+       String val = System.getenv(name);
+       return (val != null && !val.isEmpty()) ? val : defaultValue;
+   }

    public void close() {
-       try {
-           if (db != null) db.close();
-       } catch (SQLException e) {
-           System.err.println("Error closing ReplayHelper: " + e.getMessage());
+       try { if (traceDb != null) traceDb.close(); } catch (SQLException e) {
+           System.err.println("Error closing ReplayHelper trace db: " + e.getMessage());
        }
+       try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) {
+           System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage());
+       }
    }
}
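For intuition, here is a minimal sketch (class, method, and values hypothetical, not from this PR) that mirrors the marker concatenation in replayBehavior above; the comparator later matches baseline and candidate runs by stripping the trailing underscore-suffixed iteration from invId:

// Illustrative only: prints markers in the same shape as replayBehavior.
public class MarkerFormatDemo {
    public static void main(String[] args) {
        String testClassName = "codeflash.replay.ReplayTest_Calculator"; // hypothetical
        String testMethodName = "replay_add_0";                          // hypothetical
        String methodName = "add";
        int loopIndex = 1;
        String invId = testMethodName + "_" + "0"; // testIteration appended LAST
        long durationNs = 12_345L;
        // Start marker: !$######<testClass>:<testClass>.<testMethod>:<method>:<loop>:<invId>######$!
        System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
                + ":" + methodName + ":" + loopIndex + ":" + invId + "######$!");
        // End marker additionally carries the measured duration in nanoseconds.
        System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
                + ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!");
    }
}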
TraceRecorder.java
@@ -22,6 +22,7 @@ public final class TraceRecorder {
    private final TracerConfig config;
    private final TraceWriter writer;
    private final ConcurrentHashMap<String, AtomicInteger> functionCounts = new ConcurrentHashMap<>();
+   private final AtomicInteger droppedCaptures = new AtomicInteger(0);
    private final int maxFunctionCount;
    private final ExecutorService serializerExecutor;

@@ -82,11 +83,13 @@ public final class TraceRecorder {
            argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            future.cancel(true);
+           droppedCaptures.incrementAndGet();
            System.err.println("[codeflash-tracer] Serialization timed out for " + className + "."
                    + methodName);
            return;
        } catch (Exception e) {
            Throwable cause = e.getCause() != null ? e.getCause() : e;
+           droppedCaptures.incrementAndGet();
            System.err.println("[codeflash-tracer] Serialization failed for " + className + "."
                    + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage());
            return;

@@ -113,11 +116,15 @@ public final class TraceRecorder {
        }
        metadata.put("totalCaptures", String.valueOf(totalCaptures));

+       int dropped = droppedCaptures.get();
+       metadata.put("droppedCaptures", String.valueOf(dropped));
+
        writer.writeMetadata(metadata);
        writer.flush();
        writer.close();

        System.err.println("[codeflash-tracer] Captured " + totalCaptures
-               + " invocations across " + functionCounts.size() + " methods");
+               + " invocations across " + functionCounts.size() + " methods"
+               + (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : ""));
    }
}
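The timeout handling above follows a standard Future-with-deadline pattern; a self-contained sketch (the payload and the 100 ms deadline are hypothetical stand-ins for the Kryo serialization and SERIALIZATION_TIMEOUT_MS):

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// Illustrative only: bound a potentially slow serialization with a deadline,
// dropping the capture instead of stalling the traced application.
public class TimeoutSerializeDemo {
    public static void main(String[] args) throws Exception {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        Future<byte[]> future = executor.submit(() -> "payload".getBytes()); // stand-in for Kryo
        try {
            byte[] blob = future.get(100, TimeUnit.MILLISECONDS);
            System.out.println("serialized " + blob.length + " bytes");
        } catch (TimeoutException e) {
            future.cancel(true); // count it as dropped, keep tracing
        } finally {
            executor.shutdown();
        }
    }
}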
TracingClassVisitor.java
@@ -4,14 +4,20 @@ import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

+import java.util.Collections;
+import java.util.Map;

public class TracingClassVisitor extends ClassVisitor {

    private final String internalClassName;
+   private final Map<String, Integer> methodLineNumbers;
    private String sourceFile;

-   public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName) {
+   public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName,
+                              Map<String, Integer> methodLineNumbers) {
        super(Opcodes.ASM9, classVisitor);
        this.internalClassName = internalClassName;
+       this.methodLineNumbers = methodLineNumbers != null ? methodLineNumbers : Collections.emptyMap();
    }

    @Override

@@ -37,7 +43,8 @@ public class TracingClassVisitor extends ClassVisitor {
            return mv;
        }

+       int lineNumber = methodLineNumbers.getOrDefault(name + descriptor, 0);
        return new TracingMethodAdapter(mv, access, name, descriptor,
-               internalClassName, 0, sourceFile != null ? sourceFile : "");
+               internalClassName, lineNumber, sourceFile != null ? sourceFile : "");
    }
}
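The name + descriptor keys used above are built from JVM method descriptors; a small sketch of how ASM decodes one (the descriptor string is chosen for illustration):

import org.objectweb.asm.Type;

// Illustrative: "(ILjava/lang/String;)V" describes a method taking (int, String) and returning void.
public class DescriptorDemo {
    public static void main(String[] args) {
        Type[] params = Type.getArgumentTypes("(ILjava/lang/String;)V");
        for (Type t : params) {
            System.out.println(t.getClassName()); // prints "int", then "java.lang.String"
        }
        System.out.println(Type.getReturnType("(ILjava/lang/String;)V").getClassName()); // "void"
    }
}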
TracingTransformer.java
@@ -1,10 +1,16 @@
package com.codeflash.tracer;

import org.objectweb.asm.ClassReader;
+import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
+import org.objectweb.asm.Label;
+import org.objectweb.asm.MethodVisitor;
+import org.objectweb.asm.Opcodes;

import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;
+import java.util.HashMap;
+import java.util.Map;

public class TracingTransformer implements ClassFileTransformer {

@@ -22,11 +28,6 @@ public class TracingTransformer implements ClassFileTransformer {
            return null;
        }

-       // Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization)
-       if (TraceRecorder.isRecording()) {
-           return null;
-       }
-
        // Skip internal JDK, framework, and synthetic classes
        if (className.startsWith("java/")
                || className.startsWith("javax/")

@@ -51,6 +52,30 @@ public class TracingTransformer implements ClassFileTransformer {

    private byte[] instrumentClass(String internalClassName, byte[] bytecode) {
        ClassReader cr = new ClassReader(bytecode);

+       // Pre-scan: collect the first source line number for each method.
+       // ASM's visitMethod() doesn't provide line info — it arrives later via visitLineNumber().
+       // We do a lightweight read pass first so the instrumentation pass has accurate line numbers.
+       Map<String, Integer> methodLineNumbers = new HashMap<>();
+       cr.accept(new ClassVisitor(Opcodes.ASM9) {
+           @Override
+           public MethodVisitor visitMethod(int access, String name, String descriptor,
+                                            String signature, String[] exceptions) {
+               String key = name + descriptor;
+               return new MethodVisitor(Opcodes.ASM9) {
+                   private boolean captured = false;
+
+                   @Override
+                   public void visitLineNumber(int line, Label start) {
+                       if (!captured) {
+                           methodLineNumbers.put(key, line);
+                           captured = true;
+                       }
+                   }
+               };
+           }
+       }, ClassReader.SKIP_FRAMES);

        // Use COMPUTE_MAXS only (not COMPUTE_FRAMES) to preserve original stack map frames.
        // COMPUTE_FRAMES recomputes all frames and calls getCommonSuperClass() which either
        // triggers classloader deadlocks or produces incorrect frames when returning "java/lang/Object".

@@ -58,7 +83,7 @@ public class TracingTransformer implements ClassFileTransformer {
        // adjusts offsets for injected code. Our AdviceAdapter only injects at method entry
        // (before any branch points), so existing frames remain valid.
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS);
-       TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName);
+       TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName, methodLineNumbers);
        cr.accept(cv, ClassReader.EXPAND_FRAMES);
        return cw.toByteArray();
    }
@@ -376,6 +376,7 @@ def _build_parser() -> ArgumentParser:
    subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
    subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")

+   trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False)
    auth_parser = subparsers.add_parser("auth", help="Authentication commands")
    auth_subparsers = auth_parser.add_subparsers(dest="auth_command", help="Auth sub-commands")
    auth_subparsers.add_parser("login", help="Log in to Codeflash via OAuth")

@@ -391,8 +392,6 @@ def _build_parser() -> ArgumentParser:
    compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)")
    compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")

-   trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
-
    trace_optimize.add_argument(
        "--max-function-count",
        type=int,
@@ -554,11 +554,13 @@ def get_all_replay_test_functions(


def _get_java_replay_test_functions(
-   replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
+   replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path | str
) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
    """Parse Java replay test files to extract functions and trace file path."""
    from codeflash.languages.java.replay_test import parse_replay_test_metadata

+   project_root_path = Path(project_root_path)
+
    trace_file_path: Path | None = None
    functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list)

@@ -602,7 +604,7 @@ def _get_java_replay_test_functions(
    all_functions = lang_support.discover_functions(source_code, source_file)

    for func in all_functions:
-       if func.function_name in function_names:
+       if func.function_name in function_names or func.qualified_name in function_names:
            functions[source_file].append(func)

    if trace_file_path is None:
@@ -745,26 +745,35 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str,
        if _is_test_annotation(stripped):
            if not helper_added:
                helper_added = True
-           result.append(line)
-           i += 1
-
-           # Collect any additional annotations
-           while i < len(lines) and lines[i].strip().startswith("@"):
-               result.append(lines[i])
-               i += 1
-
-           # Now find the method signature and opening brace
-           method_lines = []
-           while i < len(lines):
-               method_lines.append(lines[i])
-               if "{" in lines[i]:
-                   break
-               i += 1
-
-           # Add the method signature lines
-           for ml in method_lines:
-               result.append(ml)
-               i += 1
+           # Check if the @Test line already contains the method signature and opening brace
+           # (common in compact test styles like replay tests: @Test void replay_foo_0() throws Exception {)
+           if "{" in line:
+               # The annotation line IS the method signature — don't look for a separate one
+               result.append(line)
+               i += 1
+               method_lines = [line]
+           else:
+               result.append(line)
+               i += 1
+
+               # Collect any additional annotations
+               while i < len(lines) and lines[i].strip().startswith("@"):
+                   result.append(lines[i])
+                   i += 1
+
+               # Now find the method signature and opening brace
+               method_lines = []
+               while i < len(lines):
+                   method_lines.append(lines[i])
+                   if "{" in lines[i]:
+                       break
+                   i += 1
+
+               # Add the method signature lines
+               for ml in method_lines:
+                   result.append(ml)
+                   i += 1

            # Extract the test method name from the method signature
            test_method_name = _extract_test_method_name(method_lines)
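To see the two layouts the branch above distinguishes, compare a compact test (as generated for replay tests) with a conventionally formatted one; a minimal JUnit 5 sketch with hypothetical method bodies:

import org.junit.jupiter.api.Test;

// Illustrative only: the two test layouts _add_behavior_instrumentation must handle.
class LayoutExamples {
    // Compact style: the @Test line already contains the signature and the opening brace.
    @Test void compactStyle() throws Exception { System.out.println("compact"); }

    // Expanded style: the annotation is on its own line; the signature and brace follow.
    @Test
    void expandedStyle() {
        System.out.println("expanded");
    }
}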
codeflash/languages/java/jfr_parser.py
@@ -2,6 +2,7 @@ from __future__ import annotations

import json
import logging
+import os
import shutil
import subprocess
from datetime import datetime

@@ -42,6 +43,13 @@ class JfrProfile:
        candidate = Path(home) / "bin" / "jfr"
        if candidate.exists():
            return str(candidate)

+       java_home_env = os.environ.get("JAVA_HOME")
+       if java_home_env:
+           candidate = Path(java_home_env) / "bin" / "jfr"
+           if candidate.exists():
+               return str(candidate)
+
        return None

    def _parse(self) -> None:

@@ -152,6 +160,8 @@ class JfrProfile:
        method_name = method.get("name", "")
        if not class_name or not method_name:
            return None
+       # JFR uses / separators (JVM internal format), normalize to dots for package matching
+       class_name = class_name.replace("/", ".")
        return f"{class_name}.{method_name}"

    def _store_method_info(self, key: str, frame: dict[str, Any]) -> None:

@@ -159,7 +169,7 @@ class JfrProfile:
            return
        method = frame.get("method", {})
        self._method_info[key] = {
-           "class_name": method.get("type", {}).get("name", ""),
+           "class_name": method.get("type", {}).get("name", "").replace("/", "."),
            "method_name": method.get("name", ""),
            "descriptor": method.get("descriptor", ""),
            "line_number": str(frame.get("lineNumber", 0)),
@@ -906,7 +906,15 @@ class MavenStrategy(BuildToolStrategy):
            " --add-opens java.base/java.net=ALL-UNNAMED"
            " --add-opens java.base/java.util.zip=ALL-UNNAMED"
        )
-       if javaagent_arg:
+       if enable_coverage:
+           # When coverage is enabled, JaCoCo's prepare-agent goal sets argLine via
+           # @{argLine}. Overriding -DargLine would clobber the JaCoCo agent flag.
+           # Pass add-opens and javaagent via JDK_JAVA_OPTIONS instead.
+           jdk_opts_parts = [add_opens_flags]
+           if javaagent_arg:
+               jdk_opts_parts.insert(0, javaagent_arg)
+           env["JDK_JAVA_OPTIONS"] = " ".join(jdk_opts_parts)
+       elif javaagent_arg:
            cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}")
        else:
            cmd.append(f"-DargLine={add_opens_flags}")
codeflash/languages/java/replay_test.py
@@ -12,9 +12,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)


-def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256) -> int:
-   """Generate JUnit 5 replay test files from a trace SQLite database.
+def generate_replay_tests(
+    trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256, test_framework: str = "junit5"
+) -> int:
+   """Generate JUnit replay test files from a trace SQLite database.

+   Supports both JUnit 5 (default) and JUnit 4.
    Returns the number of test files generated.
    """
    if not trace_db_path.exists():

@@ -44,22 +47,29 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P

    test_methods_code: list[str] = []
    class_function_names: list[str] = []
+   # Global test counter to avoid duplicate method names for overloaded Java methods
+   method_name_counters: dict[str, int] = {}

    for method_name, descriptor in method_list:
        # Count invocations for this method
        count_result = conn.execute(
            "SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?",
            (classname, method_name, descriptor),
        ).fetchone()
        invocation_count = min(count_result[0], max_run_count)

-       class_function_names.append(method_name)
+       simple_class = classname.rsplit(".", 1)[-1]
+       class_function_names.append(f"{simple_class}.{method_name}")
        safe_method = _sanitize_identifier(method_name)

        for i in range(invocation_count):
+           # Use a global counter per method name to avoid collisions on overloaded methods
+           test_idx = method_name_counters.get(safe_method, 0)
+           method_name_counters[safe_method] = test_idx + 1
+
            escaped_descriptor = descriptor.replace('"', '\\"')
+           access = "public " if test_framework == "junit4" else ""
            test_methods_code.append(
-               f"    @Test void replay_{safe_method}_{i}() throws Exception {{\n"
+               f"    @Test {access}void replay_{safe_method}_{test_idx}() throws Exception {{\n"
                f'        helper.replay("{classname}", "{method_name}", '
                f'"{escaped_descriptor}", {i});\n'
                f"    }}"

@@ -69,18 +79,28 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P

    # Generate the test file
    functions_comment = ",".join(class_function_names)
+   if test_framework == "junit4":
+       test_imports = "import org.junit.Test;\nimport org.junit.AfterClass;\n"
+       cleanup_annotation = "@AfterClass"
+       class_modifier = "public "
+   else:
+       test_imports = "import org.junit.jupiter.api.Test;\nimport org.junit.jupiter.api.AfterAll;\n"
+       cleanup_annotation = "@AfterAll"
+       class_modifier = ""
+
    test_content = (
        f"// codeflash:functions={functions_comment}\n"
        f"// codeflash:trace_file={trace_db_path.as_posix()}\n"
        f"// codeflash:classname={classname}\n"
        f"package codeflash.replay;\n\n"
-       f"import org.junit.jupiter.api.Test;\n"
-       f"import org.junit.jupiter.api.AfterAll;\n"
+       f"{test_imports}"
        f"import com.codeflash.ReplayHelper;\n\n"
-       f"class {test_class_name} {{\n"
+       f"{class_modifier}class {test_class_name} {{\n"
        f"    private static final ReplayHelper helper =\n"
        f'        new ReplayHelper("{trace_db_path.as_posix()}");\n\n'
-       f"    @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n"
+       f"    {cleanup_annotation} public static void cleanup() {{ helper.close(); }}\n\n"
+       + "\n\n".join(test_methods_code)
+       + "\n"
        "}\n"
    )
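Putting the template together, a hypothetical JUnit 5 file this generator would emit for a traced com.example.Calculator.add(int, int); the class, trace path, and descriptor are illustrative, not taken from this PR:

// codeflash:functions=Calculator.add
// codeflash:trace_file=/tmp/java_trace.db
// codeflash:classname=com.example.Calculator
package codeflash.replay;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.AfterAll;
import com.codeflash.ReplayHelper;

class ReplayTest_Calculator {
    private static final ReplayHelper helper =
        new ReplayHelper("/tmp/java_trace.db");

    @AfterAll public static void cleanup() { helper.close(); }

    @Test void replay_add_0() throws Exception {
        helper.replay("com.example.Calculator", "add", "(II)I", 0);
    }
}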
Binary file not shown.
codeflash/languages/java/tracer.py
@@ -14,6 +14,39 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

+GRACEFUL_SHUTDOWN_WAIT = 5  # seconds to wait after SIGTERM before SIGKILL
+
+
+def _run_java_with_graceful_timeout(
+    java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
+) -> None:
+    """Run a Java command with graceful timeout handling.
+
+    Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
+    then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.
+    """
+    if not timeout:
+        subprocess.run(java_command, env=env, check=False)
+        return
+
+    import signal
+
+    proc = subprocess.Popen(java_command, env=env)
+    try:
+        proc.wait(timeout=timeout)
+    except subprocess.TimeoutExpired:
+        logger.warning(
+            "%s stage timed out after %d seconds, sending SIGTERM for graceful shutdown...", stage_name, timeout
+        )
+        proc.send_signal(signal.SIGTERM)
+        try:
+            proc.wait(timeout=GRACEFUL_SHUTDOWN_WAIT)
+        except subprocess.TimeoutExpired:
+            logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
+            proc.kill()
+            proc.wait()


# --add-opens flags needed for Kryo serialization on Java 16+
ADD_OPENS_FLAGS = (
    "--add-opens=java.base/java.util=ALL-UNNAMED "
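The reason SIGTERM comes first in the helper above: the JVM runs registered shutdown hooks (and JFR's dumponexit) on SIGTERM, but a SIGKILL terminates the process with no chance to flush anything. A minimal sketch demonstrating the difference:

// Illustrative only: a shutdown hook fires on SIGTERM (graceful) but is skipped on SIGKILL.
public class ShutdownHookDemo {
    public static void main(String[] args) throws InterruptedException {
        Runtime.getRuntime().addShutdownHook(new Thread(
                () -> System.err.println("shutdown hook: flushing trace data...")));
        Thread.sleep(60_000); // send SIGTERM during this sleep to see the hook run
    }
}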
@@ -48,10 +81,7 @@ class JavaTracer:
        # Stage 1: JFR Profiling
        logger.info("Stage 1: Running JFR profiling...")
        jfr_env = self.build_jfr_env(jfr_file)
-       try:
-           subprocess.run(java_command, env=jfr_env, check=False, timeout=timeout or None)
-       except subprocess.TimeoutExpired:
-           logger.warning("JFR profiling stage timed out after %d seconds", timeout)
+       _run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling")

        if not jfr_file.exists():
            logger.warning("JFR file was not created at %s", jfr_file)

@@ -62,10 +92,7 @@ class JavaTracer:
            trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout
        )
        agent_env = self.build_agent_env(config_path)
-       try:
-           subprocess.run(java_command, env=agent_env, check=False, timeout=timeout or None)
-       except subprocess.TimeoutExpired:
-           logger.warning("Argument capture stage timed out after %d seconds", timeout)
+       _run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture")

        if not trace_db_path.exists():
            logger.error("Trace database was not created at %s", trace_db_path)
@@ -95,7 +122,12 @@ class JavaTracer:

    def build_jfr_env(self, jfr_file: Path) -> dict[str, str]:
        env = os.environ.copy()
-       jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true"
+       # Use profile settings with increased sampling frequency (1ms instead of default 10ms)
+       # This captures more samples for short-running programs
+       jfr_opts = (
+           f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true"
+           ",jdk.ExecutionSample#period=1ms"
+       )
        existing = env.get("JAVA_TOOL_OPTIONS", "")
        env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip()
        return env
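For experimentation, the same 1ms ExecutionSample period can also be set programmatically with the jdk.jfr API rather than the -XX:StartFlightRecording flag used above; a sketch not part of this PR, with a hypothetical busy-loop workload:

import jdk.jfr.Recording;
import java.nio.file.Path;
import java.time.Duration;

// Illustrative: configures the jdk.ExecutionSample period to 1ms in code.
public class JfrPeriodDemo {
    public static void main(String[] args) throws Exception {
        try (Recording recording = new Recording()) {
            recording.enable("jdk.ExecutionSample").withPeriod(Duration.ofMillis(1));
            recording.start();
            long sum = 0;
            for (int i = 0; i < 50_000_000; i++) sum += i; // hypothetical workload
            recording.stop();
            recording.dump(Path.of("profile.jfr"));
            System.out.println("sum=" + sum + ", wrote profile.jfr");
        }
    }
}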
@@ -133,7 +165,7 @@ class JavaTracer:
        if stripped.startswith("package "):
            pkg = stripped[8:].rstrip(";").strip()
            parts = pkg.split(".")
-           prefix = ".".join(parts[: min(2, len(parts))])
+           prefix = ".".join(parts[: min(3, len(parts))])
            packages.add(prefix)
            break
        if stripped and not stripped.startswith("//"):
@@ -153,6 +185,7 @@ def run_java_tracer(
    max_function_count: int = 256,
    timeout: int = 0,
    max_run_count: int = 256,
+   test_framework: str = "junit5",
) -> tuple[Path, Path, int]:
    """High-level entry point: trace a Java command and generate replay tests.

@@ -169,7 +202,11 @@ def run_java_tracer(
    )

    test_count = generate_replay_tests(
-       trace_db_path=trace_db, output_dir=output_dir, project_root=project_root, max_run_count=max_run_count
+       trace_db_path=trace_db,
+       output_dir=output_dir,
+       project_root=project_root,
+       max_run_count=max_run_count,
+       test_framework=test_framework,
    )

    return trace_db, jfr_file, test_count
@@ -950,7 +950,7 @@ class TestResults(BaseModel):  # noqa: PLW1641
        by_id: dict[InvocationId, list[int]] = {}
        for result in self.test_results:
            if result.did_pass:
-               if result.runtime:
+               if result.runtime is not None:
                    by_id.setdefault(result.id, []).append(result.runtime)
                else:
                    msg = (
@@ -349,8 +349,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
    max_function_count = getattr(config, "max_function_count", 256)
    timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0)

+   console.print("[bold]Java project detected[/]")
+   console.print(f"  Project root: {project_root}")
+   console.print(f"  Module root: {getattr(config, 'module_root', '?')}")
+   console.print(f"  Tests root: {getattr(config, 'tests_root', '?')}")
+
    from codeflash.code_utils.code_utils import get_run_tmp_file
-   from codeflash.languages.java.build_tools import find_test_root
    from codeflash.languages.java.tracer import JavaTracer, run_java_tracer

    tracer = JavaTracer()

@@ -360,12 +364,16 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:

    trace_db_path = get_run_tmp_file(Path("java_trace.db"))

-   # Place replay tests in the project's test source tree so Maven/Gradle can compile them
-   test_root = find_test_root(project_root)
-   if test_root:
-       output_dir = test_root / "codeflash" / "replay"
+   # Place replay tests in the project's test source tree so Maven/Gradle can compile them.
+   # Use the config's tests_root (correctly resolved for multi-module projects) not find_test_root().
+   tests_root = Path(getattr(config, "tests_root", ""))
+   if tests_root.is_dir():
+       output_dir = tests_root / "codeflash" / "replay"
    else:
-       output_dir = project_root / "src" / "test" / "java" / "codeflash" / "replay"
+       from codeflash.languages.java.build_tools import find_test_root
+
+       test_root = find_test_root(project_root)
+       output_dir = (test_root or project_root / "src" / "test" / "java") / "codeflash" / "replay"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Remaining args after our flags are the Java command

@@ -377,6 +385,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
        sys.exit(1)
    java_command = remaining

+   # Detect test framework for replay test generation
+   from codeflash.languages.java.config import detect_java_project
+
+   java_config = detect_java_project(project_root)
+   test_framework = java_config.test_framework if java_config else "junit5"
+
    trace_db, jfr_file, test_count = run_java_tracer(
        java_command=java_command,
        trace_db_path=trace_db_path,

@@ -385,6 +399,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
        output_dir=output_dir,
        max_function_count=max_function_count,
        timeout=timeout,
+       test_framework=test_framework,
    )

    console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated")

@@ -404,7 +419,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
    config.replay_test = replay_test_paths
    config.previous_checkpoint_functions = None
    config.effort = EffortLevel.HIGH.value
-   config.no_pr = True
+   config.no_pr = getattr(config, "no_pr", False)
    config.file = None
    config.function = None
    config.test_project_root = project_root
@@ -50,6 +50,7 @@ def run_test(expected_improvement_pct: int) -> bool:
        "--no-project",
        "-m",
        "codeflash.main",
+       "--no-pr",
        "optimize",
        "java",
        "-cp",

@@ -59,6 +60,7 @@ def run_test(expected_improvement_pct: int) -> bool:

    env = os.environ.copy()
    env["PYTHONIOENCODING"] = "utf-8"
+   env["PYTHONUNBUFFERED"] = "1"
    logging.info(f"Running command: {' '.join(command)}")
    logging.info(f"Working directory: {fixture_dir}")
    process = subprocess.Popen(

@@ -73,13 +75,11 @@ def run_test(expected_improvement_pct: int) -> bool:

    output = []
    for line in process.stdout:
-       logging.info(line.strip())
+       print(line, end="", flush=True)
        output.append(line)

    return_code = process.wait()
    stdout = "".join(output)
-   if return_code != 0:
-       logging.error(f"Full output:\n{stdout}")

    if return_code != 0:
        logging.error(f"Command returned exit code {return_code}")

@@ -90,7 +90,7 @@ def run_test(expected_improvement_pct: int) -> bool:
        logging.error("Failed to find replay test generation message")
        return False

-   # Validate: replay tests were discovered
+   # Validate: replay tests were discovered (global count)
    replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout)
    if not replay_match:
        logging.error("Failed to find replay test discovery message")

@@ -101,6 +101,17 @@ def run_test(expected_improvement_pct: int) -> bool:
        return False
    logging.info(f"Replay tests discovered: {num_replay}")

+   # Validate: replay test files were used per-function
+   replay_file_match = re.search(r"Discovered \d+ existing unit test files?, (\d+) replay test files?", stdout)
+   if not replay_file_match:
+       logging.error("Failed to find per-function replay test file discovery message")
+       return False
+   num_replay_files = int(replay_file_match.group(1))
+   if num_replay_files == 0:
+       logging.error("No replay test files discovered per-function")
+       return False
+   logging.info(f"Replay test files per-function: {num_replay_files}")
+
    # Validate: at least one optimization was found
    if "⚡️ Optimization successful! 📄 " not in stdout:
        logging.error("Failed to find optimization success message")
Workload.java
@@ -36,20 +36,30 @@ public class Workload {
    }

    public static void main(String[] args) {
-       // Exercise the methods so the tracer can capture invocations
+       // Run methods with large inputs so JFR can capture CPU samples.
+       // Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval.
+       for (int round = 0; round < 1000; round++) {
+           computeSum(100_000);
+           repeatString("hello world ", 1000);
+
+           List<Integer> nums = new ArrayList<>();
+           for (int i = 1; i <= 10_000; i++) nums.add(i);
+           filterEvens(nums);
+
+           Workload w = new Workload();
+           w.instanceMethod(100_000, 42);
+       }
+
+       // Also call with small inputs for variety in traced args
        System.out.println("computeSum(100) = " + computeSum(100));
        System.out.println("computeSum(50) = " + computeSum(50));

        System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3));
        System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5));

-       List<Integer> nums = new ArrayList<>();
-       for (int i = 1; i <= 10; i++) nums.add(i);
-       System.out.println("filterEvens(1..10) = " + filterEvens(nums));
+       List<Integer> small = new ArrayList<>();
+       for (int i = 1; i <= 10; i++) small.add(i);
+       System.out.println("filterEvens(1..10) = " + filterEvens(small));

        Workload w = new Workload();
        System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3));
        System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2));

        System.out.println("Workload complete.");
    }
tests/test_languages/test_java/test_jfr_parser.py (new file, 302 lines)
@@ -0,0 +1,302 @@
"""Tests for JFR parser — class name normalization, package filtering, addressable time."""

from __future__ import annotations

import json
import subprocess
from pathlib import Path
from unittest.mock import patch

import pytest

from codeflash.languages.java.jfr_parser import JfrProfile


def _make_jfr_json(events: list[dict]) -> str:
    """Create fake JFR JSON output matching the jfr print format."""
    return json.dumps({"recording": {"events": events}})


def _make_execution_sample(class_name: str, method_name: str, start_time: str = "2026-01-01T00:00:00Z") -> dict:
    return {
        "type": "jdk.ExecutionSample",
        "values": {
            "startTime": start_time,
            "stackTrace": {
                "frames": [
                    {
                        "method": {
                            "type": {"name": class_name},
                            "name": method_name,
                            "descriptor": "()V",
                        },
                        "lineNumber": 42,
                    }
                ],
            },
        },
    }


class TestClassNameNormalization:
    """Test that JVM internal class names (com/example/Foo) are normalized to dots (com.example.Foo)."""

    def test_slash_separators_normalized_to_dots(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json(
            [
                _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"),
                _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"),
                _make_execution_sample("com/aerospike/client/util/Utf8", "encodedLength"),
            ]
        )

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.aerospike"])

        assert profile._total_samples == 3
        assert len(profile._method_samples) == 2

        # Keys should use dots, not slashes
        assert "com.aerospike.client.command.Buffer.bytesToInt" in profile._method_samples
        assert "com.aerospike.client.util.Utf8.encodedLength" in profile._method_samples

    def test_method_info_uses_dot_class_names(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json([_make_execution_sample("com/example/MyClass", "myMethod")])

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.example"])

        info = profile._method_info.get("com.example.MyClass.myMethod")
        assert info is not None
        assert info["class_name"] == "com.example.MyClass"
        assert info["method_name"] == "myMethod"


class TestPackageFiltering:
    def test_filters_by_package_prefix(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json(
            [
                _make_execution_sample("com/aerospike/client/Value", "get"),
                _make_execution_sample("java/util/HashMap", "put"),
                _make_execution_sample("com/aerospike/benchmarks/Main", "main"),
            ]
        )

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.aerospike"])

        # Only com.aerospike classes should be in samples
        assert len(profile._method_samples) == 2
        assert "com.aerospike.client.Value.get" in profile._method_samples
        assert "com.aerospike.benchmarks.Main.main" in profile._method_samples
        assert "java.util.HashMap.put" not in profile._method_samples

    def test_empty_packages_includes_all(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json(
            [
                _make_execution_sample("com/example/Foo", "bar"),
                _make_execution_sample("java/lang/String", "length"),
            ]
        )

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, [])

        assert len(profile._method_samples) == 2


class TestAddressableTime:
    def test_addressable_time_proportional_to_samples(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        # 3 samples for methodA, 1 for methodB, spanning 10 seconds
        jfr_json = _make_jfr_json(
            [
                _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:00Z"),
                _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:03Z"),
                _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:06Z"),
                _make_execution_sample("com/example/Foo", "methodB", "2026-01-01T00:00:10Z"),
            ]
        )

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.example"])

        time_a = profile.get_addressable_time_ns("com.example.Foo", "methodA")
        time_b = profile.get_addressable_time_ns("com.example.Foo", "methodB")

        # methodA has 3x the samples of methodB, so 3x the addressable time
        assert time_a > 0
        assert time_b > 0
        assert time_a == pytest.approx(time_b * 3, rel=0.01)

    def test_addressable_time_zero_for_unknown_method(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json([_make_execution_sample("com/example/Foo", "bar")])

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.example"])

        assert profile.get_addressable_time_ns("com.example.Foo", "nonExistent") == 0.0


class TestMethodRanking:
    def test_ranking_ordered_by_sample_count(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json(
            [
                _make_execution_sample("com/example/A", "hot"),
                _make_execution_sample("com/example/A", "hot"),
                _make_execution_sample("com/example/A", "hot"),
                _make_execution_sample("com/example/B", "warm"),
                _make_execution_sample("com/example/B", "warm"),
                _make_execution_sample("com/example/C", "cold"),
            ]
        )

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.example"])

        ranking = profile.get_method_ranking()
        assert len(ranking) == 3
        assert ranking[0]["method_name"] == "hot"
        assert ranking[0]["sample_count"] == 3
        assert ranking[1]["method_name"] == "warm"
        assert ranking[1]["sample_count"] == 2
        assert ranking[2]["method_name"] == "cold"
        assert ranking[2]["sample_count"] == 1

    def test_empty_ranking_when_no_samples(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json([])

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.example"])

        assert profile.get_method_ranking() == []

    def test_ranking_uses_dot_class_names(self, tmp_path: Path) -> None:
        jfr_file = tmp_path / "test.jfr"
        jfr_file.write_text("dummy", encoding="utf-8")

        jfr_json = _make_jfr_json([_make_execution_sample("com/example/nested/Deep", "method")])

        with patch("subprocess.run") as mock_run:
            mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
            profile = JfrProfile(jfr_file, ["com.example"])

        ranking = profile.get_method_ranking()
        assert len(ranking) == 1
        assert ranking[0]["class_name"] == "com.example.nested.Deep"


class TestGracefulTimeout:
    """Test that _run_java_with_graceful_timeout sends SIGTERM before SIGKILL."""

    def test_sends_sigterm_on_timeout(self) -> None:
        import signal

        from codeflash.languages.java.tracer import _run_java_with_graceful_timeout

        # Run a sleep command with a 1s timeout — should get SIGTERM'd
        import os

        env = os.environ.copy()
        _run_java_with_graceful_timeout(["sleep", "60"], env, timeout=1, stage_name="test")
        # If we get here, the process was killed (didn't hang for 60s)

    def test_no_timeout_runs_normally(self) -> None:
        import os

        from codeflash.languages.java.tracer import _run_java_with_graceful_timeout

        env = os.environ.copy()
        _run_java_with_graceful_timeout(["echo", "hello"], env, timeout=0, stage_name="test")
        # Should complete without error


class TestProjectRootResolution:
    """Test that project_root is correctly set for Java multi-module projects."""

    def test_java_project_root_is_build_root_not_module(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """For multi-module Maven, project_root should be the root with <modules>, not a sub-module."""
        # Create a multi-module project
        (tmp_path / "pom.xml").write_text(
            '<project xmlns="http://maven.apache.org/POM/4.0.0"><modules><module>client</module></modules></project>',
            encoding="utf-8",
        )
        client = tmp_path / "client"
        client.mkdir()
        (client / "pom.xml").write_text("<project/>", encoding="utf-8")
        src = client / "src" / "main" / "java"
        src.mkdir(parents=True)
        test = tmp_path / "src" / "test" / "java"
        test.mkdir(parents=True)
        monkeypatch.chdir(tmp_path)

        from codeflash.code_utils.config_parser import parse_config_file

        config, config_path = parse_config_file()
        assert config["language"] == "java"

        # config_path should be the project root directory
        assert config_path == tmp_path

    def test_project_root_is_path_not_string(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
        """project_root from process_pyproject_config should be a Path for Java projects."""
        from argparse import Namespace

        (tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
        src = tmp_path / "src" / "main" / "java"
        src.mkdir(parents=True)
        test = tmp_path / "src" / "test" / "java"
        test.mkdir(parents=True)
        monkeypatch.chdir(tmp_path)

        from codeflash.cli_cmds.cli import process_pyproject_config

        # Create a minimal args namespace matching what parse_args produces
        args = Namespace(
            config_file=None, module_root=None, tests_root=None, benchmarks_root=None,
            ignore_paths=None, pytest_cmd=None, formatter_cmds=None, disable_telemetry=None,
            disable_imports_sorting=None, git_remote=None, override_fixtures=None,
            benchmark=False, verbose=False, version=False, show_config=False, reset_config=False,
        )
        args = process_pyproject_config(args)

        assert hasattr(args, "project_root")
        assert isinstance(args.project_root, Path)
        assert args.project_root == tmp_path
tests/test_languages/test_java/test_replay_test_generation.py (new file, 255 lines)
@ -0,0 +1,255 @@
|
|||
"""Tests for Java replay test generation — JUnit 4/5 support, overload handling, instrumentation skip."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_db(tmp_path: Path) -> Path:
|
||||
"""Create a trace database with sample function calls."""
|
||||
db_path = tmp_path / "trace.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute(
|
||||
"CREATE TABLE function_calls("
|
||||
"type TEXT, function TEXT, classname TEXT, filename TEXT, "
|
||||
"line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)"
|
||||
)
|
||||
conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
|
||||
conn.execute(
|
||||
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 1000, b"\x00"),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 2000, b"\x00"),
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
("call", "multiply", "com.example.Calculator", "Calculator.java", 20, "(II)I", 3000, b"\x00"),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
def trace_db_overloaded(tmp_path: Path) -> Path:
    """Create a trace database with overloaded methods (same name, different descriptors)."""
    db_path = tmp_path / "trace_overloaded.db"
    conn = sqlite3.connect(str(db_path))
    conn.execute(
        "CREATE TABLE function_calls("
        "type TEXT, function TEXT, classname TEXT, filename TEXT, "
        "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)"
    )
    conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
    # Two overloads of estimateKeySize with different descriptors
    for i in range(3):
        conn.execute(
            "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            ("call", "estimateKeySize", "com.example.Command", "Command.java", 10, "(I)I", i * 1000, b"\x00"),
        )
    for i in range(2):
        conn.execute(
            "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (
                "call",
                "estimateKeySize",
                "com.example.Command",
                "Command.java",
                15,
                "(Ljava/lang/String;)I",
                (i + 10) * 1000,
                b"\x00",
            ),
        )
    conn.commit()
    conn.close()
    return db_path
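The descriptor column holds standard JVM method descriptors: "(I)I" is int -> int, and "(Ljava/lang/String;)I" takes a String and returns an int. A small decoder covering the primitive and object cases used in these fixtures (array types are out of scope; split_descriptor is a hypothetical name):

def split_descriptor(descriptor: str) -> tuple[list[str], str]:
    # Splits "(II)I" into (["I", "I"], "I") and
    # "(Ljava/lang/String;)I" into (["Ljava/lang/String;"], "I").
    params_part, return_type = descriptor[1:].rsplit(")", 1)
    params: list[str] = []
    i = 0
    while i < len(params_part):
        if params_part[i] == "L":
            end = params_part.index(";", i)
            params.append(params_part[i : end + 1])
            i = end + 1
        else:
            params.append(params_part[i])
            i += 1
    return params, return_type

assert split_descriptor("(II)I") == (["I", "I"], "I")
assert split_descriptor("(Ljava/lang/String;)I") == (["Ljava/lang/String;"], "I")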
class TestGenerateReplayTestsJunit5:
    def test_generates_junit5_by_default(self, trace_db: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        count = generate_replay_tests(trace_db, output_dir, tmp_path)
        assert count == 1

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")
        assert "import org.junit.jupiter.api.Test;" in content
        assert "import org.junit.jupiter.api.AfterAll;" in content
        assert "@Test void replay_add_0()" in content

    def test_junit5_class_is_package_private(self, trace_db: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        generate_replay_tests(trace_db, output_dir, tmp_path)

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")
        assert "class ReplayTest_" in content
        assert "public class ReplayTest_" not in content
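Taken together, the assertions above pin down the rough shape of a generated JUnit 5 replay test; the sketch below is illustrative only (the class-name suffix and method bodies are not asserted):

EXPECTED_JUNIT5_SHAPE = """\
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.AfterAll;

class ReplayTest_Calculator {  // package-private; suffix illustrative
    @Test void replay_add_0() { /* replays the first recorded add(int, int) call */ }
    @Test void replay_add_1() { /* replays the second recorded call */ }
}
"""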
class TestGenerateReplayTestsJunit4:
    def test_generates_junit4_imports(self, trace_db: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        count = generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
        assert count == 1

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")
        assert "import org.junit.Test;" in content
        assert "import org.junit.AfterClass;" in content
        assert "org.junit.jupiter" not in content

    def test_junit4_methods_are_public(self, trace_db: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")
        assert "@Test public void replay_add_0()" in content

    def test_junit4_class_is_public(self, trace_db: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")
        assert "public class ReplayTest_" in content

    def test_junit4_cleanup_uses_afterclass(self, trace_db: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")
        assert "@AfterClass" in content
        assert "@AfterAll" not in content
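The JUnit 4 variant differs from JUnit 5 only in imports, visibility, and the cleanup annotation. Summarizing the assertions from both test classes:

FRAMEWORK_EXPECTATIONS = {
    "junit5": {
        "test_import": "org.junit.jupiter.api.Test",
        "cleanup_annotation": "@AfterAll",
        "class_visibility": "package-private",
        "test_method": "@Test void replay_add_0()",
    },
    "junit4": {
        "test_import": "org.junit.Test",
        "cleanup_annotation": "@AfterClass",
        "class_visibility": "public",
        "test_method": "@Test public void replay_add_0()",
    },
}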
class TestOverloadedMethods:
    def test_no_duplicate_method_names(self, trace_db_overloaded: Path, tmp_path: Path) -> None:
        output_dir = tmp_path / "output"
        count = generate_replay_tests(trace_db_overloaded, output_dir, tmp_path)
        assert count == 1

        test_file = list(output_dir.glob("*.java"))[0]
        content = test_file.read_text(encoding="utf-8")

        # Should have 5 unique methods (3 from the int overload + 2 from the String overload)
        assert "replay_estimateKeySize_0" in content
        assert "replay_estimateKeySize_1" in content
        assert "replay_estimateKeySize_2" in content
        assert "replay_estimateKeySize_3" in content
        assert "replay_estimateKeySize_4" in content

        # Verify no duplicates by counting occurrences
        lines = content.splitlines()
        method_lines = [line for line in lines if "void replay_estimateKeySize_" in line]
        method_names = [line.split("void ")[1].split("(")[0] for line in method_lines]
        assert len(method_names) == len(set(method_names)), f"Duplicate methods: {method_names}"
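A minimal sketch of the de-duplication this test checks: a single counter per method name, shared across all overloads, so two descriptors can never collide on a generated name (whether the real generator numbers invocations by timestamp or insertion order is an assumption; replay_method_names is a hypothetical name):

from collections import defaultdict

def replay_method_names(invocations: list[tuple[str, str]]) -> list[str]:
    # invocations: (function, descriptor) pairs in replay order.
    counters: defaultdict[str, int] = defaultdict(int)
    names: list[str] = []
    for function, _descriptor in invocations:
        names.append(f"replay_{function}_{counters[function]}")
        counters[function] += 1
    return names

calls = [("estimateKeySize", "(I)I")] * 3 + [("estimateKeySize", "(Ljava/lang/String;)I")] * 2
assert replay_method_names(calls) == [f"replay_estimateKeySize_{i}" for i in range(5)]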
class TestReplayTestInstrumentation:
    def test_replay_tests_instrumented_correctly(self, trace_db: Path, tmp_path: Path) -> None:
        """Replay tests with compact @Test lines should be instrumented without orphaned code."""
        from codeflash.languages.java.discovery import discover_functions_from_source

        output_dir = tmp_path / "output"
        generate_replay_tests(trace_db, output_dir, tmp_path)

        test_file = list(output_dir.glob("*.java"))[0]

        src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
        funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
        target = funcs[0]

        from codeflash.languages.java.support import JavaSupport

        support = JavaSupport()
        success, instrumented = support.instrument_existing_test(
            test_path=test_file,
            call_positions=[],
            function_to_optimize=target,
            tests_project_root=tmp_path,
            mode="behavior",
        )
        assert success
        assert instrumented is not None
        assert "__perfinstrumented" in instrumented

        # Verify no code outside the class body: once the unindented top-level
        # closing brace is seen, only blank lines and // comments may follow.
        lines = instrumented.splitlines()
        class_closed = False
        for line in lines:
            if line.strip() == "}" and not line.startswith(" "):
                class_closed = True
            elif class_closed and line.strip() and not line.strip().startswith("//"):
                pytest.fail(f"Orphaned code outside class: {line}")
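    # Illustration of the failure mode guarded against above (contents
    # illustrative, not produced by the tool): instrumentation emitted past
    # the class's closing brace would look like
    #
    #     public class ReplayTest_x {
    #         @Test void replay_add_0() { /* ... */ }
    #     }
    #     long __start = System.nanoTime();  // orphaned: outside the class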
    def test_replay_tests_perf_instrumented(self, trace_db: Path, tmp_path: Path) -> None:
        from codeflash.languages.java.discovery import discover_functions_from_source

        output_dir = tmp_path / "output"
        generate_replay_tests(trace_db, output_dir, tmp_path)

        test_file = list(output_dir.glob("*.java"))[0]

        src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
        funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
        target = funcs[0]

        from codeflash.languages.java.support import JavaSupport

        support = JavaSupport()
        success, instrumented = support.instrument_existing_test(
            test_path=test_file,
            call_positions=[],
            function_to_optimize=target,
            tests_project_root=tmp_path,
            mode="performance",
        )
        assert success
        assert "__perfonlyinstrumented" in instrumented
    def test_regular_tests_still_instrumented(self, tmp_path: Path) -> None:
        from codeflash.languages.java.discovery import discover_functions_from_source

        src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
        funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
        target = funcs[0]

        test_file = tmp_path / "CalculatorTest.java"
        test_file.write_text(
            """
import org.junit.jupiter.api.Test;
public class CalculatorTest {
    @Test
    public void testAdd() {
        Calculator calc = new Calculator();
        calc.add(1, 2);
    }
}
""",
            encoding="utf-8",
        )

        from codeflash.languages.java.support import JavaSupport

        support = JavaSupport()
        success, instrumented = support.instrument_existing_test(
            test_path=test_file,
            call_positions=[],
            function_to_optimize=target,
            tests_project_root=tmp_path,
            mode="behavior",
        )
        assert success
        assert "CODEFLASH_LOOP_INDEX" in instrumented