Merge branch 'main' into cf-1086-java-init-yes-flag

This commit is contained in:
Mohamed Ashraf 2026-04-08 17:11:32 +00:00
commit 9c8bc7b8df
74 changed files with 4546 additions and 751 deletions

View file

@ -21,7 +21,7 @@ codeflash/
├── api/ # AI service communication
├── code_utils/ # Code parsing, git utilities
├── models/ # Pydantic models and types
├── languages/ # Multi-language support (Python, JavaScript/TypeScript, Java planned)
├── languages/ # Multi-language support (Python, JavaScript/TypeScript, Java)
│ ├── base.py # LanguageSupport protocol and shared data types
│ ├── registry.py # Language registration and lookup by extension/enum
│ ├── current.py # Current language singleton (set_current_language / current_language_support)
@ -35,11 +35,29 @@ codeflash/
│ │ ├── test_runner.py # Test subprocess execution for Python
│ │ ├── instrument_codeflash_capture.py # Instrument __init__ with capture decorators
│ │ └── parse_line_profile_test_output.py # Parse line profiler output
│ └── javascript/
│ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
│ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
│ ├── optimizer.py # JS project root finding & module preparation
│ └── normalizer.py # JS/TS code normalization for deduplication
│ ├── javascript/
│ │ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
│ │ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
│ │ ├── optimizer.py # JS project root finding & module preparation
│ │ └── normalizer.py # JS/TS code normalization for deduplication
│ └── java/
│ ├── support.py # JavaSupport (LanguageSupport implementation)
│ ├── function_optimizer.py # JavaFunctionOptimizer subclass
│ ├── build_tool_strategy.py # Abstract BuildToolStrategy for Maven/Gradle
│ ├── maven_strategy.py # Maven build tool strategy
│ ├── gradle_strategy.py # Gradle build tool strategy
│ ├── build_tools.py # Build tool detection and project info
│ ├── build_config_strategy.py # Config read/write for pom.xml / gradle.properties
│ ├── test_runner.py # Test execution via Maven/Gradle
│ ├── instrumentation.py # Behavior capture and benchmarking instrumentation
│ ├── discovery.py # Function discovery using tree-sitter
│ ├── test_discovery.py # Test discovery for JUnit/TestNG
│ ├── context.py # Code context extraction
│ ├── comparator.py # Test result comparison
│ ├── config.py # Java project detection and config
│ ├── formatter.py # Code formatting and normalization
│ ├── line_profiler.py # JVM bytecode agent-based line profiling
│ └── tracer.py # Two-stage JFR + argument capture tracer
├── setup/ # Config schema, auto-detection, first-run experience
├── picklepatch/ # Serialization/deserialization utilities
├── tracing/ # Function call tracing
@ -57,7 +75,7 @@ codeflash/
|------|------------|
| CLI arguments & commands | `cli_cmds/cli.py` (parsing), `main.py` (subcommand dispatch) |
| Optimization orchestration | `optimization/optimizer.py` `run()` |
| Per-function optimization | `languages/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
| Per-function optimization | `languages/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py`, `languages/java/function_optimizer.py` |
| Function discovery | `discovery/functions_to_optimize.py` |
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
| Test execution | `languages/<lang>/support.py` (`run_behavioral_tests`, etc.), `verification/pytest_plugin.py` |
@ -67,7 +85,7 @@ codeflash/
## LanguageSupport Protocol Methods
Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScriptSupport`) implements these.
Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScriptSupport`, `JavaSupport`) implements these.
| Category | Method/Property | Purpose |
|----------|----------------|---------|
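A rough sketch of what that protocol might look like (the method names below appear elsewhere in this commit; the exact signatures are assumptions, not copied from `languages/base.py`):

from __future__ import annotations

from pathlib import Path
from typing import Any, Protocol

class LanguageSupport(Protocol):
    # Per-language implementation registered in languages/registry.py and looked up by file extension.
    def discover_functions(self, source_code: str, source_file: Path) -> list[Any]:
        """Return the optimizable functions found in one source file."""
        ...

    def run_behavioral_tests(self, *, test_paths: list[Path], test_env: dict[str, str], cwd: Path,
                             timeout: int, project_root: Path | None, enable_coverage: bool,
                             candidate_index: int) -> tuple[Any, ...]:
        """Run behavior-capturing tests and return result/coverage artifacts."""
        ...

    def run_benchmarking_tests(self, *, test_paths: list[Path], test_env: dict[str, str], cwd: Path,
                               timeout: int, project_root: Path | None, min_loops: int, max_loops: int,
                               target_duration_seconds: float) -> tuple[Any, ...]:
        """Run timing tests for an optimization candidate."""
        ...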

View file

@ -3,17 +3,9 @@ name: E2E - Java Tracer
on:
pull_request:
paths:
- 'codeflash/languages/java/**'
- 'codeflash/languages/base.py'
- 'codeflash/languages/registry.py'
- 'codeflash/tracer.py'
- 'codeflash/benchmarking/function_ranker.py'
- 'codeflash/discovery/functions_to_optimize.py'
- 'codeflash/optimization/**'
- 'codeflash/verification/**'
- 'codeflash/**'
- 'codeflash-java-runtime/**'
- 'tests/test_languages/fixtures/java_tracer_e2e/**'
- 'tests/scripts/end_to_end_test_java_tracer.py'
- 'tests/**'
- '.github/workflows/e2e-java-tracer.yaml'
workflow_dispatch:

View file

@ -42,7 +42,7 @@
<dependency>
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>1.0.0</version>
<version>1.0.1</version>
<scope>test</scope>
</dependency>
</dependencies>

View file

@ -7,7 +7,7 @@
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>1.0.0</version>
<version>1.0.1</version>
<packaging>jar</packaging>
<name>CodeFlash Java Runtime</name>

View file

@ -12,20 +12,181 @@ import org.objectweb.asm.Type;
public class ReplayHelper {
private final Connection db;
private final Connection traceDb;
// Codeflash instrumentation state read from environment variables once
private final String mode; // "behavior", "performance", or null
private final int loopIndex;
private final String testIteration;
private final String outputFile; // SQLite path for behavior capture
private final int innerIterations; // for performance looping
// Behavior mode: lazily opened SQLite connection for writing results
private Connection behaviorDb;
private boolean behaviorDbInitialized;
public ReplayHelper(String traceDbPath) {
try {
this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath);
} catch (SQLException e) {
throw new RuntimeException("Failed to open trace database: " + traceDbPath, e);
}
// Read codeflash instrumentation env vars (set by the test runner)
this.mode = System.getenv("CODEFLASH_MODE");
this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1);
this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0");
this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE");
this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10);
}
public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception {
// Query the function_calls table for this method at the given index
// Deserialize args and resolve method (done once, outside timing)
Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex);
Class<?> targetClass = Class.forName(className);
Type[] paramTypes = Type.getArgumentTypes(descriptor);
Class<?>[] paramClasses = new Class<?>[paramTypes.length];
for (int i = 0; i < paramTypes.length; i++) {
paramClasses[i] = typeToClass(paramTypes[i]);
}
Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
method.setAccessible(true);
boolean isStatic = Modifier.isStatic(method.getModifiers());
Object instance = null;
if (!isStatic) {
try {
java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
ctor.setAccessible(true);
instance = ctor.newInstance();
} catch (NoSuchMethodException e) {
instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
}
}
// Get the calling test method name from the stack trace
String testMethodName = getCallingTestMethodName();
// Module name = the test class that called us
String testClassName = getCallingTestClassName();
if ("behavior".equals(mode)) {
replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName);
} else if ("performance".equals(mode)) {
replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName);
} else {
// No codeflash mode: just invoke (trace-only or manual testing)
method.invoke(instance, allArgs);
}
}
private void replayBehavior(Method method, Object instance, Object[] args,
String className, String methodName,
String testClassName, String testMethodName) throws Exception {
// testIteration goes at the END so the Comparator's lastUnderscore stripping
// removes it, making baseline (iteration=0) and candidate (iteration=N) keys match.
String invId = testMethodName + "_" + testIteration;
// Print start marker (same format as behavior instrumentation)
System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopIndex + ":" + invId + "######$!");
long startNs = System.nanoTime();
Object result;
try {
result = method.invoke(instance, args);
} catch (java.lang.reflect.InvocationTargetException e) {
throw (Exception) e.getCause();
}
long durationNs = System.nanoTime() - startNs;
// Print end marker
System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!");
// Write return value to SQLite for correctness comparison
if (outputFile != null && !outputFile.isEmpty()) {
writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result);
}
}
private void replayPerformance(Method method, Object instance, Object[] args,
String className, String methodName,
String testClassName, String testMethodName) throws Exception {
// Performance mode: run inner loop for JIT warmup, print timing for each iteration
int maxInner = innerIterations;
for (int inner = 0; inner < maxInner; inner++) {
int loopId = (loopIndex - 1) * maxInner + inner;
String invId = testMethodName;
// Print start marker
System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopId + ":" + invId + "######$!");
long startNs = System.nanoTime();
try {
method.invoke(instance, args);
} catch (java.lang.reflect.InvocationTargetException e) {
// Swallow the exception; performance mode doesn't check correctness
}
long durationNs = System.nanoTime() - startNs;
// Print end marker
System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName
+ ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!");
}
}
private void writeBehaviorResult(String testClassName, String testMethodName,
String functionName, String invId,
long durationNs, Object result) {
try {
ensureBehaviorDb();
String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) {
ps.setString(1, testClassName); // test_module_path
ps.setString(2, testClassName); // test_class_name
ps.setString(3, testMethodName); // test_function_name
ps.setString(4, functionName); // function_getting_tested
ps.setInt(5, loopIndex); // loop_index
ps.setString(6, invId); // iteration_id
ps.setLong(7, durationNs); // runtime
ps.setBytes(8, serializeResult(result)); // return_value
ps.setString(9, "function_call"); // verification_type
ps.executeUpdate();
}
} catch (Exception e) {
System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage());
}
}
private void ensureBehaviorDb() throws SQLException {
if (behaviorDbInitialized) return;
behaviorDbInitialized = true;
behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile);
try (java.sql.Statement stmt = behaviorDb.createStatement()) {
stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" +
"test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +
"function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +
"runtime INTEGER, return_value BLOB, verification_type TEXT)");
}
}
private byte[] serializeResult(Object result) {
if (result == null) return null;
try {
return Serializer.serialize(result);
} catch (Exception e) {
// Fall back to String.valueOf if Kryo fails
return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8);
}
}
private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex)
throws SQLException {
byte[] argsBlob;
try (PreparedStatement stmt = db.prepareStatement(
try (PreparedStatement stmt = traceDb.prepareStatement(
"SELECT args FROM function_calls " +
"WHERE classname = ? AND function = ? AND descriptor = ? " +
"ORDER BY time_ns LIMIT 1 OFFSET ?")) {
@ -43,46 +204,35 @@ public class ReplayHelper {
}
}
// Deserialize args
Object deserialized = Serializer.deserialize(argsBlob);
if (!(deserialized instanceof Object[])) {
throw new RuntimeException("Deserialized args is not Object[], got: "
+ (deserialized == null ? "null" : deserialized.getClass().getName()));
}
Object[] allArgs = (Object[]) deserialized;
// Load the target class
Class<?> targetClass = Class.forName(className);
// Parse descriptor to find parameter types
Type[] paramTypes = Type.getArgumentTypes(descriptor);
Class<?>[] paramClasses = new Class<?>[paramTypes.length];
for (int i = 0; i < paramTypes.length; i++) {
paramClasses[i] = typeToClass(paramTypes[i]);
return (Object[]) deserialized;
}
// Find the method
Method method = targetClass.getDeclaredMethod(methodName, paramClasses);
method.setAccessible(true);
boolean isStatic = Modifier.isStatic(method.getModifiers());
if (isStatic) {
method.invoke(null, allArgs);
} else {
// Args contain only explicit parameters (no 'this').
// Create a default instance via no-arg constructor or Kryo.
Object instance;
try {
java.lang.reflect.Constructor<?> ctor = targetClass.getDeclaredConstructor();
ctor.setAccessible(true);
instance = ctor.newInstance();
} catch (NoSuchMethodException e) {
// Fall back to Objenesis instantiation (no constructor needed)
instance = new org.objenesis.ObjenesisStd().newInstance(targetClass);
private static String getCallingTestMethodName() {
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
// Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method
for (int i = 3; i < stack.length; i++) {
String method = stack[i].getMethodName();
if (method.startsWith("replay_")) {
return method;
}
method.invoke(instance, allArgs);
}
return stack.length > 3 ? stack[3].getMethodName() : "unknown";
}
private static String getCallingTestClassName() {
StackTraceElement[] stack = Thread.currentThread().getStackTrace();
for (int i = 3; i < stack.length; i++) {
String cls = stack[i].getClassName();
if (cls.contains("ReplayTest") || cls.contains("replay")) {
return cls;
}
}
return stack.length > 3 ? stack[3].getClassName() : "unknown";
}
private static Class<?> typeToClass(Type type) throws ClassNotFoundException {
@ -106,11 +256,23 @@ public class ReplayHelper {
}
}
private static int parseIntEnv(String name, int defaultValue) {
String val = System.getenv(name);
if (val == null || val.isEmpty()) return defaultValue;
try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; }
}
private static String getEnvOrDefault(String name, String defaultValue) {
String val = System.getenv(name);
return (val != null && !val.isEmpty()) ? val : defaultValue;
}
public void close() {
try {
if (db != null) db.close();
} catch (SQLException e) {
System.err.println("Error closing ReplayHelper: " + e.getMessage());
try { if (traceDb != null) traceDb.close(); } catch (SQLException e) {
System.err.println("Error closing ReplayHelper trace db: " + e.getMessage());
}
try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) {
System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage());
}
}
}
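ReplayHelper is driven entirely by environment variables set before Maven/Gradle is launched. A minimal sketch of that setup on the Python side (the variable names come from the constructor above; the values and the surrounding runner code are hypothetical):

import os

test_env = dict(os.environ)
test_env.update({
    "CODEFLASH_MODE": "behavior",  # or "performance"; unset means plain invocation
    "CODEFLASH_LOOP_INDEX": "1",
    "CODEFLASH_TEST_ITERATION": "0",  # 0 for the baseline run, N for candidate N
    "CODEFLASH_OUTPUT_FILE": "/tmp/codeflash_behavior.sqlite",  # SQLite file for captured return values
    "CODEFLASH_INNER_ITERATIONS": "10",  # inner-loop count used in performance mode
})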

View file

@ -22,6 +22,7 @@ public final class TraceRecorder {
private final TracerConfig config;
private final TraceWriter writer;
private final ConcurrentHashMap<String, AtomicInteger> functionCounts = new ConcurrentHashMap<>();
private final AtomicInteger droppedCaptures = new AtomicInteger(0);
private final int maxFunctionCount;
private final ExecutorService serializerExecutor;
@ -82,11 +83,13 @@ public final class TraceRecorder {
argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
} catch (TimeoutException e) {
future.cancel(true);
droppedCaptures.incrementAndGet();
System.err.println("[codeflash-tracer] Serialization timed out for " + className + "."
+ methodName);
return;
} catch (Exception e) {
Throwable cause = e.getCause() != null ? e.getCause() : e;
droppedCaptures.incrementAndGet();
System.err.println("[codeflash-tracer] Serialization failed for " + className + "."
+ methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage());
return;
@ -113,11 +116,15 @@ public final class TraceRecorder {
}
metadata.put("totalCaptures", String.valueOf(totalCaptures));
int dropped = droppedCaptures.get();
metadata.put("droppedCaptures", String.valueOf(dropped));
writer.writeMetadata(metadata);
writer.flush();
writer.close();
System.err.println("[codeflash-tracer] Captured " + totalCaptures
+ " invocations across " + functionCounts.size() + " methods");
+ " invocations across " + functionCounts.size() + " methods"
+ (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : ""));
}
}

View file

@ -4,14 +4,20 @@ import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import java.util.Collections;
import java.util.Map;
public class TracingClassVisitor extends ClassVisitor {
private final String internalClassName;
private final Map<String, Integer> methodLineNumbers;
private String sourceFile;
public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName) {
public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName,
Map<String, Integer> methodLineNumbers) {
super(Opcodes.ASM9, classVisitor);
this.internalClassName = internalClassName;
this.methodLineNumbers = methodLineNumbers != null ? methodLineNumbers : Collections.emptyMap();
}
@Override
@ -37,7 +43,8 @@ public class TracingClassVisitor extends ClassVisitor {
return mv;
}
int lineNumber = methodLineNumbers.getOrDefault(name + descriptor, 0);
return new TracingMethodAdapter(mv, access, name, descriptor,
internalClassName, 0, sourceFile != null ? sourceFile : "");
internalClassName, lineNumber, sourceFile != null ? sourceFile : "");
}
}

View file

@ -1,10 +1,16 @@
package com.codeflash.tracer;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import java.lang.instrument.ClassFileTransformer;
import java.security.ProtectionDomain;
import java.util.HashMap;
import java.util.Map;
public class TracingTransformer implements ClassFileTransformer {
@ -22,11 +28,6 @@ public class TracingTransformer implements ClassFileTransformer {
return null;
}
// Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization)
if (TraceRecorder.isRecording()) {
return null;
}
// Skip internal JDK, framework, and synthetic classes
if (className.startsWith("java/")
|| className.startsWith("javax/")
@ -51,6 +52,30 @@ public class TracingTransformer implements ClassFileTransformer {
private byte[] instrumentClass(String internalClassName, byte[] bytecode) {
ClassReader cr = new ClassReader(bytecode);
// Pre-scan: collect the first source line number for each method.
// ASM's visitMethod() doesn't provide line info; it arrives later via visitLineNumber().
// We do a lightweight read pass first so the instrumentation pass has accurate line numbers.
Map<String, Integer> methodLineNumbers = new HashMap<>();
cr.accept(new ClassVisitor(Opcodes.ASM9) {
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor,
String signature, String[] exceptions) {
String key = name + descriptor;
return new MethodVisitor(Opcodes.ASM9) {
private boolean captured = false;
@Override
public void visitLineNumber(int line, Label start) {
if (!captured) {
methodLineNumbers.put(key, line);
captured = true;
}
}
};
}
}, ClassReader.SKIP_FRAMES);
// Use COMPUTE_MAXS only (not COMPUTE_FRAMES) to preserve original stack map frames.
// COMPUTE_FRAMES recomputes all frames and calls getCommonSuperClass() which either
// triggers classloader deadlocks or produces incorrect frames when returning "java/lang/Object".
@ -58,7 +83,7 @@ public class TracingTransformer implements ClassFileTransformer {
// adjusts offsets for injected code. Our AdviceAdapter only injects at method entry
// (before any branch points), so existing frames remain valid.
ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS);
TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName);
TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName, methodLineNumbers);
cr.accept(cv, ClassReader.EXPAND_FRAMES);
return cw.toByteArray();
}

File diff suppressed because it is too large

View file

@ -1,10 +1,14 @@
from __future__ import annotations
import gc
import importlib.util
import os
import sqlite3
import statistics
import sys
import time
from dataclasses import dataclass
from math import ceil
from pathlib import Path
from typing import TYPE_CHECKING
@ -18,6 +22,96 @@ if TYPE_CHECKING:
PYTEST_BENCHMARK_INSTALLED = importlib.util.find_spec("pytest_benchmark") is not None
# Calibration defaults (matching pytest-benchmark)
MIN_TIME = 0.000005 # 5µs — minimum time per round during calibration
MAX_TIME = 1.0 # 1s — maximum wall-clock time per test
MIN_ROUNDS = 5
CALIBRATION_PRECISION = 10
@dataclass
class BenchmarkStats:
min_ns: float
max_ns: float
mean_ns: float
median_ns: float
stddev_ns: float
iqr_ns: float
rounds: int
iterations: int
outliers: str
@staticmethod
def from_per_iteration_times(times_ns: list[float], iterations: int) -> BenchmarkStats:
n = len(times_ns)
sorted_times = sorted(times_ns)
q1 = sorted_times[n // 4] if n >= 4 else sorted_times[0]
q3 = sorted_times[3 * n // 4] if n >= 4 else sorted_times[-1]
iqr = q3 - q1
low_fence = q1 - 1.5 * iqr
high_fence = q3 + 1.5 * iqr
mild_outliers = sum(1 for t in times_ns if t < low_fence or t > high_fence)
severe_fence_low = q1 - 3.0 * iqr
severe_fence_high = q3 + 3.0 * iqr
severe_outliers = sum(1 for t in times_ns if t < severe_fence_low or t > severe_fence_high)
return BenchmarkStats(
min_ns=min(times_ns),
max_ns=max(times_ns),
mean_ns=statistics.mean(times_ns),
median_ns=statistics.median(times_ns),
stddev_ns=statistics.stdev(times_ns) if n > 1 else 0.0,
iqr_ns=iqr,
rounds=n,
iterations=iterations,
outliers=f"{severe_outliers};{mild_outliers}",
)
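# Worked example of the outlier counting above (hypothetical timings): for per-iteration times
# [10, 10, 10, 11, 100] ns, q1 = 10 and q3 = 11, so iqr = 1; the value 100 lies beyond both the
# mild fence (q3 + 1.5*iqr = 12.5) and the severe fence (q3 + 3*iqr = 14), giving outliers "1;1"
# (mild_outliers counts everything past the mild fence, including severe ones).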
@dataclass
class MemoryStats:
peak_memory_bytes: int
total_allocations: int
@staticmethod
def parse_memray_results(bin_dir: Path, bin_prefix: str) -> dict:
from codeflash.models.models import BenchmarkKey
try:
from memray import FileReader
except ImportError as e:
msg = "memray is required for --memory profiling. Install with: uv add memray pytest-memray"
raise ImportError(msg) from e
results: dict[BenchmarkKey, MemoryStats] = {}
for bin_file in sorted(bin_dir.glob(f"{bin_prefix}-*.bin")):
stem = bin_file.stem
# pytest-memray names: {prefix}-{nodeid with :: and os.sep replaced by -}.bin
nodeid_part = stem[len(bin_prefix) + 1 :] # strip "{prefix}-"
# Extract the test function name (last segment after the final -)
# Node IDs look like: tests-benchmarks-test_file.py-test_func_name
# We need the module_path and function_name for BenchmarkKey
# Split on ".py-" to separate module path from function name
parts = nodeid_part.split(".py-", 1)
if len(parts) == 2:
module_part = parts[0].replace("-", ".")
function_name = parts[1]
else:
module_part = nodeid_part.rsplit("-", 1)[0].replace("-", ".")
function_name = nodeid_part.rsplit("-", 1)[-1] if "-" in nodeid_part else nodeid_part
try:
reader = FileReader(str(bin_file))
meta = reader.metadata
bm_key = BenchmarkKey(module_path=module_part, function_name=function_name)
results[bm_key] = MemoryStats(
peak_memory_bytes=meta.peak_memory, total_allocations=meta.total_allocations
)
reader.close()
except OSError:
continue
return results
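# Worked example of the file-name parsing above (hypothetical names): with bin_prefix
# "codeflash-mem", a file "codeflash-mem-tests-benchmarks-test_sort.py-test_sort_large.bin"
# has nodeid_part "tests-benchmarks-test_sort.py-test_sort_large"; splitting on ".py-" gives
# module_part "tests.benchmarks.test_sort" and function_name "test_sort_large".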
class CodeFlashBenchmarkPlugin:
def __init__(self) -> None:
@ -28,7 +122,6 @@ class CodeFlashBenchmarkPlugin:
def setup(self, trace_path: str, project_root: str) -> None:
try:
# Open connection
self.project_root = project_root
self._trace_path = trace_path
self._connection = sqlite3.connect(self._trace_path)
@ -38,10 +131,10 @@ class CodeFlashBenchmarkPlugin:
cur.execute(
"CREATE TABLE IF NOT EXISTS benchmark_timings("
"benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
"benchmark_time_ns INTEGER)"
"round_index INTEGER, iterations INTEGER, round_time_ns INTEGER)"
)
self._connection.commit()
self.close() # Reopen only at the end of pytest session
self.close()
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
@ -51,20 +144,21 @@ class CodeFlashBenchmarkPlugin:
def write_benchmark_timings(self) -> None:
if not self.benchmark_timings:
return # No data to write
return
if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
try:
cur = self._connection.cursor()
# Insert data into the benchmark_timings table
cur.executemany(
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
"INSERT INTO benchmark_timings "
"(benchmark_module_path, benchmark_function_name, benchmark_line_number, "
"round_index, iterations, round_time_ns) VALUES (?, ?, ?, ?, ?, ?)",
self.benchmark_timings,
)
self._connection.commit()
self.benchmark_timings = [] # Clear the benchmark timings list
self.benchmark_timings = []
except Exception as e:
print(f"Error writing to benchmark timings database: {e}")
self._connection.rollback()
@ -76,124 +170,107 @@ class CodeFlashBenchmarkPlugin:
self._connection = None
@staticmethod
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, float]]:
from codeflash.models.models import BenchmarkKey
"""Process the trace file and extract timing data for all functions.
Args:
----
trace_path: Path to the trace file
Returns:
-------
A nested dictionary where:
- Outer keys are module_name.qualified_name (module.class.function)
- Inner keys are of type BenchmarkKey
- Values are average per-iteration function time in nanoseconds
"""
# Initialize the result dictionary
result = {}
# Connect to the SQLite database
result: dict[str, dict[BenchmarkKey, float]] = {}
connection = sqlite3.connect(trace_path)
cursor = connection.cursor()
try:
# Query the function_calls table for all function calls
# Get total iterations per benchmark to normalize
cursor.execute(
"SELECT benchmark_module_path, benchmark_function_name, "
"SUM(iterations) FROM benchmark_timings "
"GROUP BY benchmark_module_path, benchmark_function_name"
)
total_iterations: dict[BenchmarkKey, int] = {}
for row in cursor.fetchall():
bm_file, bm_func, total_iters = row
key = BenchmarkKey(module_path=bm_file, function_name=bm_func)
total_iterations[key] = total_iters
cursor.execute(
"SELECT module_name, class_name, function_name, "
"benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns "
"FROM benchmark_function_timings"
)
# Process each row
# Accumulate total function time
raw_totals: dict[str, dict[BenchmarkKey, int]] = {}
for row in cursor.fetchall():
module_name, class_name, function_name, benchmark_file, benchmark_func, _benchmark_line, time_ns = row
# Create the function key (module_name.class_name.function_name)
if class_name:
qualified_name = f"{module_name}.{class_name}.{function_name}"
else:
qualified_name = f"{module_name}.{function_name}"
# Create the benchmark key (file::function::line)
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Initialize the inner dictionary if needed
if qualified_name not in result:
result[qualified_name] = {}
if qualified_name not in raw_totals:
raw_totals[qualified_name] = {}
raw_totals[qualified_name][benchmark_key] = raw_totals[qualified_name].get(benchmark_key, 0) + time_ns
# If multiple calls to the same function in the same benchmark,
# add the times together
if benchmark_key in result[qualified_name]:
result[qualified_name][benchmark_key] += time_ns
else:
result[qualified_name][benchmark_key] = time_ns
# Normalize to per-iteration average
for qualified_name, bm_dict in raw_totals.items():
result[qualified_name] = {}
for bm_key, total_ns in bm_dict.items():
iters = total_iterations.get(bm_key, 1)
result[qualified_name][bm_key] = total_ns / iters
finally:
# Close the connection
connection.close()
return result
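# Example of the normalization above (hypothetical numbers): a function that accumulated
# 3_000_000 ns of time across a benchmark whose rounds summed to 300 iterations is stored
# as 3_000_000 / 300 = 10_000 ns per iteration.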
@staticmethod
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, BenchmarkStats]:
from codeflash.models.models import BenchmarkKey
"""Extract total benchmark timings from trace files.
Args:
----
trace_path: Path to the trace file
Returns:
-------
A dictionary where:
- Keys are of type BenchmarkKey
- Values are BenchmarkStats computed from per-iteration round times in nanoseconds (with per-round overhead subtracted)
"""
# Initialize the result dictionary
result = {}
overhead_by_benchmark = {}
# Connect to the SQLite database
connection = sqlite3.connect(trace_path)
cursor = connection.cursor()
try:
# Query the benchmark_function_timings table to get total overhead for each benchmark
# Get overhead per benchmark to subtract
cursor.execute(
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
"FROM benchmark_function_timings "
"GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number"
)
# Process overhead information
overhead_by_benchmark: dict[BenchmarkKey, int] = {}
for row in cursor.fetchall():
benchmark_file, benchmark_func, _benchmark_line, total_overhead_ns = row
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
bm_file, bm_func, _bm_line, total_overhead_ns = row
key = BenchmarkKey(module_path=bm_file, function_name=bm_func)
overhead_by_benchmark[key] = total_overhead_ns or 0
# Query the benchmark_timings table for total times
# Get per-round data
cursor.execute(
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
"FROM benchmark_timings"
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, "
"round_index, iterations, round_time_ns "
"FROM benchmark_timings ORDER BY round_index"
)
# Process each row and subtract overhead
rounds_data: dict[BenchmarkKey, list[tuple[int, int]]] = {}
for row in cursor.fetchall():
benchmark_file, benchmark_func, _benchmark_line, time_ns = row
bm_file, bm_func, _bm_line, _round_idx, iterations, round_time_ns = row
key = BenchmarkKey(module_path=bm_file, function_name=bm_func)
if key not in rounds_data:
rounds_data[key] = []
rounds_data[key].append((iterations, round_time_ns))
# Create the benchmark key (file::function::line)
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Subtract overhead from total time
overhead = overhead_by_benchmark.get(benchmark_key, 0)
result[benchmark_key] = time_ns - overhead
result: dict[BenchmarkKey, BenchmarkStats] = {}
for bm_key, rounds in rounds_data.items():
total_overhead = overhead_by_benchmark.get(bm_key, 0)
total_rounds = len(rounds)
overhead_per_round = total_overhead / total_rounds if total_rounds > 0 else 0
iterations = rounds[0][0] # All rounds have the same iteration count
per_iteration_times = []
for iters, round_time_ns in rounds:
adjusted = max(0, round_time_ns - overhead_per_round)
per_iteration_times.append(adjusted / iters)
result[bm_key] = BenchmarkStats.from_per_iteration_times(per_iteration_times, iterations)
finally:
# Close the connection
connection.close()
return result
@ -201,56 +278,42 @@ class CodeFlashBenchmarkPlugin:
# Pytest hooks
@pytest.hookimpl
def pytest_sessionfinish(self, session, exitstatus) -> None:
"""Execute after whole test run is completed."""
# Write any remaining benchmark timings to the database
codeflash_trace.close()
if self.benchmark_timings:
self.write_benchmark_timings()
# Close the database connection
self.close()
@staticmethod
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
# Skip tests that don't have the benchmark fixture
if not config.getoption("--codeflash-trace"):
return
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames # ty:ignore[unsupported-operator]
# Check for @pytest.mark.benchmark marker
has_marker = False
if hasattr(item, "get_closest_marker"):
marker = item.get_closest_marker("benchmark")
if marker is not None:
has_marker = True
# Skip if neither fixture nor marker is present
if not (has_fixture or has_marker):
item.add_marker(skip_no_benchmark)
# Benchmark fixture
class Benchmark: # noqa: D106
def __init__(self, request: pytest.FixtureRequest) -> None:
self.request = request
def __call__(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
"""Handle both direct function calls and decorator usage."""
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
return self._run_benchmark(func, *args, **kwargs)
return self.run_benchmark(func, *args, **kwargs)
# Used as @benchmark decorator
def wrapped_func(*args, **kwargs): # noqa: ANN002, ANN003
return func(*args, **kwargs)
self._run_benchmark(func)
self.run_benchmark(func)
return wrapped_func
def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003
"""Actual benchmark implementation."""
def run_benchmark(self, func, *args, **kwargs): # noqa: ANN002, ANN003, ANN201
node_path = getattr(self.request.node, "path", None) or getattr(self.request.node, "fspath", None)
if node_path is None:
raise RuntimeError("Unable to determine test file path from pytest node")
@ -258,31 +321,87 @@ class CodeFlashBenchmarkPlugin:
benchmark_module_path = module_name_from_file_path(
Path(str(node_path)), Path(codeflash_benchmark_plugin.project_root), traverse_up=True
)
benchmark_function_name = self.request.node.name
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack # noqa: SLF001
# Set env vars
line_number = int(str(sys._getframe(2).f_lineno)) # noqa: SLF001
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
# Phase 1: Calibrate (tracing disabled to avoid overhead)
os.environ["CODEFLASH_BENCHMARKING"] = "False"
iterations, calibrated_duration = calibrate(func, args, kwargs)
# Phase 2: Multi-round benchmark (tracing enabled)
os.environ["CODEFLASH_BENCHMARKING"] = "True"
# Run the function
rounds = max(MIN_ROUNDS, ceil(MAX_TIME / calibrated_duration)) if calibrated_duration > 0 else MIN_ROUNDS
result = None
for round_idx in range(rounds):
gc_was_enabled = gc.isenabled()
gc.disable()
try:
start = time.perf_counter_ns()
for _ in range(iterations):
result = func(*args, **kwargs)
end = time.perf_counter_ns()
# Reset the environment variable
os.environ["CODEFLASH_BENCHMARKING"] = "False"
finally:
if gc_was_enabled:
gc.enable()
# Write function calls
codeflash_trace.write_function_timings()
# Reset function call count
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
round_time = end - start
codeflash_benchmark_plugin.benchmark_timings.append(
(benchmark_module_path, benchmark_function_name, line_number, end - start)
(benchmark_module_path, benchmark_function_name, line_number, round_idx, iterations, round_time)
)
# Flush function timings per round
codeflash_trace.write_function_timings()
codeflash_trace.function_call_count = 0
os.environ["CODEFLASH_BENCHMARKING"] = "False"
return result
def compute_timer_precision() -> float:
minimum = float("inf")
for _ in range(20):
t1 = time.perf_counter_ns()
t2 = time.perf_counter_ns()
dt = t2 - t1
if dt > 0:
minimum = min(minimum, dt)
return minimum / 1e9 # Convert to seconds
def calibrate(func, args, kwargs) -> tuple[int, float]:
timer_precision = compute_timer_precision()
min_time = max(MIN_TIME, timer_precision * CALIBRATION_PRECISION)
min_time_estimate = min_time * 5 / CALIBRATION_PRECISION
iterations = 1
while True:
gc_was_enabled = gc.isenabled()
gc.disable()
try:
start = time.perf_counter_ns()
for _ in range(iterations):
func(*args, **kwargs)
end = time.perf_counter_ns()
finally:
if gc_was_enabled:
gc.enable()
duration = (end - start) / 1e9 # Convert to seconds
if duration >= min_time:
break
if duration >= min_time_estimate:
iterations = ceil(min_time * iterations / duration)
else:
iterations *= 10
return iterations, duration
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
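The round count picked in run_benchmark follows pytest-benchmark's shape: calibrate once, then run as many rounds as fit in MAX_TIME, never fewer than MIN_ROUNDS. A small illustration using the constants defined above (the calibrated duration is hypothetical):

from math import ceil

MIN_ROUNDS, MAX_TIME = 5, 1.0
calibrated_duration = 0.02  # hypothetical: one calibrated round took 20 ms
rounds = max(MIN_ROUNDS, ceil(MAX_TIME / calibrated_duration)) if calibrated_duration > 0 else MIN_ROUNDS
assert rounds == 50  # a 0.5 s calibrated round would instead clamp to MIN_ROUNDS = 5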

View file

@ -0,0 +1,42 @@
"""Subprocess entry point for memory profiling benchmarks via pytest-memray.
Runs pytest with --memray --native to profile peak memory per test function.
The codeflash-benchmark plugin is left active (without --codeflash-trace) so it
provides a no-op ``benchmark`` fixture for tests that depend on it.
"""
import sys
from pathlib import Path
benchmarks_root = sys.argv[1]
memray_bin_dir = sys.argv[2]
memray_bin_prefix = sys.argv[3]
if __name__ == "__main__":
import pytest
Path(memray_bin_dir).mkdir(parents=True, exist_ok=True)
exitcode = pytest.main(
[
benchmarks_root,
"--memray",
"--native",
f"--memray-bin-path={memray_bin_dir}",
f"--memray-bin-prefix={memray_bin_prefix}",
"--hide-memray-summary",
"-p",
"no:benchmark",
"-p",
"no:codspeed",
"-p",
"no:cov",
"-p",
"no:profiling",
"-s",
"-o",
"addopts=",
]
)
sys.exit(exitcode)

View file

@ -46,3 +46,39 @@ def trace_benchmarks_pytest(
error_section = combined_output
logger.warning(f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}")
logger.debug(f"Full pytest output:\n{combined_output}")
def memory_benchmarks_pytest(
benchmarks_root: Path, project_root: Path, memray_bin_dir: Path, memray_bin_prefix: str, timeout: int = 300
) -> None:
benchmark_env = make_env_with_project_root(project_root)
run_args = get_cross_platform_subprocess_run_args(
cwd=project_root, env=benchmark_env, timeout=timeout, check=False, text=True, capture_output=True
)
result = subprocess.run( # noqa: PLW1510
[
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "pytest_new_process_memory_benchmarks.py",
benchmarks_root,
memray_bin_dir,
memray_bin_prefix,
],
**run_args,
)
if result.returncode != 0:
combined_output = result.stdout
if result.stderr:
combined_output = combined_output + "\n" + result.stderr if combined_output else result.stderr
if "ERROR collecting" in combined_output:
error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
match = re.search(error_pattern, combined_output)
error_section = match.group(1) if match else combined_output
elif "FAILURES" in combined_output:
error_pattern = r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
match = re.search(error_pattern, combined_output)
error_section = match.group(1) if match else combined_output
else:
error_section = combined_output
logger.warning(f"Error collecting memory benchmarks - Pytest Exit code: {result.returncode}, {error_section}")
logger.debug(f"Full pytest output:\n{combined_output}")

View file

@ -1,6 +1,8 @@
from __future__ import annotations
import logging
import shutil
from operator import itemgetter
from typing import TYPE_CHECKING, Optional
from rich.console import Console
@ -16,27 +18,30 @@ if TYPE_CHECKING:
def validate_and_format_benchmark_table(
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int]
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
function_to_result = {}
# Process each function's benchmark data
scale = 1_000_000.0
for func_path, test_times in function_benchmark_timings.items():
# Sort by percentage (highest first)
sorted_tests = []
for benchmark_key, func_time in test_times.items():
total_time = total_benchmark_timings.get(benchmark_key, 0)
if func_time > total_time:
logger.debug(f"Skipping test {benchmark_key} due to func_time {func_time} > total_time {total_time}")
# If the function time is greater than the total time, the benchmark likely has multithreading / multiprocessing issues.
# Do not try to project the optimization impact for this function.
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
f"Skipping test {benchmark_key} due to func_time {func_time} > total_time {total_time}"
)
sorted_tests.append((benchmark_key, 0.0, 0.0, 0.0))
elif total_time > 0:
percentage = (func_time / total_time) * 100
# Convert nanoseconds to milliseconds
func_time_ms = func_time / 1_000_000
total_time_ms = total_time / 1_000_000
func_time_ms = func_time / scale
total_time_ms = total_time / scale
sorted_tests.append((benchmark_key, total_time_ms, func_time_ms, percentage))
sorted_tests.sort(key=lambda x: x[3], reverse=True)
sorted_tests.sort(key=itemgetter(3), reverse=True)
function_to_result[func_path] = sorted_tests
return function_to_result
@ -77,8 +82,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
def process_benchmark_data(
replay_performance_gain: dict[BenchmarkKey, float],
fto_benchmark_timings: dict[BenchmarkKey, int],
total_benchmark_timings: dict[BenchmarkKey, int],
fto_benchmark_timings: dict[BenchmarkKey, float],
total_benchmark_timings: dict[BenchmarkKey, float],
) -> Optional[ProcessedBenchmarkInfo]:
"""Process benchmark data and generate detailed benchmark information.

View file

@ -376,22 +376,40 @@ def _build_parser() -> ArgumentParser:
subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension")
subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False)
auth_parser = subparsers.add_parser("auth", help="Authentication commands")
auth_subparsers = auth_parser.add_subparsers(dest="auth_command", help="Auth sub-commands")
auth_subparsers.add_parser("login", help="Log in to Codeflash via OAuth")
auth_subparsers.add_parser("status", help="Check authentication status")
compare_parser = subparsers.add_parser("compare", help="Compare benchmark performance between two git refs.")
compare_parser.add_argument("base_ref", help="Base git ref (branch, tag, or commit)")
compare_parser.add_argument(
"base_ref", nargs="?", default=None, help="Base git ref (default: auto-detect from PR or default branch)"
)
compare_parser.add_argument("head_ref", nargs="?", default=None, help="Head git ref (default: current branch)")
compare_parser.add_argument("--pr", type=int, help="Resolve head ref from a PR number (requires gh CLI)")
compare_parser.add_argument(
"--functions", type=str, help="Explicit functions to instrument: 'file.py::func1,func2;other.py::func3'"
)
compare_parser.add_argument("--timeout", type=int, default=600, help="Benchmark timeout in seconds (default: 600)")
compare_parser.add_argument("--output", "-o", type=str, help="Write markdown report to file")
compare_parser.add_argument(
"--memory", action="store_true", help="Profile peak memory usage per benchmark (requires memray, Linux/macOS)"
)
compare_parser.add_argument("--script", type=str, help="Shell command to run as benchmark in each worktree")
compare_parser.add_argument(
"--script-output",
type=str,
dest="script_output",
help="Relative path to JSON results file produced by --script (required with --script)",
)
compare_parser.add_argument("--config-file", type=str, dest="config_file", help="Path to pyproject.toml")
trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.")
compare_parser.add_argument(
"--inject",
nargs="+",
default=None,
help="Files or directories to copy into both worktrees before benchmarking. Paths are relative to repo root.",
)
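# Hypothetical invocations of the compare flags defined above (illustrative only):
#   codeflash compare --pr 123 --memory -o report.md
#   codeflash compare main feature-branch --script "make bench" --script-output results.json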
trace_optimize.add_argument(
"--max-function-count",

View file

@ -13,15 +13,76 @@ if TYPE_CHECKING:
from codeflash.models.function_types import FunctionToOptimize
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_parser import parse_config_file
def run_compare(args: Namespace) -> None:
"""Entry point for the compare subcommand."""
# Load project config
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
# Resolve head_ref: explicit arg > --pr > current branch
head_ref = args.head_ref
if args.pr:
head_ref = resolve_pr_branch(args.pr)
if not head_ref:
head_ref = get_current_branch()
if not head_ref:
logger.error("Must provide head_ref, --pr, or be on a branch")
sys.exit(1)
logger.info(f"Auto-detected head ref: {head_ref}")
# Resolve base_ref: explicit arg > PR base branch > repo default branch
base_ref = args.base_ref
if not base_ref:
base_ref = detect_base_ref(head_ref)
if not base_ref:
logger.error("Could not auto-detect base ref. Provide it explicitly or ensure gh CLI is available.")
sys.exit(1)
logger.info(f"Auto-detected base ref: {base_ref}")
# Script mode: run an arbitrary benchmark command on each worktree (no codeflash config needed)
script_cmd = getattr(args, "script", None)
if script_cmd:
if getattr(args, "inject", None):
logger.warning("--inject is not supported in --script mode and will be ignored")
script_output = getattr(args, "script_output", None)
if not script_output:
logger.error("--script-output is required when using --script")
sys.exit(1)
import git
project_root = Path(git.Repo(Path.cwd(), search_parent_directories=True).working_dir)
from codeflash.benchmarking.compare import compare_with_script
result = compare_with_script(
base_ref=base_ref,
head_ref=head_ref,
project_root=project_root,
script_cmd=script_cmd,
script_output=script_output,
timeout=args.timeout,
memory=getattr(args, "memory", False),
)
if not result.base_results and not result.head_results:
logger.warning("No benchmark data collected. Check that --script-output points to a valid JSON file.")
sys.exit(1)
if args.output:
md = result.format_markdown()
Path(args.output).write_text(md, encoding="utf-8")
logger.info(f"Markdown report written to {args.output}")
return
# Standard trace-benchmark mode: requires codeflash config
from codeflash.code_utils.config_parser import parse_config_file
pyproject_config, pyproject_file_path = parse_config_file(args.config_file)
module_root = Path(pyproject_config.get("module_root", ".")).resolve()
from codeflash.cli_cmds.cli import project_root_from_module_root
project_root = project_root_from_module_root(module_root, pyproject_file_path)
tests_root = Path(pyproject_config.get("tests_root", "tests")).resolve()
benchmarks_root_str = pyproject_config.get("benchmarks_root")
@ -34,42 +95,90 @@ def run_compare(args: Namespace) -> None:
logger.error(f"benchmarks-root {benchmarks_root} is not a valid directory")
sys.exit(1)
from codeflash.cli_cmds.cli import project_root_from_module_root
project_root = project_root_from_module_root(module_root, pyproject_file_path)
# Resolve head_ref
head_ref = args.head_ref
if args.pr:
head_ref = _resolve_pr_branch(args.pr)
if not head_ref:
logger.error("Must provide head_ref or --pr")
sys.exit(1)
# Parse explicit functions if provided
functions = None
if args.functions:
functions = _parse_functions_arg(args.functions, project_root)
functions = parse_functions_arg(args.functions, project_root)
from codeflash.benchmarking.compare import compare_branches
result = compare_branches(
base_ref=args.base_ref,
base_ref=base_ref,
head_ref=head_ref,
project_root=project_root,
benchmarks_root=benchmarks_root,
tests_root=tests_root,
functions=functions,
timeout=args.timeout,
memory=getattr(args, "memory", False),
inject_paths=getattr(args, "inject", None),
)
if not result.base_total_ns and not result.head_total_ns:
if not result.base_stats and not result.head_stats:
logger.warning("No benchmark data collected. Check that benchmarks-root is configured and benchmarks exist.")
sys.exit(1)
if args.output:
md = result.format_markdown()
Path(args.output).write_text(md, encoding="utf-8")
logger.info(f"Markdown report written to {args.output}")
def _resolve_pr_branch(pr_number: int) -> str:
"""Resolve a PR number to its head branch name using gh CLI."""
def get_current_branch() -> str | None:
try:
result = subprocess.run(
["git", "rev-parse", "--abbrev-ref", "HEAD"], capture_output=True, text=True, check=True
)
branch = result.stdout.strip()
return branch if branch and branch != "HEAD" else None
except (FileNotFoundError, subprocess.CalledProcessError):
return None
def detect_base_ref(head_ref: str) -> str | None:
# Try to find an open PR for this branch and use its base
try:
result = subprocess.run(
["gh", "pr", "view", head_ref, "--json", "baseRefName", "-q", ".baseRefName"],
capture_output=True,
text=True,
check=True,
)
base = result.stdout.strip()
if base:
return base
except (FileNotFoundError, subprocess.CalledProcessError):
pass
# Fall back to repo default branch
try:
result = subprocess.run(
["gh", "repo", "view", "--json", "defaultBranchRef", "-q", ".defaultBranchRef.name"],
capture_output=True,
text=True,
check=True,
)
default = result.stdout.strip()
if default:
return default
except (FileNotFoundError, subprocess.CalledProcessError):
pass
# Last resort: check for common default branch names
try:
for candidate in ("main", "master"):
result = subprocess.run(
["git", "rev-parse", "--verify", candidate], capture_output=True, text=True, check=False
)
if result.returncode == 0:
return candidate
except FileNotFoundError:
pass
return None
def resolve_pr_branch(pr_number: int) -> str:
try:
result = subprocess.run(
["gh", "pr", "view", str(pr_number), "--json", "headRefName", "-q", ".headRefName"],
@ -91,7 +200,7 @@ def _resolve_pr_branch(pr_number: int) -> str:
sys.exit(1)
def _parse_functions_arg(functions_str: str, project_root: Path) -> dict[Path, list[FunctionToOptimize]]:
def parse_functions_arg(functions_str: str, project_root: Path) -> dict[Path, list[FunctionToOptimize]]:
"""Parse --functions arg format: 'file.py::func1,func2;other.py::func3'."""
from codeflash.models.function_types import FunctionToOptimize
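# For example (hypothetical input), "src/app.py::parse,render;lib/util.py::slug" maps
# project_root/src/app.py to [parse, render] and project_root/lib/util.py to [slug],
# each entry wrapped as a FunctionToOptimize.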

View file

@ -554,11 +554,13 @@ def get_all_replay_test_functions(
def _get_java_replay_test_functions(
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path
replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path | str
) -> tuple[dict[Path, list[FunctionToOptimize]], Path]:
"""Parse Java replay test files to extract functions and trace file path."""
from codeflash.languages.java.replay_test import parse_replay_test_metadata
project_root_path = Path(project_root_path)
trace_file_path: Path | None = None
functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list)
@ -602,7 +604,7 @@ def _get_java_replay_test_functions(
all_functions = lang_support.discover_functions(source_code, source_file)
for func in all_functions:
if func.function_name in function_names:
if func.function_name in function_names or func.qualified_name in function_names:
functions[source_file].append(func)
if trace_file_path is None:

View file

@ -2787,6 +2787,25 @@ class FunctionOptimizer:
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
if not did_pass_all_tests:
return Failure("Tests failed to pass for the original code.")
# Check if coverage data was not found (file excluded from coverage)
from codeflash.models.models import CoverageStatus
if coverage_results and coverage_results.status == CoverageStatus.NOT_FOUND:
# File was not found in coverage data - likely excluded by test framework config
logger.warning(
f"No coverage data found for {self.function_to_optimize.file_path}. "
f"This file may be excluded from coverage collection by your test framework configuration "
f"(e.g., coverage.exclude in vitest.config.ts for Vitest, or testMatch/coveragePathIgnorePatterns "
f"for Jest). Tests ran successfully but coverage cannot be measured."
)
return Failure(
f"Coverage data not found for {self.function_to_optimize.file_path}. "
f"The file may be excluded from coverage by your test framework config. "
f"Check coverage.exclude patterns in vitest.config.ts or jest.config.js."
)
# Normal coverage failure (tests ran but coverage below threshold)
coverage_pct = coverage_results.coverage if coverage_results else 0
return Failure(
f"Test coverage is {coverage_pct}%, which is below the required threshold of {COVERAGE_THRESHOLD}%."
@ -3066,6 +3085,16 @@ class FunctionOptimizer:
)
)
def get_js_project_root(self) -> Path | None:
# Only calculate for JavaScript/TypeScript projects
if self.function_to_optimize.language not in ("javascript", "typescript"):
return self.test_cfg.js_project_root # Fall back to cached value for non-JS
# For JS/TS, calculate fresh for each function to support monorepos
from codeflash.languages.javascript.test_runner import find_node_project_root
return find_node_project_root(Path(self.function_to_optimize.file_path))
def run_and_parse_tests(
self,
testing_type: TestingMode,
@ -3084,33 +3113,39 @@ class FunctionOptimizer:
coverage_config_file = None
try:
if testing_type == TestingMode.BEHAVIOR:
# Calculate js_project_root for the current function being optimized
# instead of using cached value from test_cfg, which may be from a different function
js_project_root = self.get_js_project_root()
result_file_path, run_result, coverage_database_file, coverage_config_file = (
self.language_support.run_behavioral_tests(
test_paths=test_files,
test_env=test_env,
cwd=self.project_root,
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
project_root=self.test_cfg.js_project_root,
project_root=js_project_root,
enable_coverage=enable_coverage,
candidate_index=optimization_iteration,
)
)
elif testing_type == TestingMode.LINE_PROFILE:
js_project_root = self.get_js_project_root()
result_file_path, run_result = self.language_support.run_line_profile_tests(
test_paths=test_files,
test_env=test_env,
cwd=self.project_root,
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
project_root=self.test_cfg.js_project_root,
project_root=js_project_root,
line_profile_output_file=line_profiler_output_file,
)
elif testing_type == TestingMode.PERFORMANCE:
js_project_root = self.get_js_project_root()
result_file_path, run_result = self.language_support.run_benchmarking_tests(
test_paths=test_files,
test_env=test_env,
cwd=self.project_root,
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
project_root=self.test_cfg.js_project_root,
project_root=js_project_root,
min_loops=pytest_min_loops,
max_loops=pytest_max_loops,
target_duration_seconds=testing_time,

View file

@ -9,6 +9,7 @@ from __future__ import annotations
import logging
import os
import shutil
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any
@ -20,7 +21,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_RUNTIME_JAR_NAME = "codeflash-runtime-1.0.0.jar"
_RUNTIME_JAR_NAME = "codeflash-runtime-1.0.1.jar"
_JAVA_RUNTIME_DIR = Path(__file__).parent.parent.parent.parent / "codeflash-java-runtime"
@ -73,6 +74,18 @@ class BuildToolStrategy(ABC):
return None
def find_wrapper_executable(
self, build_root: Path, wrapper_names: tuple[str, ...], system_command: str
) -> str | None:
search = build_root.resolve()
while search != search.parent:
for name in wrapper_names:
candidate = search / name
if candidate.exists():
return str(candidate)
search = search.parent
return shutil.which(system_command)
@abstractmethod
def find_executable(self, build_root: Path) -> str | None:
"""Find the build tool executable, searching up parent directories if needed."""

View file

@ -14,7 +14,7 @@ from pathlib import Path # noqa: TC003 — used at runtime
logger = logging.getLogger(__name__)
CODEFLASH_RUNTIME_VERSION = "1.0.0"
CODEFLASH_RUNTIME_VERSION = "1.0.1"
CODEFLASH_RUNTIME_JAR_NAME = f"codeflash-runtime-{CODEFLASH_RUNTIME_VERSION}.jar"
JACOCO_PLUGIN_VERSION = "0.8.13"

View file

@ -45,7 +45,8 @@ gradle.projectsEvaluated {
'spotbugsMain', 'spotbugsTest',
'pmdMain', 'pmdTest',
'rat', 'japicmp',
'jarHell', 'thirdPartyAudit'
'jarHell', 'thirdPartyAudit',
'spotlessCheck', 'spotlessApply', 'spotlessJava', 'spotlessKotlin', 'spotlessScala'
]
}.configureEach {
enabled = false
@ -417,22 +418,7 @@ class GradleStrategy(BuildToolStrategy):
)
def find_executable(self, build_root: Path) -> str | None:
# Walk up from build_root to find gradlew — for multi-module projects
# the wrapper lives at the repo root, which may be a parent of build_root.
current = build_root.resolve()
while True:
gradlew_path = current / "gradlew"
if gradlew_path.exists():
return str(gradlew_path)
gradlew_bat_path = current / "gradlew.bat"
if gradlew_bat_path.exists():
return str(gradlew_bat_path)
parent = current.parent
if parent == current:
break
current = parent
# Fall back to system Gradle
return shutil.which("gradle")
return self.find_wrapper_executable(build_root, ("gradlew", "gradlew.bat"), "gradle")
def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
runtime_jar = self.find_runtime_jar()
@ -447,7 +433,7 @@ class GradleStrategy(BuildToolStrategy):
libs_dir = module_root / "libs"
libs_dir.mkdir(parents=True, exist_ok=True)
dest_jar = libs_dir / "codeflash-runtime-1.0.0.jar"
dest_jar = libs_dir / "codeflash-runtime-1.0.1.jar"
if not dest_jar.exists():
logger.info("Copying codeflash-runtime JAR to %s", dest_jar)

View file

@ -745,6 +745,15 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str,
if _is_test_annotation(stripped):
if not helper_added:
helper_added = True
# Check if the @Test line already contains the method signature and opening brace
# (common in compact test styles like replay tests: @Test void replay_foo_0() throws Exception {)
if "{" in line:
# The annotation line IS the method signature — don't look for a separate one
result.append(line)
i += 1
method_lines = [line]
else:
result.append(line)
i += 1

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import json
import logging
import os
import shutil
import subprocess
from datetime import datetime
@ -42,6 +43,13 @@ class JfrProfile:
candidate = Path(home) / "bin" / "jfr"
if candidate.exists():
return str(candidate)
java_home_env = os.environ.get("JAVA_HOME")
if java_home_env:
candidate = Path(java_home_env) / "bin" / "jfr"
if candidate.exists():
return str(candidate)
return None
def _parse(self) -> None:
@ -152,6 +160,8 @@ class JfrProfile:
method_name = method.get("name", "")
if not class_name or not method_name:
return None
# JFR uses / separators (JVM internal format), normalize to dots for package matching
class_name = class_name.replace("/", ".")
return f"{class_name}.{method_name}"
def _store_method_info(self, key: str, frame: dict[str, Any]) -> None:
@ -159,7 +169,7 @@ class JfrProfile:
return
method = frame.get("method", {})
self._method_info[key] = {
"class_name": method.get("type", {}).get("name", ""),
"class_name": method.get("type", {}).get("name", "").replace("/", "."),
"method_name": method.get("name", ""),
"descriptor": method.get("descriptor", ""),
"line_number": str(frame.get("lineNumber", 0)),

View file

@ -43,6 +43,8 @@ _MAVEN_VALIDATION_SKIP_FLAGS = [
"-Denforcer.skip=true",
"-Djapicmp.skip=true",
"-Derrorprone.skip=true",
"-Dspotless.check.skip=true",
"-Dspotless.apply.skip=true",
"-Dmaven.compiler.failOnWarning=false",
"-Dmaven.compiler.showWarnings=false",
]
@ -62,11 +64,11 @@ GITHUB_RELEASE_URL = (
CODEFLASH_CACHE_DIR = Path.home() / ".cache" / "codeflash"
CODEFLASH_DEPENDENCY_SNIPPET = """\
CODEFLASH_DEPENDENCY_SNIPPET = f"""\
<dependency>
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>1.0.0</version>
<version>{CODEFLASH_RUNTIME_VERSION}</version>
<scope>test</scope>
</dependency>
</dependencies>"""
@ -140,7 +142,7 @@ def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path, mvn: s
f"-Dfile={runtime_jar_path}",
"-DgroupId=com.codeflash",
"-DartifactId=codeflash-runtime",
"-Dversion=1.0.0",
f"-Dversion={CODEFLASH_RUNTIME_VERSION}",
"-Dpackaging=jar",
"-B",
]
@ -288,26 +290,26 @@ def add_codeflash_dependency(pom_path: Path) -> bool:
content = pom_path.read_text(encoding="utf-8")
if "codeflash-runtime" in content:
if "<scope>system</scope>" in content:
def replace_system_dep(match: re.Match[str]) -> str:
def update_codeflash_dep(match: re.Match[str]) -> str:
block: str = match.group(0)
if "codeflash-runtime" in block and "<scope>system</scope>" in block:
if "codeflash-runtime" not in block:
return block
return (
"<dependency>\n"
" <groupId>com.codeflash</groupId>\n"
" <artifactId>codeflash-runtime</artifactId>\n"
" <version>1.0.0</version>\n"
f" <version>{CODEFLASH_RUNTIME_VERSION}</version>\n"
" <scope>test</scope>\n"
" </dependency>"
)
return block
content = re.sub(r"<dependency>[\s\S]*?</dependency>", replace_system_dep, content)
pom_path.write_text(content, encoding="utf-8")
logger.info("Replaced system-scope codeflash-runtime dependency with test scope")
return True
logger.info("codeflash-runtime dependency already present in pom.xml")
updated = re.sub(r"<dependency>[\s\S]*?</dependency>", update_codeflash_dep, content)
if updated != content:
pom_path.write_text(updated, encoding="utf-8")
logger.info("Updated codeflash-runtime dependency to version %s in pom.xml", CODEFLASH_RUNTIME_VERSION)
else:
logger.info("codeflash-runtime dependency already up to date in pom.xml")
return True
closing_tag = "</dependencies>"
@ -571,8 +573,8 @@ class MavenStrategy(BuildToolStrategy):
/ "com"
/ "codeflash"
/ "codeflash-runtime"
/ "1.0.0"
/ "codeflash-runtime-1.0.0.jar"
/ "1.0.1"
/ "codeflash-runtime-1.0.1.jar"
)
@property
@ -647,17 +649,7 @@ class MavenStrategy(BuildToolStrategy):
return None
def find_executable(self, build_root: Path) -> str | None:
mvnw_path = build_root / "mvnw"
if mvnw_path.exists():
return str(mvnw_path)
mvnw_cmd_path = build_root / "mvnw.cmd"
if mvnw_cmd_path.exists():
return str(mvnw_cmd_path)
if Path("mvnw").exists():
return "./mvnw"
if Path("mvnw.cmd").exists():
return "mvnw.cmd"
return shutil.which("mvn")
return self.find_wrapper_executable(build_root, ("mvnw", "mvnw.cmd"), "mvn")
def find_runtime_jar(self) -> Path | None:
if self._M2_JAR.exists():
@ -916,7 +908,15 @@ class MavenStrategy(BuildToolStrategy):
" --add-opens java.base/java.net=ALL-UNNAMED"
" --add-opens java.base/java.util.zip=ALL-UNNAMED"
)
if enable_coverage:
# When coverage is enabled, JaCoCo's prepare-agent goal sets argLine via
# @{argLine}. Overriding -DargLine would clobber the JaCoCo agent flag.
# Pass add-opens and javaagent via JDK_JAVA_OPTIONS instead.
jdk_opts_parts = [add_opens_flags]
if javaagent_arg:
jdk_opts_parts.insert(0, javaagent_arg)
env["JDK_JAVA_OPTIONS"] = " ".join(jdk_opts_parts)
elif javaagent_arg:
cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}")
else:
cmd.append(f"-DargLine={add_opens_flags}")

View file

@ -12,9 +12,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256) -> int:
"""Generate JUnit 5 replay test files from a trace SQLite database.
def generate_replay_tests(
trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256, test_framework: str = "junit5"
) -> int:
"""Generate JUnit replay test files from a trace SQLite database.
Supports both JUnit 5 (default) and JUnit 4.
Returns the number of test files generated.
"""
if not trace_db_path.exists():
@ -44,22 +47,29 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P
test_methods_code: list[str] = []
class_function_names: list[str] = []
# Global test counter to avoid duplicate method names for overloaded Java methods
method_name_counters: dict[str, int] = {}
for method_name, descriptor in method_list:
# Count invocations for this method
count_result = conn.execute(
"SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?",
(classname, method_name, descriptor),
).fetchone()
invocation_count = min(count_result[0], max_run_count)
class_function_names.append(method_name)
simple_class = classname.rsplit(".", 1)[-1]
class_function_names.append(f"{simple_class}.{method_name}")
safe_method = _sanitize_identifier(method_name)
for i in range(invocation_count):
# Use a global counter per method name to avoid collisions on overloaded methods
test_idx = method_name_counters.get(safe_method, 0)
method_name_counters[safe_method] = test_idx + 1
escaped_descriptor = descriptor.replace('"', '\\"')
access = "public " if test_framework == "junit4" else ""
test_methods_code.append(
f" @Test void replay_{safe_method}_{i}() throws Exception {{\n"
f" @Test {access}void replay_{safe_method}_{test_idx}() throws Exception {{\n"
f' helper.replay("{classname}", "{method_name}", '
f'"{escaped_descriptor}", {i});\n'
f" }}"
@ -69,18 +79,28 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P
# Generate the test file
functions_comment = ",".join(class_function_names)
if test_framework == "junit4":
test_imports = "import org.junit.Test;\nimport org.junit.AfterClass;\n"
cleanup_annotation = "@AfterClass"
class_modifier = "public "
else:
test_imports = "import org.junit.jupiter.api.Test;\nimport org.junit.jupiter.api.AfterAll;\n"
cleanup_annotation = "@AfterAll"
class_modifier = ""
test_content = (
f"// codeflash:functions={functions_comment}\n"
f"// codeflash:trace_file={trace_db_path.as_posix()}\n"
f"// codeflash:classname={classname}\n"
f"package codeflash.replay;\n\n"
f"import org.junit.jupiter.api.Test;\n"
f"import org.junit.jupiter.api.AfterAll;\n"
f"{test_imports}"
f"import com.codeflash.ReplayHelper;\n\n"
f"class {test_class_name} {{\n"
f"{class_modifier}class {test_class_name} {{\n"
f" private static final ReplayHelper helper =\n"
f' new ReplayHelper("{trace_db_path.as_posix()}");\n\n'
f" @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n"
f" {cleanup_annotation} public static void cleanup() {{ helper.close(); }}\n\n"
+ "\n\n".join(test_methods_code)
+ "\n"
"}\n"
)

View file

@ -14,6 +14,39 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL
def _run_java_with_graceful_timeout(
java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
) -> None:
"""Run a Java command with graceful timeout handling.
Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.
"""
if not timeout:
subprocess.run(java_command, env=env, check=False)
return
import signal
proc = subprocess.Popen(java_command, env=env)
try:
proc.wait(timeout=timeout)
except subprocess.TimeoutExpired:
logger.warning(
"%s stage timed out after %d seconds, sending SIGTERM for graceful shutdown...", stage_name, timeout
)
proc.send_signal(signal.SIGTERM)
try:
proc.wait(timeout=GRACEFUL_SHUTDOWN_WAIT)
except subprocess.TimeoutExpired:
logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
proc.kill()
proc.wait()
# --add-opens flags needed for Kryo serialization on Java 16+
ADD_OPENS_FLAGS = (
"--add-opens=java.base/java.util=ALL-UNNAMED "
@ -48,10 +81,7 @@ class JavaTracer:
# Stage 1: JFR Profiling
logger.info("Stage 1: Running JFR profiling...")
jfr_env = self.build_jfr_env(jfr_file)
try:
subprocess.run(java_command, env=jfr_env, check=False, timeout=timeout or None)
except subprocess.TimeoutExpired:
logger.warning("JFR profiling stage timed out after %d seconds", timeout)
_run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling")
if not jfr_file.exists():
logger.warning("JFR file was not created at %s", jfr_file)
@ -62,10 +92,7 @@ class JavaTracer:
trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout
)
agent_env = self.build_agent_env(config_path)
try:
subprocess.run(java_command, env=agent_env, check=False, timeout=timeout or None)
except subprocess.TimeoutExpired:
logger.warning("Argument capture stage timed out after %d seconds", timeout)
_run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture")
if not trace_db_path.exists():
logger.error("Trace database was not created at %s", trace_db_path)
@ -95,7 +122,12 @@ class JavaTracer:
def build_jfr_env(self, jfr_file: Path) -> dict[str, str]:
env = os.environ.copy()
jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true"
# Use profile settings with increased sampling frequency (1ms instead of default 10ms)
# This captures more samples for short-running programs
jfr_opts = (
f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true"
",jdk.ExecutionSample#period=1ms"
)
existing = env.get("JAVA_TOOL_OPTIONS", "")
env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip()
return env
@ -133,7 +165,7 @@ class JavaTracer:
if stripped.startswith("package "):
pkg = stripped[8:].rstrip(";").strip()
parts = pkg.split(".")
prefix = ".".join(parts[: min(2, len(parts))])
prefix = ".".join(parts[: min(3, len(parts))])
packages.add(prefix)
break
if stripped and not stripped.startswith("//"):
@ -153,6 +185,7 @@ def run_java_tracer(
max_function_count: int = 256,
timeout: int = 0,
max_run_count: int = 256,
test_framework: str = "junit5",
) -> tuple[Path, Path, int]:
"""High-level entry point: trace a Java command and generate replay tests.
@ -169,7 +202,11 @@ def run_java_tracer(
)
test_count = generate_replay_tests(
trace_db_path=trace_db, output_dir=output_dir, project_root=project_root, max_run_count=max_run_count
trace_db_path=trace_db,
output_dir=output_dir,
project_root=project_root,
max_run_count=max_run_count,
test_framework=test_framework,
)
return trace_db, jfr_file, test_count

View file

@ -226,6 +226,14 @@ def normalize_codeflash_imports(source: str) -> str:
return _CODEFLASH_IMPORT_PATTERN.sub(r"import \1 from 'codeflash'", source)
# Pattern to detect existing framework imports (regardless of specific identifiers imported)
# This catches semantic duplicates even if the order/identifiers differ from what we'd inject
_HAS_VITEST_IMPORT_RE = re.compile(r"import\s+\{[^}]*\}\s+from\s+['\"]vitest['\"]", re.MULTILINE)
_HAS_JEST_IMPORT_RE = re.compile(r"import\s+\{[^}]*\}\s+from\s+['\"]@jest/globals['\"]", re.MULTILINE)
_HAS_MOCHA_ASSERT_IMPORT_RE = re.compile(r"import\s+.*\s+from\s+['\"]node:assert", re.MULTILINE)
_HAS_MOCHA_ASSERT_REQUIRE_RE = re.compile(r"(?:const|let|var)\s+.*\s*=\s*require\s*\(\s*['\"]node:assert", re.MULTILINE)
# Author: ali <mohammed18200118@gmail.com>
def inject_test_globals(
generated_tests: GeneratedTestsList, test_framework: str = "jest", module_system: str = "esm"
@ -246,24 +254,29 @@ def inject_test_globals(
# Use vitest imports for vitest projects, jest imports for jest projects
if test_framework == "vitest":
global_import = "import { vi, describe, it, expect, beforeEach, afterEach, beforeAll, test } from 'vitest'\n"
has_import_re = _HAS_VITEST_IMPORT_RE
elif test_framework == "mocha":
if is_cjs:
global_import = "const assert = require('node:assert/strict');\n"
has_import_re = _HAS_MOCHA_ASSERT_REQUIRE_RE
else:
global_import = "import assert from 'node:assert/strict';\n"
has_import_re = _HAS_MOCHA_ASSERT_IMPORT_RE
else:
# Default to jest imports for jest and other frameworks
global_import = (
"import { jest, describe, it, expect, beforeEach, afterEach, beforeAll, test } from '@jest/globals'\n"
)
has_import_re = _HAS_JEST_IMPORT_RE
for test in generated_tests.generated_tests:
# Skip injection if the source already has the import (LLM may have included it)
if global_import.strip() not in test.generated_original_test_source:
# Skip injection if the source already has ANY import from the framework
# This catches semantic duplicates even if the AI used different identifiers/order
if not has_import_re.search(test.generated_original_test_source):
test.generated_original_test_source = global_import + test.generated_original_test_source
if global_import.strip() not in test.instrumented_behavior_test_source:
if not has_import_re.search(test.instrumented_behavior_test_source):
test.instrumented_behavior_test_source = global_import + test.instrumented_behavior_test_source
if global_import.strip() not in test.instrumented_perf_test_source:
if not has_import_re.search(test.instrumented_perf_test_source):
test.instrumented_perf_test_source = global_import + test.instrumented_perf_test_source
return generated_tests

View file

@ -1287,13 +1287,13 @@ def fix_imports_inside_test_blocks(test_code: str) -> str:
def fix_jest_mock_paths(test_code: str, test_file_path: Path, source_file_path: Path, tests_root: Path) -> str:
"""Fix relative paths in jest.mock() calls to be correct from the test file's location.
"""Fix relative paths in jest.mock() and vi.mock() calls to be correct from the test file's location.
The AI sometimes generates jest.mock() calls with paths relative to the source file
The AI sometimes generates mock calls with paths relative to the source file
instead of the test file. For example:
- Source at `src/queue/queue.ts` imports `../environment` (-> src/environment)
- Test at `tests/test.test.ts` generates `jest.mock('../environment')` (-> ./environment, wrong!)
- Should generate `jest.mock('../src/environment')`
- Test at `tests/test.test.ts` generates `jest.mock('../environment')` or `vi.mock('../environment')` (-> ./environment, wrong!)
- Should generate `jest.mock('../src/environment')` or `vi.mock('../src/environment')`
This function detects relative mock paths and adjusts them based on the test file's
location relative to the source file's directory.
@ -1318,8 +1318,8 @@ def fix_jest_mock_paths(test_code: str, test_file_path: Path, source_file_path:
test_dir = test_file_path.resolve().parent
project_root = tests_root.resolve().parent if tests_root.name == "tests" else tests_root.resolve()
# Pattern to match jest.mock() or jest.doMock() with relative paths
mock_pattern = re.compile(r"(jest\.(?:mock|doMock)\s*\(\s*['\"])(\.\./[^'\"]+|\.\/[^'\"]+)(['\"])")
# Pattern to match jest.mock(), jest.doMock(), or vi.mock() with relative paths
mock_pattern = re.compile(r"((?:jest|vi)\.(?:mock|doMock)\s*\(\s*['\"])(\.\./[^'\"]+|\.\/[^'\"]+)(['\"])")
def fix_mock_path(match: re.Match[str]) -> str:
original = match.group(0)
@ -1359,7 +1359,7 @@ def fix_jest_mock_paths(test_code: str, test_file_path: Path, source_file_path:
if not new_rel_path.startswith("../") and not new_rel_path.startswith("./"):
new_rel_path = f"./{new_rel_path}"
logger.debug(f"Fixed jest.mock path: {rel_path} -> {new_rel_path}")
logger.debug(f"Fixed mock path: {rel_path} -> {new_rel_path}")
return f"{prefix}{new_rel_path}{suffix}"
except (ValueError, OSError):

View file

@ -513,3 +513,54 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
logger.debug("Added vitest imports: %s", used_globals)
return "\n".join(lines)
def add_js_extensions_to_relative_imports(code: str) -> str:
"""Add .js extensions to relative imports in ESM code.
In ESM mode with TypeScript, Node.js requires explicit .js extensions
for relative imports, even though the source files are .ts files.
This function adds .js extensions to relative imports that don't already
have a file extension.
Args:
code: JavaScript/TypeScript code with import statements.
Returns:
Code with .js extensions added to relative imports.
Examples:
>>> add_js_extensions_to_relative_imports("import X from './module';")
"import X from './module.js';"
>>> add_js_extensions_to_relative_imports("import X from './module.js';")
"import X from './module.js';"
>>> add_js_extensions_to_relative_imports("import X from 'node:assert';")
"import X from 'node:assert';"
"""
# Pattern to match ES module import statements with relative paths
# Matches: import ... from './path' or import ... from "../path"
# Groups: (import statement)(quote char)(relative path)(quote char)
import_pattern = re.compile(
r"(import\s+(?:(?:\{[^}]*\})|(?:\*\s+as\s+\w+)|(?:\w+))\s+from\s+)(['\"])(\.\.?[^'\"]+)(['\"])"
)
def add_extension(match):
"""Add .js extension if the import path doesn't have one."""
prefix = match.group(1) # "import ... from "
quote_open = match.group(2) # ' or "
path = match.group(3) # The relative path (e.g., "./module" or "../foo/bar")
quote_close = match.group(4) # ' or "
# Check if path already has an extension
# Common extensions: .js, .ts, .jsx, .tsx, .mjs, .mts, .json
if re.search(r"\.(js|ts|jsx|tsx|mjs|mts|json)$", path):
return match.group(0)
# Add .js extension
return f"{prefix}{quote_open}{path}.js{quote_close}"
return import_pattern.sub(add_extension, code)

View file

@ -7,6 +7,7 @@ using tree-sitter for code analysis and Jest for test execution.
from __future__ import annotations
import logging
import re
import subprocess
import xml.etree.ElementTree as ET
from pathlib import Path
@ -160,9 +161,15 @@ class JavaScriptSupport:
if not criteria.include_async and func.is_async:
continue
# Skip nested functions (functions defined inside other functions)
# Nested functions depend on closure variables from parent scope and cannot
# be optimized in isolation without complex context extraction
if func.parent_function:
logger.debug(f"Skipping nested function: {func.name} (parent: {func.parent_function})") # noqa: G004
continue
# Skip non-exported functions (can't be imported in tests)
# Exception: nested functions and methods are allowed if their parent is exported
if criteria.require_export and not func.is_exported and not func.parent_function:
if criteria.require_export and not func.is_exported:
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
continue
@ -224,6 +231,15 @@ class JavaScriptSupport:
"""
result: dict[str, list[TestInfo]] = {}
# Build indices for O(1) lookup per imported name (avoids O(NxM) loop)
function_name_to_qualified: dict[str, str] = {}
class_name_to_qualified_names: dict[str, list[str]] = {}
for func in source_functions:
function_name_to_qualified[func.function_name] = func.qualified_name
for parent in func.parents:
if parent.type == "ClassDef":
class_name_to_qualified_names.setdefault(parent.name, []).append(func.qualified_name)
# Find all test files using language-specific patterns
test_patterns = self._get_test_patterns()
@ -237,24 +253,39 @@ class JavaScriptSupport:
analyzer = get_analyzer_for_file(test_file)
imports = analyzer.find_imports(source)
# Build a set of imported function names
# Build a set of imported names, resolving aliases and namespace member access
imported_names: set[str] = set()
for imp in imports:
if imp.default_import:
imported_names.add(imp.default_import)
# Extract member access patterns: e.g. `math.calculate(...)` → "calculate"
for m in re.finditer(rf"\b{re.escape(imp.default_import)}\.(\w+)", source):
imported_names.add(m.group(1))
if imp.namespace_import:
imported_names.add(imp.namespace_import)
for m in re.finditer(rf"\b{re.escape(imp.namespace_import)}\.(\w+)", source):
imported_names.add(m.group(1))
for name, alias in imp.named_imports:
imported_names.add(alias or name)
imported_names.add(name)
if alias:
imported_names.add(alias)
# Find test functions (describe/it/test blocks)
test_functions = self._find_jest_tests(source, analyzer)
# Match source functions to tests
for func in source_functions:
if func.function_name in imported_names or func.function_name in source:
if func.qualified_name not in result:
result[func.qualified_name] = []
# Match via indices: function names and class names → qualified names
matched_qualified_names: set[str] = set()
for imported_name in imported_names:
if imported_name in function_name_to_qualified:
matched_qualified_names.add(function_name_to_qualified[imported_name])
if imported_name in class_name_to_qualified_names:
matched_qualified_names.update(class_name_to_qualified_names[imported_name])
for qualified_name in matched_qualified_names:
if qualified_name not in result:
result[qualified_name] = []
for test_name in test_functions:
result[func.qualified_name].append(
result[qualified_name].append(
TestInfo(test_name=test_name, test_file=test_file, test_class=None)
)
except Exception as e:
@ -2012,6 +2043,7 @@ class JavaScriptSupport:
validate_and_fix_import_style,
)
from codeflash.languages.javascript.module_system import (
ModuleSystem,
ensure_module_system_compatibility,
ensure_vitest_imports,
)
@ -2036,6 +2068,13 @@ class JavaScriptSupport:
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
)
# Add .js extensions to relative imports for ESM projects
# TypeScript + ESM requires explicit .js extensions even for .ts source files
if project_module_system == ModuleSystem.ES_MODULE:
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
generated_test_source = add_js_extensions_to_relative_imports(generated_test_source)
# Ensure vitest imports are present when using vitest framework
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
@ -2229,7 +2268,10 @@ class JavaScriptSupport:
source_without_ext = source_file_abs.with_suffix("")
# Use os.path.relpath to compute relative path from tests_root to source file
rel_path = os.path.relpath(str(source_without_ext), str(tests_root_abs))
# Replace backslashes with forward slashes — JavaScript import/require paths
# must use forward slashes. Backslashes are escape chars in JS strings
# (e.g. \t → tab, \n → newline) and would break imports on Windows.
rel_path = os.path.relpath(str(source_without_ext), str(tests_root_abs)).replace("\\", "/")
# For ESM, add .js extension (TypeScript convention)
# TypeScript requires imports to reference the OUTPUT file extension (.js),

View file

@ -369,7 +369,9 @@ def _create_runtime_jest_config(base_config_path: Path | None, project_root: Pat
runtime_config_path = config_dir / f"jest.codeflash.runtime.config{config_ext}"
test_dirs_js = ", ".join(f"'{d}'" for d in sorted(test_dirs))
# Normalize to forward slashes — backslashes in JS strings are escape chars
# (e.g. \t → tab, \n → newline) and would corrupt paths on Windows.
test_dirs_js = ", ".join(f"'{d.replace(chr(92), '/')}'" for d in sorted(test_dirs))
# In monorepos, add the root node_modules to moduleDirectories so Jest
# can resolve workspace packages that are hoisted to the monorepo root.
@ -382,7 +384,13 @@ def _create_runtime_jest_config(base_config_path: Path | None, project_root: Pat
else:
module_dirs_line_no_base = ""
if base_config_path:
project_root_posix = project_root.as_posix()
# TypeScript configs (.ts) cannot be required from CommonJS modules
# because Node.js cannot parse TypeScript syntax in require().
# When the base config is TypeScript, we create a standalone config
# instead of trying to extend it via require().
if base_config_path and base_config_path.suffix != ".ts":
require_path = f"./{base_config_path.name}"
config_content = f"""// Auto-generated by codeflash - runtime config with test roots
const baseConfig = require('{require_path}');
@ -393,12 +401,13 @@ module.exports = {{
{test_dirs_js},
],
testMatch: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
testRegex: undefined, // Clear testRegex from baseConfig to avoid conflict with testMatch
{module_dirs_line}}};
"""
else:
config_content = f"""// Auto-generated by codeflash - runtime config with test roots
module.exports = {{
roots: ['{project_root}', {test_dirs_js}],
roots: ['{project_root_posix}', {test_dirs_js}],
testMatch: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
{module_dirs_line_no_base}}};
"""

View file

@ -7,6 +7,7 @@ verification and performance benchmarking.
from __future__ import annotations
import os
import re
import subprocess
import time
from pathlib import Path
@ -169,9 +170,24 @@ def _is_vitest_workspace(project_root: Path) -> bool:
return False
try:
content = vitest_config.read_text()
# Check for workspace indicators
return "workspace" in content.lower() or "defineWorkspace" in content
content = vitest_config.read_text(encoding="utf-8")
# Check for actual workspace configuration patterns (not just the word "workspace" in comments)
# Valid indicators:
# - defineWorkspace() function call
# - workspace: [ array config
# - separate vitest.workspace.ts/js file
# Match defineWorkspace calls or workspace: property assignments
workspace_pattern = re.compile(
r"(?:^|[^a-zA-Z_])defineWorkspace\s*\(|" # defineWorkspace( function call
r"(?:^|[^a-zA-Z_])workspace\s*:\s*\[", # workspace: [ array
re.MULTILINE,
)
if workspace_pattern.search(content):
return True
# Also check for separate workspace config file
if (project_root / "vitest.workspace.ts").exists() or (project_root / "vitest.workspace.js").exists():
return True
return False
except Exception:
return False
@ -238,6 +254,18 @@ export default mergeConfig(originalConfig, {{
include: ['**/*.test.ts', '**/*.test.js', '**/*.test.tsx', '**/*.test.jsx'],
// Use forks pool so timing markers from process.stdout.write flow to parent stdout
pool: 'forks',
// Disable setupFiles to prevent relative path resolution issues in nested directories.
// Project setupFiles often use relative paths (e.g., "test/setup.ts") which resolve
// incorrectly when tests are in subdirectories (e.g., extensions/discord/test/).
// Codeflash-generated tests are self-contained and don't require project setup files.
setupFiles: [],
// Override coverage settings to ensure JSON reporter is used.
// Vitest's mergeConfig doesn't properly handle nested coverage object merge with
// command-line flags, so we explicitly set reporter here to guarantee coverage
// files are written to the expected location (coverage-final.json).
coverage: {{
reporter: ['json'],
}},
}},
}});
"""
@ -254,6 +282,10 @@ export default defineConfig({
exclude: ['**/node_modules/**', '**/dist/**'],
// Use forks pool so timing markers from process.stdout.write flow to parent stdout
pool: 'forks',
// Override coverage settings to ensure JSON reporter is used
coverage: {
reporter: ['json'],
},
},
});
"""
@ -446,7 +478,21 @@ def run_vitest_behavioral_tests(
# Pre-creating an empty directory may cause vitest to delete it
logger.debug(f"Coverage will be written to: {coverage_dir}")
vitest_cmd.extend(["--coverage", "--coverage.reporter=json", f"--coverage.reportsDirectory={coverage_dir}"])
vitest_cmd.extend(
[
"--coverage",
"--coverage.reporter=json",
f"--coverage.reportsDirectory={coverage_dir}",
# Disable project-level coverage thresholds to prevent false failures.
# Codeflash-generated tests typically cover only a single function (~1-2% of codebase),
# which would fail projects with thresholds like 70% lines/functions configured
# in their vitest.config.ts.
"--coverage.thresholds.lines=0",
"--coverage.thresholds.functions=0",
"--coverage.thresholds.statements=0",
"--coverage.thresholds.branches=0",
]
)
# Note: Removed --coverage.enabled=true (redundant) and --coverage.all false
# The version mismatch between vitest and @vitest/coverage-v8 can cause
# issues with coverage flag parsing. Let vitest use default settings.

View file

@ -950,7 +950,7 @@ class TestResults(BaseModel): # noqa: PLW1641
by_id: dict[InvocationId, list[int]] = {}
for result in self.test_results:
if result.did_pass:
if result.runtime:
if result.runtime is not None:
by_id.setdefault(result.id, []).append(result.runtime)
else:
msg = (

View file

@ -127,7 +127,8 @@ class Optimizer:
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(
self.trace_file
)
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(self.trace_file)
total_benchmark_stats = CodeFlashBenchmarkPlugin.get_benchmark_timings(self.trace_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(
function_benchmark_timings, total_benchmark_timings
)

View file

@ -349,8 +349,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
max_function_count = getattr(config, "max_function_count", 256)
timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0)
console.print("[bold]Java project detected[/]")
console.print(f" Project root: {project_root}")
console.print(f" Module root: {getattr(config, 'module_root', '?')}")
console.print(f" Tests root: {getattr(config, 'tests_root', '?')}")
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.languages.java.build_tools import find_test_root
from codeflash.languages.java.tracer import JavaTracer, run_java_tracer
tracer = JavaTracer()
@ -360,12 +364,16 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
trace_db_path = get_run_tmp_file(Path("java_trace.db"))
# Place replay tests in the project's test source tree so Maven/Gradle can compile them
test_root = find_test_root(project_root)
if test_root:
output_dir = test_root / "codeflash" / "replay"
# Place replay tests in the project's test source tree so Maven/Gradle can compile them.
# Use the config's tests_root (correctly resolved for multi-module projects) not find_test_root().
tests_root = Path(getattr(config, "tests_root", ""))
if tests_root.is_dir():
output_dir = tests_root / "codeflash" / "replay"
else:
output_dir = project_root / "src" / "test" / "java" / "codeflash" / "replay"
from codeflash.languages.java.build_tools import find_test_root
test_root = find_test_root(project_root)
output_dir = (test_root or project_root / "src" / "test" / "java") / "codeflash" / "replay"
output_dir.mkdir(parents=True, exist_ok=True)
# Remaining args after our flags are the Java command
@ -377,6 +385,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
sys.exit(1)
java_command = remaining
# Detect test framework for replay test generation
from codeflash.languages.java.config import detect_java_project
java_config = detect_java_project(project_root)
test_framework = java_config.test_framework if java_config else "junit5"
trace_db, jfr_file, test_count = run_java_tracer(
java_command=java_command,
trace_db_path=trace_db_path,
@ -385,6 +399,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
output_dir=output_dir,
max_function_count=max_function_count,
timeout=timeout,
test_framework=test_framework,
)
console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated")
@ -404,7 +419,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser:
config.replay_test = replay_test_paths
config.previous_checkpoint_functions = None
config.effort = EffortLevel.HIGH.value
config.no_pr = True
config.no_pr = getattr(config, "no_pr", False)
config.file = None
config.function = None
config.test_project_root = project_root

View file

@ -43,30 +43,33 @@ class JestCoverageUtils:
"""
if not coverage_json_path or not coverage_json_path.exists():
logger.debug(f"Jest coverage file not found: {coverage_json_path}")
logger.debug(f"JavaScript coverage file not found: {coverage_json_path}")
return CoverageData.create_empty(source_code_path, function_name, code_context)
try:
with coverage_json_path.open(encoding="utf-8") as f:
coverage_data = json.load(f)
except (json.JSONDecodeError, OSError) as e:
logger.warning(f"Failed to parse Jest coverage file: {e}")
logger.warning(f"Failed to parse JavaScript coverage file: {e}")
return CoverageData.create_empty(source_code_path, function_name, code_context)
# Find the file entry in coverage data
# Jest uses absolute paths as keys
# Jest/Vitest always writes coverage keys with forward slashes (POSIX paths),
# so we normalize our paths to POSIX for comparison — critical on Windows
# where Path.resolve() and str(Path) produce backslash paths.
file_coverage = None
source_path_str = str(source_code_path.resolve())
source_path_posix = source_code_path.resolve().as_posix()
source_relative_posix = source_code_path.as_posix()
for file_path, file_data in coverage_data.items():
# Match exact path or path ending with full relative path from src/
# Avoid matching files with same name in different directories (e.g., db/utils.ts vs utils/utils.ts)
if file_path == source_path_str or file_path.endswith(str(source_code_path)):
if file_path == source_path_posix or file_path.endswith(source_relative_posix):
file_coverage = file_data
break
if not file_coverage:
logger.debug(f"No coverage data found for {source_code_path} in Jest coverage")
logger.debug(f"No coverage data found for {source_code_path} in JavaScript coverage")
return CoverageData.create_empty(source_code_path, function_name, code_context)
# Extract line coverage from statement map and execution counts
@ -94,7 +97,7 @@ class JestCoverageUtils:
# If function not found in fnMap, use entire file
fn_start_line = 1
fn_end_line = 999999
logger.debug(f"Function {function_name} not found in Jest fnMap, using file coverage")
logger.debug(f"Function {function_name} not found in JavaScript fnMap, using file coverage")
# Calculate executed and unexecuted lines within the function
executed_lines = []

View file

@ -34,7 +34,20 @@ def generate_tests(
# TODO: Sometimes this recreates the original Class definition. This overrides and messes up the original
# class import. Remove the recreation of the class definition
start_time = time.perf_counter()
test_module_path = Path(module_name_from_file_path(test_path, test_cfg.tests_project_rootdir))
# Compute test module path - handle case where test file is outside tests_project_rootdir
# (e.g., JavaScript/TypeScript tests generated in __tests__ subdirectories adjacent to source files)
# Similar to javascript/parse.py:330-333 fallback pattern
try:
# Use traverse_up=True to handle co-located __tests__ directories that may be outside
# the configured tests_root (e.g., src/gateway/__tests__/ when tests_root is test/)
test_module_path = Path(module_name_from_file_path(test_path, test_cfg.tests_project_rootdir, traverse_up=True))
except ValueError:
# Test file is not within tests_project_rootdir - use just the filename
# This can happen for JavaScript/TypeScript when get_test_dir_for_source()
# places tests adjacent to source files (e.g., in src/foo/__tests__/)
# instead of within the configured tests_root
test_module_path = Path(test_path.name)
# Detect module system via language support (non-None for JS/TS, None for Python)
lang_support = current_language_support()

View file

@ -1,2 +1,2 @@
# These version placeholders will be replaced by uv-dynamic-versioning during build.
__version__ = "0.20.4"
__version__ = "0.20.5"

View file

@ -29,7 +29,7 @@ Flags can be combined: `/optimize src/utils.py my_function`
### What happens behind the scenes
1. The skill (defined in `skills/optimize/SKILL.md`) forks context and spawns the **optimizer agent**
2. The agent locates your project config (`pyproject.toml` or `package.json` or `codeflash.toml`)
2. The agent locates your project config (`pyproject.toml`, `package.json`, or `pom.xml`/`gradle.properties`)
3. It verifies the codeflash CLI is installed and the project is configured
4. It runs `codeflash --subagent` as a **background task** with a 10-minute timeout
5. You're notified when optimization completes with results

View file

@ -1,43 +1,52 @@
---
title: "Java Configuration"
description: "Configure Codeflash for Java projects using codeflash.toml"
description: "Configure Codeflash for Java projects"
icon: "java"
sidebarTitle: "Java (codeflash.toml)"
sidebarTitle: "Java"
keywords:
[
"configuration",
"codeflash.toml",
"java",
"maven",
"gradle",
"junit",
"pom.xml",
"gradle.properties",
]
---
# Java Configuration
Codeflash stores its configuration in `codeflash.toml` under the `[tool.codeflash]` section.
Codeflash stores its configuration inside your existing build file — `pom.xml` properties for Maven projects, or `gradle.properties` for Gradle projects. No separate config file is needed.
## Full Reference
## Maven Configuration
```toml
[tool.codeflash]
# Required
module-root = "src/main/java"
tests-root = "src/test/java"
language = "java"
For Maven projects, Codeflash writes properties under the `<properties>` section of your `pom.xml` with the `codeflash.` prefix:
# Optional
test-framework = "junit5" # "junit5", "junit4", or "testng"
disable-telemetry = false
git-remote = "origin"
ignore-paths = ["src/main/java/generated/"]
```xml
<properties>
<!-- Only non-default overrides are written -->
<codeflash.moduleRoot>src/main/java</codeflash.moduleRoot>
<codeflash.testsRoot>src/test/java</codeflash.testsRoot>
<codeflash.gitRemote>origin</codeflash.gitRemote>
<codeflash.formatterCmds>mvn spotless:apply -DspotlessFiles=$file</codeflash.formatterCmds>
<codeflash.disableTelemetry>false</codeflash.disableTelemetry>
<codeflash.ignorePaths>src/main/java/generated/</codeflash.ignorePaths>
</properties>
```
All file paths are relative to the directory containing `codeflash.toml`.
## Gradle Configuration
For Gradle projects, Codeflash writes settings to `gradle.properties` with the `codeflash.` prefix:
```properties
codeflash.moduleRoot=src/main/java
codeflash.testsRoot=src/test/java
codeflash.gitRemote=origin
```
<Info>
Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed.
Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed. For standard Maven/Gradle layouts, Codeflash may write no config at all if all defaults are correct.
</Info>
## Auto-Detection
@ -46,54 +55,42 @@ When you run `codeflash init`, Codeflash inspects your project and auto-detects:
| Setting | Detection logic |
|---------|----------------|
| `module-root` | Looks for `src/main/java` (Maven/Gradle standard layout) |
| `tests-root` | Looks for `src/test/java`, `test/`, `tests/` |
| `language` | Detected from build files (`pom.xml`, `build.gradle`) and `.java` files |
| `test-framework` | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG |
| **Source root** | Looks for `src/main/java` (Maven/Gradle standard layout), falls back to the `sourceDirectory` set in `pom.xml` |
| **Test root** | Looks for `src/test/java`, `test/`, `tests/` |
| **Build tool** | Detects Maven (`pom.xml`) or Gradle (`build.gradle` / `build.gradle.kts`) |
| **Test framework** | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG |
## Required Options
## Configuration Options
- **`module-root`**: The source directory to optimize. Only code under this directory is discovered for optimization. For standard Maven/Gradle projects, this is `src/main/java`.
- **`tests-root`**: The directory where your tests are located. Codeflash discovers existing tests and places generated replay tests here.
- **`language`**: Must be set to `"java"` for Java projects.
| Property | Description | Default |
|----------|-------------|---------|
| `moduleRoot` | Source directory to optimize | `src/main/java` |
| `testsRoot` | Test directory | `src/test/java` |
| `gitRemote` | Git remote for pull requests | `origin` |
| `formatterCmds` | Code formatter command (`$file` placeholder for file path) | (none) |
| `disableTelemetry` | Disable anonymized telemetry | `false` |
| `ignorePaths` | Paths within source root to skip during optimization | (none) |
## Optional Options
- **`test-framework`**: Test framework. Auto-detected from build dependencies. Supported values: `"junit5"` (default), `"junit4"`, `"testng"`.
- **`disable-telemetry`**: Disable anonymized telemetry. Defaults to `false`.
- **`git-remote`**: Git remote for pull requests. Defaults to `"origin"`.
- **`ignore-paths`**: Paths within `module-root` to skip during optimization.
<Info>
Only non-default values are written to the config. If your project uses the standard `src/main/java` and `src/test/java` layout with the default `origin` remote, Codeflash may not need to write any config properties at all.
</Info>
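For Gradle projects, the optional overrides go in `gradle.properties` as well. A minimal sketch, assuming the remaining table options use the same `codeflash.`-prefixed camelCase keys as the example above (the formatter command shown is an illustrative placeholder, not a default):

```properties
# Sketch: assumed key names for the optional overrides; only values that differ from defaults are written
codeflash.formatterCmds=google-java-format --replace $file
codeflash.ignorePaths=src/main/java/generated/
codeflash.disableTelemetry=true
```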
## Multi-Module Projects
For multi-module Maven/Gradle projects, place `codeflash.toml` at the project root and set `module-root` to the module you want to optimize:
For multi-module Maven/Gradle projects, run `codeflash init` from the module you want to optimize. The config is written to that module's `pom.xml` or `gradle.properties`:
```text
my-project/
|- client/
| |- src/main/java/com/example/client/
| |- src/test/java/com/example/client/
| |- pom.xml <-- run codeflash init here
|- server/
| |- src/main/java/com/example/server/
|- pom.xml
|- codeflash.toml
```
```toml
[tool.codeflash]
module-root = "client/src/main/java"
tests-root = "client/src/test/java"
language = "java"
```
For non-standard layouts (like the Aerospike client where source is under `client/src/`), adjust paths accordingly:
```toml
[tool.codeflash]
module-root = "client/src"
tests-root = "test/src"
language = "java"
```
For non-standard layouts (like the Aerospike client where source is under `client/src/`), `codeflash init` will prompt you to override the detected paths.
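For the layout above, the resulting overrides might look roughly like this, shown here in `gradle.properties` form (for Maven the equivalent `codeflash.*` properties go in `pom.xml`). This is a sketch only: the paths mirror the earlier example, and the actual values come from your answers to the `codeflash init` prompts.

```properties
# Illustrative only; codeflash init writes the paths you confirm at the prompt
codeflash.moduleRoot=client/src
codeflash.testsRoot=test/src
```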
## Tracer Options
@ -124,15 +121,9 @@ my-app/
| |- test/java/com/example/
| |- AppTest.java
|- pom.xml
|- codeflash.toml
```
```toml
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"
language = "java"
```
Standard layout — no extra config needed. `codeflash init` detects everything automatically.
### Gradle project
@ -142,12 +133,7 @@ my-lib/
| |- main/java/com/example/
| |- test/java/com/example/
|- build.gradle
|- codeflash.toml
|- gradle.properties <-- codeflash config written here if overrides needed
```
```toml
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"
language = "java"
```
Standard layout — no extra config needed. `codeflash init` detects everything automatically.

View file

@ -15,7 +15,9 @@ keywords:
]
---
Codeflash supports Java projects using Maven or Gradle build systems. It uses a two-stage tracing approach to capture method arguments and profiling data from running Java programs, then optimizes the hottest functions.
Codeflash supports optimizing Java projects using Maven or Gradle build systems. It works in two main ways:
1. Codeflash can optimize new Java code written in a pull request through GitHub Actions.
2. Codeflash can optimize real workloads end to end. It uses a two-stage tracing approach to capture method arguments and profiling data from a running Java program, then optimizes the hottest functions with that data.
### Prerequisites
@ -32,16 +34,16 @@ Good to have (optional):
<Steps>
<Step title="Install Codeflash CLI">
Codeflash CLI is a Python tool. Install it with pip:
Codeflash uses Python to run its CLI. You can use uv as a package manager and installer for Python tools.
To install uv, run the following command, or [see these instructions](https://docs.astral.sh/uv/getting-started/installation/):
```bash
pip install codeflash
curl -LsSf https://astral.sh/uv/install.sh | sh
```
Or with uv:
Then install Codeflash as a uv tool.
```bash
uv pip install codeflash
uv tool install codeflash
```
</Step>
@ -56,65 +58,35 @@ codeflash init
This will:
- Detect your build tool (Maven/Gradle)
- Find your source and test directories
- Create a `codeflash.toml` configuration file
</Step>
<Step title="Verify setup">
Check that the configuration looks correct:
```bash
cat codeflash.toml
```
You should see something like:
```toml
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"
language = "java"
```
- Write Codeflash configuration to your `pom.xml` properties (Maven) or `gradle.properties` (Gradle)
</Step>
<Step title="Run your first optimization">
Trace and optimize a running Java program:
Optimize a specific function:
```bash
codeflash optimize java -jar target/my-app.jar
codeflash --file src/main/java/com/example/Utils.java --function myMethod
```
Or with Maven:
Or optimize all functions in your project:
```bash
codeflash optimize mvn exec:java -Dexec.mainClass="com.example.Main"
codeflash --all
```
Codeflash will:
1. Profile your program using JFR (Java Flight Recorder)
2. Capture method arguments using a bytecode instrumentation agent
3. Generate JUnit replay tests from the captured data
4. Rank functions by performance impact
5. Optimize the most impactful functions
1. Discover optimizable functions in your source code
2. Generate tests and optimization candidates using AI
3. Verify correctness by running tests (JUnit 5, JUnit 4, or TestNG)
4. Benchmark performance improvements
5. Create a pull request with the optimization (if the GitHub App is installed)
For advanced workflow tracing (profiling a running Java program), see [Trace & Optimize](/optimizing-with-codeflash/trace-and-optimize).
</Step>
</Steps>
## How it works
Codeflash uses a **two-stage tracing** approach for Java:
1. **Stage 1 — JFR Profiling**: Runs your program with Java Flight Recorder enabled to collect accurate method-level CPU profiling data. JFR has ~1% overhead and doesn't affect JIT compilation.
2. **Stage 2 — Argument Capture**: Runs your program again with a bytecode instrumentation agent that captures method arguments using Kryo serialization. Arguments are stored in an SQLite database.
The traced data is used to generate **JUnit replay tests** that exercise your functions with real-world inputs. Codeflash uses these tests alongside any existing unit tests to verify correctness and benchmark optimization candidates.
<Info>
Your program runs **twice** — once for profiling, once for argument capture. This separation ensures profiling data isn't distorted by serialization overhead.
</Info>
## Supported build tools
| Build Tool | Detection | Test Execution |

View file

@ -71,16 +71,15 @@ bun add --dev codeflash
</Tip>
<Info>
**Codeflash also requires a Python installation** (3.9+) to run the CLI optimizer. Install the Python CLI globally:
**One-time setup required.** The Codeflash optimizer runs on Python behind the scenes. After installing the npm package, run:
```bash
pip install codeflash
# or
uv pip install codeflash
npx codeflash setup
```
The Python CLI orchestrates the optimization pipeline, while the npm package provides the JavaScript runtime (test runners, serialization, reporters).
This automatically creates an isolated Python environment — no global installs or manual Python management needed. After setup, all Codeflash commands run through `npx codeflash`, which uses the installed binary automatically.
</Info>
</Step>
<Step title="Generate a Codeflash API Key">

View file

@ -2,11 +2,11 @@
title: "Codeflash is an AI performance optimizer for your code"
icon: "rocket"
sidebarTitle: "Overview"
keywords: ["python", "javascript", "typescript", "performance", "optimization", "AI", "code analysis", "benchmarking"]
keywords: ["python", "javascript", "typescript", "java", "performance", "optimization", "AI", "code analysis", "benchmarking"]
---
Codeflash speeds up your code by figuring out the best way to rewrite it while verifying that the behavior is unchanged, and verifying real speed
gains through performance benchmarking. It supports **Python**, **JavaScript**, and **TypeScript**.
gains through performance benchmarking. It supports **Python**, **JavaScript**, **TypeScript**, and **Java**.
The optimizations Codeflash finds are generally better algorithms, opportunities to remove wasteful compute, better logic, utilizing caching and utilization of more efficient library methods. Codeflash
does not modify the system architecture of your code, but it tries to find the most efficient implementation of your current architecture.
@ -15,18 +15,21 @@ does not modify the system architecture of your code, but it tries to find the m
Pick your language to install and configure Codeflash:
<CardGroup cols={2}>
<CardGroup cols={3}>
<Card title="Python" icon="python" href="/getting-started/local-installation">
Install via pip, uv, or poetry. Configure in `pyproject.toml`.
</Card>
<Card title="JavaScript / TypeScript" icon="js" href="/getting-started/javascript-installation">
Install via npm, yarn, pnpm, or bun. Configure in `package.json`. Supports Jest, Vitest, and Mocha.
</Card>
<Card title="Java" icon="java" href="/getting-started/java-installation">
Install via uv. Supports Maven and Gradle. JUnit 5, JUnit 4, and TestNG.
</Card>
</CardGroup>
### How to use Codeflash
These commands work for both Python and JS/TS projects:
These commands work for Python, JS/TS, and Java projects:
<CardGroup cols={2}>
<Card title="Optimize a Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
@ -56,13 +59,16 @@ These commands work for both Python and JS/TS projects:
### Configuration Reference
<CardGroup cols={2}>
<CardGroup cols={3}>
<Card title="Python Config" icon="python" href="/configuration/python">
`pyproject.toml` reference
</Card>
<Card title="JS / TS Config" icon="js" href="/configuration/javascript">
`package.json` reference — includes monorepo, scattered tests, manual setup
</Card>
<Card title="Java Config" icon="java" href="/configuration/java">
`pom.xml` / `gradle.properties` reference
</Card>
</CardGroup>
### How does Codeflash verify correctness?

View file

@ -9,7 +9,7 @@ keywords: ["codebase optimization", "all functions", "batch optimization", "gith
# Optimize your entire codebase
Codeflash can optimize your entire codebase by analyzing all the functions in your project and generating optimized versions of them.
It iterates through all the functions in your codebase and optimizes them one by one. This works for Python, JavaScript, and TypeScript projects.
It iterates through all the functions in your codebase and optimizes them one by one. This works for Python, JavaScript, TypeScript, and Java projects.
To optimize your entire codebase, run the following command in your project directory:

View file

@ -13,6 +13,7 @@ keywords:
"javascript",
"typescript",
"python",
"java",
]
---
@ -45,6 +46,11 @@ codeflash --file path/to/your/file.js --function functionName
codeflash --file path/to/your/file.ts --function functionName
```
</Tab>
<Tab title="Java">
```bash
codeflash --file src/main/java/com/example/Utils.java --function methodName
```
</Tab>
</Tabs>
If you have installed the GitHub App to your repository, the above command will open a pull request with the optimized function.
@ -61,6 +67,11 @@ codeflash --file path/to/your/file.py --function function_name --no-pr
codeflash --file path/to/your/file.ts --function functionName --no-pr
```
</Tab>
<Tab title="Java">
```bash
codeflash --file src/main/java/com/example/Utils.java --function methodName --no-pr
```
</Tab>
</Tabs>
### Optimizing class methods
@ -78,4 +89,9 @@ codeflash --file path/to/your/file.py --function ClassName.method_name
codeflash --file path/to/your/file.ts --function ClassName.methodName
```
</Tab>
<Tab title="Java">
```bash
codeflash --file src/main/java/com/example/Utils.java --function methodName
```
</Tab>
</Tabs>

View file

@ -53,6 +53,8 @@ dependencies = [
"filelock>=3.20.3; python_version >= '3.10'",
"filelock<3.20.3; python_version < '3.10'",
"pytest-asyncio>=0.18.0",
"memray>=1.12; sys_platform != 'win32'",
"pytest-memray>=1.7; sys_platform != 'win32'",
]
[project.urls]

View file

@ -0,0 +1,117 @@
"""Tests for JavaScript/TypeScript function discovery logic."""
from __future__ import annotations
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionFilterCriteria
from codeflash.languages.javascript.support import JavaScriptSupport
class TestFunctionDiscovery:
"""Tests for discover_functions method."""
@pytest.fixture
def js_support(self) -> JavaScriptSupport:
"""Create a JavaScriptSupport instance."""
return JavaScriptSupport()
def test_discovers_top_level_function(self, js_support: JavaScriptSupport) -> None:
"""Should discover top-level exported functions."""
code = """
export function topLevelFunc() {
return 42;
}
"""
functions = js_support.discover_functions(
code,
Path("/tmp/test.js"),
FunctionFilterCriteria(require_export=True, require_return=True),
)
assert len(functions) == 1
assert functions[0].function_name == "topLevelFunc"
assert functions[0].parents == []
def test_skips_nested_functions_in_closures(self, js_support: JavaScriptSupport) -> None:
"""Should skip nested functions that are defined inside other functions.
Nested functions depend on closure variables from their parent scope and cannot
be optimized in isolation without extracting the entire parent context.
Bug: Previously, nested functions were discovered and attempted to be optimized,
but the extraction logic only captured the nested function body, causing
validation errors like "Undefined variable(s): base, streamFn, record".
"""
code = """
export function wrapStreamFn(streamFn) {
const base = { id: 1 };
const record = (event) => { };
const wrapped = (model, context, options) => {
if (!model) {
return streamFn(model, context, options);
}
record({ data: base });
return base;
};
return wrapped;
}
"""
functions = js_support.discover_functions(
code,
Path("/tmp/test.js"),
FunctionFilterCriteria(require_export=True, require_return=True),
)
# Should only discover the top-level function, not the nested ones
assert len(functions) == 1, f"Expected 1 function but found {len(functions)}: {[f.function_name for f in functions]}"
assert functions[0].function_name == "wrapStreamFn"
assert functions[0].parents == []
def test_discovers_class_methods(self, js_support: JavaScriptSupport) -> None:
"""Should discover class methods (these are handled specially with class wrapping)."""
code = """
export class MyClass {
myMethod() {
return 42;
}
}
"""
functions = js_support.discover_functions(
code,
Path("/tmp/test.js"),
FunctionFilterCriteria(require_export=True, require_return=True, include_methods=True),
)
assert len(functions) == 1
assert functions[0].function_name == "myMethod"
assert len(functions[0].parents) == 1
assert functions[0].parents[0].name == "MyClass"
assert functions[0].parents[0].type == "ClassDef"
def test_skips_nested_functions_with_multiple_levels(self, js_support: JavaScriptSupport) -> None:
"""Should skip deeply nested functions."""
code = """
export function outer() {
const middle = () => {
const inner = () => {
return 42;
};
return inner();
};
return middle();
}
"""
functions = js_support.discover_functions(
code,
Path("/tmp/test.js"),
FunctionFilterCriteria(require_export=True, require_return=True),
)
# Should only discover the top-level function
assert len(functions) == 1
assert functions[0].function_name == "outer"
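A sketch of the nesting rule these discovery tests encode, using an illustrative node shape rather than the real tree-sitter API: a candidate is kept only when its nearest enclosing definition is a class (a method) or nothing at all (top level); anything defined inside another function is a closure and is skipped.
```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    kind: str                       # illustrative: "function", "class", or "module"
    parent: Optional["Node"] = None


def is_optimizable(node: Node) -> bool:
    parent = node.parent
    while parent is not None:
        if parent.kind == "function":
            return False  # nested in another function: depends on closure variables
        if parent.kind == "class":
            return True   # class method: discovered and later wrapped with its class
        parent = parent.parent
    return True  # top-level function
```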

View file

@ -0,0 +1,109 @@
"""Test for false positive test discovery bug (Bug #4)."""
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.javascript.support import TypeScriptSupport
from codeflash.models.models import CodePosition
def test_discover_tests_should_not_match_mocked_functions():
"""Test that functions mentioned only in mocks are not matched as test targets.
Regression test for Bug #4: False positive test discovery due to substring matching.
When a test file mocks a function (e.g., vi.mock("./restart-request.js", () => ({...}))),
that function should NOT be considered as tested by that file, since it's only mocked,
not actually called or tested.
"""
support = TypeScriptSupport()
with TemporaryDirectory() as tmpdir:
test_root = Path(tmpdir)
# Create a test file that MOCKS parseRestartRequestParams but doesn't test it
test_file = test_root / "update.test.ts"
test_file.write_text(
'''
import { updateSomething } from "./update.js";
vi.mock("./restart-request.js", () => ({
parseRestartRequestParams: (params: any) => ({ sessionKey: undefined }),
}));
describe("updateSomething", () => {
it("should update successfully", () => {
const result = updateSomething();
expect(result).toBe(true);
});
});
'''
)
# Source function that is only mocked, not tested
source_function = FunctionToOptimize(
qualified_name="parseRestartRequestParams",
function_name="parseRestartRequestParams",
file_path=test_root / "restart-request.ts",
starting_line=1,
ending_line=10,
function_signature="",
code_position=CodePosition(line_no=1, col_no=0),
file_path_relative_to_project_root="restart-request.ts",
)
# Discover tests
result = support.discover_tests(test_root, [source_function])
# The bug: discovers update.test.ts as a test for parseRestartRequestParams
# because "parseRestartRequestParams" appears as a substring in the mock
# Expected: should NOT match (empty result)
assert (
source_function.qualified_name not in result or len(result[source_function.qualified_name]) == 0
), f"Should not match mocked function, but found: {result.get(source_function.qualified_name, [])}"
def test_discover_tests_should_match_actually_imported_functions():
"""Test that functions actually imported and tested ARE correctly matched.
This is the positive case to ensure we don't break legitimate test discovery.
"""
support = TypeScriptSupport()
with TemporaryDirectory() as tmpdir:
test_root = Path(tmpdir)
# Create a test file that ACTUALLY imports and tests the function
test_file = test_root / "restart-request.test.ts"
test_file.write_text(
'''
import { parseRestartRequestParams } from "./restart-request.js";
describe("parseRestartRequestParams", () => {
it("should parse valid params", () => {
const result = parseRestartRequestParams({ sessionKey: "abc" });
expect(result.sessionKey).toBe("abc");
});
});
'''
)
source_function = FunctionToOptimize(
qualified_name="parseRestartRequestParams",
function_name="parseRestartRequestParams",
file_path=test_root / "restart-request.ts",
starting_line=1,
ending_line=10,
function_signature="",
code_position=CodePosition(line_no=1, col_no=0),
file_path_relative_to_project_root="restart-request.ts",
)
result = support.discover_tests(test_root, [source_function])
# Should match: function is imported and tested
assert source_function.qualified_name in result, f"Should match imported function, but got: {result}"
assert len(result[source_function.qualified_name]) > 0, "Should find at least one test"

View file

@ -0,0 +1,79 @@
"""Test that Codeflash Vitest config properly overrides coverage settings."""
from pathlib import Path
import pytest
from codeflash.languages.javascript.vitest_runner import _ensure_codeflash_vitest_config
def test_codeflash_vitest_config_overrides_coverage(tmp_path: Path) -> None:
project_root = tmp_path.resolve()
vitest_config = project_root / "vitest.config.ts"
vitest_config.write_text(
"""
import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
include: ['test/**/*.test.ts'],
coverage: {
provider: 'v8',
reporter: ['text', 'lcov'],
all: false,
thresholds: {
lines: 70,
functions: 70,
},
},
},
});
""",
encoding="utf-8",
)
config_path = _ensure_codeflash_vitest_config(project_root)
assert config_path is not None, "Config should be created"
assert config_path.exists(), "Config file should exist"
config_content = config_path.read_text(encoding="utf-8")
assert "mergeConfig" in config_content, "Should use mergeConfig"
assert "import originalConfig from './vitest.config.ts'" in config_content
assert "coverage:" in config_content, (
"Config must explicitly override coverage settings to ensure "
"json reporter is used regardless of project config"
)
assert "reporter:" in config_content, "Config must override coverage.reporter to ['json']"
assert "['json']" in config_content or '["json"]' in config_content, (
"Coverage reporter must be set to ['json'] to ensure coverage files are written in the expected format"
)
def test_codeflash_vitest_config_without_original_coverage(tmp_path: Path) -> None:
project_root = tmp_path.resolve()
vitest_config = project_root / "vitest.config.ts"
vitest_config.write_text(
"""
import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
include: ['test/**/*.test.ts'],
},
});
""",
encoding="utf-8",
)
config_path = _ensure_codeflash_vitest_config(project_root)
assert config_path is not None
assert config_path.exists()
config_content = config_path.read_text(encoding="utf-8")
assert "coverage:" in config_content, "Config must explicitly set coverage even when original doesn't have it"

View file

@ -0,0 +1,125 @@
"""Tests for handling Vitest coverage exclusions.
These tests verify that Codeflash correctly detects and handles files
that are excluded from coverage by vitest.config.ts, preventing false
0% coverage reports.
"""
from __future__ import annotations
import json
import tempfile
from pathlib import Path
import pytest
from codeflash.models.models import CodeOptimizationContext, CoverageStatus
from codeflash.verification.coverage_utils import JestCoverageUtils
class TestVitestCoverageExclusions:
"""Tests for Vitest coverage exclusion handling."""
def test_missing_coverage_returns_not_found_status(self) -> None:
"""Should return NOT_FOUND status when file is not in coverage data.
When a file is excluded from Vitest coverage (via coverage.exclude),
it won't appear in coverage-final.json. Codeflash should return
NOT_FOUND status (not PARSED_SUCCESSFULLY).
This test verifies the current behavior is correct at the coverage
parsing level. The issue is at a higher level (function_optimizer.py)
where NOT_FOUND status needs better handling.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
# Create mock coverage-final.json that's missing the target file
coverage_file = tmp_path / "coverage-final.json"
coverage_data = {
"/workspace/project/src/utils/helpers.ts": {
"fnMap": {},
"s": {},
},
# src/agents/sandbox/fs-paths.ts is NOT here (excluded by Vitest)
}
with coverage_file.open("w") as f:
json.dump(coverage_data, f)
# Try to load coverage for a missing file
missing_file = Path("/workspace/project/src/agents/sandbox/fs-paths.ts")
from codeflash.models.models import CodeStringsMarkdown
mock_context = CodeOptimizationContext(
testgen_context=CodeStringsMarkdown(language="typescript"),
read_writable_code=CodeStringsMarkdown(language="typescript"),
helper_functions=[],
preexisting_objects=set(),
)
result = JestCoverageUtils.load_from_jest_json(
coverage_json_path=coverage_file,
function_name="parseSandboxBindMount",
code_context=mock_context,
source_code_path=missing_file,
)
# Should return NOT_FOUND when file not in coverage
assert result.status == CoverageStatus.NOT_FOUND, (
f"Expected NOT_FOUND for missing file, got {result.status}"
)
assert result.coverage == 0.0
def test_handles_included_file_normally(self) -> None:
"""Should handle files that ARE included in coverage normally.
This test verifies that the fix doesn't break normal coverage parsing
for files that are NOT excluded.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir)
# Create mock coverage-final.json with a valid file
coverage_file = tmp_path / "coverage-final.json"
test_file = "/workspace/project/src/utils/helpers.ts"
coverage_data = {
test_file: {
"fnMap": {
"0": {"name": "someHelper", "loc": {"start": {"line": 1}, "end": {"line": 5}}}
},
"statementMap": {
"0": {"start": {"line": 2}, "end": {"line": 2}},
"1": {"start": {"line": 3}, "end": {"line": 3}},
},
"s": {"0": 5, "1": 5}, # Both statements executed
"branchMap": {},
"b": {},
}
}
with coverage_file.open("w") as f:
json.dump(coverage_data, f)
source_file = Path(test_file)
from codeflash.models.models import CodeStringsMarkdown
mock_context = CodeOptimizationContext(
testgen_context=CodeStringsMarkdown(language="typescript"),
read_writable_code=CodeStringsMarkdown(language="typescript"),
helper_functions=[],
preexisting_objects=set(),
)
result = JestCoverageUtils.load_from_jest_json(
coverage_json_path=coverage_file,
function_name="someHelper",
code_context=mock_context,
source_code_path=source_file,
)
# Should parse successfully for non-excluded files
assert result.status == CoverageStatus.PARSED_SUCCESSFULLY
assert result.coverage > 0.0 # Should have actual coverage
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View file

@ -0,0 +1,61 @@
from pathlib import Path
import pytest
from codeflash.languages.javascript.vitest_runner import _ensure_codeflash_vitest_config
def test_codeflash_vitest_config_overrides_setupfiles(tmp_path: Path) -> None:
project_root = tmp_path.resolve()
# Create a project with setup file
(project_root / "test").mkdir()
(project_root / "test" / "setup.ts").write_text("// Setup file\n", encoding="utf-8")
vitest_config = """import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
setupFiles: ["test/setup.ts"], // Relative path - will cause issues
include: ["src/**/*.test.ts"],
},
});
"""
(project_root / "vitest.config.ts").write_text(vitest_config, encoding="utf-8")
codeflash_config_path = _ensure_codeflash_vitest_config(project_root)
assert codeflash_config_path is not None
assert codeflash_config_path.exists()
config_content = codeflash_config_path.read_text(encoding="utf-8")
assert "setupFiles" in config_content, (
"Generated config must explicitly handle setupFiles to prevent "
"relative path resolution issues. Current config:\n" + config_content
)
assert "setupFiles: []" in config_content or "setupFiles:" in config_content, (
"setupFiles must be explicitly set in the merged config"
)
def test_codeflash_vitest_config_without_setupfiles(tmp_path: Path) -> None:
project_root = tmp_path.resolve()
vitest_config = """import { defineConfig } from 'vitest/config';
export default defineConfig({
test: {
include: ["src/**/*.test.ts"],
},
});
"""
(project_root / "vitest.config.ts").write_text(vitest_config, encoding="utf-8")
codeflash_config_path = _ensure_codeflash_vitest_config(project_root)
assert codeflash_config_path is not None
assert codeflash_config_path.exists()
config_content = codeflash_config_path.read_text(encoding="utf-8")
assert "mergeConfig" in config_content or "defineConfig" in config_content

View file

@ -0,0 +1,45 @@
"""Test for coverage exclusion error message (Bug #5 regression test)."""
from pathlib import Path
from codeflash.models.function_types import FunctionToOptimize
from codeflash.models.models import CodePosition
def test_function_to_optimize_has_file_path_not_source_file_path():
"""Test that FunctionToOptimize has file_path attribute, not source_file_path.
Regression test for Bug #5: Bug #1's fix used wrong attribute name 'source_file_path'
instead of 'file_path', causing AttributeError when constructing coverage error messages.
The bug occurred in function_optimizer.py lines 2797 and 2803:
f"No coverage data found for {self.function_to_optimize.source_file_path}."
This should be:
f"No coverage data found for {self.function_to_optimize.file_path}."
Trace ID: 5c4a75fb-d8eb-4f75-9e57-893f0c44b9c7
"""
# Create a FunctionToOptimize object
func = FunctionToOptimize(
function_name="testFunc",
file_path=Path("/workspace/target/src/test.ts"),
starting_line=1,
ending_line=10,
code_position=CodePosition(line_no=1, col_no=0),
file_path_relative_to_project_root="src/test.ts",
)
# Verify correct attribute exists
assert hasattr(func, "file_path"), "FunctionToOptimize should have 'file_path' attribute"
assert func.file_path == Path("/workspace/target/src/test.ts")
# Verify wrong attribute does NOT exist
assert not hasattr(
func, "source_file_path"
), "FunctionToOptimize should NOT have 'source_file_path' attribute (it's a typo/bug)"
# Verify we can access file_path in string formatting (like the bug location does)
error_message = f"No coverage data found for {func.file_path}."
assert "test.ts" in error_message
# This should NOT raise AttributeError

View file

@ -50,6 +50,7 @@ def run_test(expected_improvement_pct: int) -> bool:
"--no-project",
"-m",
"codeflash.main",
"--no-pr",
"optimize",
"java",
"-cp",
@ -59,6 +60,7 @@ def run_test(expected_improvement_pct: int) -> bool:
env = os.environ.copy()
env["PYTHONIOENCODING"] = "utf-8"
env["PYTHONUNBUFFERED"] = "1"
logging.info(f"Running command: {' '.join(command)}")
logging.info(f"Working directory: {fixture_dir}")
process = subprocess.Popen(
@ -73,13 +75,11 @@ def run_test(expected_improvement_pct: int) -> bool:
output = []
for line in process.stdout:
logging.info(line.strip())
print(line, end="", flush=True)
output.append(line)
return_code = process.wait()
stdout = "".join(output)
if return_code != 0:
logging.error(f"Full output:\n{stdout}")
if return_code != 0:
logging.error(f"Command returned exit code {return_code}")
@ -90,7 +90,7 @@ def run_test(expected_improvement_pct: int) -> bool:
logging.error("Failed to find replay test generation message")
return False
# Validate: replay tests were discovered
# Validate: replay tests were discovered (global count)
replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout)
if not replay_match:
logging.error("Failed to find replay test discovery message")
@ -101,6 +101,17 @@ def run_test(expected_improvement_pct: int) -> bool:
return False
logging.info(f"Replay tests discovered: {num_replay}")
# Validate: replay test files were used per-function
replay_file_match = re.search(r"Discovered \d+ existing unit test files?, (\d+) replay test files?", stdout)
if not replay_file_match:
logging.error("Failed to find per-function replay test file discovery message")
return False
num_replay_files = int(replay_file_match.group(1))
if num_replay_files == 0:
logging.error("No replay test files discovered per-function")
return False
logging.info(f"Replay test files per-function: {num_replay_files}")
# Validate: at least one optimization was found
if "⚡️ Optimization successful! 📄 " not in stdout:
logging.error("Failed to find optimization success message")

240
tests/test_compare.py Normal file
View file

@ -0,0 +1,240 @@
from __future__ import annotations
from codeflash.benchmarking.compare import (
CompareResult,
ScriptCompareResult,
has_meaningful_memory_change,
render_comparison,
render_script_comparison,
)
from codeflash.benchmarking.plugin.plugin import BenchmarkStats, MemoryStats
from codeflash.models.models import BenchmarkKey
def _make_stats(median_ns: float = 1000.0, rounds: int = 10) -> BenchmarkStats:
return BenchmarkStats(
min_ns=median_ns * 0.9,
max_ns=median_ns * 1.1,
mean_ns=median_ns,
median_ns=median_ns,
stddev_ns=median_ns * 0.05,
iqr_ns=median_ns * 0.1,
rounds=rounds,
iterations=100,
outliers="0;0",
)
def _make_memory(peak: int = 4_194_304, allocs: int = 1000) -> MemoryStats:
return MemoryStats(peak_memory_bytes=peak, total_allocations=allocs)
BM_KEY = BenchmarkKey(module_path="tests.benchmarks.test_example", function_name="test_func")
class TestFormatMarkdownMemoryOnly:
def test_memory_only_no_timing_table(self) -> None:
result = CompareResult(
base_ref="abc123",
head_ref="def456",
base_memory={BM_KEY: _make_memory(peak=10_000_000, allocs=500)},
head_memory={BM_KEY: _make_memory(peak=7_000_000, allocs=400)},
)
md = result.format_markdown()
# Should have memory data
assert "Peak Memory" in md
assert "Allocations" in md
# Should NOT have timing table headers
assert "Min | Median | Mean | OPS" not in md
assert "Per-Function" not in md
def test_memory_only_returns_empty_when_no_data(self) -> None:
result = CompareResult(base_ref="abc123", head_ref="def456")
md = result.format_markdown()
assert md == "_No benchmark results to compare._"
def test_mixed_timing_and_memory(self) -> None:
result = CompareResult(
base_ref="abc123",
head_ref="def456",
base_stats={BM_KEY: _make_stats()},
head_stats={BM_KEY: _make_stats(median_ns=500.0)},
base_memory={BM_KEY: _make_memory(peak=10_000_000)},
head_memory={BM_KEY: _make_memory(peak=5_000_000)},
)
md = result.format_markdown()
# Should have both timing and memory
assert "Min | Median | Mean | OPS" in md
assert "Peak Memory" in md
def test_memory_only_always_shows_memory(self) -> None:
"""Memory-only keys always render the memory table, even if delta is <1%."""
result = CompareResult(
base_ref="abc123",
head_ref="def456",
base_memory={BM_KEY: _make_memory(peak=10_000_000, allocs=1000)},
head_memory={BM_KEY: _make_memory(peak=10_000_000, allocs=1000)},
)
md = result.format_markdown()
# Even with identical memory, memory-only keys always show the table
assert "Peak Memory" in md
def test_timing_with_negligible_memory_suppressed(self) -> None:
"""When timing data exists, negligible memory changes are suppressed."""
result = CompareResult(
base_ref="abc123",
head_ref="def456",
base_stats={BM_KEY: _make_stats()},
head_stats={BM_KEY: _make_stats()},
base_memory={BM_KEY: _make_memory(peak=10_000_000, allocs=1000)},
head_memory={BM_KEY: _make_memory(peak=10_000_000, allocs=1000)},
)
md = result.format_markdown()
# Timing table should be there
assert "Min | Median | Mean | OPS" in md
# Memory table should be suppressed (delta <1% and timing exists)
assert "Peak Memory" not in md
def test_memory_only_key_mixed_with_timing_key(self) -> None:
"""Some keys have timing, others are memory-only."""
timing_key = BenchmarkKey(module_path="tests.bench", function_name="test_timing")
memory_key = BenchmarkKey(module_path="tests.bench", function_name="test_memory")
result = CompareResult(
base_ref="abc123",
head_ref="def456",
base_stats={timing_key: _make_stats()},
head_stats={timing_key: _make_stats(median_ns=500.0)},
base_memory={timing_key: _make_memory(peak=10_000_000), memory_key: _make_memory(peak=8_000_000)},
head_memory={timing_key: _make_memory(peak=5_000_000), memory_key: _make_memory(peak=6_000_000)},
)
md = result.format_markdown()
# Both benchmark keys should appear
assert "test_timing" in md
assert "test_memory" in md
# Timing table for timing_key
assert "Min | Median | Mean | OPS" in md
class TestRenderComparisonMemoryOnly:
def test_memory_only_no_crash(self, capsys: object) -> None:
"""render_comparison should not crash or warn with memory-only data."""
result = CompareResult(
base_ref="abc123",
head_ref="def456",
base_memory={BM_KEY: _make_memory(peak=10_000_000)},
head_memory={BM_KEY: _make_memory(peak=7_000_000)},
)
# Should not raise
render_comparison(result)
def test_empty_result_warns(self) -> None:
result = CompareResult(base_ref="abc123", head_ref="def456")
# Should return without error (just logs a warning)
render_comparison(result)
class TestHasMeaningfulMemoryChange:
def test_both_none(self) -> None:
assert not has_meaningful_memory_change(None, None)
def test_one_none(self) -> None:
assert has_meaningful_memory_change(_make_memory(), None)
assert has_meaningful_memory_change(None, _make_memory())
def test_both_zero(self) -> None:
assert not has_meaningful_memory_change(_make_memory(0, 0), _make_memory(0, 0))
def test_no_change(self) -> None:
mem = _make_memory(peak=1000, allocs=100)
assert not has_meaningful_memory_change(mem, mem)
def test_significant_peak_change(self) -> None:
base = _make_memory(peak=10_000_000, allocs=1000)
head = _make_memory(peak=8_000_000, allocs=1000)
assert has_meaningful_memory_change(base, head)
def test_significant_alloc_change(self) -> None:
base = _make_memory(peak=10_000_000, allocs=1000)
head = _make_memory(peak=10_000_000, allocs=800)
assert has_meaningful_memory_change(base, head)
class TestScriptCompareResult:
def test_format_markdown_basic(self) -> None:
result = ScriptCompareResult(
base_ref="abc123",
head_ref="def456",
base_results={"file1.pdf": 12.34, "file2.docx": 1.23},
head_results={"file1.pdf": 10.21, "file2.docx": 1.45},
)
md = result.format_markdown()
assert "file1.pdf" in md
assert "file2.docx" in md
assert "Base" in md
assert "Head" in md
def test_format_markdown_empty(self) -> None:
result = ScriptCompareResult(base_ref="abc123", head_ref="def456")
md = result.format_markdown()
assert md == "_No benchmark results to compare._"
def test_format_markdown_total_row(self) -> None:
result = ScriptCompareResult(
base_ref="abc123",
head_ref="def456",
base_results={"test1": 1.0, "__total__": 5.0},
head_results={"test1": 0.8, "__total__": 4.0},
)
md = result.format_markdown()
assert "**TOTAL**" in md
# __total__ should not appear as a regular key row
assert md.count("__total__") == 0
def test_format_markdown_missing_keys(self) -> None:
result = ScriptCompareResult(
base_ref="abc123", head_ref="def456", base_results={"only_base": 2.0}, head_results={"only_head": 3.0}
)
md = result.format_markdown()
assert "only_base" in md
assert "only_head" in md
def test_format_markdown_with_memory(self) -> None:
result = ScriptCompareResult(
base_ref="abc123",
head_ref="def456",
base_results={"test1": 1.0},
head_results={"test1": 0.5},
base_memory=_make_memory(peak=10_000_000, allocs=500),
head_memory=_make_memory(peak=7_000_000, allocs=400),
)
md = result.format_markdown()
assert "Peak Memory" in md
assert "Allocations" in md
def test_render_no_crash(self) -> None:
result = ScriptCompareResult(
base_ref="abc123",
head_ref="def456",
base_results={"a": 1.0, "b": 2.0, "__total__": 3.0},
head_results={"a": 0.5, "b": 1.5, "__total__": 2.0},
)
render_script_comparison(result)
def test_render_empty_no_crash(self) -> None:
result = ScriptCompareResult(base_ref="abc123", head_ref="def456")
render_script_comparison(result)
def test_render_with_memory_no_crash(self) -> None:
result = ScriptCompareResult(
base_ref="abc123",
head_ref="def456",
base_results={"test1": 5.0},
head_results={"test1": 4.0},
base_memory=_make_memory(peak=10_000_000, allocs=1000),
head_memory=_make_memory(peak=8_000_000, allocs=900),
)
render_script_comparison(result)
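A sketch of the thresholding behaviour the memory tests above pin down, assuming a relative cutoff of roughly 1%; the real function lives in codeflash.benchmarking.compare and its exact signature and threshold may differ.
```python
def has_meaningful_memory_change_sketch(base, head, rel_threshold=0.01):
    # Missing on one side counts as a change; missing on both, or identical stats, does not.
    if base is None and head is None:
        return False
    if base is None or head is None:
        return True

    def changed(old: int, new: int) -> bool:
        if old == 0 and new == 0:
            return False
        if old == 0 or new == 0:
            return True
        return abs(new - old) / old > rel_threshold

    return changed(base.peak_memory_bytes, head.peak_memory_bytes) or changed(
        base.total_allocations, head.total_allocations
    )
```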

View file

@ -0,0 +1,94 @@
"""Test fix_jest_mock_paths function with vitest mocks."""
from pathlib import Path
from codeflash.languages.javascript.instrument import fix_jest_mock_paths
def test_fix_vitest_mock_paths():
"""Test that vi.mock() paths are fixed correctly."""
# Simulate source at src/agents/workspace.ts importing from ../routing/session-key
# Test at test/test_workspace.test.ts should mock ../src/routing/session-key, not ../routing/session-key
test_code = """
vi.mock('../routing/session-key', () => ({
isSubagentSessionKey: vi.fn(),
isCronSessionKey: vi.fn(),
}));
import { filterBootstrapFilesForSession } from '../src/agents/workspace.js';
"""
# Create temp directories and files for testing
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
project = Path(tmpdir)
# Create directory structure
src = project / "src"
src_agents = src / "agents"
src_routing = src / "routing"
test_dir = project / "test"
src_agents.mkdir(parents=True)
src_routing.mkdir(parents=True)
test_dir.mkdir(parents=True)
# Create files
source_file = src_agents / "workspace.ts"
source_file.write_text("export function filterBootstrapFilesForSession() {}")
routing_file = src_routing / "session-key.ts"
routing_file.write_text("export function isSubagentSessionKey() {}")
test_file = test_dir / "test_workspace.test.ts"
test_file.write_text(test_code)
# Fix the paths
fixed = fix_jest_mock_paths(test_code, test_file, source_file, test_dir)
# Should change ../routing/session-key to ../src/routing/session-key
assert "../src/routing/session-key" in fixed, f"Expected path to be fixed, got: {fixed}"
assert "'../routing/session-key'" not in fixed, "Old relative mock path should have been rewritten"
def test_fix_jest_mock_paths_still_works():
"""Test that jest.mock() paths are still fixed correctly."""
test_code = """
jest.mock('../routing/session-key', () => ({
isSubagentSessionKey: jest.fn(),
}));
"""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
project = Path(tmpdir)
src = project / "src"
src_agents = src / "agents"
src_routing = src / "routing"
test_dir = project / "test"
src_agents.mkdir(parents=True)
src_routing.mkdir(parents=True)
test_dir.mkdir(parents=True)
source_file = src_agents / "workspace.ts"
source_file.write_text("")
routing_file = src_routing / "session-key.ts"
routing_file.write_text("")
test_file = test_dir / "test_workspace.test.ts"
test_file.write_text(test_code)
fixed = fix_jest_mock_paths(test_code, test_file, source_file, test_dir)
assert "../src/routing/session-key" in fixed
if __name__ == "__main__":
test_fix_vitest_mock_paths()
test_fix_jest_mock_paths_still_works()
print("All tests passed!")

View file

@ -0,0 +1,66 @@
"""Test that js_project_root is recalculated per function, not cached."""
from pathlib import Path
from codeflash.languages.javascript.test_runner import find_node_project_root
def test_find_node_project_root_returns_different_roots_for_different_files(tmp_path: Path) -> None:
"""Test that find_node_project_root returns the correct root for each file."""
# Create main project structure
main_project = (tmp_path / "project").resolve()
main_project.mkdir()
(main_project / "package.json").write_text("{}", encoding="utf-8")
(main_project / "src").mkdir()
main_file = (main_project / "src" / "main.ts").resolve()
main_file.write_text("// main file", encoding="utf-8")
# Create extension subdirectory with its own package.json
extension_dir = (main_project / "extensions" / "discord").resolve()
extension_dir.mkdir(parents=True)
(extension_dir / "package.json").write_text("{}", encoding="utf-8")
(extension_dir / "src").mkdir()
extension_file = (extension_dir / "src" / "accounts.ts").resolve()
extension_file.write_text("// extension file", encoding="utf-8")
# Extension file should return extension directory
result1 = find_node_project_root(extension_file)
assert result1 == extension_dir, f"Expected {extension_dir}, got {result1}"
# Main file should return main project directory
result2 = find_node_project_root(main_file)
assert result2 == main_project, f"Expected {main_project}, got {result2}"
# Calling again with extension file should still return extension dir
result3 = find_node_project_root(extension_file)
assert result3 == extension_dir, f"Expected {extension_dir}, got {result3}"
def test_js_project_root_recalculated_per_function(tmp_path: Path) -> None:
"""Each function in a monorepo should resolve to its own nearest package.json root."""
# Create main project
main_project = (tmp_path / "project").resolve()
main_project.mkdir()
(main_project / "package.json").write_text('{"name": "main"}', encoding="utf-8")
(main_project / "src").mkdir()
# Create extension with its own package.json
extension_dir = (main_project / "extensions" / "discord").resolve()
extension_dir.mkdir(parents=True)
(extension_dir / "package.json").write_text('{"name": "discord-extension"}', encoding="utf-8")
(extension_dir / "src").mkdir()
extension_file = (extension_dir / "src" / "accounts.ts").resolve()
extension_file.write_text("export function foo() {}", encoding="utf-8")
main_file = (main_project / "src" / "commands.ts").resolve()
main_file.write_text("export function bar() {}", encoding="utf-8")
js_project_root_1 = find_node_project_root(extension_file)
assert js_project_root_1 == extension_dir
js_project_root_2 = find_node_project_root(main_file)
assert js_project_root_2 == main_project, (
f"Expected {main_project}, got {js_project_root_2}. "
"This happens when js_project_root is cached instead of being recalculated per function."
)
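A sketch of the lookup these tests rely on: walk upward from the source file and stop at the first directory that contains a package.json, so a file inside a nested workspace resolves to its own package rather than the repository root. The real helper in codeflash.languages.javascript.test_runner may add caching or validation on top of this.
```python
from pathlib import Path
from typing import Optional


def find_node_project_root_sketch(file_path: Path) -> Optional[Path]:
    # Nearest ancestor with a package.json wins; None if the file is outside any Node project.
    for candidate in [file_path.parent, *file_path.parent.parents]:
        if (candidate / "package.json").is_file():
            return candidate
    return None
```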

View file

@ -23,7 +23,7 @@
<dependency>
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>1.0.0</version>
<version>1.0.1</version>
<scope>test</scope>
</dependency>
</dependencies>

View file

@ -36,20 +36,30 @@ public class Workload {
}
public static void main(String[] args) {
// Exercise the methods so the tracer can capture invocations
System.out.println("computeSum(100) = " + computeSum(100));
System.out.println("computeSum(50) = " + computeSum(50));
System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3));
System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5));
// Run methods with large inputs so JFR can capture CPU samples.
// Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval.
for (int round = 0; round < 1000; round++) {
computeSum(100_000);
repeatString("hello world ", 1000);
List<Integer> nums = new ArrayList<>();
for (int i = 1; i <= 10; i++) nums.add(i);
System.out.println("filterEvens(1..10) = " + filterEvens(nums));
for (int i = 1; i <= 10_000; i++) nums.add(i);
filterEvens(nums);
Workload w = new Workload();
w.instanceMethod(100_000, 42);
}
// Also call with small inputs for variety in traced args
System.out.println("computeSum(100) = " + computeSum(100));
System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3));
List<Integer> small = new ArrayList<>();
for (int i = 1; i <= 10; i++) small.add(i);
System.out.println("filterEvens(1..10) = " + filterEvens(small));
Workload w = new Workload();
System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3));
System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2));
System.out.println("Workload complete.");
}

View file

@ -0,0 +1,100 @@
"""Test for inject_test_globals duplicate import bug.
This test reproduces the bug where AI-generated tests already have vitest imports,
but inject_test_globals() adds them again because the string-based check doesn't
catch semantic duplicates with different identifier orders.
"""
import pytest
from codeflash.languages.javascript.edit_tests import inject_test_globals
from codeflash.models.models import GeneratedTests, GeneratedTestsList
from pathlib import Path
def test_inject_test_globals_skips_existing_vitest_imports() -> None:
"""Test that inject_test_globals skips injection when vitest import already exists."""
# AI service generated this test with vitest imports already present
# (note: different order and identifiers than what inject_test_globals would add)
ai_generated_test = """// vitest imports (REQUIRED for vitest - globals are NOT enabled by default)
import { describe, test, expect, vi, beforeEach, afterEach } from 'vitest';
// function import
import { isWindowsDrivePath } from './infra/archive-path';
// unit tests
describe('isWindowsDrivePath', () => {
test('should return true for Windows drive paths', () => {
expect(isWindowsDrivePath('C:\\\\')).toBe(true);
});
});
"""
generated_tests = GeneratedTestsList(
generated_tests=[
GeneratedTests(
generated_original_test_source=ai_generated_test,
instrumented_behavior_test_source=ai_generated_test,
instrumented_perf_test_source=ai_generated_test,
behavior_file_path=Path("/tmp/test_isWindowsDrivePath.test.ts"),
perf_file_path=Path("/tmp/test_isWindowsDrivePath_perf.test.ts"),
)
]
)
# Call inject_test_globals for vitest + esm (this is what the CLI does)
result = inject_test_globals(generated_tests, test_framework="vitest", module_system="esm")
# Check that the import was NOT duplicated
result_source = result.generated_tests[0].generated_original_test_source
# Count how many times "from 'vitest'" appears
import_count = result_source.count("from 'vitest'")
# Should be exactly 1 import, not 2
assert import_count == 1, (
f"Expected exactly 1 vitest import, but found {import_count}. "
f"inject_test_globals() added a duplicate import when one already existed.\n"
f"Result:\n{result_source[:500]}"
)
# Also verify that we have the expected number of import statements
# Count actual import statements, not comments containing the word "import"
import_lines = [line for line in result_source.split('\n') if line.strip().startswith('import ')]
assert len(import_lines) == 2, f"Should have 2 import statements (vitest + function), found {len(import_lines)}: {import_lines}"
def test_inject_test_globals_adds_import_when_missing() -> None:
"""Test that inject_test_globals DOES add import when it's truly missing."""
# Test without any vitest imports
test_without_imports = """// function import
import { isWindowsDrivePath } from './infra/archive-path';
describe('isWindowsDrivePath', () => {
test('should return true', () => {
expect(isWindowsDrivePath('C:\\\\')).toBe(true);
});
});
"""
generated_tests = GeneratedTestsList(
generated_tests=[
GeneratedTests(
generated_original_test_source=test_without_imports,
instrumented_behavior_test_source=test_without_imports,
instrumented_perf_test_source=test_without_imports,
behavior_file_path=Path("/tmp/test.test.ts"),
perf_file_path=Path("/tmp/test_perf.test.ts"),
)
]
)
result = inject_test_globals(generated_tests, test_framework="vitest", module_system="esm")
result_source = result.generated_tests[0].generated_original_test_source
# Should have exactly 1 vitest import (the one we added)
import_count = result_source.count("from 'vitest'")
assert import_count == 1, f"Expected vitest import to be added, found {import_count}"
# Should be at the beginning of the file
assert result_source.startswith("import { vi, describe, it, expect"), (
"Vitest import should be added at the beginning"
)
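A sketch of the duplicate check the first test demands: instead of comparing against one exact import string, detect any existing named import from 'vitest' regardless of identifier order or quoting before prepending a new one. The actual logic lives in codeflash.languages.javascript.edit_tests and may handle more import forms.
```python
import re

_VITEST_IMPORT_RE = re.compile(r"""^\s*import\s*\{[^}]*\}\s*from\s*['"]vitest['"]""", re.MULTILINE)


def already_imports_vitest(test_source: str) -> bool:
    # True if the file already has a named vitest import, in any identifier order.
    return _VITEST_IMPORT_RE.search(test_source) is not None
```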

View file

@ -641,3 +641,28 @@ class TestGradleEnsureRuntimeMultiModule:
assert result is True
nested_build = (nested / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in nested_build
class TestValidationSkipFlags:
"""Tests that validation skip flags include all known static analysis and formatting plugins."""
def test_maven_skip_flags_include_spotless(self):
from codeflash.languages.java.maven_strategy import _MAVEN_VALIDATION_SKIP_FLAGS
flags_str = " ".join(_MAVEN_VALIDATION_SKIP_FLAGS)
assert "-Dspotless.check.skip=true" in flags_str
assert "-Dspotless.apply.skip=true" in flags_str
def test_maven_skip_flags_include_all_known_plugins(self):
from codeflash.languages.java.maven_strategy import _MAVEN_VALIDATION_SKIP_FLAGS
flags_str = " ".join(_MAVEN_VALIDATION_SKIP_FLAGS)
for plugin in ["rat", "checkstyle", "spotbugs", "pmd", "enforcer", "japicmp", "errorprone", "spotless"]:
assert plugin in flags_str, f"Missing skip flag for {plugin}"
def test_gradle_skip_script_includes_spotless(self):
from codeflash.languages.java.gradle_strategy import _GRADLE_SKIP_VALIDATION_INIT_SCRIPT
assert "spotlessCheck" in _GRADLE_SKIP_VALIDATION_INIT_SCRIPT
assert "spotlessApply" in _GRADLE_SKIP_VALIDATION_INIT_SCRIPT
assert "spotlessJava" in _GRADLE_SKIP_VALIDATION_INIT_SCRIPT

View file

@ -0,0 +1,302 @@
"""Tests for JFR parser — class name normalization, package filtering, addressable time."""
from __future__ import annotations
import json
import subprocess
from pathlib import Path
from unittest.mock import patch
import pytest
from codeflash.languages.java.jfr_parser import JfrProfile
def _make_jfr_json(events: list[dict]) -> str:
"""Create fake JFR JSON output matching the jfr print format."""
return json.dumps({"recording": {"events": events}})
def _make_execution_sample(class_name: str, method_name: str, start_time: str = "2026-01-01T00:00:00Z") -> dict:
return {
"type": "jdk.ExecutionSample",
"values": {
"startTime": start_time,
"stackTrace": {
"frames": [
{
"method": {
"type": {"name": class_name},
"name": method_name,
"descriptor": "()V",
},
"lineNumber": 42,
}
],
},
},
}
class TestClassNameNormalization:
"""Test that JVM internal class names (com/example/Foo) are normalized to dots (com.example.Foo)."""
def test_slash_separators_normalized_to_dots(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"),
_make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"),
_make_execution_sample("com/aerospike/client/util/Utf8", "encodedLength"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.aerospike"])
assert profile._total_samples == 3
assert len(profile._method_samples) == 2
# Keys should use dots, not slashes
assert "com.aerospike.client.command.Buffer.bytesToInt" in profile._method_samples
assert "com.aerospike.client.util.Utf8.encodedLength" in profile._method_samples
def test_method_info_uses_dot_class_names(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[_make_execution_sample("com/example/MyClass", "myMethod")]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
info = profile._method_info.get("com.example.MyClass.myMethod")
assert info is not None
assert info["class_name"] == "com.example.MyClass"
assert info["method_name"] == "myMethod"
class TestPackageFiltering:
def test_filters_by_package_prefix(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/aerospike/client/Value", "get"),
_make_execution_sample("java/util/HashMap", "put"),
_make_execution_sample("com/aerospike/benchmarks/Main", "main"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.aerospike"])
# Only com.aerospike classes should be in samples
assert len(profile._method_samples) == 2
assert "com.aerospike.client.Value.get" in profile._method_samples
assert "com.aerospike.benchmarks.Main.main" in profile._method_samples
assert "java.util.HashMap.put" not in profile._method_samples
def test_empty_packages_includes_all(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/example/Foo", "bar"),
_make_execution_sample("java/lang/String", "length"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, [])
assert len(profile._method_samples) == 2
class TestAddressableTime:
def test_addressable_time_proportional_to_samples(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
# 3 samples for methodA, 1 for methodB, spanning 10 seconds
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:00Z"),
_make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:03Z"),
_make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:06Z"),
_make_execution_sample("com/example/Foo", "methodB", "2026-01-01T00:00:10Z"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
time_a = profile.get_addressable_time_ns("com.example.Foo", "methodA")
time_b = profile.get_addressable_time_ns("com.example.Foo", "methodB")
# methodA has 3x the samples of methodB, so 3x the addressable time
assert time_a > 0
assert time_b > 0
assert time_a == pytest.approx(time_b * 3, rel=0.01)
def test_addressable_time_zero_for_unknown_method(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[_make_execution_sample("com/example/Foo", "bar")]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
assert profile.get_addressable_time_ns("com.example.Foo", "nonExistent") == 0.0
class TestMethodRanking:
def test_ranking_ordered_by_sample_count(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[
_make_execution_sample("com/example/A", "hot"),
_make_execution_sample("com/example/A", "hot"),
_make_execution_sample("com/example/A", "hot"),
_make_execution_sample("com/example/B", "warm"),
_make_execution_sample("com/example/B", "warm"),
_make_execution_sample("com/example/C", "cold"),
]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
ranking = profile.get_method_ranking()
assert len(ranking) == 3
assert ranking[0]["method_name"] == "hot"
assert ranking[0]["sample_count"] == 3
assert ranking[1]["method_name"] == "warm"
assert ranking[1]["sample_count"] == 2
assert ranking[2]["method_name"] == "cold"
assert ranking[2]["sample_count"] == 1
def test_empty_ranking_when_no_samples(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json([])
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
assert profile.get_method_ranking() == []
def test_ranking_uses_dot_class_names(self, tmp_path: Path) -> None:
jfr_file = tmp_path / "test.jfr"
jfr_file.write_text("dummy", encoding="utf-8")
jfr_json = _make_jfr_json(
[_make_execution_sample("com/example/nested/Deep", "method")]
)
with patch("subprocess.run") as mock_run:
mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="")
profile = JfrProfile(jfr_file, ["com.example"])
ranking = profile.get_method_ranking()
assert len(ranking) == 1
assert ranking[0]["class_name"] == "com.example.nested.Deep"
class TestGracefulTimeout:
"""Test that _run_java_with_graceful_timeout sends SIGTERM before SIGKILL."""
def test_sends_sigterm_on_timeout(self) -> None:
import signal
from codeflash.languages.java.tracer import _run_java_with_graceful_timeout
# Run a sleep command with a 1s timeout — should get SIGTERM'd
import os
env = os.environ.copy()
_run_java_with_graceful_timeout(["sleep", "60"], env, timeout=1, stage_name="test")
# If we get here, the process was killed (didn't hang for 60s)
def test_no_timeout_runs_normally(self) -> None:
import os
from codeflash.languages.java.tracer import _run_java_with_graceful_timeout
env = os.environ.copy()
_run_java_with_graceful_timeout(["echo", "hello"], env, timeout=0, stage_name="test")
# Should complete without error
class TestProjectRootResolution:
"""Test that project_root is correctly set for Java multi-module projects."""
def test_java_project_root_is_build_root_not_module(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""For multi-module Maven, project_root should be the root with <modules>, not a sub-module."""
# Create a multi-module project
(tmp_path / "pom.xml").write_text(
'<project xmlns="http://maven.apache.org/POM/4.0.0"><modules><module>client</module></modules></project>',
encoding="utf-8",
)
client = tmp_path / "client"
client.mkdir()
(client / "pom.xml").write_text("<project/>", encoding="utf-8")
src = client / "src" / "main" / "java"
src.mkdir(parents=True)
test = tmp_path / "src" / "test" / "java"
test.mkdir(parents=True)
monkeypatch.chdir(tmp_path)
from codeflash.code_utils.config_parser import parse_config_file
config, config_path = parse_config_file()
assert config["language"] == "java"
# config_path should be the project root directory
assert config_path == tmp_path
def test_project_root_is_path_not_string(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""project_root from process_pyproject_config should be a Path for Java projects."""
from argparse import Namespace
(tmp_path / "pom.xml").write_text("<project/>", encoding="utf-8")
src = tmp_path / "src" / "main" / "java"
src.mkdir(parents=True)
test = tmp_path / "src" / "test" / "java"
test.mkdir(parents=True)
monkeypatch.chdir(tmp_path)
from codeflash.cli_cmds.cli import process_pyproject_config
# Create a minimal args namespace matching what parse_args produces
args = Namespace(
config_file=None, module_root=None, tests_root=None, benchmarks_root=None,
ignore_paths=None, pytest_cmd=None, formatter_cmds=None, disable_telemetry=None,
disable_imports_sorting=None, git_remote=None, override_fixtures=None,
benchmark=False, verbose=False, version=False, show_config=False, reset_config=False,
)
args = process_pyproject_config(args)
assert hasattr(args, "project_root")
assert isinstance(args.project_root, Path)
assert args.project_root == tmp_path
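Two small sketches of the behaviour the JFR parser tests above encode: class names arrive in the JVM's slash-separated internal form and are normalized to dotted form, and addressable time is attributed to each method in proportion to its share of execution samples. Illustrative only; the real parser in codeflash.languages.java.jfr_parser also handles package filtering and sample timestamps.
```python
def normalize_class_name(jvm_name: str) -> str:
    # "com/aerospike/client/command/Buffer" -> "com.aerospike.client.command.Buffer"
    return jvm_name.replace("/", ".")


def addressable_time_ns(method_samples: int, total_samples: int, recording_span_ns: float) -> float:
    # A method with three times the samples of another gets three times the addressable time.
    if total_samples == 0:
        return 0.0
    return recording_span_ns * (method_samples / total_samples)
```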

View file

@ -0,0 +1,255 @@
"""Tests for Java replay test generation — JUnit 4/5 support, overload handling, instrumentation skip."""
from __future__ import annotations
import sqlite3
from pathlib import Path
import pytest
from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata
@pytest.fixture
def trace_db(tmp_path: Path) -> Path:
"""Create a trace database with sample function calls."""
db_path = tmp_path / "trace.db"
conn = sqlite3.connect(str(db_path))
conn.execute(
"CREATE TABLE function_calls("
"type TEXT, function TEXT, classname TEXT, filename TEXT, "
"line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)"
)
conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 1000, b"\x00"),
)
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 2000, b"\x00"),
)
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "multiply", "com.example.Calculator", "Calculator.java", 20, "(II)I", 3000, b"\x00"),
)
conn.commit()
conn.close()
return db_path
@pytest.fixture
def trace_db_overloaded(tmp_path: Path) -> Path:
"""Create a trace database with overloaded methods (same name, different descriptors)."""
db_path = tmp_path / "trace_overloaded.db"
conn = sqlite3.connect(str(db_path))
conn.execute(
"CREATE TABLE function_calls("
"type TEXT, function TEXT, classname TEXT, filename TEXT, "
"line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)"
)
conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
# Two overloads of estimateKeySize with different descriptors
for i in range(3):
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
("call", "estimateKeySize", "com.example.Command", "Command.java", 10, "(I)I", i * 1000, b"\x00"),
)
for i in range(2):
conn.execute(
"INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(
"call",
"estimateKeySize",
"com.example.Command",
"Command.java",
15,
"(Ljava/lang/String;)I",
(i + 10) * 1000,
b"\x00",
),
)
conn.commit()
conn.close()
return db_path
class TestGenerateReplayTestsJunit5:
def test_generates_junit5_by_default(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
count = generate_replay_tests(trace_db, output_dir, tmp_path)
assert count == 1
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "import org.junit.jupiter.api.Test;" in content
assert "import org.junit.jupiter.api.AfterAll;" in content
assert "@Test void replay_add_0()" in content
def test_junit5_class_is_package_private(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path)
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "class ReplayTest_" in content
assert "public class ReplayTest_" not in content
class TestGenerateReplayTestsJunit4:
def test_generates_junit4_imports(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
count = generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
assert count == 1
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "import org.junit.Test;" in content
assert "import org.junit.AfterClass;" in content
assert "org.junit.jupiter" not in content
def test_junit4_methods_are_public(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "@Test public void replay_add_0()" in content
def test_junit4_class_is_public(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "public class ReplayTest_" in content
def test_junit4_cleanup_uses_afterclass(self, trace_db: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4")
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
assert "@AfterClass" in content
assert "@AfterAll" not in content
class TestOverloadedMethods:
def test_no_duplicate_method_names(self, trace_db_overloaded: Path, tmp_path: Path) -> None:
output_dir = tmp_path / "output"
count = generate_replay_tests(trace_db_overloaded, output_dir, tmp_path)
assert count == 1
test_file = list(output_dir.glob("*.java"))[0]
content = test_file.read_text(encoding="utf-8")
# Should have 5 unique methods (3 from first overload + 2 from second)
assert "replay_estimateKeySize_0" in content
assert "replay_estimateKeySize_1" in content
assert "replay_estimateKeySize_2" in content
assert "replay_estimateKeySize_3" in content
assert "replay_estimateKeySize_4" in content
# Verify no duplicates by counting occurrences
lines = content.splitlines()
method_lines = [l for l in lines if "void replay_estimateKeySize_" in l]
method_names = [l.split("void ")[1].split("(")[0] for l in method_lines]
assert len(method_names) == len(set(method_names)), f"Duplicate methods: {method_names}"
class TestReplayTestInstrumentation:
def test_replay_tests_instrumented_correctly(self, trace_db: Path, tmp_path: Path) -> None:
"""Replay tests with compact @Test lines should be instrumented without orphaned code."""
from codeflash.languages.java.discovery import discover_functions_from_source
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path)
test_file = list(output_dir.glob("*.java"))[0]
src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
target = funcs[0]
from codeflash.languages.java.support import JavaSupport
support = JavaSupport()
success, instrumented = support.instrument_existing_test(
test_path=test_file,
call_positions=[],
function_to_optimize=target,
tests_project_root=tmp_path,
mode="behavior",
)
assert success
assert instrumented is not None
assert "__perfinstrumented" in instrumented
# Verify no code outside class body
lines = instrumented.splitlines()
class_closed = False
for line in lines:
if line.strip() == "}" and not line.startswith(" "):
class_closed = True
elif class_closed and line.strip() and not line.strip().startswith("//"):
pytest.fail(f"Orphaned code outside class: {line}")
def test_replay_tests_perf_instrumented(self, trace_db: Path, tmp_path: Path) -> None:
from codeflash.languages.java.discovery import discover_functions_from_source
output_dir = tmp_path / "output"
generate_replay_tests(trace_db, output_dir, tmp_path)
test_file = list(output_dir.glob("*.java"))[0]
src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
target = funcs[0]
from codeflash.languages.java.support import JavaSupport
support = JavaSupport()
success, instrumented = support.instrument_existing_test(
test_path=test_file,
call_positions=[],
function_to_optimize=target,
tests_project_root=tmp_path,
mode="performance",
)
assert success
assert "__perfonlyinstrumented" in instrumented
def test_regular_tests_still_instrumented(self, tmp_path: Path) -> None:
from codeflash.languages.java.discovery import discover_functions_from_source
src = "public class Calculator { public int add(int a, int b) { return a + b; } }"
funcs = discover_functions_from_source(src, tmp_path / "Calculator.java")
target = funcs[0]
test_file = tmp_path / "CalculatorTest.java"
test_file.write_text(
"""
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
calc.add(1, 2);
}
}
""",
encoding="utf-8",
)
from codeflash.languages.java.support import JavaSupport
support = JavaSupport()
success, instrumented = support.instrument_existing_test(
test_path=test_file,
call_positions=[],
function_to_optimize=target,
tests_project_root=tmp_path,
mode="behavior",
)
assert success
assert "CODEFLASH_LOOP_INDEX" in instrumented

View file

@ -512,13 +512,16 @@ public class PreciseWaiterTest {
stddev_runtime = statistics.stdev(runtimes)
coefficient_of_variation = stddev_runtime / mean_runtime
# Target: 10ms (10,000,000 ns), allow <5% coefficient of variation
# (accounts for JIT warmup - first iteration is cold, subsequent are optimized)
# Target: 10ms (10,000,000 ns), allow <15% coefficient of variation.
# The first iteration per test method runs with cold JIT, and shared CI VMs
# (especially Windows) have ~15ms scheduler granularity that adds noise.
# 15% still catches instrumentation bugs (e.g., 0ms or 100ms outliers)
# while the ±5% mean check below validates timing accuracy.
expected_ns = 10_000_000
runtimes_ms = [r / 1_000_000 for r in runtimes]
assert coefficient_of_variation < 0.05, (
f"Timing variance too high: CV={coefficient_of_variation:.2%} (should be <5%). "
assert coefficient_of_variation < 0.15, (
f"Timing variance too high: CV={coefficient_of_variation:.2%} (should be <15%). "
f"Runtimes: {runtimes_ms} ms (mean={mean_runtime / 1_000_000:.3f}ms)"
)
@ -597,13 +600,16 @@ public class PreciseWaiterMultiTest {
stddev_runtime = statistics.stdev(runtimes)
coefficient_of_variation = stddev_runtime / mean_runtime
# Target: 10ms (10,000,000 ns), allow <5% coefficient of variation
# (accounts for JIT warmup - first iteration is cold, subsequent are optimized)
# Target: 10ms (10,000,000 ns), allow <15% coefficient of variation.
# The first iteration per test method runs with cold JIT, and shared CI VMs
# (especially Windows) have ~15ms scheduler granularity that adds noise.
# 15% still catches instrumentation bugs (e.g., 0ms or 100ms outliers)
# while the ±5% mean check below validates timing accuracy.
expected_ns = 10_000_000
runtimes_ms = [r / 1_000_000 for r in runtimes]
assert coefficient_of_variation < 0.05, (
f"Timing variance too high: CV={coefficient_of_variation:.2%} (should be <5%). "
assert coefficient_of_variation < 0.15, (
f"Timing variance too high: CV={coefficient_of_variation:.2%} (should be <15%). "
f"Runtimes: {runtimes_ms} ms (mean={mean_runtime / 1_000_000:.3f}ms)"
)

View file

@ -284,3 +284,80 @@ import { process } from './processor';"""
result = convert_commonjs_to_esm(code)
expected = "import { queue, context, db as dbCore, cache, events } from '@budibase/backend-core';"
assert result == expected
class TestAddJsExtensionsToRelativeImports:
"""Tests for adding .js extensions to relative imports in ESM mode."""
def test_add_js_extension_to_relative_import(self):
"""Test adding .js extension to relative import without extension."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = "import TreeNode from '../../injector/topology-tree/tree-node';"
result = add_js_extensions_to_relative_imports(code)
expected = "import TreeNode from '../../injector/topology-tree/tree-node.js';"
assert result == expected
def test_add_js_extension_to_single_dot_import(self):
"""Test adding .js extension to same-directory import."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = "import { foo } from './module';"
result = add_js_extensions_to_relative_imports(code)
expected = "import { foo } from './module.js';"
assert result == expected
def test_skip_imports_with_existing_extensions(self):
"""Test that imports with extensions are left unchanged."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = "import TreeNode from '../../tree-node.js';"
result = add_js_extensions_to_relative_imports(code)
assert result == code
code2 = "import TreeNode from '../../tree-node.ts';"
result2 = add_js_extensions_to_relative_imports(code2)
assert result2 == code2
def test_skip_node_modules_imports(self):
"""Test that node_modules imports are left unchanged."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = "import assert from 'node:assert/strict';"
result = add_js_extensions_to_relative_imports(code)
assert result == code
code2 = "import { describe } from 'mocha';"
result2 = add_js_extensions_to_relative_imports(code2)
assert result2 == code2
def test_multiple_imports(self):
"""Test handling multiple imports in one code block."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = """import assert from 'node:assert/strict';
import TreeNode from '../../injector/topology-tree/tree-node';
import { helper } from './helper';"""
result = add_js_extensions_to_relative_imports(code)
expected = """import assert from 'node:assert/strict';
import TreeNode from '../../injector/topology-tree/tree-node.js';
import { helper } from './helper.js';"""
assert result == expected
def test_named_imports(self):
"""Test adding extensions to named imports."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = "import { foo, bar } from '../utils/helpers';"
result = add_js_extensions_to_relative_imports(code)
expected = "import { foo, bar } from '../utils/helpers.js';"
assert result == expected
def test_namespace_imports(self):
"""Test adding extensions to namespace imports."""
from codeflash.languages.javascript.module_system import add_js_extensions_to_relative_imports
code = "import * as helpers from '../utils';"
result = add_js_extensions_to_relative_imports(code)
expected = "import * as helpers from '../utils.js';"
assert result == expected
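The behavior these tests pin down can be approximated with a small regex-based sketch; this is an illustration of the contract, not the actual implementation in codeflash.languages.javascript.module_system:

import re

_RELATIVE_IMPORT = re.compile(r"(from\s+['\"])(\.{1,2}/[^'\"]+)(['\"])")

def add_js_extensions_sketch(code: str) -> str:
    """Append .js to extensionless relative import specifiers (sketch only)."""
    def _fix(match: re.Match) -> str:
        spec = match.group(2)
        # Bare package imports never match the pattern, and specifiers that already
        # carry an extension (.js/.ts/.jsx/.tsx/.mjs/.cjs/.json) are left untouched.
        if re.search(r"\.([cm]?[jt]sx?|json)$", spec):
            return match.group(0)
        return f"{match.group(1)}{spec}.js{match.group(3)}"
    return _RELATIVE_IMPORT.sub(_fix, code)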

View file

@ -122,7 +122,7 @@ class TestJestRootsConfiguration:
runtime_configs = [f for f in get_created_config_files() if "codeflash.runtime" in f.name]
assert len(runtime_configs) == 1, f"Expected 1 runtime config, got {len(runtime_configs)}"
config_content = runtime_configs[0].read_text(encoding="utf-8")
assert str(external_path) in config_content, "Runtime config should contain external test directory"
assert external_path.as_posix() in config_content, "Runtime config should contain external test directory"
clear_created_config_files()
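The as_posix() comparison matters mainly on Windows, where str() of a Path produces backslashes while the generated runtime config presumably stores forward-slash paths. A minimal illustration with a synthetic path:

from pathlib import PureWindowsPath

external = PureWindowsPath(r"C:\repo\external\tests")
str(external)         # 'C:\\repo\\external\\tests' -- backslashes, would not match the config text
external.as_posix()   # 'C:/repo/external/tests'    -- forward slashes, matches the form written to the config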

View file

@ -0,0 +1,155 @@
"""Test for TypeScript Jest config require bug.
Regression test for the issue where _create_runtime_jest_config generates
code that tries to require('./jest.config.ts'), which fails because Node.js
CommonJS cannot load TypeScript files directly.
Bug: https://github.com/codeflash-ai/codeflash/issues/XXX
Affects: 18 out of 38 optimization runs in initial testing
"""
import subprocess
import tempfile
from pathlib import Path
import pytest
class TestTypeScriptJestConfigRequire:
"""Test that runtime config correctly handles TypeScript base configs."""
def test_runtime_config_with_typescript_base_config_loads_without_error(self):
"""Runtime config should NOT try to require .ts files directly.
When base_config_path points to jest.config.ts, the generated runtime
config must not use require('./jest.config.ts') because Node.js cannot
parse TypeScript syntax in CommonJS require().
This test creates a jest.config.ts file and verifies that the generated
runtime config can be successfully loaded by Node.js without syntax errors.
"""
from codeflash.languages.javascript.test_runner import _create_runtime_jest_config
with tempfile.TemporaryDirectory() as tmpdir:
project_path = Path(tmpdir).resolve()
# Create a TypeScript Jest config (realistic content with TS syntax)
ts_config_path = project_path / "jest.config.ts"
ts_config_content = """import { Config } from "jest"
const config: Config = {
testEnvironment: 'node',
testMatch: ['**/*.test.ts'],
moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'],
}
export default config
"""
ts_config_path.write_text(ts_config_content, encoding="utf-8")
# Create runtime config with the TS base config
test_dirs = {str(project_path / "test")}
runtime_config_path = _create_runtime_jest_config(
base_config_path=ts_config_path,
project_root=project_path,
test_dirs=test_dirs
)
assert runtime_config_path is not None, "Runtime config should be created"
assert runtime_config_path.exists(), "Runtime config file should exist"
# Read the generated content
runtime_content = runtime_config_path.read_text(encoding="utf-8")
# CRITICAL CHECK: Should NOT contain require('./jest.config.ts')
# This is the bug we're fixing
assert "require('./jest.config.ts')" not in runtime_content, (
"Runtime config should not try to require .ts files directly"
)
# The config should handle TypeScript configs appropriately:
# - Either omit the extension (let Node resolve to .js)
# - Or use a TypeScript loader (ts-node)
# - Or skip requiring TS configs entirely
# Verify the generated config can be loaded by Node.js without errors
test_script = project_path / "test_load_config.js"
test_script_content = f"""
try {{
const config = require('./{runtime_config_path.name}');
console.log('SUCCESS');
process.exit(0);
}} catch (err) {{
console.error('FAILED:', err.message);
process.exit(1);
}}
"""
test_script.write_text(test_script_content, encoding="utf-8")
result = subprocess.run(
["node", str(test_script)],
capture_output=True,
text=True,
cwd=project_path,
timeout=30,
)
assert result.returncode == 0, (
f"Generated runtime config should load without errors.\n"
f"Config path: {runtime_config_path}\n"
f"Config content:\n{runtime_content}\n"
f"Node output:\n{result.stdout}\n{result.stderr}"
)
assert "SUCCESS" in result.stdout
def test_runtime_config_with_js_base_config_works(self):
"""Verify that .js base configs still work correctly (control test)."""
from codeflash.languages.javascript.test_runner import _create_runtime_jest_config
with tempfile.TemporaryDirectory() as tmpdir:
project_path = Path(tmpdir).resolve()
# Create a JavaScript Jest config
js_config_path = project_path / "jest.config.js"
js_config_content = """module.exports = {
testEnvironment: 'node',
testMatch: ['**/*.test.js'],
}
"""
js_config_path.write_text(js_config_content, encoding="utf-8")
# Create runtime config with the JS base config
test_dirs = {str(project_path / "test")}
runtime_config_path = _create_runtime_jest_config(
base_config_path=js_config_path,
project_root=project_path,
test_dirs=test_dirs
)
assert runtime_config_path is not None
assert runtime_config_path.exists()
# Verify it loads without errors
test_script = project_path / "test_load_config.js"
test_script_content = f"""
try {{
const config = require('./{runtime_config_path.name}');
console.log('SUCCESS');
process.exit(0);
}} catch (err) {{
console.error('FAILED:', err.message);
process.exit(1);
}}
"""
test_script.write_text(test_script_content, encoding="utf-8")
result = subprocess.run(
["node", str(test_script)],
capture_output=True,
text=True,
cwd=project_path,
timeout=30,
)
assert result.returncode == 0, f"JS config should load: {result.stderr}"
assert "SUCCESS" in result.stdout

View file

@ -440,25 +440,19 @@ class TestDiscoverFunctionsParity:
assert js_sync.is_async is False, "JavaScript sync function should have is_async=False"
def test_nested_functions_discovery(self, python_support, js_support):
"""Python skips nested functions; JavaScript discovers them with parent info."""
"""Both Python and JavaScript skip nested functions — only outer is discovered."""
py_file = write_temp_file(NESTED_FUNCTIONS.python, ".py")
js_file = write_temp_file(NESTED_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Python skips nested functions — only outer is discovered
# Both skip nested functions — only outer is discovered
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
assert py_funcs[0].function_name == "outer"
# JavaScript discovers both
assert len(js_funcs) == 2, f"JavaScript found {len(js_funcs)}, expected 2"
js_names = {f.function_name for f in js_funcs}
assert js_names == {"outer", "inner"}, f"JavaScript found {js_names}"
js_inner = next(f for f in js_funcs if f.function_name == "inner")
assert len(js_inner.parents) >= 1, "JavaScript inner should have parent info"
assert js_inner.parents[0].name == "outer", "JavaScript inner's parent should be outer"
assert len(js_funcs) == 1, f"JavaScript found {len(js_funcs)}, expected 1"
assert js_funcs[0].function_name == "outer"
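For reference, the kind of fixture this parity test exercises looks roughly like the following (hypothetical shape; the actual NESTED_FUNCTIONS constant may differ):

# Python variant: only `outer` should be discovered, `inner` is skipped.
def outer(values):
    def inner(x):
        return x * 2
    return [inner(v) for v in values]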
def test_static_methods_discovery(self, python_support, js_support):
"""Both should discover static methods."""

View file

@ -0,0 +1,85 @@
"""Tests for module_name_from_file_path with co-located test directories."""
import pytest
from pathlib import Path
from codeflash.code_utils.code_utils import module_name_from_file_path
class TestModuleNameFromFilePath:
"""Test module name resolution for various directory structures."""
def test_file_inside_project_root(self, tmp_path: Path) -> None:
"""Test normal case where file is inside project root."""
project_root = tmp_path / "project"
project_root.mkdir()
test_file = project_root / "test" / "test_foo.py"
test_file.parent.mkdir()
test_file.touch()
result = module_name_from_file_path(test_file, project_root)
assert result == "test.test_foo"
def test_file_outside_project_root_without_traverse_up(self, tmp_path: Path) -> None:
"""Test that file outside project root raises ValueError by default."""
project_root = tmp_path / "project" / "test"
project_root.mkdir(parents=True)
# File is in a sibling directory, not under project_root
test_file = tmp_path / "project" / "src" / "__tests__" / "test_foo.py"
test_file.parent.mkdir(parents=True)
test_file.touch()
with pytest.raises(ValueError, match="is not within the project root"):
module_name_from_file_path(test_file, project_root)
def test_file_outside_project_root_with_traverse_up(self, tmp_path: Path) -> None:
"""Test that traverse_up=True handles files outside project root."""
project_root = tmp_path / "project" / "test"
project_root.mkdir(parents=True)
# File is in a sibling directory, not under project_root
test_file = tmp_path / "project" / "src" / "__tests__" / "codeflash-generated" / "test_foo.py"
test_file.parent.mkdir(parents=True)
test_file.touch()
# With traverse_up=True, it should find a common ancestor
result = module_name_from_file_path(test_file, project_root, traverse_up=True)
# Should return a relative path from some ancestor directory
assert "test_foo" in result
assert not result.startswith(".")
def test_colocated_test_directory_structure(self, tmp_path: Path) -> None:
"""Test real-world scenario with co-located __tests__ directory.
This reproduces the bug from trace 7b97ddba-6ecd-42fd-b572-d40658746836:
- Source: /workspace/target/src/gateway/server/ws-connection/connect-policy.ts
- Tests root: /workspace/target/test
- Generated test: /workspace/target/src/gateway/server/__tests__/codeflash-generated/test_xxx.test.ts
Without traverse_up=True, this should fail.
"""
project_root = tmp_path / "target"
project_root.mkdir()
tests_root = project_root / "test"
tests_root.mkdir()
# Source file location
source_file = project_root / "src" / "gateway" / "server" / "ws-connection" / "connect-policy.ts"
source_file.parent.mkdir(parents=True)
source_file.touch()
# Generated test in co-located __tests__ directory
test_file = project_root / "src" / "gateway" / "server" / "__tests__" / "codeflash-generated" / "test_resolveControlUiAuthPolicy.test.ts"
test_file.parent.mkdir(parents=True)
test_file.touch()
# This should fail WITHOUT traverse_up
with pytest.raises(ValueError, match="is not within the project root"):
module_name_from_file_path(test_file, tests_root)
# This should succeed WITH traverse_up
result = module_name_from_file_path(test_file, tests_root, traverse_up=True)
assert "test_resolveControlUiAuthPolicy" in result

View file

@ -0,0 +1,57 @@
"""Test that test_cfg.js_project_root caching bug is demonstrated and bypassed by the fix."""
from pathlib import Path
from unittest.mock import patch
from codeflash.languages.javascript.support import JavaScriptSupport
from codeflash.verification.verification_utils import TestConfig
@patch("codeflash.languages.javascript.optimizer.verify_js_requirements")
def test_js_project_root_cached_in_test_cfg(mock_verify: object, tmp_path: Path) -> None:
"""Demonstrates that test_cfg.js_project_root is set once per setup_test_config call.
This test shows the root cause: test_cfg caches the project root from the first function.
The fix bypasses this cache in FunctionOptimizer.get_js_project_root() instead of
changing how test_cfg stores the value.
"""
mock_verify.return_value = [] # type: ignore[attr-defined]
# Create main project
main_project = (tmp_path / "project").resolve()
main_project.mkdir()
(main_project / "package.json").write_text('{"name": "main"}', encoding="utf-8")
(main_project / "src").mkdir()
(main_project / "test").mkdir()
(main_project / "node_modules").mkdir()
# Create extension with its own package.json
extension_dir = (main_project / "extensions" / "discord").resolve()
extension_dir.mkdir(parents=True)
(extension_dir / "package.json").write_text('{"name": "discord-extension"}', encoding="utf-8")
(extension_dir / "src").mkdir()
(extension_dir / "node_modules").mkdir()
test_cfg = TestConfig(
tests_root=main_project / "test",
project_root_path=main_project,
tests_project_rootdir=main_project / "test",
)
test_cfg.set_language("javascript")
js_support = JavaScriptSupport()
extension_file = (extension_dir / "src" / "accounts.ts").resolve()
extension_file.write_text("export function foo() {}", encoding="utf-8")
success = js_support.setup_test_config(test_cfg, extension_file, current_worktree=None)
assert success, "setup_test_config should succeed"
# After setup for extension file, js_project_root is the extension directory
assert test_cfg.js_project_root == extension_dir
# test_cfg is NOT re-initialized for subsequent functions — js_project_root stays cached
main_file = (main_project / "src" / "commands.ts").resolve()
main_file.write_text("export function bar() {}", encoding="utf-8")
# The cached value is still extension_dir, not main_project — this is the root cause
assert test_cfg.js_project_root == extension_dir
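The per-function fix the docstring refers to can be pictured as a lookup that ignores the cached value and walks up from each source file to the nearest package.json; a hedged sketch (the actual FunctionOptimizer.get_js_project_root may differ):

from pathlib import Path

def nearest_js_project_root(source_file: Path, stop_at: Path) -> Path:
    """Sketch: closest ancestor of source_file that contains a package.json."""
    for directory in source_file.parents:
        if (directory / "package.json").exists():
            return directory
        if directory == stop_at:
            break
    return stop_at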

View file

@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None:
cursor = conn.cursor()
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert (
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None:
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_unused_socket) == 1
assert len(test_results_unused_socket) >= 1
assert (
test_results_unused_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None:
test_results_unused_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
)
assert test_results_unused_socket.test_results[0].did_pass == True
assert test_results_unused_socket.test_results[0].did_pass is True
# Replace with optimized candidate
fto_unused_socket_path.write_text("""
@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container):
pytest_max_loops=1,
testing_time=1.0,
)
assert len(optimized_test_results_unused_socket) == 1
assert len(optimized_test_results_unused_socket) >= 1
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
assert match
@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container):
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert len(test_results_used_socket) >= 1
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container):
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert len(test_results_used_socket) >= 1
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"

View file

@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None:
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
@ -220,6 +220,7 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
if replay_tests_dir.exists():
shutil.rmtree(replay_tests_dir)
@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None:
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
assert len(function_calls) == 1, f"Expected 1 function call, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
@ -304,23 +306,24 @@ def test_trace_benchmark_decorator() -> None:
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
test_name, total_time, function_time, percent = function_to_results[
"code_to_optimize.bubble_sort_codeflash_trace.sorter"
][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
assert total_time >= 0.0
assert function_time >= 0.0
assert percent >= 0.0
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls

View file

@ -0,0 +1,91 @@
"""Test that coverage error messages are framework-agnostic."""
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from codeflash.languages.language_enum import Language
from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.coverage_utils import JestCoverageUtils
class TestCoverageUtilsFrameworkAgnostic:
"""Test that error messages don't hardcode 'Jest' when used for Vitest."""
def test_missing_coverage_file_message_is_framework_agnostic(self, caplog):
"""When coverage file is missing, error message should not say 'Jest' specifically.
This class is used for both Jest and Vitest (they use the same Istanbul/v8 format).
Error messages should be generic, not hardcode 'Jest'.
"""
# Set log level to DEBUG to capture all messages
caplog.set_level("DEBUG")
# Create minimal context
context = MagicMock(spec=CodeOptimizationContext)
context.language = Language.JAVASCRIPT
context.target_code = "export function test() {}"
context.helper_functions = []
nonexistent_path = Path("/tmp/nonexistent_coverage_12345.json")
# Load coverage from non-existent file
result = JestCoverageUtils.load_from_jest_json(
coverage_json_path=nonexistent_path,
function_name="testFunc",
code_context=context,
source_code_path=Path("/tmp/test.ts")
)
# Should return empty coverage data
assert result.status.name in ("NOT_FOUND", "EMPTY")
# Error message should NOT hardcode "Jest" - it should be framework-agnostic
# since this util is used for both Jest and Vitest
log_messages = [record.message for record in caplog.records]
# Check that if there's a message about coverage file, it doesn't say "Jest"
coverage_messages = [msg for msg in log_messages if "coverage file not found" in msg.lower()]
if coverage_messages:
# The message should NOT contain "Jest" specifically
# It should say something like "Coverage file not found" or "JavaScript coverage file not found"
for msg in coverage_messages:
assert "Jest" not in msg, (
f"Error message should not hardcode 'Jest' since this util is used for Vitest too. "
f"Got: {msg}"
)
def test_parse_error_message_is_framework_agnostic(self, tmp_path, caplog):
"""When coverage file is malformed, error should not say 'Jest' specifically."""
# Set log level to capture all messages
caplog.set_level("DEBUG")
# Create invalid JSON file
coverage_file = tmp_path / "invalid_coverage.json"
coverage_file.write_text("{invalid json")
context = MagicMock(spec=CodeOptimizationContext)
context.language = Language.JAVASCRIPT
context.target_code = "export function test() {}"
context.helper_functions = []
result = JestCoverageUtils.load_from_jest_json(
coverage_json_path=coverage_file,
function_name="testFunc",
code_context=context,
source_code_path=Path("/tmp/test.ts")
)
# Should return empty coverage
assert result.status.name in ("NOT_FOUND", "EMPTY")
# Check log messages don't hardcode "Jest"
log_messages = [record.message for record in caplog.records]
parse_error_messages = [msg for msg in log_messages if "parse" in msg.lower() and "coverage" in msg.lower()]
for msg in parse_error_messages:
assert "Jest" not in msg, (
f"Parse error message should not hardcode 'Jest'. Got: {msg}"
)

View file

@ -0,0 +1,55 @@
"""Test that verifier.py handles test files outside tests_project_rootdir gracefully.
This tests the fix for the bug where JavaScript/TypeScript test files generated
in __tests__ subdirectories (adjacent to source files) caused ValueError when
verifier.py tried to compute their module path relative to tests_project_rootdir.
Trace ID: 84f5467f-8acf-427f-b468-02cb3342097e
"""
from pathlib import Path
import pytest
from codeflash.code_utils.code_utils import module_name_from_file_path
class TestVerifierPathHandling:
"""Test path handling in verifier.py for test files outside tests_root."""
def test_module_name_from_file_path_raises_valueerror_when_outside_root(self) -> None:
"""Verify that module_name_from_file_path raises ValueError when file is outside root.
This is the current behavior that causes the bug in verifier.py line 37.
Scenario:
- JavaScript support generates test at: /workspace/target/src/gateway/server/__tests__/codeflash-generated/test_foo.test.ts
- tests_project_rootdir is: /workspace/target/test
- Test file is NOT within tests_root, so relative_to() fails
"""
test_path = Path("/workspace/target/src/gateway/server/__tests__/codeflash-generated/test_foo.test.ts")
tests_root = Path("/workspace/target/test")
# This should raise ValueError before the fix
with pytest.raises(ValueError, match="is not within the project root"):
module_name_from_file_path(test_path, tests_root)
def test_module_name_from_file_path_with_fallback_succeeds(self) -> None:
"""Test that adding a fallback (try-except) allows graceful handling.
This is the pattern used in javascript/parse.py:330-333 that should
also be applied to verifier.py:37.
"""
test_path = Path("/workspace/target/src/gateway/server/__tests__/codeflash-generated/test_foo.test.ts")
tests_root = Path("/workspace/target/test")
# Simulate the fix: try-except with fallback to filename
try:
test_module_path = module_name_from_file_path(test_path, tests_root)
except ValueError:
# Fallback: use just the filename (or relative path from parent)
# This is what javascript/parse.py does
test_module_path = test_path.name
# After fallback, we should have a valid path
assert test_module_path == "test_foo.test.ts"

uv.lock
View file

@ -466,6 +466,7 @@ dependencies = [
{ name = "libcst" },
{ name = "line-profiler" },
{ name = "lxml" },
{ name = "memray", marker = "sys_platform != 'win32'" },
{ name = "parameterized" },
{ name = "platformdirs", version = "4.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "platformdirs", version = "4.9.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
@ -477,6 +478,7 @@ dependencies = [
{ name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "pytest-asyncio", version = "1.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "pytest-asyncio", version = "1.3.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "pytest-memray", marker = "sys_platform != 'win32'" },
{ name = "pytest-timeout" },
{ name = "requests", version = "2.32.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "requests", version = "2.33.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
@ -576,6 +578,7 @@ requires-dist = [
{ name = "libcst", specifier = ">=1.0.1" },
{ name = "line-profiler", specifier = ">=4.2.0" },
{ name = "lxml", specifier = ">=5.3.0" },
{ name = "memray", marker = "sys_platform != 'win32'", specifier = ">=1.12" },
{ name = "parameterized", specifier = ">=0.9.0" },
{ name = "platformdirs", specifier = ">=4.3.7" },
{ name = "posthog", specifier = ">=3.0.0" },
@ -583,6 +586,7 @@ requires-dist = [
{ name = "pygls", specifier = ">=2.0.0,<3.0.0" },
{ name = "pytest", specifier = ">=7.0.0" },
{ name = "pytest-asyncio", specifier = ">=0.18.0" },
{ name = "pytest-memray", marker = "sys_platform != 'win32'", specifier = ">=1.7" },
{ name = "pytest-timeout", specifier = ">=2.1.0" },
{ name = "requests", specifier = ">=2.28.0" },
{ name = "rich", specifier = ">=13.8.1" },
@ -2261,6 +2265,45 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/32/9f/228020e1bce6308723b5455e7de054428b9908b340b4c702dd2b3409f016/line_profiler-5.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:2b70a38fe852d7c95eca105ec603a28ca6f0bd3c909f2cca9e7cca2bf19cb77e", size = 480441, upload-time = "2026-02-23T23:31:19.162Z" },
]
[[package]]
name = "linkify-it-py"
version = "2.0.3"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.9.2' and python_full_version < '3.10'",
"python_full_version < '3.9.2'",
]
dependencies = [
{ name = "uc-micro-py", version = "1.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2a/ae/bb56c6828e4797ba5a4821eec7c43b8bf40f69cda4d4f5f8c8a2810ec96a/linkify-it-py-2.0.3.tar.gz", hash = "sha256:68cda27e162e9215c17d786649d1da0021a451bdc436ef9e0fa0ba5234b9b048", size = 27946, upload-time = "2024-02-04T14:48:04.179Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/1e/b832de447dee8b582cac175871d2f6c3d5077cc56d5575cadba1fd1cccfa/linkify_it_py-2.0.3-py3-none-any.whl", hash = "sha256:6bcbc417b0ac14323382aef5c5192c0075bf8a9d6b41820a2b66371eac6b6d79", size = 19820, upload-time = "2024-02-04T14:48:02.496Z" },
]
[[package]]
name = "linkify-it-py"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.14' and sys_platform == 'emscripten'",
"python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.13.*' and sys_platform == 'emscripten'",
"python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.12.*' and sys_platform == 'emscripten'",
"python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.11.*' and sys_platform == 'emscripten'",
"python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.10.*'",
]
dependencies = [
{ name = "uc-micro-py", version = "2.0.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/2e/c9/06ea13676ef354f0af6169587ae292d3e2406e212876a413bf9eece4eb23/linkify_it_py-2.1.0.tar.gz", hash = "sha256:43360231720999c10e9328dc3691160e27a718e280673d444c38d7d3aaa3b98b", size = 29158, upload-time = "2026-03-01T07:48:47.683Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b4/de/88b3be5c31b22333b3ca2f6ff1de4e863d8fe45aaea7485f591970ec1d3e/linkify_it_py-2.1.0-py3-none-any.whl", hash = "sha256:0d252c1594ecba2ecedc444053db5d3a9b7ec1b0dd929c8f1d74dce89f86c05e", size = 19878, upload-time = "2026-03-01T07:48:46.098Z" },
]
[[package]]
name = "llvmlite"
version = "0.43.0"
@ -2515,6 +2558,11 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/42/d7/1ec15b46af6af88f19b8e5ffea08fa375d433c998b8a7639e76935c14f1f/markdown_it_py-3.0.0-py3-none-any.whl", hash = "sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1", size = 87528, upload-time = "2023-06-03T06:41:11.019Z" },
]
[package.optional-dependencies]
linkify = [
{ name = "linkify-it-py", version = "2.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
]
[[package]]
name = "markdown-it-py"
version = "4.0.0"
@ -2542,6 +2590,11 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
]
[package.optional-dependencies]
linkify = [
{ name = "linkify-it-py", version = "2.1.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
]
[[package]]
name = "markupsafe"
version = "3.0.3"
@ -2650,6 +2703,45 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" },
]
[[package]]
name = "mdit-py-plugins"
version = "0.4.2"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.9.2' and python_full_version < '3.10'",
"python_full_version < '3.9.2'",
]
dependencies = [
{ name = "markdown-it-py", version = "3.0.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/19/03/a2ecab526543b152300717cf232bb4bb8605b6edb946c845016fa9c9c9fd/mdit_py_plugins-0.4.2.tar.gz", hash = "sha256:5f2cd1fdb606ddf152d37ec30e46101a60512bc0e5fa1a7002c36647b09e26b5", size = 43542, upload-time = "2024-09-09T20:27:49.564Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a7/f7/7782a043553ee469c1ff49cfa1cdace2d6bf99a1f333cf38676b3ddf30da/mdit_py_plugins-0.4.2-py3-none-any.whl", hash = "sha256:0c673c3f889399a33b95e88d2f0d111b4447bdfea7f237dab2d488f459835636", size = 55316, upload-time = "2024-09-09T20:27:48.397Z" },
]
[[package]]
name = "mdit-py-plugins"
version = "0.5.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.14' and sys_platform == 'emscripten'",
"python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.13.*' and sys_platform == 'emscripten'",
"python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.12.*' and sys_platform == 'emscripten'",
"python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.11.*' and sys_platform == 'emscripten'",
"python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.10.*'",
]
dependencies = [
{ name = "markdown-it-py", version = "4.0.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b2/fd/a756d36c0bfba5f6e39a1cdbdbfdd448dc02692467d83816dff4592a1ebc/mdit_py_plugins-0.5.0.tar.gz", hash = "sha256:f4918cb50119f50446560513a8e311d574ff6aaed72606ddae6d35716fe809c6", size = 44655, upload-time = "2025-08-11T07:25:49.083Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fb/86/dd6e5db36df29e76c7a7699123569a4a18c1623ce68d826ed96c62643cae/mdit_py_plugins-0.5.0-py3-none-any.whl", hash = "sha256:07a08422fc1936a5d26d146759e9155ea466e842f5ab2f7d2266dd084c8dab1f", size = 57205, upload-time = "2025-08-11T07:25:47.597Z" },
]
[[package]]
name = "mdurl"
version = "0.1.2"
@ -2659,6 +2751,61 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" },
]
[[package]]
name = "memray"
version = "1.19.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
{ name = "rich", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
{ name = "textual", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e9/db/56ff21f47be261ab781105b233d1851d3f2fcdd4f08ebf689f6d6fd84f0d/memray-1.19.2.tar.gz", hash = "sha256:680cb90ac4564d140673ac9d8b7a7e07a8405bd1fb8f933da22616f93124ca84", size = 2410256, upload-time = "2026-03-13T15:22:31.825Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3e/5f/48c6d7c6e4d02883d0c3de98c46c71d20c53038dfdde79614d0e55f9f163/memray-1.19.2-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:50d7130bb0c8609176b3b691c8b67fc92805180166e087549a59e7881ae8cf36", size = 2181142, upload-time = "2026-03-13T15:20:26.87Z" },
{ url = "https://files.pythonhosted.org/packages/1d/85/34d5dc497741bf684cfb5f59d58428b6fd4a034e55cb950339ee8f137f9d/memray-1.19.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3643d601c4c1c413a62fb296598ed05dce1e1c3a58ea5c8659ae98ad36ce3a7a", size = 2162529, upload-time = "2026-03-13T15:20:29.187Z" },
{ url = "https://files.pythonhosted.org/packages/95/5f/ca6ab3cd76de6134cbe29f5a6daa77234f216ae9bd8c963beda226a22653/memray-1.19.2-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:661aca0dbf4c448eef93f2f0bd0852eeefe3de2460e8105c2160c86e308beea5", size = 9707355, upload-time = "2026-03-13T15:20:30.941Z" },
{ url = "https://files.pythonhosted.org/packages/bd/c9/4b79508b2cf646ca3fe3c87bdef80cd26362679274b26dab1f4b725ebba0/memray-1.19.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d13f33f1fa76165c5596e73bc45a366d58066be567fb131498cd770fa87f5a02", size = 9938651, upload-time = "2026-03-13T15:20:33.755Z" },
{ url = "https://files.pythonhosted.org/packages/d5/d6/ca9cef1c0aba2245c41aed699a45a748db7b0dd8a9a63484e809b0f8e448/memray-1.19.2-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74291aa9bbf54ff2ac5df2665c792d490c576720dd2cbad89af53528bda5443f", size = 9327619, upload-time = "2026-03-13T15:20:36.179Z" },
{ url = "https://files.pythonhosted.org/packages/ce/66/572f819ff58d0f0fefeeeeaa7206f192107f39027a92fd90af1c1cbff61b/memray-1.19.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:716a1b2569e049d0cb769015e5be9138bd97bd157e67920cc9e215e011fbb9cd", size = 12158374, upload-time = "2026-03-13T15:20:39.213Z" },
{ url = "https://files.pythonhosted.org/packages/63/bf/b8f28adbd3e1eeeb88e188053a26164b195ebcf66f8af6b30003a83f5660/memray-1.19.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:c8d35a9f5b222165c5aedbfc18b79dc5161a724963a4fca8d1053faa0b571195", size = 2181644, upload-time = "2026-03-13T15:20:41.756Z" },
{ url = "https://files.pythonhosted.org/packages/21/66/0791e5514b475d6300d13ebe87839db1606b2dc2fbe00fecce4da2fb405d/memray-1.19.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3735567011cc22339aee2c59b5fc94d1bdd4a23f9990e02a2c3cccc9c3cf6de4", size = 2164670, upload-time = "2026-03-13T15:20:44.14Z" },
{ url = "https://files.pythonhosted.org/packages/0f/aa/086878e99693b174b0d04d0b267231862fb6a3cfc35cab2920284c2a2e38/memray-1.19.2-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:ab78af759eebcb8d8ecef173042515711d2dcc9600d5dd446d1592b24a89b7d9", size = 9777844, upload-time = "2026-03-13T15:20:46.266Z" },
{ url = "https://files.pythonhosted.org/packages/40/a6/40247667e72b5d8322c5dc2ef30513238b3480be1e482faaaf9cc573ff38/memray-1.19.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f3ae7983297d168cdcc2d05cd93a4934b9b6fe0d341a91ac5b71bf45f9cec06c", size = 10021548, upload-time = "2026-03-13T15:20:49.079Z" },
{ url = "https://files.pythonhosted.org/packages/b3/bb/50603e8f7fe950b3f6a6e09a80413a8f25c4a9d360d8b3b027a8841e1fe8/memray-1.19.2-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08a4316d7a92eb415024b46988844ed0fd44b2d02ca00fa4a21f2481c1f803e6", size = 9400168, upload-time = "2026-03-13T15:20:51.801Z" },
{ url = "https://files.pythonhosted.org/packages/e2/89/a21e0b639496ed59d2a733e60869ff2e685c5a78891474a494e09a17dc7c/memray-1.19.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dbdb14fd31e2a031312755dc76146aeff9d0889e82ccffe231f1f20f50526f57", size = 12234413, upload-time = "2026-03-13T15:20:54.454Z" },
{ url = "https://files.pythonhosted.org/packages/13/4e/8685c202ddd76860cd8fc5f7f552115ea6f317e9f5f16219a56f336e351e/memray-1.19.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:22d4482f559ffa91a9727693e7e338856bee5e316f922839bf8b96e0f9b8a4de", size = 2183484, upload-time = "2026-03-13T15:20:56.696Z" },
{ url = "https://files.pythonhosted.org/packages/89/79/602f55d5466f1f587cdddf0324f82752bd0319ea814bc7cca2efb8593bc8/memray-1.19.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fd1476868177ee8d9f7f85e5a085a20cc3c3a8228a23ced72749265885d55ca", size = 2162900, upload-time = "2026-03-13T15:20:58.174Z" },
{ url = "https://files.pythonhosted.org/packages/02/1b/402207971653b9861bbbe449cbed7d82e7bb9b953dd6ac93dd4d78e76fa2/memray-1.19.2-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:23375d50faa199e1c1bc2e89f08691f6812478fddb49a1b82bebe6ef5a56df2c", size = 9731991, upload-time = "2026-03-13T15:21:00.299Z" },
{ url = "https://files.pythonhosted.org/packages/3f/7d/895ce73fcf9ab0a2b675ed49bbc91cbca14bda187e2b4df86ccefeb1c9bc/memray-1.19.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8ef3d8e4fba0b26280b550278a0660554283135cbccc34e2d49ba82a1945eb61", size = 9997104, upload-time = "2026-03-13T15:21:02.959Z" },
{ url = "https://files.pythonhosted.org/packages/a0/b9/586bf51a1321cde736d886ca8ac3d4b1f910e4f3f813d7c8eb22498ee16f/memray-1.19.2-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a4d6cf9597ae5d60f7893a0b7b6b9af9c349121446b3c1e7b9ac1d8b5d45a505", size = 9373508, upload-time = "2026-03-13T15:21:05.945Z" },
{ url = "https://files.pythonhosted.org/packages/5d/f1/7cb51edeeceaaee770d4222e833369fbc927227d27e0a917b5ad6f4b2f85/memray-1.19.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:716a0a0e9048d21da98f9107fa030a76138eb694a16a81ad15eace54fddef4cd", size = 12222756, upload-time = "2026-03-13T15:21:08.9Z" },
{ url = "https://files.pythonhosted.org/packages/34/10/cbf57c122988d6e3bd148aa374e91e0e2f156cc7db1ac6397eb6db3946d1/memray-1.19.2-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:13aa87ad34cc88b3f31f7205e0a4543c391032e8600dc0c0cbf22555ff816d97", size = 2182910, upload-time = "2026-03-13T15:21:11.357Z" },
{ url = "https://files.pythonhosted.org/packages/5c/0e/7979dfe7e2b034431e44e3bab86356d9bc2c4f3ed0eb1594cb0ceb38c859/memray-1.19.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d6b249618a3e4fa8e10291445a2b2dfaf6f188e7cc1765966aac8fb52cb22066", size = 2161575, upload-time = "2026-03-13T15:21:13.051Z" },
{ url = "https://files.pythonhosted.org/packages/f9/92/2f0ca3936cdf4c59bc8c59fc8738ce8854ba24fd8519988f2ece0eba10fa/memray-1.19.2-cp313-cp313-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:34985e5e638ef8d4d54de8173c5e4481c478930f545bd0eb4738a631beb63d04", size = 9732172, upload-time = "2026-03-13T15:21:15.115Z" },
{ url = "https://files.pythonhosted.org/packages/52/23/de78510b4e3a0668b793d8b5dff03f2af20eef97943ca5b3263effff799c/memray-1.19.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ee0fcfafd1e8535bdc0d0ed75bcdd48d436a6f62d467df91871366cbb3bbaebc", size = 9999447, upload-time = "2026-03-13T15:21:18.099Z" },
{ url = "https://files.pythonhosted.org/packages/00/0d/b0e50537470f93bddfa2c134177fe9332c20be44a571588866776ff92b82/memray-1.19.2-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:846185c393ff0dc6bca55819b1c83b510b77d8d561b7c0c50f4873f69579e35d", size = 9379158, upload-time = "2026-03-13T15:21:21.003Z" },
{ url = "https://files.pythonhosted.org/packages/5c/53/78f6de5c7208821b15cfbbb9da44ab4a5a881a7cc5075f9435a1700320e8/memray-1.19.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8cc31327ed71e9f6ef7e9ed558e764f0e9c3f01da13ad8547734eb65fbeade1d", size = 12226753, upload-time = "2026-03-13T15:21:24.041Z" },
{ url = "https://files.pythonhosted.org/packages/e1/f4/3d8205b9f46657d26d54d1e644f27d09955b737189354a01907d8a08c7e2/memray-1.19.2-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:410377c0eae8d544421f74b919a18e119279fe1a2fa5ff381404b55aeb4c6514", size = 2184823, upload-time = "2026-03-13T15:21:27.176Z" },
{ url = "https://files.pythonhosted.org/packages/fb/07/7a342801317eff410a8267b55cb7514e156ee1f574e690852eb240bbe9fd/memray-1.19.2-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:a53dc4032581ed075fcb62a4acc0ced14fb90a8269159d4e53dfac7af269c255", size = 2163669, upload-time = "2026-03-13T15:21:29.123Z" },
{ url = "https://files.pythonhosted.org/packages/d4/00/2c342b1472f9f03018bb88c80760cdfa6979404d63c4300c607fd0562607/memray-1.19.2-cp314-cp314-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:a7630865fbf3823aa2d1a6f7536f7aec88cf8ccf5b2498aad44adbc733f6bd2e", size = 9732615, upload-time = "2026-03-13T15:21:31.038Z" },
{ url = "https://files.pythonhosted.org/packages/fe/ae/2cf960526c9b1f6d46977fc70e11de29ca6b9eafeeb42d1cec7d3bcb056a/memray-1.19.2-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c23e2b4be22a23cf5cae08854549e3460869a36c5f4bedc739b646ac97da4a60", size = 9979299, upload-time = "2026-03-13T15:21:34.072Z" },
{ url = "https://files.pythonhosted.org/packages/e1/78/73ee3d0ebee3c38fbb2d51766854d2932beec6481063532a6019bf340a2d/memray-1.19.2-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:95b6c02ca7f8555b5bee1c54c50cbbcf2033e07ebca95dade2ac3a27bb36b320", size = 9375722, upload-time = "2026-03-13T15:21:36.884Z" },
{ url = "https://files.pythonhosted.org/packages/3b/c6/2f02475e85ccd32fa306736986f1f77f99365066ecdc859f5078148ebc40/memray-1.19.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:907470e2684568eb91a993ae69a08b1430494c8f2f6ef489b4b78519d9dae3d0", size = 12220041, upload-time = "2026-03-13T15:21:40.16Z" },
{ url = "https://files.pythonhosted.org/packages/76/12/01bb32188c011e6d802469e04c1d7c8054eb8300164e2269c830f5b26a8e/memray-1.19.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:124138f35fea36c434256c417f1b8cb32f78769f208530c1e56bf2c2b7654120", size = 2201353, upload-time = "2026-03-13T15:21:42.607Z" },
{ url = "https://files.pythonhosted.org/packages/e5/e0/d9b59f8be00f27440f60b95da5db6515a1c44c481651b8d2fa8f3468fc35/memray-1.19.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:240192dc98ff0b3501055521bfd73566d339808b11bd5af10865afe6ae18abef", size = 2180420, upload-time = "2026-03-13T15:21:44.623Z" },
{ url = "https://files.pythonhosted.org/packages/a5/5c/30aca63f4b88dca79ba679675200938652c816edee34c12565d2f17ea936/memray-1.19.2-cp314-cp314t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:edb7a3c2a9e97fb409b352f6c316598c7c0c3c22732e73704d25b9eb75ae2f2d", size = 9697953, upload-time = "2026-03-13T15:21:47.088Z" },
{ url = "https://files.pythonhosted.org/packages/9f/02/9e4a68bdd5ebc9079f97bdf287cc0ccc51c18e9edc205de7d41648315809/memray-1.19.2-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b6a43db4c1466446a905a77944813253231ac0269f758c6c6bc03ceb1821c1b6", size = 9944517, upload-time = "2026-03-13T15:21:50.125Z" },
{ url = "https://files.pythonhosted.org/packages/4a/f0/3adad59ebed6841c2f88b43c9b90cc9c03ff086129a8aef3cff23c92d6ac/memray-1.19.2-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cf951dae8d27d502fbc549f6784460a70cce05b1e71bf5446d8692a74051f14f", size = 9365528, upload-time = "2026-03-13T15:21:53.141Z" },
{ url = "https://files.pythonhosted.org/packages/45/0e/083e00fe74e576b463e7b00e4214b8962f27bd70c5c77e494c0211a77342/memray-1.19.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:8033b78232555bb1856b3298bef2898ec8b334d3d465c1822c665206d1fa910a", size = 12143894, upload-time = "2026-03-13T15:21:56.486Z" },
{ url = "https://files.pythonhosted.org/packages/4d/1b/b2e54cbe9a67a63a2f8b0c0d3cbfef0db8592e00ced4d6afb324245910e5/memray-1.19.2-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:f82ee0a0b50a04894dacfbe49db1c259fa8a19efb094514b0100e9916d3b1c55", size = 2183022, upload-time = "2026-03-13T15:22:14.81Z" },
{ url = "https://files.pythonhosted.org/packages/fd/1e/17a3e62bccf2c34cfa2208c28bdab127afd279c8a6d7fbb7c2b835a606db/memray-1.19.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5b1c58a54372707b3977c079ef93e751109f0bfe566adc7bd640971d123d8d11", size = 2163707, upload-time = "2026-03-13T15:22:16.507Z" },
{ url = "https://files.pythonhosted.org/packages/9c/bd/a9bb3d747b138c8bc382389857879941f6c7a83fb3beeebce1c3251ad401/memray-1.19.2-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:fa236140320ef1b8801cd289962fd81a2d7e59484cc3ecdbc851d1b5c321795e", size = 9703623, upload-time = "2026-03-13T15:22:19.551Z" },
{ url = "https://files.pythonhosted.org/packages/a3/70/24006fcab90eb6a21b5b2c45f046746578a817c82cb7ed2987d08dffad9d/memray-1.19.2-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:816baeda8e62fddf99c900bdc9e748339dba65df091a7c7ceb0f4f9544c2e7ec", size = 9925887, upload-time = "2026-03-13T15:22:23.297Z" },
{ url = "https://files.pythonhosted.org/packages/41/5e/6ac00a20da0b84c9e41d1e0ebaf27d49907ff7be1cd66b1e2b410d1c9c25/memray-1.19.2-cp39-cp39-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a1532d5dcf8036ec55e43ab6d6b1ff4e70b11a3902dd1c8396b6d2a24ec69d98", size = 9323522, upload-time = "2026-03-13T15:22:26.144Z" },
{ url = "https://files.pythonhosted.org/packages/2d/e0/74c17f7095e7c476fef3f47a13637fe0d717b58c8e0e5e06a388b7ca3cac/memray-1.19.2-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:86060df2e8e18cc867335c50bf92deb973d4dff856bdb565e17fc86ca7a6619b", size = 12154107, upload-time = "2026-03-13T15:22:29.341Z" },
]
[[package]]
name = "ml-dtypes"
version = "0.5.4"
@ -4368,6 +4515,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" },
]
[[package]]
name = "pytest-memray"
version = "1.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "memray", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
{ name = "pytest", version = "8.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "pytest", version = "9.0.2", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3d/28/f67963efed56d847d028d0bb939f26cdeb32c4de474b3befc9da43bf18f9/pytest_memray-1.8.0.tar.gz", hash = "sha256:c0c706ef81941a7aa7064f2b3b8b5cdc0cea72b5277c6a6a09b113ca9ab30bdb", size = 240608, upload-time = "2025-08-18T17:32:47.329Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cc/52/b8b8e126c176c5f405b307354e1722025063ea104dbd7d286e8b18a76e9f/pytest_memray-1.8.0-py3-none-any.whl", hash = "sha256:44da9fe0d98541abf4cc76acea6e4a9c525b3c8e604655e5537705f336c9b875", size = 17688, upload-time = "2025-08-18T17:32:45.476Z" },
]
[[package]]
name = "pytest-timeout"
version = "2.4.0"
@ -5338,6 +5499,26 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/33/d1/8bb87d21e9aeb323cc03034f5eaf2c8f69841e40e4853c2627edf8111ed3/termcolor-3.3.0-py3-none-any.whl", hash = "sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5", size = 7734, upload-time = "2025-12-29T12:55:20.718Z" },
]
[[package]]
name = "textual"
version = "8.2.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "markdown-it-py", version = "3.0.0", source = { registry = "https://pypi.org/simple" }, extra = ["linkify"], marker = "python_full_version < '3.10'" },
{ name = "markdown-it-py", version = "4.0.0", source = { registry = "https://pypi.org/simple" }, extra = ["linkify"], marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
{ name = "mdit-py-plugins", version = "0.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "mdit-py-plugins", version = "0.5.0", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
{ name = "platformdirs", version = "4.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
{ name = "platformdirs", version = "4.9.4", source = { registry = "https://pypi.org/simple" }, marker = "(python_full_version >= '3.10' and sys_platform != 'win32') or (python_full_version == '3.10.*' and sys_platform == 'win32')" },
{ name = "pygments", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
{ name = "rich", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
{ name = "typing-extensions", marker = "python_full_version < '3.11' or sys_platform != 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/4f/07/766ad19cf2b15cae2d79e0db46a1b783b62316e9ff3e058e7424b2a4398b/textual-8.2.1.tar.gz", hash = "sha256:4176890e9cd5c95dcdd206541b2956b0808e74c8c36381c88db53dcb45237451", size = 1848386, upload-time = "2026-03-29T03:57:32.242Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/09/c6f000c2e3702036e593803319af02feee58a662528d0d5728a37e1cf81b/textual-8.2.1-py3-none-any.whl", hash = "sha256:746cbf947a8ca875afc09779ef38cadbc7b9f15ac886a5090f7099fef5ade990", size = 723871, upload-time = "2026-03-29T03:57:34.334Z" },
]
[[package]]
name = "tomli"
version = "2.4.1"
@ -6324,6 +6505,39 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" },
]
[[package]]
name = "uc-micro-py"
version = "1.0.3"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.9.2' and python_full_version < '3.10'",
"python_full_version < '3.9.2'",
]
sdist = { url = "https://files.pythonhosted.org/packages/91/7a/146a99696aee0609e3712f2b44c6274566bc368dfe8375191278045186b8/uc-micro-py-1.0.3.tar.gz", hash = "sha256:d321b92cff673ec58027c04015fcaa8bb1e005478643ff4a500882eaab88c48a", size = 6043, upload-time = "2024-02-09T16:52:01.654Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/37/87/1f677586e8ac487e29672e4b17455758fce261de06a0d086167bb760361a/uc_micro_py-1.0.3-py3-none-any.whl", hash = "sha256:db1dffff340817673d7b466ec86114a9dc0e9d4d9b5ba229d9d60e5c12600cd5", size = 6229, upload-time = "2024-02-09T16:52:00.371Z" },
]
[[package]]
name = "uc-micro-py"
version = "2.0.0"
source = { registry = "https://pypi.org/simple" }
resolution-markers = [
"python_full_version >= '3.14' and sys_platform == 'emscripten'",
"python_full_version >= '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.13.*' and sys_platform == 'emscripten'",
"python_full_version == '3.13.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.12.*' and sys_platform == 'emscripten'",
"python_full_version == '3.12.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.11.*' and sys_platform == 'emscripten'",
"python_full_version == '3.11.*' and sys_platform != 'emscripten' and sys_platform != 'win32'",
"python_full_version == '3.10.*'",
]
sdist = { url = "https://files.pythonhosted.org/packages/78/67/9a363818028526e2d4579334460df777115bdec1bb77c08f9db88f6389f2/uc_micro_py-2.0.0.tar.gz", hash = "sha256:c53691e495c8db60e16ffc4861a35469b0ba0821fe409a8a7a0a71864d33a811", size = 6611, upload-time = "2026-03-01T06:31:27.526Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/73/d21edf5b204d1467e06500080a50f79d49ef2b997c79123a536d4a17d97c/uc_micro_py-2.0.0-py3-none-any.whl", hash = "sha256:3603a3859af53e5a39bc7677713c78ea6589ff188d70f4fee165db88e22b242c", size = 6383, upload-time = "2026-03-01T06:31:26.257Z" },
]
[[package]]
name = "unidiff"
version = "0.7.5"