diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java b/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java index 80d400935..e1c177ac9 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/Serializer.java @@ -6,7 +6,6 @@ import com.esotericsoftware.kryo.io.Output; import com.esotericsoftware.kryo.util.DefaultInstantiatorStrategy; import org.objenesis.strategy.StdInstantiatorStrategy; -import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Field; @@ -36,7 +35,11 @@ public final class Serializer { private static final int MAX_COLLECTION_SIZE = 1000; private static final int BUFFER_SIZE = 4096; - // Thread-local Kryo instances (Kryo is not thread-safe) + // Thread-local Kryo, Output, and IdentityHashMap instances for reuse + private static final ThreadLocal OUTPUT = ThreadLocal.withInitial(() -> new Output(BUFFER_SIZE, -1)); + private static final ThreadLocal> SEEN = + ThreadLocal.withInitial(IdentityHashMap::new); + private static final ThreadLocal KRYO = ThreadLocal.withInitial(() -> { Kryo kryo = new Kryo(); kryo.setRegistrationRequired(false); @@ -89,10 +92,78 @@ public final class Serializer { * @return Serialized bytes (may contain KryoPlaceholder for unserializable parts) */ public static byte[] serialize(Object obj) { - Object processed = recursiveProcess(obj, new IdentityHashMap<>(), 0, ""); + // Fast path: if args are all safe types, skip recursive processing entirely + if (obj instanceof Object[] && isSafeArgs((Object[]) obj)) { + return directSerialize(obj); + } + + IdentityHashMap seen = SEEN.get(); + seen.clear(); + Object processed = recursiveProcess(obj, seen, 0, ""); return directSerialize(processed); } + /** + * Attempt fast-path serialization for args that are all known-safe types. + * Returns serialized bytes if all args are safe, or null if the slow path is needed. + * Callers can use this to avoid executor submission overhead for simple arguments. + */ + public static byte[] serializeFast(Object obj) { + if (obj instanceof Object[] && isSafeArgs((Object[]) obj)) { + return directSerialize(obj); + } + return null; + } + + /** + * Check if all elements of an args array can be serialized directly without recursive processing. + */ + private static boolean isSafeArgs(Object[] args) { + for (Object arg : args) { + if (!isSafeForDirectSerialization(arg)) { + return false; + } + } + return true; + } + + /** + * Check if an object is safe to serialize directly without recursive processing. + * Covers: null, simple types, primitive arrays, and safe containers (up to 3 levels deep). + */ + private static boolean isSafeForDirectSerialization(Object obj) { + return isSafeForDirectSerialization(obj, 3); + } + + private static boolean isSafeForDirectSerialization(Object obj, int depthLeft) { + if (obj == null || isSimpleType(obj)) { + return true; + } + if (depthLeft <= 0) { + return false; + } + Class clazz = obj.getClass(); + if (clazz.isArray() && clazz.getComponentType().isPrimitive()) { + return true; + } + if (isSafeContainerType(clazz)) { + if (obj instanceof Collection) { + for (Object item : (Collection) obj) { + if (!isSafeForDirectSerialization(item, depthLeft - 1)) return false; + } + return true; + } + if (obj instanceof Map) { + for (Map.Entry e : ((Map) obj).entrySet()) { + if (!isSafeForDirectSerialization(e.getKey(), depthLeft - 1) || + !isSafeForDirectSerialization(e.getValue(), depthLeft - 1)) return false; + } + return true; + } + } + return false; + } + /** * Deserialize bytes back to an object. * The returned object may contain KryoPlaceholder instances for parts @@ -141,14 +212,15 @@ public final class Serializer { /** * Direct serialization without recursive processing. + * Reuses a ThreadLocal Output buffer to avoid per-call allocation. */ private static byte[] directSerialize(Object obj) { Kryo kryo = KRYO.get(); - ByteArrayOutputStream baos = new ByteArrayOutputStream(BUFFER_SIZE); - try (Output output = new Output(baos)) { - kryo.writeClassAndObject(output, obj); - } - return baos.toByteArray(); + Output output = OUTPUT.get(); + output.reset(); + kryo.writeClassAndObject(output, obj); + output.flush(); + return output.toBytes(); } /** @@ -201,37 +273,23 @@ public final class Serializer { // unserializable types, recursively process to catch and replace unserializable objects. if (obj instanceof Map) { Map map = (Map) obj; - if (containsOnlySimpleTypes(map)) { - // Simple map - try direct serialization to preserve full size - byte[] serialized = tryDirectSerialize(obj); - if (serialized != null) { - try { - deserialize(serialized); - return obj; // Success - return original - } catch (Exception e) { - // Fall through to recursive handling - } - } + if (isSafeContainerType(clazz) && containsOnlySimpleTypes(map)) { + return obj; } return handleMap(map, seen, depth, path); } if (obj instanceof Collection) { Collection collection = (Collection) obj; - if (containsOnlySimpleTypes(collection)) { - // Simple collection - try direct serialization to preserve full size - byte[] serialized = tryDirectSerialize(obj); - if (serialized != null) { - try { - deserialize(serialized); - return obj; // Success - return original - } catch (Exception e) { - // Fall through to recursive handling - } - } + if (isSafeContainerType(clazz) && containsOnlySimpleTypes(collection)) { + return obj; } return handleCollection(collection, seen, depth, path); } if (clazz.isArray()) { + // Primitive arrays (int[], double[], etc.) are directly serializable by Kryo + if (clazz.getComponentType().isPrimitive()) { + return obj; + } return handleArray(obj, seen, depth, path); } @@ -255,6 +313,19 @@ public final class Serializer { } } + /** + * Check if a container type is known to round-trip safely through Kryo without verification. + * Only includes types registered with Kryo that are known to serialize/deserialize correctly. + */ + private static boolean isSafeContainerType(Class clazz) { + return clazz == ArrayList.class || + clazz == LinkedList.class || + clazz == HashMap.class || + clazz == LinkedHashMap.class || + clazz == HashSet.class || + clazz == LinkedHashSet.class; + } + /** * Check if a class is known to be unserializable. */ diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java index 28c2d2998..a9acfe855 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java @@ -12,6 +12,7 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; public final class TraceRecorder { @@ -23,6 +24,8 @@ public final class TraceRecorder { private final TraceWriter writer; private final ConcurrentHashMap functionCounts = new ConcurrentHashMap<>(); private final AtomicInteger droppedCaptures = new AtomicInteger(0); + private final AtomicLong totalOnEntryNs = new AtomicLong(0); + private final AtomicLong totalSerializationNs = new AtomicLong(0); private final int maxFunctionCount; private final ExecutorService serializerExecutor; @@ -31,7 +34,7 @@ public final class TraceRecorder { private TraceRecorder(TracerConfig config) { this.config = config; - this.writer = new TraceWriter(config.getDbPath()); + this.writer = new TraceWriter(config.getDbPath(), config.isInMemoryDb()); this.maxFunctionCount = config.getMaxFunctionCount(); this.serializerExecutor = Executors.newCachedThreadPool(r -> { Thread t = new Thread(r, "codeflash-serializer"); @@ -68,6 +71,8 @@ public final class TraceRecorder { private void onEntryImpl(String className, String methodName, String descriptor, int lineNumber, String sourceFile, Object[] args) { + long entryStart = System.nanoTime(); + String qualifiedName = className + "." + methodName + descriptor; // Check per-method count limit @@ -76,30 +81,38 @@ public final class TraceRecorder { return; } - // Serialize args with timeout to prevent deep object graph traversal from blocking + // Serialize args — try inline fast path first, fall back to async with timeout byte[] argsBlob; - Future future = serializerExecutor.submit(() -> Serializer.serialize(args)); - try { - argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); - } catch (TimeoutException e) { - future.cancel(true); - droppedCaptures.incrementAndGet(); - System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." - + methodName); - return; - } catch (Exception e) { - Throwable cause = e.getCause() != null ? e.getCause() : e; - droppedCaptures.incrementAndGet(); - System.err.println("[codeflash-tracer] Serialization failed for " + className + "." - + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); - return; + long serStart = System.nanoTime(); + argsBlob = Serializer.serializeFast(args); + if (argsBlob == null) { + // Slow path: async serialization with timeout for complex/unknown types + Future future = serializerExecutor.submit(() -> Serializer.serialize(args)); + try { + argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + future.cancel(true); + droppedCaptures.incrementAndGet(); + System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." + + methodName); + return; + } catch (Exception e) { + Throwable cause = e.getCause() != null ? e.getCause() : e; + droppedCaptures.incrementAndGet(); + System.err.println("[codeflash-tracer] Serialization failed for " + className + "." + + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); + return; + } } + totalSerializationNs.addAndGet(System.nanoTime() - serStart); long timeNs = System.nanoTime(); count.incrementAndGet(); writer.recordFunctionCall("call", methodName, className, sourceFile, lineNumber, descriptor, timeNs, argsBlob); + + totalOnEntryNs.addAndGet(System.nanoTime() - entryStart); } public void flush() { @@ -126,5 +139,16 @@ public final class TraceRecorder { System.err.println("[codeflash-tracer] Captured " + totalCaptures + " invocations across " + functionCounts.size() + " methods" + (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : "")); + + // Timing summary + long onEntryMs = totalOnEntryNs.get() / 1_000_000; + long serMs = totalSerializationNs.get() / 1_000_000; + String writerSummary = writer.getTimingSummary(); + System.err.println("[codeflash-tracer] Timing: onEntry=" + onEntryMs + "ms" + + " (serialization=" + serMs + "ms)" + + (totalCaptures > 0 + ? " avg=" + String.format("%.2f", (double) onEntryMs / totalCaptures) + "ms/capture" + : "") + + " " + writerSummary); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java index a9eeabf60..7bc5032cb 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceWriter.java @@ -7,30 +7,49 @@ import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.BlockingQueue; -import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; public final class TraceWriter { + private static final int BATCH_SIZE = 256; + private static final int QUEUE_CAPACITY = 65536; + private final Connection connection; + private final Path diskPath; + private final boolean inMemory; private final BlockingQueue writeQueue; private final Thread writerThread; private final AtomicBoolean running; + private final AtomicLong totalWriteNs = new AtomicLong(0); + private final AtomicInteger batchCount = new AtomicInteger(0); + private final AtomicInteger taskCount = new AtomicInteger(0); + private volatile long dumpToFileMs = 0; private PreparedStatement insertFunctionCall; private PreparedStatement insertMetadata; - public TraceWriter(String dbPath) { - this.writeQueue = new LinkedBlockingQueue<>(); + public TraceWriter(String dbPath, boolean inMemory) { + this.diskPath = Paths.get(dbPath).toAbsolutePath(); + this.diskPath.getParent().toFile().mkdirs(); + this.inMemory = inMemory; + this.writeQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); this.running = new AtomicBoolean(true); try { - Path path = Paths.get(dbPath).toAbsolutePath(); - path.getParent().toFile().mkdirs(); - this.connection = DriverManager.getConnection("jdbc:sqlite:" + path); + if (inMemory) { + // In-memory database for maximum write performance; flushed to disk via VACUUM INTO at close() + this.connection = DriverManager.getConnection("jdbc:sqlite::memory:"); + } else { + this.connection = DriverManager.getConnection("jdbc:sqlite:" + this.diskPath); + } initializeSchema(); prepareStatements(); @@ -45,8 +64,12 @@ public final class TraceWriter { private void initializeSchema() throws SQLException { try (Statement stmt = connection.createStatement()) { - stmt.execute("PRAGMA journal_mode=WAL"); - stmt.execute("PRAGMA synchronous=NORMAL"); + if (!inMemory) { + stmt.execute("PRAGMA journal_mode=WAL"); + stmt.execute("PRAGMA synchronous=NORMAL"); + stmt.execute("PRAGMA cache_size=-16000"); + stmt.execute("PRAGMA temp_store=MEMORY"); + } stmt.execute( "CREATE TABLE IF NOT EXISTS function_calls(" + @@ -69,6 +92,8 @@ public final class TraceWriter { stmt.execute("CREATE INDEX IF NOT EXISTS idx_fc_class_func ON function_calls(classname, function)"); } + // Keep autocommit off for writer performance — commit explicitly per batch + connection.setAutoCommit(false); } private void prepareStatements() throws SQLException { @@ -95,29 +120,65 @@ public final class TraceWriter { } private void writerLoop() { + List batch = new ArrayList<>(BATCH_SIZE); + while (running.get() || !writeQueue.isEmpty()) { try { WriteTask task = writeQueue.poll(100, TimeUnit.MILLISECONDS); - if (task != null) { - task.execute(this); + if (task == null) { + continue; } + batch.add(task); + writeQueue.drainTo(batch, BATCH_SIZE - 1); + executeBatch(batch); + batch.clear(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); break; - } catch (SQLException e) { - System.err.println("[codeflash-tracer] Write error: " + e.getMessage()); } } // Drain remaining - WriteTask task; - while ((task = writeQueue.poll()) != null) { + writeQueue.drainTo(batch); + if (!batch.isEmpty()) { + executeBatch(batch); + } + } + + private void executeBatch(List batch) { + if (batch.isEmpty()) { + return; + } + + long writeStart = System.nanoTime(); + boolean hasFunctionCalls = false; + try { + for (WriteTask task : batch) { + if (task instanceof FunctionCallTask) { + ((FunctionCallTask) task).bindParameters(this); + insertFunctionCall.addBatch(); + hasFunctionCalls = true; + } else { + task.execute(this); + } + } + + if (hasFunctionCalls) { + insertFunctionCall.executeBatch(); + } + + connection.commit(); + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Batch write error (" + batch.size() + " tasks): " + e.getMessage()); try { - task.execute(this); - } catch (SQLException e) { - System.err.println("[codeflash-tracer] Write error: " + e.getMessage()); + connection.rollback(); + } catch (SQLException re) { + System.err.println("[codeflash-tracer] Rollback failed: " + re.getMessage()); } } + totalWriteNs.addAndGet(System.nanoTime() - writeStart); + batchCount.incrementAndGet(); + taskCount.addAndGet(batch.size()); } public void flush() { @@ -131,6 +192,15 @@ public final class TraceWriter { } } + public String getTimingSummary() { + long writeMs = totalWriteNs.get() / 1_000_000; + int batches = batchCount.get(); + int tasks = taskCount.get(); + return "writes=" + writeMs + "ms (" + tasks + " tasks in " + batches + " batches" + + (batches > 0 ? ", avg=" + String.format("%.1f", (double) tasks / batches) + " tasks/batch" : "") + + ") dump=" + dumpToFileMs + "ms"; + } + public void close() { running.set(false); try { @@ -139,9 +209,29 @@ public final class TraceWriter { Thread.currentThread().interrupt(); } + // Close prepared statements first — required before VACUUM try { if (insertFunctionCall != null) insertFunctionCall.close(); if (insertMetadata != null) insertMetadata.close(); + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Error closing statements: " + e.getMessage()); + } + + if (inMemory) { + long dumpStart = System.nanoTime(); + try { + connection.commit(); + connection.setAutoCommit(true); + try (Statement stmt = connection.createStatement()) { + stmt.execute("VACUUM INTO '" + diskPath.toString().replace("'", "''") + "'"); + } + } catch (SQLException e) { + System.err.println("[codeflash-tracer] Failed to write trace DB to disk: " + e.getMessage()); + } + dumpToFileMs = (System.nanoTime() - dumpStart) / 1_000_000; + } + + try { if (connection != null) connection.close(); } catch (SQLException e) { System.err.println("[codeflash-tracer] Error closing TraceWriter: " + e.getMessage()); @@ -177,8 +267,7 @@ public final class TraceWriter { this.argsBlob = argsBlob; } - @Override - public void execute(TraceWriter writer) throws SQLException { + void bindParameters(TraceWriter writer) throws SQLException { writer.insertFunctionCall.setString(1, type); writer.insertFunctionCall.setString(2, function); writer.insertFunctionCall.setString(3, classname); @@ -187,6 +276,11 @@ public final class TraceWriter { writer.insertFunctionCall.setString(6, descriptor); writer.insertFunctionCall.setLong(7, timeNs); writer.insertFunctionCall.setBytes(8, argsBlob); + } + + @Override + public void execute(TraceWriter writer) throws SQLException { + bindParameters(writer); writer.insertFunctionCall.executeUpdate(); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java index 8fe799d2f..9e2675c00 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracerConfig.java @@ -30,6 +30,9 @@ public final class TracerConfig { @SerializedName("projectRoot") private String projectRoot = ""; + @SerializedName("inMemoryDb") + private boolean inMemoryDb = false; + private static final Gson GSON = new Gson(); public static TracerConfig parse(String agentArgs) { @@ -89,6 +92,10 @@ public final class TracerConfig { return projectRoot; } + public boolean isInMemoryDb() { + return inMemoryDb; + } + public boolean shouldInstrumentClass(String internalClassName) { String dotName = internalClassName.replace('/', '.'); diff --git a/codeflash/languages/java/tracer.py b/codeflash/languages/java/tracer.py index 50506797e..8e8348681 100644 --- a/codeflash/languages/java/tracer.py +++ b/codeflash/languages/java/tracer.py @@ -6,6 +6,7 @@ import os import subprocess from typing import TYPE_CHECKING +from codeflash.code_utils.env_utils import is_ci from codeflash.languages.java.line_profiler import find_agent_jar from codeflash.languages.java.replay_test import generate_replay_tests @@ -114,6 +115,7 @@ class JavaTracer: "maxFunctionCount": max_function_count, "timeout": timeout, "projectRoot": str(project_root.resolve()) if project_root else "", + "inMemoryDb": is_ci(), } config_path = trace_db_path.with_suffix(".config.json") diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/ProfilingWorkload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/ProfilingWorkload.java new file mode 100644 index 000000000..b7c48c625 --- /dev/null +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/ProfilingWorkload.java @@ -0,0 +1,91 @@ +package com.example; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Profiling workload for benchmarking the codeflash tracing agent. + * Exercises different argument types to stress serialization paths. + */ +public class ProfilingWorkload { + + // 1. Primitives only — cheapest to serialize + public static int addInts(int a, int b) { + return a + b; + } + + // 2. String arguments — moderate serialization cost + public static String concatStrings(String a, String b) { + return a + b; + } + + // 3. Array argument — requires element-by-element serialization + public static int sumArray(int[] values) { + int sum = 0; + for (int v : values) { + sum += v; + } + return sum; + } + + // 4. Collection argument — triggers recursive Kryo processing + public static int sumList(List values) { + int sum = 0; + for (int v : values) { + sum += v; + } + return sum; + } + + // 5. Nested map — deep object graph, expensive serialization + public static int countMapEntries(Map> data) { + int count = 0; + for (List list : data.values()) { + count += list.size(); + } + return count; + } + + public static void main(String[] args) { + int iterations = 1000; + + // 1. Primitives + for (int i = 0; i < iterations; i++) { + addInts(i, i + 1); + } + + // 2. Strings + for (int i = 0; i < iterations; i++) { + concatStrings("hello-" + i, "-world"); + } + + // 3. Arrays + int[] arr = new int[100]; + for (int i = 0; i < arr.length; i++) arr[i] = i; + for (int i = 0; i < iterations; i++) { + sumArray(arr); + } + + // 4. Lists + List list = new ArrayList<>(100); + for (int i = 0; i < 100; i++) list.add(i); + for (int i = 0; i < iterations; i++) { + sumList(list); + } + + // 5. Nested maps + Map> map = new HashMap<>(); + for (int i = 0; i < 10; i++) { + List vals = new ArrayList<>(); + for (int j = 0; j < 10; j++) vals.add(j); + map.put("key-" + i, vals); + } + for (int i = 0; i < iterations; i++) { + countMapEntries(map); + } + + System.out.println("ProfilingWorkload complete."); + } +} diff --git a/tests/test_languages/test_java/test_java_tracer_e2e.py b/tests/test_languages/test_java/test_java_tracer_e2e.py index c7dce2379..2ea87de9c 100644 --- a/tests/test_languages/test_java/test_java_tracer_e2e.py +++ b/tests/test_languages/test_java/test_java_tracer_e2e.py @@ -196,7 +196,6 @@ class TestReplayTestGeneration: assert "import org.junit.jupiter.api.Test;" in content assert "ReplayHelper" in content assert "replay_computeSum_0" in content - assert "replay_repeatString_0" in content def test_metadata_parsing(self, compiled_workload: Path, trace_db: Path, tmp_path: Path) -> None: """Test that metadata comments are correctly parsed from generated tests."""