wip java support

This commit is contained in:
misrasaurabh1 2026-01-30 00:37:24 -08:00
parent 351dd7539f
commit 29f266ee63
61 changed files with 13048 additions and 12 deletions

View file

@ -0,0 +1,5 @@
# Codeflash configuration for Java project
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"

View file

@ -0,0 +1,122 @@
package com.example;
import java.util.ArrayList;
import java.util.List;
/**
* Collection of algorithms that can be optimized by Codeflash.
*/
public class Algorithms {
/**
* Calculate Fibonacci number using naive recursive approach.
* This has O(2^n) time complexity and should be optimized.
*
* @param n The position in Fibonacci sequence (0-indexed)
* @return The nth Fibonacci number
*/
public long fibonacci(int n) {
if (n <= 1) {
return n;
}
return fibonacci(n - 1) + fibonacci(n - 2);
}
/**
* Find all prime numbers up to n using naive approach.
* This can be optimized with Sieve of Eratosthenes.
*
* @param n Upper bound for finding primes
* @return List of all prime numbers <= n
*/
public List<Integer> findPrimes(int n) {
List<Integer> primes = new ArrayList<>();
for (int i = 2; i <= n; i++) {
if (isPrime(i)) {
primes.add(i);
}
}
return primes;
}
/**
* Check if a number is prime using naive trial division.
*
* @param num Number to check
* @return true if num is prime
*/
private boolean isPrime(int num) {
if (num < 2) return false;
for (int i = 2; i < num; i++) {
if (num % i == 0) {
return false;
}
}
return true;
}
/**
* Find duplicates in an array using O(n^2) nested loops.
* This can be optimized with HashSet to O(n).
*
* @param arr Input array
* @return List of duplicate elements
*/
public List<Integer> findDuplicates(int[] arr) {
List<Integer> duplicates = new ArrayList<>();
for (int i = 0; i < arr.length; i++) {
for (int j = i + 1; j < arr.length; j++) {
if (arr[i] == arr[j] && !duplicates.contains(arr[i])) {
duplicates.add(arr[i]);
}
}
}
return duplicates;
}
/**
* Calculate factorial recursively without tail optimization.
*
* @param n Number to calculate factorial for
* @return n!
*/
public long factorial(int n) {
if (n <= 1) {
return 1;
}
return n * factorial(n - 1);
}
/**
* Concatenate strings in a loop using String concatenation.
* Should be optimized to use StringBuilder.
*
* @param items List of strings to concatenate
* @return Concatenated result
*/
public String concatenateStrings(List<String> items) {
String result = "";
for (String item : items) {
result = result + item + ", ";
}
if (result.length() > 2) {
result = result.substring(0, result.length() - 2);
}
return result;
}
/**
* Calculate sum of squares using a loop.
* This is already efficient but shows a simple example.
*
* @param n Upper bound
* @return Sum of squares from 1 to n
*/
public long sumOfSquares(int n) {
long sum = 0;
for (int i = 1; i <= n; i++) {
sum += (long) i * i;
}
return sum;
}
}

View file

@ -0,0 +1,129 @@
package com.example;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import java.util.Arrays;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
* Unit tests for Algorithms class.
*/
class AlgorithmsTest {
private Algorithms algorithms;
@BeforeEach
void setUp() {
algorithms = new Algorithms();
}
@Test
@DisplayName("Fibonacci of 0 should return 0")
void testFibonacciZero() {
assertEquals(0, algorithms.fibonacci(0));
}
@Test
@DisplayName("Fibonacci of 1 should return 1")
void testFibonacciOne() {
assertEquals(1, algorithms.fibonacci(1));
}
@Test
@DisplayName("Fibonacci of 10 should return 55")
void testFibonacciTen() {
assertEquals(55, algorithms.fibonacci(10));
}
@Test
@DisplayName("Fibonacci of 20 should return 6765")
void testFibonacciTwenty() {
assertEquals(6765, algorithms.fibonacci(20));
}
@Test
@DisplayName("Find primes up to 10")
void testFindPrimesUpToTen() {
List<Integer> primes = algorithms.findPrimes(10);
assertEquals(Arrays.asList(2, 3, 5, 7), primes);
}
@Test
@DisplayName("Find primes up to 20")
void testFindPrimesUpToTwenty() {
List<Integer> primes = algorithms.findPrimes(20);
assertEquals(Arrays.asList(2, 3, 5, 7, 11, 13, 17, 19), primes);
}
@Test
@DisplayName("Find duplicates in array with duplicates")
void testFindDuplicatesWithDuplicates() {
int[] arr = {1, 2, 3, 2, 4, 3, 5};
List<Integer> duplicates = algorithms.findDuplicates(arr);
assertTrue(duplicates.contains(2));
assertTrue(duplicates.contains(3));
assertEquals(2, duplicates.size());
}
@Test
@DisplayName("Find duplicates in array without duplicates")
void testFindDuplicatesNoDuplicates() {
int[] arr = {1, 2, 3, 4, 5};
List<Integer> duplicates = algorithms.findDuplicates(arr);
assertTrue(duplicates.isEmpty());
}
@Test
@DisplayName("Factorial of 0 should return 1")
void testFactorialZero() {
assertEquals(1, algorithms.factorial(0));
}
@Test
@DisplayName("Factorial of 5 should return 120")
void testFactorialFive() {
assertEquals(120, algorithms.factorial(5));
}
@Test
@DisplayName("Factorial of 10 should return 3628800")
void testFactorialTen() {
assertEquals(3628800, algorithms.factorial(10));
}
@Test
@DisplayName("Concatenate empty list")
void testConcatenateEmptyList() {
assertEquals("", algorithms.concatenateStrings(List.of()));
}
@Test
@DisplayName("Concatenate single item")
void testConcatenateSingleItem() {
assertEquals("hello", algorithms.concatenateStrings(List.of("hello")));
}
@Test
@DisplayName("Concatenate multiple items")
void testConcatenateMultipleItems() {
assertEquals("a, b, c", algorithms.concatenateStrings(Arrays.asList("a", "b", "c")));
}
@Test
@DisplayName("Sum of squares up to 5")
void testSumOfSquaresFive() {
// 1 + 4 + 9 + 16 + 25 = 55
assertEquals(55, algorithms.sumOfSquares(5));
}
@Test
@DisplayName("Sum of squares up to 10")
void testSumOfSquaresTen() {
// 1 + 4 + 9 + 16 + 25 + 36 + 49 + 64 + 81 + 100 = 385
assertEquals(385, algorithms.sumOfSquares(10));
}
}

View file

@ -0,0 +1,42 @@
package com.codeflash;
/**
* Context object for tracking benchmark timing.
*
* Created by {@link CodeFlash#startBenchmark(String)} and passed to
* {@link CodeFlash#endBenchmark(BenchmarkContext)}.
*/
public final class BenchmarkContext {
private final String methodId;
private final long startTime;
/**
* Create a new benchmark context.
*
* @param methodId Method being benchmarked
* @param startTime Start time in nanoseconds
*/
BenchmarkContext(String methodId, long startTime) {
this.methodId = methodId;
this.startTime = startTime;
}
/**
* Get the method ID.
*
* @return Method identifier
*/
public String getMethodId() {
return methodId;
}
/**
* Get the start time.
*
* @return Start time in nanoseconds
*/
public long getStartTime() {
return startTime;
}
}

View file

@ -0,0 +1,160 @@
package com.codeflash;
import java.util.Arrays;
/**
* Result of a benchmark run with statistical analysis.
*
* Provides JMH-style statistics including mean, standard deviation,
* and percentiles (p50, p90, p99).
*/
public final class BenchmarkResult {
private final String methodId;
private final long[] measurements;
private final long mean;
private final long stdDev;
private final long min;
private final long max;
private final long p50;
private final long p90;
private final long p99;
/**
* Create a benchmark result from raw measurements.
*
* @param methodId Method that was benchmarked
* @param measurements Array of timing measurements in nanoseconds
*/
public BenchmarkResult(String methodId, long[] measurements) {
this.methodId = methodId;
this.measurements = measurements.clone();
// Sort for percentile calculations
long[] sorted = measurements.clone();
Arrays.sort(sorted);
this.min = sorted[0];
this.max = sorted[sorted.length - 1];
this.mean = calculateMean(sorted);
this.stdDev = calculateStdDev(sorted, this.mean);
this.p50 = percentile(sorted, 50);
this.p90 = percentile(sorted, 90);
this.p99 = percentile(sorted, 99);
}
private static long calculateMean(long[] values) {
long sum = 0;
for (long v : values) {
sum += v;
}
return sum / values.length;
}
private static long calculateStdDev(long[] values, long mean) {
if (values.length < 2) {
return 0;
}
long sumSquaredDiff = 0;
for (long v : values) {
long diff = v - mean;
sumSquaredDiff += diff * diff;
}
return (long) Math.sqrt(sumSquaredDiff / (values.length - 1));
}
private static long percentile(long[] sorted, int percentile) {
int index = (int) Math.ceil(percentile / 100.0 * sorted.length) - 1;
return sorted[Math.max(0, Math.min(index, sorted.length - 1))];
}
// Getters
public String getMethodId() {
return methodId;
}
public long[] getMeasurements() {
return measurements.clone();
}
public int getIterationCount() {
return measurements.length;
}
public long getMean() {
return mean;
}
public long getStdDev() {
return stdDev;
}
public long getMin() {
return min;
}
public long getMax() {
return max;
}
public long getP50() {
return p50;
}
public long getP90() {
return p90;
}
public long getP99() {
return p99;
}
/**
* Get mean in milliseconds.
*/
public double getMeanMs() {
return mean / 1_000_000.0;
}
/**
* Get standard deviation in milliseconds.
*/
public double getStdDevMs() {
return stdDev / 1_000_000.0;
}
/**
* Calculate coefficient of variation (CV) as percentage.
* CV = (stdDev / mean) * 100
* Lower is better (more stable measurements).
*/
public double getCoefficientOfVariation() {
if (mean == 0) {
return 0;
}
return (stdDev * 100.0) / mean;
}
/**
* Check if measurements are stable (CV < 10%).
*/
public boolean isStable() {
return getCoefficientOfVariation() < 10.0;
}
@Override
public String toString() {
return String.format(
"BenchmarkResult{method='%s', mean=%.3fms, stdDev=%.3fms, p50=%.3fms, p90=%.3fms, p99=%.3fms, cv=%.1f%%, iterations=%d}",
methodId,
getMeanMs(),
getStdDevMs(),
p50 / 1_000_000.0,
p90 / 1_000_000.0,
p99 / 1_000_000.0,
getCoefficientOfVariation(),
measurements.length
);
}
}

View file

@ -0,0 +1,148 @@
package com.codeflash;
/**
* Utility class to prevent dead code elimination by the JIT compiler.
*
* Inspired by JMH's Blackhole class. When the JVM detects that a computed
* value is never used, it may eliminate the computation entirely. By
* "consuming" values through this class, we prevent such optimizations.
*
* Usage:
* <pre>
* int result = expensiveComputation();
* Blackhole.consume(result); // Prevents JIT from eliminating the computation
* </pre>
*
* The implementation uses volatile writes which act as memory barriers,
* preventing the JIT from optimizing away the computation.
*/
public final class Blackhole {
// Volatile fields act as memory barriers, preventing optimization
private static volatile int intSink;
private static volatile long longSink;
private static volatile double doubleSink;
private static volatile Object objectSink;
private Blackhole() {
// Utility class, no instantiation
}
/**
* Consume an int value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(int value) {
intSink = value;
}
/**
* Consume a long value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(long value) {
longSink = value;
}
/**
* Consume a double value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(double value) {
doubleSink = value;
}
/**
* Consume a float value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(float value) {
doubleSink = value;
}
/**
* Consume a boolean value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(boolean value) {
intSink = value ? 1 : 0;
}
/**
* Consume a byte value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(byte value) {
intSink = value;
}
/**
* Consume a short value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(short value) {
intSink = value;
}
/**
* Consume a char value to prevent dead code elimination.
*
* @param value Value to consume
*/
public static void consume(char value) {
intSink = value;
}
/**
* Consume an Object to prevent dead code elimination.
* Works for any reference type including arrays and collections.
*
* @param value Value to consume
*/
public static void consume(Object value) {
objectSink = value;
}
/**
* Consume an int array to prevent dead code elimination.
*
* @param values Array to consume
*/
public static void consume(int[] values) {
objectSink = values;
if (values != null && values.length > 0) {
intSink = values[0];
}
}
/**
* Consume a long array to prevent dead code elimination.
*
* @param values Array to consume
*/
public static void consume(long[] values) {
objectSink = values;
if (values != null && values.length > 0) {
longSink = values[0];
}
}
/**
* Consume a double array to prevent dead code elimination.
*
* @param values Array to consume
*/
public static void consume(double[] values) {
objectSink = values;
if (values != null && values.length > 0) {
doubleSink = values[0];
}
}
}

View file

@ -0,0 +1,264 @@
package com.codeflash;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.concurrent.atomic.AtomicLong;
/**
* Main API for CodeFlash runtime instrumentation.
*
* Provides methods for:
* - Capturing function inputs/outputs for behavior verification
* - Benchmarking with JMH-inspired best practices
* - Preventing dead code elimination
*
* Usage:
* <pre>
* // Behavior capture
* CodeFlash.captureInput("Calculator.add", a, b);
* int result = a + b;
* return CodeFlash.captureOutput("Calculator.add", result);
*
* // Benchmarking
* BenchmarkContext ctx = CodeFlash.startBenchmark("Calculator.add");
* // ... code to benchmark ...
* CodeFlash.endBenchmark(ctx);
* </pre>
*/
public final class CodeFlash {
private static final AtomicLong callIdCounter = new AtomicLong(0);
private static volatile ResultWriter resultWriter;
private static volatile boolean initialized = false;
private static volatile String outputFile;
// Configuration from environment variables
private static final int DEFAULT_WARMUP_ITERATIONS = 10;
private static final int DEFAULT_MEASUREMENT_ITERATIONS = 20;
static {
// Register shutdown hook to flush results
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
if (resultWriter != null) {
resultWriter.close();
}
}));
}
private CodeFlash() {
// Utility class, no instantiation
}
/**
* Initialize CodeFlash with output file path.
* Called automatically if CODEFLASH_OUTPUT_FILE env var is set.
*
* @param outputPath Path to output file (SQLite database)
*/
public static synchronized void initialize(String outputPath) {
if (!initialized || !outputPath.equals(outputFile)) {
outputFile = outputPath;
Path path = Paths.get(outputPath);
resultWriter = new ResultWriter(path);
initialized = true;
}
}
/**
* Get or create the result writer, initializing from environment if needed.
*/
private static ResultWriter getWriter() {
if (!initialized) {
String envPath = System.getenv("CODEFLASH_OUTPUT_FILE");
if (envPath != null && !envPath.isEmpty()) {
initialize(envPath);
} else {
// Default to temp file if no env var
initialize(System.getProperty("java.io.tmpdir") + "/codeflash_results.db");
}
}
return resultWriter;
}
/**
* Capture function input arguments.
*
* @param methodId Unique identifier for the method (e.g., "Calculator.add")
* @param args Input arguments
*/
public static void captureInput(String methodId, Object... args) {
long callId = callIdCounter.incrementAndGet();
String argsJson = Serializer.toJson(args);
getWriter().recordInput(callId, methodId, argsJson, System.nanoTime());
}
/**
* Capture function output and return it (for chaining in return statements).
*
* @param methodId Unique identifier for the method
* @param result The result value
* @param <T> Type of the result
* @return The same result (for chaining)
*/
public static <T> T captureOutput(String methodId, T result) {
long callId = callIdCounter.get(); // Use same callId as input
String resultJson = Serializer.toJson(result);
getWriter().recordOutput(callId, methodId, resultJson, System.nanoTime());
return result;
}
/**
* Capture an exception thrown by the function.
*
* @param methodId Unique identifier for the method
* @param error The exception
*/
public static void captureException(String methodId, Throwable error) {
long callId = callIdCounter.get();
String errorJson = Serializer.exceptionToJson(error);
getWriter().recordError(callId, methodId, errorJson, System.nanoTime());
}
/**
* Start a benchmark context for timing code execution.
* Implements JMH-inspired warmup and measurement phases.
*
* @param methodId Unique identifier for the method being benchmarked
* @return BenchmarkContext to pass to endBenchmark
*/
public static BenchmarkContext startBenchmark(String methodId) {
return new BenchmarkContext(methodId, System.nanoTime());
}
/**
* End a benchmark and record the timing.
*
* @param ctx The benchmark context from startBenchmark
*/
public static void endBenchmark(BenchmarkContext ctx) {
long endTime = System.nanoTime();
long duration = endTime - ctx.getStartTime();
getWriter().recordBenchmark(ctx.getMethodId(), duration, endTime);
}
/**
* Run a benchmark with proper JMH-style warmup and measurement.
*
* @param methodId Unique identifier for the method
* @param runnable Code to benchmark
* @return Benchmark result with statistics
*/
public static BenchmarkResult runBenchmark(String methodId, Runnable runnable) {
int warmupIterations = getWarmupIterations();
int measurementIterations = getMeasurementIterations();
// Warmup phase - results discarded
for (int i = 0; i < warmupIterations; i++) {
runnable.run();
}
// Suggest GC before measurement (hint only, not guaranteed)
System.gc();
// Measurement phase
long[] measurements = new long[measurementIterations];
for (int i = 0; i < measurementIterations; i++) {
long start = System.nanoTime();
runnable.run();
measurements[i] = System.nanoTime() - start;
}
BenchmarkResult result = new BenchmarkResult(methodId, measurements);
getWriter().recordBenchmarkResult(methodId, result);
return result;
}
/**
* Run a benchmark that returns a value (prevents dead code elimination).
*
* @param methodId Unique identifier for the method
* @param supplier Code to benchmark that returns a value
* @param <T> Return type
* @return Benchmark result with statistics
*/
public static <T> BenchmarkResult runBenchmarkWithResult(String methodId, java.util.function.Supplier<T> supplier) {
int warmupIterations = getWarmupIterations();
int measurementIterations = getMeasurementIterations();
// Warmup phase - consume results to prevent dead code elimination
for (int i = 0; i < warmupIterations; i++) {
Blackhole.consume(supplier.get());
}
// Suggest GC before measurement
System.gc();
// Measurement phase
long[] measurements = new long[measurementIterations];
for (int i = 0; i < measurementIterations; i++) {
long start = System.nanoTime();
T result = supplier.get();
measurements[i] = System.nanoTime() - start;
Blackhole.consume(result); // Prevent dead code elimination
}
BenchmarkResult benchmarkResult = new BenchmarkResult(methodId, measurements);
getWriter().recordBenchmarkResult(methodId, benchmarkResult);
return benchmarkResult;
}
/**
* Get warmup iterations from environment or use default.
*/
private static int getWarmupIterations() {
String env = System.getenv("CODEFLASH_WARMUP_ITERATIONS");
if (env != null) {
try {
return Integer.parseInt(env);
} catch (NumberFormatException e) {
// Use default
}
}
return DEFAULT_WARMUP_ITERATIONS;
}
/**
* Get measurement iterations from environment or use default.
*/
private static int getMeasurementIterations() {
String env = System.getenv("CODEFLASH_MEASUREMENT_ITERATIONS");
if (env != null) {
try {
return Integer.parseInt(env);
} catch (NumberFormatException e) {
// Use default
}
}
return DEFAULT_MEASUREMENT_ITERATIONS;
}
/**
* Get the current call ID (for correlation).
*
* @return Current call ID
*/
public static long getCurrentCallId() {
return callIdCounter.get();
}
/**
* Reset the call ID counter (for testing).
*/
public static void resetCallId() {
callIdCounter.set(0);
}
/**
* Force flush all pending writes.
*/
public static void flush() {
if (resultWriter != null) {
resultWriter.flush();
}
}
}

View file

@ -0,0 +1,349 @@
package com.codeflash;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
/**
* Compares test results between original and optimized code.
*
* Used by CodeFlash to verify that optimized code produces the
* same outputs as the original code for the same inputs.
*
* Can be run as a CLI tool:
* java -jar codeflash-runtime.jar original.db candidate.db
*/
public final class Comparator {
private static final Gson GSON = new GsonBuilder()
.serializeNulls()
.setPrettyPrinting()
.create();
// Tolerance for floating point comparison
private static final double EPSILON = 1e-9;
private Comparator() {
// Utility class
}
/**
* Main entry point for CLI usage.
*
* @param args [originalDb, candidateDb]
*/
public static void main(String[] args) {
if (args.length != 2) {
System.err.println("Usage: java -jar codeflash-runtime.jar <original.db> <candidate.db>");
System.exit(1);
}
try {
ComparisonResult result = compare(args[0], args[1]);
System.out.println(GSON.toJson(result));
System.exit(result.isEquivalent() ? 0 : 1);
} catch (Exception e) {
JsonObject error = new JsonObject();
error.addProperty("error", e.getMessage());
System.out.println(GSON.toJson(error));
System.exit(2);
}
}
/**
* Compare two result databases.
*
* @param originalDbPath Path to original results database
* @param candidateDbPath Path to candidate results database
* @return Comparison result with list of differences
*/
public static ComparisonResult compare(String originalDbPath, String candidateDbPath) throws SQLException {
List<Diff> diffs = new ArrayList<>();
try (Connection originalConn = DriverManager.getConnection("jdbc:sqlite:" + originalDbPath);
Connection candidateConn = DriverManager.getConnection("jdbc:sqlite:" + candidateDbPath)) {
// Get all invocations from original
List<Invocation> originalInvocations = getInvocations(originalConn);
List<Invocation> candidateInvocations = getInvocations(candidateConn);
// Create lookup map for candidate invocations
java.util.Map<Long, Invocation> candidateMap = new java.util.HashMap<>();
for (Invocation inv : candidateInvocations) {
candidateMap.put(inv.callId, inv);
}
// Compare each original invocation with candidate
for (Invocation original : originalInvocations) {
Invocation candidate = candidateMap.get(original.callId);
if (candidate == null) {
diffs.add(new Diff(
original.callId,
original.methodId,
DiffType.MISSING_IN_CANDIDATE,
"Invocation not found in candidate",
original.resultJson,
null
));
continue;
}
// Compare results
if (!compareJsonValues(original.resultJson, candidate.resultJson)) {
diffs.add(new Diff(
original.callId,
original.methodId,
DiffType.RETURN_VALUE,
"Return values differ",
original.resultJson,
candidate.resultJson
));
}
// Compare errors
boolean originalHasError = original.errorJson != null && !original.errorJson.isEmpty();
boolean candidateHasError = candidate.errorJson != null && !candidate.errorJson.isEmpty();
if (originalHasError != candidateHasError) {
diffs.add(new Diff(
original.callId,
original.methodId,
DiffType.EXCEPTION,
originalHasError ? "Original threw exception, candidate did not" :
"Candidate threw exception, original did not",
original.errorJson,
candidate.errorJson
));
} else if (originalHasError && !compareExceptions(original.errorJson, candidate.errorJson)) {
diffs.add(new Diff(
original.callId,
original.methodId,
DiffType.EXCEPTION,
"Exception details differ",
original.errorJson,
candidate.errorJson
));
}
// Remove from map to track extra invocations
candidateMap.remove(original.callId);
}
// Check for extra invocations in candidate
for (Invocation extra : candidateMap.values()) {
diffs.add(new Diff(
extra.callId,
extra.methodId,
DiffType.EXTRA_IN_CANDIDATE,
"Extra invocation in candidate",
null,
extra.resultJson
));
}
}
return new ComparisonResult(diffs.isEmpty(), diffs);
}
private static List<Invocation> getInvocations(Connection conn) throws SQLException {
List<Invocation> invocations = new ArrayList<>();
String sql = "SELECT call_id, method_id, args_json, result_json, error_json FROM invocations ORDER BY call_id";
try (PreparedStatement stmt = conn.prepareStatement(sql);
ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
invocations.add(new Invocation(
rs.getLong("call_id"),
rs.getString("method_id"),
rs.getString("args_json"),
rs.getString("result_json"),
rs.getString("error_json")
));
}
}
return invocations;
}
/**
* Compare two JSON values for equivalence.
*/
private static boolean compareJsonValues(String json1, String json2) {
if (json1 == null && json2 == null) return true;
if (json1 == null || json2 == null) return false;
if (json1.equals(json2)) return true;
try {
JsonElement elem1 = JsonParser.parseString(json1);
JsonElement elem2 = JsonParser.parseString(json2);
return compareJsonElements(elem1, elem2);
} catch (Exception e) {
// If parsing fails, fall back to string comparison
return json1.equals(json2);
}
}
private static boolean compareJsonElements(JsonElement elem1, JsonElement elem2) {
if (elem1 == null && elem2 == null) return true;
if (elem1 == null || elem2 == null) return false;
if (elem1.isJsonNull() && elem2.isJsonNull()) return true;
// Compare primitives
if (elem1.isJsonPrimitive() && elem2.isJsonPrimitive()) {
return comparePrimitives(elem1.getAsJsonPrimitive(), elem2.getAsJsonPrimitive());
}
// Compare arrays
if (elem1.isJsonArray() && elem2.isJsonArray()) {
return compareArrays(elem1.getAsJsonArray(), elem2.getAsJsonArray());
}
// Compare objects
if (elem1.isJsonObject() && elem2.isJsonObject()) {
return compareObjects(elem1.getAsJsonObject(), elem2.getAsJsonObject());
}
return false;
}
private static boolean comparePrimitives(com.google.gson.JsonPrimitive p1, com.google.gson.JsonPrimitive p2) {
// Handle numeric comparison with epsilon
if (p1.isNumber() && p2.isNumber()) {
double d1 = p1.getAsDouble();
double d2 = p2.getAsDouble();
// Handle NaN
if (Double.isNaN(d1) && Double.isNaN(d2)) return true;
// Handle infinity
if (Double.isInfinite(d1) && Double.isInfinite(d2)) {
return (d1 > 0) == (d2 > 0);
}
// Compare with epsilon
return Math.abs(d1 - d2) < EPSILON;
}
return Objects.equals(p1, p2);
}
private static boolean compareArrays(JsonArray arr1, JsonArray arr2) {
if (arr1.size() != arr2.size()) return false;
for (int i = 0; i < arr1.size(); i++) {
if (!compareJsonElements(arr1.get(i), arr2.get(i))) {
return false;
}
}
return true;
}
private static boolean compareObjects(JsonObject obj1, JsonObject obj2) {
// Skip type metadata for comparison
java.util.Set<String> keys1 = new java.util.HashSet<>(obj1.keySet());
java.util.Set<String> keys2 = new java.util.HashSet<>(obj2.keySet());
keys1.remove("__type__");
keys2.remove("__type__");
if (!keys1.equals(keys2)) return false;
for (String key : keys1) {
if (!compareJsonElements(obj1.get(key), obj2.get(key))) {
return false;
}
}
return true;
}
private static boolean compareExceptions(String error1, String error2) {
try {
JsonObject e1 = JsonParser.parseString(error1).getAsJsonObject();
JsonObject e2 = JsonParser.parseString(error2).getAsJsonObject();
// Compare exception type and message
String type1 = e1.has("type") ? e1.get("type").getAsString() : "";
String type2 = e2.has("type") ? e2.get("type").getAsString() : "";
// Types must match
return type1.equals(type2);
} catch (Exception e) {
return error1.equals(error2);
}
}
// Data classes
private static class Invocation {
final long callId;
final String methodId;
final String argsJson;
final String resultJson;
final String errorJson;
Invocation(long callId, String methodId, String argsJson, String resultJson, String errorJson) {
this.callId = callId;
this.methodId = methodId;
this.argsJson = argsJson;
this.resultJson = resultJson;
this.errorJson = errorJson;
}
}
public enum DiffType {
RETURN_VALUE,
EXCEPTION,
MISSING_IN_CANDIDATE,
EXTRA_IN_CANDIDATE
}
public static class Diff {
private final long callId;
private final String methodId;
private final DiffType type;
private final String message;
private final String originalValue;
private final String candidateValue;
public Diff(long callId, String methodId, DiffType type, String message,
String originalValue, String candidateValue) {
this.callId = callId;
this.methodId = methodId;
this.type = type;
this.message = message;
this.originalValue = originalValue;
this.candidateValue = candidateValue;
}
// Getters
public long getCallId() { return callId; }
public String getMethodId() { return methodId; }
public DiffType getType() { return type; }
public String getMessage() { return message; }
public String getOriginalValue() { return originalValue; }
public String getCandidateValue() { return candidateValue; }
}
public static class ComparisonResult {
private final boolean equivalent;
private final List<Diff> diffs;
public ComparisonResult(boolean equivalent, List<Diff> diffs) {
this.equivalent = equivalent;
this.diffs = diffs;
}
public boolean isEquivalent() { return equivalent; }
public List<Diff> getDiffs() { return diffs; }
}
}

View file

@ -0,0 +1,318 @@
package com.codeflash;
import java.nio.file.Path;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Writes benchmark and behavior capture results to SQLite database.
*
* Uses a background thread for non-blocking writes to minimize
* impact on benchmark measurements.
*
* Database schema:
* - invocations: call_id, method_id, args_json, result_json, error_json, start_time, end_time
* - benchmarks: method_id, duration_ns, timestamp
* - benchmark_results: method_id, mean_ns, stddev_ns, min_ns, max_ns, p50_ns, p90_ns, p99_ns, iterations
*/
public final class ResultWriter {
private final Path dbPath;
private final Connection connection;
private final BlockingQueue<WriteTask> writeQueue;
private final Thread writerThread;
private final AtomicBoolean running;
// Prepared statements for performance
private PreparedStatement insertInvocationInput;
private PreparedStatement updateInvocationOutput;
private PreparedStatement updateInvocationError;
private PreparedStatement insertBenchmark;
private PreparedStatement insertBenchmarkResult;
/**
* Create a new ResultWriter that writes to the specified database file.
*
* @param dbPath Path to SQLite database file (will be created if not exists)
*/
public ResultWriter(Path dbPath) {
this.dbPath = dbPath;
this.writeQueue = new LinkedBlockingQueue<>();
this.running = new AtomicBoolean(true);
try {
// Create connection and initialize schema
this.connection = DriverManager.getConnection("jdbc:sqlite:" + dbPath.toAbsolutePath());
initializeSchema();
prepareStatements();
// Start background writer thread
this.writerThread = new Thread(this::writerLoop, "codeflash-writer");
this.writerThread.setDaemon(true);
this.writerThread.start();
} catch (SQLException e) {
throw new RuntimeException("Failed to initialize ResultWriter: " + e.getMessage(), e);
}
}
private void initializeSchema() throws SQLException {
try (Statement stmt = connection.createStatement()) {
// Invocations table - stores input/output/error for each function call
stmt.execute(
"CREATE TABLE IF NOT EXISTS invocations (" +
"call_id INTEGER PRIMARY KEY, " +
"method_id TEXT NOT NULL, " +
"args_json TEXT, " +
"result_json TEXT, " +
"error_json TEXT, " +
"start_time INTEGER, " +
"end_time INTEGER)"
);
// Benchmarks table - stores individual benchmark timings
stmt.execute(
"CREATE TABLE IF NOT EXISTS benchmarks (" +
"id INTEGER PRIMARY KEY AUTOINCREMENT, " +
"method_id TEXT NOT NULL, " +
"duration_ns INTEGER NOT NULL, " +
"timestamp INTEGER NOT NULL)"
);
// Benchmark results table - stores aggregated statistics
stmt.execute(
"CREATE TABLE IF NOT EXISTS benchmark_results (" +
"method_id TEXT PRIMARY KEY, " +
"mean_ns INTEGER NOT NULL, " +
"stddev_ns INTEGER NOT NULL, " +
"min_ns INTEGER NOT NULL, " +
"max_ns INTEGER NOT NULL, " +
"p50_ns INTEGER NOT NULL, " +
"p90_ns INTEGER NOT NULL, " +
"p99_ns INTEGER NOT NULL, " +
"iterations INTEGER NOT NULL, " +
"coefficient_of_variation REAL NOT NULL)"
);
// Create indexes for faster queries
stmt.execute("CREATE INDEX IF NOT EXISTS idx_invocations_method ON invocations(method_id)");
stmt.execute("CREATE INDEX IF NOT EXISTS idx_benchmarks_method ON benchmarks(method_id)");
}
}
private void prepareStatements() throws SQLException {
insertInvocationInput = connection.prepareStatement(
"INSERT INTO invocations (call_id, method_id, args_json, start_time) VALUES (?, ?, ?, ?)"
);
updateInvocationOutput = connection.prepareStatement(
"UPDATE invocations SET result_json = ?, end_time = ? WHERE call_id = ?"
);
updateInvocationError = connection.prepareStatement(
"UPDATE invocations SET error_json = ?, end_time = ? WHERE call_id = ?"
);
insertBenchmark = connection.prepareStatement(
"INSERT INTO benchmarks (method_id, duration_ns, timestamp) VALUES (?, ?, ?)"
);
insertBenchmarkResult = connection.prepareStatement(
"INSERT OR REPLACE INTO benchmark_results " +
"(method_id, mean_ns, stddev_ns, min_ns, max_ns, p50_ns, p90_ns, p99_ns, iterations, coefficient_of_variation) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
);
}
/**
* Record function input (beginning of invocation).
*/
public void recordInput(long callId, String methodId, String argsJson, long startTime) {
writeQueue.offer(new WriteTask(WriteType.INPUT, callId, methodId, argsJson, null, null, startTime, 0, null));
}
/**
* Record function output (successful completion).
*/
public void recordOutput(long callId, String methodId, String resultJson, long endTime) {
writeQueue.offer(new WriteTask(WriteType.OUTPUT, callId, methodId, null, resultJson, null, 0, endTime, null));
}
/**
* Record function error (exception thrown).
*/
public void recordError(long callId, String methodId, String errorJson, long endTime) {
writeQueue.offer(new WriteTask(WriteType.ERROR, callId, methodId, null, null, errorJson, 0, endTime, null));
}
/**
* Record a single benchmark timing.
*/
public void recordBenchmark(String methodId, long durationNs, long timestamp) {
writeQueue.offer(new WriteTask(WriteType.BENCHMARK, 0, methodId, null, null, null, durationNs, timestamp, null));
}
/**
* Record aggregated benchmark results.
*/
public void recordBenchmarkResult(String methodId, BenchmarkResult result) {
writeQueue.offer(new WriteTask(WriteType.BENCHMARK_RESULT, 0, methodId, null, null, null, 0, 0, result));
}
/**
* Background writer loop - processes write tasks from queue.
*/
private void writerLoop() {
while (running.get() || !writeQueue.isEmpty()) {
try {
WriteTask task = writeQueue.poll(100, TimeUnit.MILLISECONDS);
if (task != null) {
executeTask(task);
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
} catch (SQLException e) {
System.err.println("CodeFlash ResultWriter error: " + e.getMessage());
}
}
// Process remaining tasks
WriteTask task;
while ((task = writeQueue.poll()) != null) {
try {
executeTask(task);
} catch (SQLException e) {
System.err.println("CodeFlash ResultWriter error: " + e.getMessage());
}
}
}
private void executeTask(WriteTask task) throws SQLException {
switch (task.type) {
case INPUT:
insertInvocationInput.setLong(1, task.callId);
insertInvocationInput.setString(2, task.methodId);
insertInvocationInput.setString(3, task.argsJson);
insertInvocationInput.setLong(4, task.startTime);
insertInvocationInput.executeUpdate();
break;
case OUTPUT:
updateInvocationOutput.setString(1, task.resultJson);
updateInvocationOutput.setLong(2, task.endTime);
updateInvocationOutput.setLong(3, task.callId);
updateInvocationOutput.executeUpdate();
break;
case ERROR:
updateInvocationError.setString(1, task.errorJson);
updateInvocationError.setLong(2, task.endTime);
updateInvocationError.setLong(3, task.callId);
updateInvocationError.executeUpdate();
break;
case BENCHMARK:
insertBenchmark.setString(1, task.methodId);
insertBenchmark.setLong(2, task.startTime); // duration stored in startTime field
insertBenchmark.setLong(3, task.endTime); // timestamp stored in endTime field
insertBenchmark.executeUpdate();
break;
case BENCHMARK_RESULT:
BenchmarkResult r = task.benchmarkResult;
insertBenchmarkResult.setString(1, task.methodId);
insertBenchmarkResult.setLong(2, r.getMean());
insertBenchmarkResult.setLong(3, r.getStdDev());
insertBenchmarkResult.setLong(4, r.getMin());
insertBenchmarkResult.setLong(5, r.getMax());
insertBenchmarkResult.setLong(6, r.getP50());
insertBenchmarkResult.setLong(7, r.getP90());
insertBenchmarkResult.setLong(8, r.getP99());
insertBenchmarkResult.setInt(9, r.getIterationCount());
insertBenchmarkResult.setDouble(10, r.getCoefficientOfVariation());
insertBenchmarkResult.executeUpdate();
break;
}
}
/**
* Flush all pending writes synchronously.
*/
public void flush() {
// Wait for queue to drain
while (!writeQueue.isEmpty()) {
try {
Thread.sleep(10);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
break;
}
}
}
/**
* Close the writer and database connection.
*/
public void close() {
running.set(false);
try {
writerThread.join(5000); // Wait up to 5 seconds
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
try {
if (insertInvocationInput != null) insertInvocationInput.close();
if (updateInvocationOutput != null) updateInvocationOutput.close();
if (updateInvocationError != null) updateInvocationError.close();
if (insertBenchmark != null) insertBenchmark.close();
if (insertBenchmarkResult != null) insertBenchmarkResult.close();
if (connection != null) connection.close();
} catch (SQLException e) {
System.err.println("Error closing ResultWriter: " + e.getMessage());
}
}
/**
* Get the database path.
*/
public Path getDbPath() {
return dbPath;
}
// Internal task class for queue
private enum WriteType {
INPUT, OUTPUT, ERROR, BENCHMARK, BENCHMARK_RESULT
}
private static class WriteTask {
final WriteType type;
final long callId;
final String methodId;
final String argsJson;
final String resultJson;
final String errorJson;
final long startTime;
final long endTime;
final BenchmarkResult benchmarkResult;
WriteTask(WriteType type, long callId, String methodId, String argsJson,
String resultJson, String errorJson, long startTime, long endTime,
BenchmarkResult benchmarkResult) {
this.type = type;
this.callId = callId;
this.methodId = methodId;
this.argsJson = argsJson;
this.resultJson = resultJson;
this.errorJson = errorJson;
this.startTime = startTime;
this.endTime = endTime;
this.benchmarkResult = benchmarkResult;
}
}
}

View file

@ -0,0 +1,282 @@
package com.codeflash;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.Collection;
import java.util.Date;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Optional;
/**
* Serializer for Java objects to JSON format.
*
* Handles:
* - Primitives and their wrappers
* - Strings
* - Arrays (primitive and object)
* - Collections (List, Set, etc.)
* - Maps
* - Date/Time types
* - Custom objects via reflection
* - Circular references (detected and marked)
*/
public final class Serializer {
private static final Gson GSON = new GsonBuilder()
.serializeNulls()
.create();
private static final int MAX_DEPTH = 10;
private static final int MAX_COLLECTION_SIZE = 1000;
private Serializer() {
// Utility class
}
/**
* Serialize an object to JSON string.
*
* @param obj Object to serialize
* @return JSON string representation
*/
public static String toJson(Object obj) {
try {
JsonElement element = serialize(obj, new IdentityHashMap<>(), 0);
return GSON.toJson(element);
} catch (Exception e) {
// Fallback for serialization errors
JsonObject error = new JsonObject();
error.addProperty("__serialization_error__", e.getMessage());
error.addProperty("__type__", obj != null ? obj.getClass().getName() : "null");
return GSON.toJson(error);
}
}
/**
* Serialize varargs (for capturing multiple arguments).
*
* @param args Arguments to serialize
* @return JSON array string
*/
public static String toJson(Object... args) {
JsonArray array = new JsonArray();
IdentityHashMap<Object, Boolean> seen = new IdentityHashMap<>();
for (Object arg : args) {
array.add(serialize(arg, seen, 0));
}
return GSON.toJson(array);
}
/**
* Serialize an exception to JSON.
*
* @param error Exception to serialize
* @return JSON string with exception details
*/
public static String exceptionToJson(Throwable error) {
JsonObject obj = new JsonObject();
obj.addProperty("__exception__", true);
obj.addProperty("type", error.getClass().getName());
obj.addProperty("message", error.getMessage());
// Capture stack trace
JsonArray stackTrace = new JsonArray();
for (StackTraceElement element : error.getStackTrace()) {
stackTrace.add(element.toString());
}
obj.add("stackTrace", stackTrace);
// Capture cause if present
if (error.getCause() != null) {
obj.addProperty("causeType", error.getCause().getClass().getName());
obj.addProperty("causeMessage", error.getCause().getMessage());
}
return GSON.toJson(obj);
}
private static JsonElement serialize(Object obj, IdentityHashMap<Object, Boolean> seen, int depth) {
if (obj == null) {
return JsonNull.INSTANCE;
}
// Depth limit to prevent infinite recursion
if (depth > MAX_DEPTH) {
JsonObject truncated = new JsonObject();
truncated.addProperty("__truncated__", "max depth exceeded");
return truncated;
}
Class<?> clazz = obj.getClass();
// Primitives and wrappers
if (obj instanceof Boolean) {
return new JsonPrimitive((Boolean) obj);
}
if (obj instanceof Number) {
return new JsonPrimitive((Number) obj);
}
if (obj instanceof Character) {
return new JsonPrimitive(String.valueOf(obj));
}
if (obj instanceof String) {
return new JsonPrimitive((String) obj);
}
// Check for circular reference (only for reference types)
if (seen.containsKey(obj)) {
JsonObject circular = new JsonObject();
circular.addProperty("__circular_ref__", clazz.getName());
return circular;
}
seen.put(obj, Boolean.TRUE);
try {
// Date/Time types
if (obj instanceof Date) {
return new JsonPrimitive(((Date) obj).toInstant().toString());
}
if (obj instanceof LocalDateTime) {
return new JsonPrimitive(obj.toString());
}
if (obj instanceof LocalDate) {
return new JsonPrimitive(obj.toString());
}
if (obj instanceof LocalTime) {
return new JsonPrimitive(obj.toString());
}
// Optional
if (obj instanceof Optional) {
Optional<?> opt = (Optional<?>) obj;
if (opt.isPresent()) {
return serialize(opt.get(), seen, depth + 1);
} else {
return JsonNull.INSTANCE;
}
}
// Arrays
if (clazz.isArray()) {
return serializeArray(obj, seen, depth);
}
// Collections
if (obj instanceof Collection) {
return serializeCollection((Collection<?>) obj, seen, depth);
}
// Maps
if (obj instanceof Map) {
return serializeMap((Map<?, ?>) obj, seen, depth);
}
// Enums
if (clazz.isEnum()) {
return new JsonPrimitive(((Enum<?>) obj).name());
}
// Custom objects - serialize via reflection
return serializeObject(obj, seen, depth);
} finally {
seen.remove(obj);
}
}
private static JsonElement serializeArray(Object array, IdentityHashMap<Object, Boolean> seen, int depth) {
JsonArray jsonArray = new JsonArray();
int length = java.lang.reflect.Array.getLength(array);
int limit = Math.min(length, MAX_COLLECTION_SIZE);
for (int i = 0; i < limit; i++) {
Object element = java.lang.reflect.Array.get(array, i);
jsonArray.add(serialize(element, seen, depth + 1));
}
if (length > limit) {
JsonObject truncated = new JsonObject();
truncated.addProperty("__truncated__", length - limit + " more elements");
jsonArray.add(truncated);
}
return jsonArray;
}
private static JsonElement serializeCollection(Collection<?> collection, IdentityHashMap<Object, Boolean> seen, int depth) {
JsonArray jsonArray = new JsonArray();
int count = 0;
for (Object element : collection) {
if (count >= MAX_COLLECTION_SIZE) {
JsonObject truncated = new JsonObject();
truncated.addProperty("__truncated__", collection.size() - count + " more elements");
jsonArray.add(truncated);
break;
}
jsonArray.add(serialize(element, seen, depth + 1));
count++;
}
return jsonArray;
}
private static JsonElement serializeMap(Map<?, ?> map, IdentityHashMap<Object, Boolean> seen, int depth) {
JsonObject jsonObject = new JsonObject();
int count = 0;
for (Map.Entry<?, ?> entry : map.entrySet()) {
if (count >= MAX_COLLECTION_SIZE) {
jsonObject.addProperty("__truncated__", map.size() - count + " more entries");
break;
}
String key = entry.getKey() != null ? entry.getKey().toString() : "null";
jsonObject.add(key, serialize(entry.getValue(), seen, depth + 1));
count++;
}
return jsonObject;
}
private static JsonElement serializeObject(Object obj, IdentityHashMap<Object, Boolean> seen, int depth) {
JsonObject jsonObject = new JsonObject();
Class<?> clazz = obj.getClass();
// Add type information
jsonObject.addProperty("__type__", clazz.getName());
// Serialize all fields (including inherited)
while (clazz != null && clazz != Object.class) {
for (Field field : clazz.getDeclaredFields()) {
// Skip static and transient fields
if (Modifier.isStatic(field.getModifiers()) ||
Modifier.isTransient(field.getModifiers())) {
continue;
}
try {
field.setAccessible(true);
Object value = field.get(obj);
jsonObject.add(field.getName(), serialize(value, seen, depth + 1));
} catch (IllegalAccessException e) {
jsonObject.addProperty(field.getName(), "__access_denied__");
}
}
clazz = clazz.getSuperclass();
}
return jsonObject;
}
}

View file

@ -0,0 +1,126 @@
package com.codeflash;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
/**
* Tests for the BenchmarkResult class.
*/
@DisplayName("BenchmarkResult Tests")
class BenchmarkResultTest {
@Test
@DisplayName("should calculate mean correctly")
void testMean() {
long[] measurements = {100, 200, 300, 400, 500};
BenchmarkResult result = new BenchmarkResult("test", measurements);
assertEquals(300, result.getMean());
}
@Test
@DisplayName("should calculate min and max")
void testMinMax() {
long[] measurements = {100, 50, 200, 150, 75};
BenchmarkResult result = new BenchmarkResult("test", measurements);
assertEquals(50, result.getMin());
assertEquals(200, result.getMax());
}
@Test
@DisplayName("should calculate percentiles")
void testPercentiles() {
long[] measurements = new long[100];
for (int i = 0; i < 100; i++) {
measurements[i] = i + 1; // 1 to 100
}
BenchmarkResult result = new BenchmarkResult("test", measurements);
assertEquals(50, result.getP50());
assertEquals(90, result.getP90());
assertEquals(99, result.getP99());
}
@Test
@DisplayName("should calculate standard deviation")
void testStdDev() {
// All same values should have 0 std dev
long[] sameValues = {100, 100, 100, 100, 100};
BenchmarkResult sameResult = new BenchmarkResult("test", sameValues);
assertEquals(0, sameResult.getStdDev());
// Different values should have non-zero std dev
long[] differentValues = {100, 200, 300, 400, 500};
BenchmarkResult diffResult = new BenchmarkResult("test", differentValues);
assertTrue(diffResult.getStdDev() > 0);
}
@Test
@DisplayName("should calculate coefficient of variation")
void testCoefficientOfVariation() {
long[] measurements = {100, 100, 100, 100, 100};
BenchmarkResult result = new BenchmarkResult("test", measurements);
assertEquals(0.0, result.getCoefficientOfVariation(), 0.001);
}
@Test
@DisplayName("should detect stable measurements")
void testIsStable() {
// Low variance - stable
long[] stableMeasurements = {100, 101, 99, 100, 102};
BenchmarkResult stableResult = new BenchmarkResult("test", stableMeasurements);
assertTrue(stableResult.isStable());
// High variance - unstable
long[] unstableMeasurements = {100, 200, 50, 300, 25};
BenchmarkResult unstableResult = new BenchmarkResult("test", unstableMeasurements);
assertFalse(unstableResult.isStable());
}
@Test
@DisplayName("should convert to milliseconds")
void testMillisecondConversion() {
long[] measurements = {1_000_000, 2_000_000, 3_000_000}; // 1ms, 2ms, 3ms
BenchmarkResult result = new BenchmarkResult("test", measurements);
assertEquals(2.0, result.getMeanMs(), 0.001);
}
@Test
@DisplayName("should clone measurements array")
void testMeasurementsCloned() {
long[] original = {100, 200, 300};
BenchmarkResult result = new BenchmarkResult("test", original);
long[] retrieved = result.getMeasurements();
retrieved[0] = 999;
// Original should not be affected
assertEquals(100, result.getMeasurements()[0]);
}
@Test
@DisplayName("should return correct iteration count")
void testIterationCount() {
long[] measurements = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
BenchmarkResult result = new BenchmarkResult("test", measurements);
assertEquals(10, result.getIterationCount());
}
@Test
@DisplayName("should have meaningful toString")
void testToString() {
long[] measurements = {1_000_000, 2_000_000};
BenchmarkResult result = new BenchmarkResult("Calculator.add", measurements);
String str = result.toString();
assertTrue(str.contains("Calculator.add"));
assertTrue(str.contains("mean="));
assertTrue(str.contains("ms"));
}
}

View file

@ -0,0 +1,108 @@
package com.codeflash;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.*;
/**
* Tests for the Blackhole class.
*/
@DisplayName("Blackhole Tests")
class BlackholeTest {
@Test
@DisplayName("should consume int without throwing")
void testConsumeInt() {
assertDoesNotThrow(() -> Blackhole.consume(42));
}
@Test
@DisplayName("should consume long without throwing")
void testConsumeLong() {
assertDoesNotThrow(() -> Blackhole.consume(Long.MAX_VALUE));
}
@Test
@DisplayName("should consume double without throwing")
void testConsumeDouble() {
assertDoesNotThrow(() -> Blackhole.consume(3.14159));
}
@Test
@DisplayName("should consume float without throwing")
void testConsumeFloat() {
assertDoesNotThrow(() -> Blackhole.consume(3.14f));
}
@Test
@DisplayName("should consume boolean without throwing")
void testConsumeBoolean() {
assertDoesNotThrow(() -> Blackhole.consume(true));
assertDoesNotThrow(() -> Blackhole.consume(false));
}
@Test
@DisplayName("should consume byte without throwing")
void testConsumeByte() {
assertDoesNotThrow(() -> Blackhole.consume((byte) 127));
}
@Test
@DisplayName("should consume short without throwing")
void testConsumeShort() {
assertDoesNotThrow(() -> Blackhole.consume((short) 32000));
}
@Test
@DisplayName("should consume char without throwing")
void testConsumeChar() {
assertDoesNotThrow(() -> Blackhole.consume('x'));
}
@Test
@DisplayName("should consume Object without throwing")
void testConsumeObject() {
assertDoesNotThrow(() -> Blackhole.consume("hello"));
assertDoesNotThrow(() -> Blackhole.consume(Arrays.asList(1, 2, 3)));
assertDoesNotThrow(() -> Blackhole.consume((Object) null));
}
@Test
@DisplayName("should consume int array without throwing")
void testConsumeIntArray() {
assertDoesNotThrow(() -> Blackhole.consume(new int[]{1, 2, 3}));
assertDoesNotThrow(() -> Blackhole.consume((int[]) null));
assertDoesNotThrow(() -> Blackhole.consume(new int[]{}));
}
@Test
@DisplayName("should consume long array without throwing")
void testConsumeLongArray() {
assertDoesNotThrow(() -> Blackhole.consume(new long[]{1L, 2L, 3L}));
assertDoesNotThrow(() -> Blackhole.consume((long[]) null));
}
@Test
@DisplayName("should consume double array without throwing")
void testConsumeDoubleArray() {
assertDoesNotThrow(() -> Blackhole.consume(new double[]{1.0, 2.0, 3.0}));
assertDoesNotThrow(() -> Blackhole.consume((double[]) null));
}
@Test
@DisplayName("should prevent dead code elimination in loop")
void testPreventDeadCodeInLoop() {
// This test verifies that consuming values allows the loop to run
// without the JIT potentially eliminating it
int sum = 0;
for (int i = 0; i < 1000; i++) {
sum += i;
Blackhole.consume(sum);
}
// The loop should have run - this is more of a smoke test
assertTrue(sum > 0);
}
}

View file

@ -0,0 +1,283 @@
package com.codeflash;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
/**
* Tests for the Serializer class.
*/
@DisplayName("Serializer Tests")
class SerializerTest {
@Nested
@DisplayName("Primitive Types")
class PrimitiveTests {
@Test
@DisplayName("should serialize integers")
void testInteger() {
assertEquals("42", Serializer.toJson(42));
assertEquals("-1", Serializer.toJson(-1));
assertEquals("0", Serializer.toJson(0));
}
@Test
@DisplayName("should serialize longs")
void testLong() {
assertEquals("9223372036854775807", Serializer.toJson(Long.MAX_VALUE));
}
@Test
@DisplayName("should serialize doubles")
void testDouble() {
String json = Serializer.toJson(3.14159);
assertTrue(json.startsWith("3.14"));
}
@Test
@DisplayName("should serialize booleans")
void testBoolean() {
assertEquals("true", Serializer.toJson(true));
assertEquals("false", Serializer.toJson(false));
}
@Test
@DisplayName("should serialize strings")
void testString() {
assertEquals("\"hello\"", Serializer.toJson("hello"));
assertEquals("\"with \\\"quotes\\\"\"", Serializer.toJson("with \"quotes\""));
}
@Test
@DisplayName("should serialize null")
void testNull() {
assertEquals("null", Serializer.toJson((Object) null));
}
@Test
@DisplayName("should serialize characters")
void testCharacter() {
assertEquals("\"a\"", Serializer.toJson('a'));
}
}
@Nested
@DisplayName("Array Types")
class ArrayTests {
@Test
@DisplayName("should serialize int arrays")
void testIntArray() {
int[] arr = {1, 2, 3};
assertEquals("[1,2,3]", Serializer.toJson((Object) arr));
}
@Test
@DisplayName("should serialize String arrays")
void testStringArray() {
String[] arr = {"a", "b", "c"};
assertEquals("[\"a\",\"b\",\"c\"]", Serializer.toJson((Object) arr));
}
@Test
@DisplayName("should serialize empty arrays")
void testEmptyArray() {
int[] arr = {};
assertEquals("[]", Serializer.toJson((Object) arr));
}
}
@Nested
@DisplayName("Collection Types")
class CollectionTests {
@Test
@DisplayName("should serialize Lists")
void testList() {
List<Integer> list = Arrays.asList(1, 2, 3);
assertEquals("[1,2,3]", Serializer.toJson(list));
}
@Test
@DisplayName("should serialize Sets")
void testSet() {
Set<String> set = new LinkedHashSet<>(Arrays.asList("a", "b"));
String json = Serializer.toJson(set);
assertTrue(json.contains("\"a\""));
assertTrue(json.contains("\"b\""));
}
@Test
@DisplayName("should serialize Maps")
void testMap() {
Map<String, Integer> map = new LinkedHashMap<>();
map.put("one", 1);
map.put("two", 2);
String json = Serializer.toJson(map);
assertTrue(json.contains("\"one\":1"));
assertTrue(json.contains("\"two\":2"));
}
@Test
@DisplayName("should handle nested collections")
void testNestedCollections() {
List<List<Integer>> nested = Arrays.asList(
Arrays.asList(1, 2),
Arrays.asList(3, 4)
);
assertEquals("[[1,2],[3,4]]", Serializer.toJson(nested));
}
}
@Nested
@DisplayName("Varargs")
class VarargsTests {
@Test
@DisplayName("should serialize multiple arguments")
void testVarargs() {
String json = Serializer.toJson(1, "hello", true);
assertEquals("[1,\"hello\",true]", json);
}
@Test
@DisplayName("should serialize mixed types")
void testMixedVarargs() {
String json = Serializer.toJson(42, Arrays.asList(1, 2), null);
assertTrue(json.startsWith("[42,"));
assertTrue(json.contains("null"));
}
}
@Nested
@DisplayName("Custom Objects")
class CustomObjectTests {
@Test
@DisplayName("should serialize simple objects")
void testSimpleObject() {
TestPerson person = new TestPerson("John", 30);
String json = Serializer.toJson(person);
assertTrue(json.contains("\"name\":\"John\""));
assertTrue(json.contains("\"age\":30"));
assertTrue(json.contains("\"__type__\""));
}
@Test
@DisplayName("should serialize nested objects")
void testNestedObject() {
TestAddress address = new TestAddress("123 Main St", "NYC");
TestPersonWithAddress person = new TestPersonWithAddress("Jane", address);
String json = Serializer.toJson(person);
assertTrue(json.contains("\"name\":\"Jane\""));
assertTrue(json.contains("\"city\":\"NYC\""));
}
}
@Nested
@DisplayName("Exception Serialization")
class ExceptionTests {
@Test
@DisplayName("should serialize exception with type and message")
void testException() {
Exception e = new IllegalArgumentException("test error");
String json = Serializer.exceptionToJson(e);
assertTrue(json.contains("\"__exception__\":true"));
assertTrue(json.contains("\"type\":\"java.lang.IllegalArgumentException\""));
assertTrue(json.contains("\"message\":\"test error\""));
}
@Test
@DisplayName("should include stack trace")
void testExceptionStackTrace() {
Exception e = new RuntimeException("test");
String json = Serializer.exceptionToJson(e);
assertTrue(json.contains("\"stackTrace\""));
}
@Test
@DisplayName("should include cause")
void testExceptionWithCause() {
Exception cause = new NullPointerException("root cause");
Exception e = new RuntimeException("wrapper", cause);
String json = Serializer.exceptionToJson(e);
assertTrue(json.contains("\"causeType\":\"java.lang.NullPointerException\""));
assertTrue(json.contains("\"causeMessage\":\"root cause\""));
}
}
@Nested
@DisplayName("Edge Cases")
class EdgeCaseTests {
@Test
@DisplayName("should handle Optional with value")
void testOptionalPresent() {
Optional<String> opt = Optional.of("value");
assertEquals("\"value\"", Serializer.toJson(opt));
}
@Test
@DisplayName("should handle Optional empty")
void testOptionalEmpty() {
Optional<String> opt = Optional.empty();
assertEquals("null", Serializer.toJson(opt));
}
@Test
@DisplayName("should handle enums")
void testEnum() {
assertEquals("\"MONDAY\"", Serializer.toJson(java.time.DayOfWeek.MONDAY));
}
@Test
@DisplayName("should handle Date")
void testDate() {
Date date = new Date(0); // Epoch
String json = Serializer.toJson(date);
assertTrue(json.contains("1970"));
}
}
// Test helper classes
static class TestPerson {
private final String name;
private final int age;
TestPerson(String name, int age) {
this.name = name;
this.age = age;
}
}
static class TestAddress {
private final String street;
private final String city;
TestAddress(String street, String city) {
this.street = street;
this.city = city;
}
}
static class TestPersonWithAddress {
private final String name;
private final TestAddress address;
TestPersonWithAddress(String name, TestAddress address) {
this.name = name;
this.address = address;
}
}
}

View file

@ -14,7 +14,7 @@ from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.languages import is_javascript, is_python
from codeflash.languages import is_java, is_javascript, is_python
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
AIServiceRefinerRequest,
@ -182,6 +182,8 @@ class AiServiceClient:
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif is_java():
payload["language_version"] = language_version or "17" # Default Java version
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
@ -785,6 +787,8 @@ class AiServiceClient:
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif is_java():
payload["language_version"] = language_version or "17" # Default Java version
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)

View file

@ -273,6 +273,20 @@ def process_pyproject_config(args: Namespace) -> Namespace:
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
if pyproject_file_path.parent == module_root:
return module_root
# For Java projects, find the directory containing pom.xml or build.gradle
# This handles the case where module_root is src/main/java
current = module_root
while current != current.parent:
if (current / "pom.xml").exists():
return current.resolve()
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
return current.resolve()
# Check for config file (pyproject.toml for Python, codeflash.toml for other languages)
if (current / "codeflash.toml").exists():
return current.resolve()
current = current.parent
return module_root.parent.resolve()

View file

@ -35,6 +35,9 @@ from codeflash.cli_cmds.init_javascript import (
get_js_dependency_installation_commands,
init_js_project,
)
# Import Java init module
from codeflash.cli_cmds.init_java import init_java_project
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
@ -114,6 +117,10 @@ def init_codeflash() -> None:
# Detect project language
project_language = detect_project_language()
if project_language == ProjectLanguage.JAVA:
init_java_project()
return
if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT):
init_js_project(project_language)
return
@ -798,7 +805,9 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
# Select the appropriate workflow template based on project language
project_language = detect_project_language_for_workflow(Path.cwd())
if project_language in ("javascript", "typescript"):
if project_language == "java":
workflow_template = "codeflash-optimize-java.yaml"
elif project_language in ("javascript", "typescript"):
workflow_template = "codeflash-optimize-js.yaml"
else:
workflow_template = "codeflash-optimize.yaml"
@ -1210,8 +1219,16 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
def detect_project_language_for_workflow(project_root: Path) -> str:
"""Detect the primary language of the project for workflow generation.
Returns: 'python', 'javascript', or 'typescript'
Returns: 'python', 'javascript', 'typescript', or 'java'
"""
# Check for Java project (Maven or Gradle)
has_pom_xml = (project_root / "pom.xml").exists()
has_build_gradle = (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists()
has_java_src = (project_root / "src" / "main" / "java").is_dir()
if has_pom_xml or has_build_gradle or has_java_src:
return "java"
# Check for TypeScript config
if (project_root / "tsconfig.json").exists():
return "typescript"
@ -1230,6 +1247,7 @@ def detect_project_language_for_workflow(project_root: Path) -> str:
# Both exist - count files to determine primary language
js_count = 0
py_count = 0
java_count = 0
for file in project_root.rglob("*"):
if file.is_file():
suffix = file.suffix.lower()
@ -1237,8 +1255,13 @@ def detect_project_language_for_workflow(project_root: Path) -> str:
js_count += 1
elif suffix == ".py":
py_count += 1
elif suffix == ".java":
java_count += 1
if js_count > py_count:
max_count = max(js_count, py_count, java_count)
if max_count == java_count and java_count > 0:
return "java"
if max_count == js_count and js_count > 0:
return "javascript"
return "python"
@ -1343,9 +1366,9 @@ def generate_dynamic_workflow_content(
# Detect project language
project_language = detect_project_language_for_workflow(Path.cwd())
# For JavaScript/TypeScript projects, use static template customization
# For JavaScript/TypeScript and Java projects, use static template customization
# (AI-generated steps are currently Python-only)
if project_language in ("javascript", "typescript"):
if project_language in ("javascript", "typescript", "java"):
return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
# Python project - try AI-generated steps
@ -1466,6 +1489,10 @@ def customize_codeflash_yaml_content(
# Detect project language
project_language = detect_project_language_for_workflow(Path.cwd())
if project_language == "java":
# Java project
return _customize_java_workflow_content(optimize_yml_content, git_root, benchmark_mode)
if project_language in ("javascript", "typescript"):
# JavaScript/TypeScript project
return _customize_js_workflow_content(optimize_yml_content, git_root, benchmark_mode)
@ -1562,6 +1589,54 @@ def _customize_js_workflow_content(optimize_yml_content: str, git_root: Path, be
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path, benchmark_mode: bool = False) -> str:
"""Customize workflow content for Java projects."""
from codeflash.cli_cmds.init_java import (
JavaBuildTool,
detect_java_build_tool,
get_java_dependency_installation_commands,
)
project_root = Path.cwd()
# Check for pom.xml or build.gradle
has_pom = (project_root / "pom.xml").exists()
has_gradle = (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists()
if not has_pom and not has_gradle:
click.echo(
f"I couldn't find a pom.xml or build.gradle in the current directory.{LF}"
f"Please ensure you're in a Maven or Gradle project directory."
)
apologize_and_exit()
# Determine working directory relative to git root
if project_root == git_root:
working_dir = ""
else:
rel_path = str(project_root.relative_to(git_root))
working_dir = f"""defaults:
run:
working-directory: ./{rel_path}"""
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
# Determine build tool
build_tool = detect_java_build_tool(project_root)
# Set build tool cache type for actions/setup-java
if build_tool == JavaBuildTool.GRADLE:
optimize_yml_content = optimize_yml_content.replace("{{ java_build_tool }}", "gradle")
else:
optimize_yml_content = optimize_yml_content.replace("{{ java_build_tool }}", "maven")
# Install dependencies
install_deps_cmd = get_java_dependency_installation_commands(build_tool)
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
return optimize_yml_content
def get_formatter_cmds(formatter: str) -> list[str]:
if formatter == "black":
return ["black $file"]

View file

@ -34,6 +34,7 @@ class ProjectLanguage(Enum):
PYTHON = auto()
JAVASCRIPT = auto()
TYPESCRIPT = auto()
JAVA = auto()
class JsPackageManager(Enum):
@ -89,6 +90,13 @@ def detect_project_language(project_root: Path | None = None) -> ProjectLanguage
has_setup_py = (root / "setup.py").exists()
has_package_json = (root / "package.json").exists()
has_tsconfig = (root / "tsconfig.json").exists()
has_pom_xml = (root / "pom.xml").exists()
has_build_gradle = (root / "build.gradle").exists() or (root / "build.gradle.kts").exists()
has_java_src = (root / "src" / "main" / "java").is_dir()
# Java project (Maven or Gradle)
if has_pom_xml or has_build_gradle or has_java_src:
return ProjectLanguage.JAVA
# TypeScript project
if has_tsconfig:

View file

@ -30,6 +30,7 @@ from codeflash.languages.base import (
from codeflash.languages.current import (
current_language,
current_language_support,
is_java,
is_javascript,
is_python,
is_typescript,
@ -41,6 +42,10 @@ from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
# Java language support
# Importing the module triggers registration via @register_language decorator
from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,
@ -67,6 +72,7 @@ __all__ = [
"get_language_support",
"get_supported_extensions",
"get_supported_languages",
"is_java",
"is_javascript",
"is_python",
"is_typescript",

View file

@ -22,6 +22,7 @@ class Language(str, Enum):
PYTHON = "python"
JAVASCRIPT = "javascript"
TYPESCRIPT = "typescript"
JAVA = "java"
def __str__(self) -> str:
return self.value

View file

@ -106,6 +106,16 @@ def is_typescript() -> bool:
return _current_language == Language.TYPESCRIPT
def is_java() -> bool:
"""Check if the current language is Java.
Returns:
True if the current language is Java.
"""
return _current_language == Language.JAVA
def current_language_support() -> LanguageSupport:
"""Get the LanguageSupport instance for the current language.

View file

@ -0,0 +1,195 @@
"""Java language support for codeflash.
This module provides Java-specific functionality for code analysis,
test execution, and optimization using tree-sitter for parsing and
Maven/Gradle for build operations.
"""
from codeflash.languages.java.build_tools import (
BuildTool,
JavaProjectInfo,
MavenTestResult,
add_codeflash_dependency_to_pom,
compile_maven_project,
detect_build_tool,
find_gradle_executable,
find_maven_executable,
find_source_root,
find_test_root,
get_classpath,
get_project_info,
install_codeflash_runtime,
run_maven_tests,
)
from codeflash.languages.java.comparator import (
compare_invocations_directly,
compare_test_results,
)
from codeflash.languages.java.config import (
JavaProjectConfig,
detect_java_project,
get_test_class_pattern,
get_test_file_pattern,
is_java_project,
)
from codeflash.languages.java.context import (
extract_class_context,
extract_code_context,
extract_function_source,
extract_read_only_context,
find_helper_functions,
)
from codeflash.languages.java.discovery import (
discover_functions,
discover_functions_from_source,
discover_test_methods,
get_class_methods,
get_method_by_name,
)
from codeflash.languages.java.formatter import (
JavaFormatter,
format_java_code,
format_java_file,
normalize_java_code,
)
from codeflash.languages.java.import_resolver import (
JavaImportResolver,
ResolvedImport,
find_helper_files,
resolve_imports_for_file,
)
from codeflash.languages.java.instrumentation import (
create_benchmark_test,
instrument_existing_test,
instrument_for_behavior,
instrument_for_benchmarking,
remove_instrumentation,
)
from codeflash.languages.java.parser import (
JavaAnalyzer,
JavaClassNode,
JavaFieldInfo,
JavaImportInfo,
JavaMethodNode,
get_java_analyzer,
)
from codeflash.languages.java.replacement import (
add_runtime_comments,
insert_method,
remove_method,
remove_test_functions,
replace_function,
replace_method_body,
)
from codeflash.languages.java.support import (
JavaSupport,
get_java_support,
)
from codeflash.languages.java.test_discovery import (
build_test_mapping_for_project,
discover_all_tests,
discover_tests,
find_tests_for_function,
get_test_class_for_source_class,
get_test_file_suffix,
get_test_methods_for_class,
is_test_file,
)
from codeflash.languages.java.test_runner import (
JavaTestRunResult,
get_test_run_command,
parse_surefire_results,
parse_test_results,
run_behavioral_tests,
run_benchmarking_tests,
run_tests,
)
__all__ = [
# Parser
"JavaAnalyzer",
"JavaClassNode",
"JavaFieldInfo",
"JavaImportInfo",
"JavaMethodNode",
"get_java_analyzer",
# Build tools
"BuildTool",
"JavaProjectInfo",
"MavenTestResult",
"add_codeflash_dependency_to_pom",
"compile_maven_project",
"detect_build_tool",
"find_gradle_executable",
"find_maven_executable",
"find_source_root",
"find_test_root",
"get_classpath",
"get_project_info",
"install_codeflash_runtime",
"run_maven_tests",
# Comparator
"compare_invocations_directly",
"compare_test_results",
# Config
"JavaProjectConfig",
"detect_java_project",
"get_test_class_pattern",
"get_test_file_pattern",
"is_java_project",
# Context
"extract_class_context",
"extract_code_context",
"extract_function_source",
"extract_read_only_context",
"find_helper_functions",
# Discovery
"discover_functions",
"discover_functions_from_source",
"discover_test_methods",
"get_class_methods",
"get_method_by_name",
# Formatter
"JavaFormatter",
"format_java_code",
"format_java_file",
"normalize_java_code",
# Import resolver
"JavaImportResolver",
"ResolvedImport",
"find_helper_files",
"resolve_imports_for_file",
# Instrumentation
"create_benchmark_test",
"instrument_existing_test",
"instrument_for_behavior",
"instrument_for_benchmarking",
"remove_instrumentation",
# Replacement
"add_runtime_comments",
"insert_method",
"remove_method",
"remove_test_functions",
"replace_function",
"replace_method_body",
# Support
"JavaSupport",
"get_java_support",
# Test discovery
"build_test_mapping_for_project",
"discover_all_tests",
"discover_tests",
"find_tests_for_function",
"get_test_class_for_source_class",
"get_test_file_suffix",
"get_test_methods_for_class",
"is_test_file",
# Test runner
"JavaTestRunResult",
"get_test_run_command",
"parse_surefire_results",
"parse_test_results",
"run_behavioral_tests",
"run_benchmarking_tests",
"run_tests",
]

View file

@ -0,0 +1,742 @@
"""Java build tool detection and integration.
This module provides functionality to detect and work with Java build tools
(Maven and Gradle), including running tests and managing dependencies.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class BuildTool(Enum):
"""Supported Java build tools."""
MAVEN = "maven"
GRADLE = "gradle"
UNKNOWN = "unknown"
@dataclass
class JavaProjectInfo:
"""Information about a Java project."""
project_root: Path
build_tool: BuildTool
source_roots: list[Path]
test_roots: list[Path]
target_dir: Path # build output directory
group_id: str | None
artifact_id: str | None
version: str | None
java_version: str | None
@dataclass
class MavenTestResult:
"""Result of running Maven tests."""
success: bool
tests_run: int
failures: int
errors: int
skipped: int
surefire_reports_dir: Path | None
stdout: str
stderr: str
returncode: int
def detect_build_tool(project_root: Path) -> BuildTool:
"""Detect which build tool a Java project uses.
Args:
project_root: Root directory of the Java project.
Returns:
The detected BuildTool enum value.
"""
# Check for Maven (pom.xml)
if (project_root / "pom.xml").exists():
return BuildTool.MAVEN
# Check for Gradle (build.gradle or build.gradle.kts)
if (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists():
return BuildTool.GRADLE
# Check in parent directories for multi-module projects
current = project_root
for _ in range(3): # Check up to 3 levels
parent = current.parent
if parent == current:
break
if (parent / "pom.xml").exists():
return BuildTool.MAVEN
if (parent / "build.gradle").exists() or (parent / "build.gradle.kts").exists():
return BuildTool.GRADLE
current = parent
return BuildTool.UNKNOWN
def get_project_info(project_root: Path) -> JavaProjectInfo | None:
"""Get information about a Java project.
Args:
project_root: Root directory of the Java project.
Returns:
JavaProjectInfo if a supported project is found, None otherwise.
"""
build_tool = detect_build_tool(project_root)
if build_tool == BuildTool.MAVEN:
return _get_maven_project_info(project_root)
if build_tool == BuildTool.GRADLE:
return _get_gradle_project_info(project_root)
return None
def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None:
"""Get project info from Maven pom.xml.
Args:
project_root: Root directory of the Maven project.
Returns:
JavaProjectInfo extracted from pom.xml.
"""
pom_path = project_root / "pom.xml"
if not pom_path.exists():
return None
try:
tree = ET.parse(pom_path)
root = tree.getroot()
# Handle Maven namespace
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
def get_text(xpath: str, default: str | None = None) -> str | None:
# Try with namespace first
elem = root.find(f"m:{xpath}", ns)
if elem is None:
# Try without namespace
elem = root.find(xpath)
return elem.text if elem is not None else default
group_id = get_text("groupId")
artifact_id = get_text("artifactId")
version = get_text("version")
# Get Java version from properties or compiler plugin
java_version = _extract_java_version_from_pom(root, ns)
# Standard Maven directory structure
source_roots = []
test_roots = []
main_src = project_root / "src" / "main" / "java"
if main_src.exists():
source_roots.append(main_src)
test_src = project_root / "src" / "test" / "java"
if test_src.exists():
test_roots.append(test_src)
target_dir = project_root / "target"
return JavaProjectInfo(
project_root=project_root,
build_tool=BuildTool.MAVEN,
source_roots=source_roots,
test_roots=test_roots,
target_dir=target_dir,
group_id=group_id,
artifact_id=artifact_id,
version=version,
java_version=java_version,
)
except ET.ParseError as e:
logger.warning("Failed to parse pom.xml: %s", e)
return None
def _extract_java_version_from_pom(root: ET.Element, ns: dict[str, str]) -> str | None:
"""Extract Java version from Maven pom.xml.
Checks multiple locations:
1. properties/maven.compiler.source
2. properties/java.version
3. build/plugins/plugin[compiler]/configuration/source
Args:
root: Root element of the pom.xml.
ns: XML namespace mapping.
Returns:
Java version string or None.
"""
# Check properties
for prop_name in ("maven.compiler.source", "java.version", "maven.compiler.release"):
for props in [root.find(f"m:properties", ns), root.find("properties")]:
if props is not None:
for prop in [props.find(f"m:{prop_name}", ns), props.find(prop_name)]:
if prop is not None and prop.text:
return prop.text
# Check compiler plugin configuration
for build in [root.find(f"m:build", ns), root.find("build")]:
if build is not None:
for plugins in [build.find(f"m:plugins", ns), build.find("plugins")]:
if plugins is not None:
for plugin in plugins.findall(f"m:plugin", ns) + plugins.findall("plugin"):
artifact_id = plugin.find(f"m:artifactId", ns) or plugin.find("artifactId")
if artifact_id is not None and artifact_id.text == "maven-compiler-plugin":
config = plugin.find(f"m:configuration", ns) or plugin.find("configuration")
if config is not None:
source = config.find(f"m:source", ns) or config.find("source")
if source is not None and source.text:
return source.text
return None
def _get_gradle_project_info(project_root: Path) -> JavaProjectInfo | None:
"""Get project info from Gradle build file.
Note: This is a basic implementation. Full Gradle parsing would require
running Gradle with a custom task or using the Gradle tooling API.
Args:
project_root: Root directory of the Gradle project.
Returns:
JavaProjectInfo with basic Gradle project structure.
"""
# Standard Gradle directory structure
source_roots = []
test_roots = []
main_src = project_root / "src" / "main" / "java"
if main_src.exists():
source_roots.append(main_src)
test_src = project_root / "src" / "test" / "java"
if test_src.exists():
test_roots.append(test_src)
build_dir = project_root / "build"
return JavaProjectInfo(
project_root=project_root,
build_tool=BuildTool.GRADLE,
source_roots=source_roots,
test_roots=test_roots,
target_dir=build_dir,
group_id=None, # Would need to parse build.gradle
artifact_id=None,
version=None,
java_version=None,
)
def find_maven_executable() -> str | None:
"""Find the Maven executable.
Returns:
Path to mvn executable, or None if not found.
"""
# Check for Maven wrapper first
if os.path.exists("mvnw"):
return "./mvnw"
if os.path.exists("mvnw.cmd"):
return "mvnw.cmd"
# Check system Maven
mvn_path = shutil.which("mvn")
if mvn_path:
return mvn_path
return None
def find_gradle_executable() -> str | None:
"""Find the Gradle executable.
Returns:
Path to gradle executable, or None if not found.
"""
# Check for Gradle wrapper first
if os.path.exists("gradlew"):
return "./gradlew"
if os.path.exists("gradlew.bat"):
return "gradlew.bat"
# Check system Gradle
gradle_path = shutil.which("gradle")
if gradle_path:
return gradle_path
return None
def run_maven_tests(
project_root: Path,
test_classes: list[str] | None = None,
test_methods: list[str] | None = None,
env: dict[str, str] | None = None,
timeout: int = 300,
skip_compilation: bool = False,
) -> MavenTestResult:
"""Run Maven tests using Surefire.
Args:
project_root: Root directory of the Maven project.
test_classes: Optional list of test class names to run.
test_methods: Optional list of specific test methods (format: ClassName#methodName).
env: Optional environment variables.
timeout: Maximum time in seconds for test execution.
skip_compilation: Whether to skip compilation (useful when only running tests).
Returns:
MavenTestResult with test execution results.
"""
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found. Please install Maven or use Maven wrapper.")
return MavenTestResult(
success=False,
tests_run=0,
failures=0,
errors=0,
skipped=0,
surefire_reports_dir=None,
stdout="",
stderr="Maven not found",
returncode=-1,
)
# Build Maven command
cmd = [mvn]
if skip_compilation:
cmd.append("-Dmaven.test.skip=false")
cmd.append("-DskipTests=false")
cmd.append("surefire:test")
else:
cmd.append("test")
# Add test filtering
if test_classes or test_methods:
if test_methods:
# Format: -Dtest=ClassName#method1+method2,OtherClass#method3
tests = ",".join(test_methods)
elif test_classes:
tests = ",".join(test_classes)
cmd.extend(["-Dtest=" + tests])
# Fail at end to run all tests
cmd.append("-fae")
# Use full environment with optional overrides
run_env = os.environ.copy()
if env:
run_env.update(env)
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=run_env,
capture_output=True,
text=True,
timeout=timeout,
)
# Parse test results from Surefire reports
surefire_dir = project_root / "target" / "surefire-reports"
tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir)
return MavenTestResult(
success=result.returncode == 0,
tests_run=tests_run,
failures=failures,
errors=errors,
skipped=skipped,
surefire_reports_dir=surefire_dir if surefire_dir.exists() else None,
stdout=result.stdout,
stderr=result.stderr,
returncode=result.returncode,
)
except subprocess.TimeoutExpired:
logger.error("Maven test execution timed out after %d seconds", timeout)
return MavenTestResult(
success=False,
tests_run=0,
failures=0,
errors=0,
skipped=0,
surefire_reports_dir=None,
stdout="",
stderr=f"Test execution timed out after {timeout} seconds",
returncode=-2,
)
except Exception as e:
logger.exception("Maven test execution failed: %s", e)
return MavenTestResult(
success=False,
tests_run=0,
failures=0,
errors=0,
skipped=0,
surefire_reports_dir=None,
stdout="",
stderr=str(e),
returncode=-1,
)
def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
"""Parse Surefire XML reports to get test counts.
Args:
surefire_dir: Directory containing Surefire XML reports.
Returns:
Tuple of (tests_run, failures, errors, skipped).
"""
tests_run = 0
failures = 0
errors = 0
skipped = 0
if not surefire_dir.exists():
return tests_run, failures, errors, skipped
for xml_file in surefire_dir.glob("TEST-*.xml"):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
tests_run += int(root.get("tests", 0))
failures += int(root.get("failures", 0))
errors += int(root.get("errors", 0))
skipped += int(root.get("skipped", 0))
except ET.ParseError as e:
logger.warning("Failed to parse Surefire report %s: %s", xml_file, e)
return tests_run, failures, errors, skipped
def compile_maven_project(
project_root: Path,
include_tests: bool = True,
env: dict[str, str] | None = None,
timeout: int = 300,
) -> tuple[bool, str, str]:
"""Compile a Maven project.
Args:
project_root: Root directory of the Maven project.
include_tests: Whether to compile test classes as well.
env: Optional environment variables.
timeout: Maximum time in seconds for compilation.
Returns:
Tuple of (success, stdout, stderr).
"""
mvn = find_maven_executable()
if not mvn:
return False, "", "Maven not found"
cmd = [mvn]
if include_tests:
cmd.append("test-compile")
else:
cmd.append("compile")
# Skip test execution
cmd.append("-DskipTests")
run_env = os.environ.copy()
if env:
run_env.update(env)
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=run_env,
capture_output=True,
text=True,
timeout=timeout,
)
return result.returncode == 0, result.stdout, result.stderr
except subprocess.TimeoutExpired:
return False, "", f"Compilation timed out after {timeout} seconds"
except Exception as e:
return False, "", str(e)
def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> bool:
"""Install the codeflash runtime JAR to the local Maven repository.
Args:
project_root: Root directory of the Maven project.
runtime_jar_path: Path to the codeflash-runtime.jar file.
Returns:
True if installation succeeded, False otherwise.
"""
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found")
return False
if not runtime_jar_path.exists():
logger.error("Runtime JAR not found: %s", runtime_jar_path)
return False
cmd = [
mvn,
"install:install-file",
f"-Dfile={runtime_jar_path}",
"-DgroupId=com.codeflash",
"-DartifactId=codeflash-runtime",
"-Dversion=1.0.0",
"-Dpackaging=jar",
]
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=60,
)
if result.returncode == 0:
logger.info("Successfully installed codeflash-runtime to local Maven repository")
return True
else:
logger.error("Failed to install codeflash-runtime: %s", result.stderr)
return False
except Exception as e:
logger.exception("Failed to install codeflash-runtime: %s", e)
return False
def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
"""Add codeflash-runtime dependency to pom.xml if not present.
Args:
pom_path: Path to the pom.xml file.
Returns:
True if dependency was added or already present, False on error.
"""
if not pom_path.exists():
return False
try:
tree = ET.parse(pom_path)
root = tree.getroot()
# Handle Maven namespace
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
ns_prefix = "{http://maven.apache.org/POM/4.0.0}"
# Check if namespace is used
if root.tag.startswith("{"):
use_ns = True
else:
use_ns = False
ns_prefix = ""
# Find or create dependencies section
deps = root.find(f"{ns_prefix}dependencies" if use_ns else "dependencies")
if deps is None:
deps = ET.SubElement(root, f"{ns_prefix}dependencies" if use_ns else "dependencies")
# Check if codeflash dependency already exists
for dep in deps.findall(f"{ns_prefix}dependency" if use_ns else "dependency"):
group = dep.find(f"{ns_prefix}groupId" if use_ns else "groupId")
artifact = dep.find(f"{ns_prefix}artifactId" if use_ns else "artifactId")
if group is not None and artifact is not None:
if group.text == "com.codeflash" and artifact.text == "codeflash-runtime":
logger.info("codeflash-runtime dependency already present in pom.xml")
return True
# Add codeflash dependency
dep_elem = ET.SubElement(deps, f"{ns_prefix}dependency" if use_ns else "dependency")
group_elem = ET.SubElement(dep_elem, f"{ns_prefix}groupId" if use_ns else "groupId")
group_elem.text = "com.codeflash"
artifact_elem = ET.SubElement(dep_elem, f"{ns_prefix}artifactId" if use_ns else "artifactId")
artifact_elem.text = "codeflash-runtime"
version_elem = ET.SubElement(dep_elem, f"{ns_prefix}version" if use_ns else "version")
version_elem.text = "1.0.0"
scope_elem = ET.SubElement(dep_elem, f"{ns_prefix}scope" if use_ns else "scope")
scope_elem.text = "test"
# Write back to file
tree.write(pom_path, xml_declaration=True, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to pom.xml")
return True
except ET.ParseError as e:
logger.error("Failed to parse pom.xml: %s", e)
return False
except Exception as e:
logger.exception("Failed to add dependency to pom.xml: %s", e)
return False
def find_test_root(project_root: Path) -> Path | None:
"""Find the test root directory for a Java project.
Args:
project_root: Root directory of the Java project.
Returns:
Path to test root, or None if not found.
"""
build_tool = detect_build_tool(project_root)
if build_tool in (BuildTool.MAVEN, BuildTool.GRADLE):
test_root = project_root / "src" / "test" / "java"
if test_root.exists():
return test_root
# Check common alternative locations
for test_dir in ["test", "tests", "src/test"]:
test_path = project_root / test_dir
if test_path.exists():
return test_path
return None
def find_source_root(project_root: Path) -> Path | None:
"""Find the main source root directory for a Java project.
Args:
project_root: Root directory of the Java project.
Returns:
Path to source root, or None if not found.
"""
build_tool = detect_build_tool(project_root)
if build_tool in (BuildTool.MAVEN, BuildTool.GRADLE):
src_root = project_root / "src" / "main" / "java"
if src_root.exists():
return src_root
# Check common alternative locations
for src_dir in ["src", "source", "java"]:
src_path = project_root / src_dir
if src_path.exists() and any(src_path.rglob("*.java")):
return src_path
return None
def get_classpath(project_root: Path) -> str | None:
"""Get the classpath for a Java project.
For Maven projects, this runs 'mvn dependency:build-classpath'.
Args:
project_root: Root directory of the Java project.
Returns:
Classpath string, or None if unable to determine.
"""
build_tool = detect_build_tool(project_root)
if build_tool == BuildTool.MAVEN:
return _get_maven_classpath(project_root)
if build_tool == BuildTool.GRADLE:
return _get_gradle_classpath(project_root)
return None
def _get_maven_classpath(project_root: Path) -> str | None:
"""Get classpath from Maven."""
mvn = find_maven_executable()
if not mvn:
return None
try:
result = subprocess.run(
[mvn, "dependency:build-classpath", "-q", "-DincludeScope=test"],
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
if result.returncode == 0:
# The classpath is in stdout
return result.stdout.strip()
except Exception as e:
logger.warning("Failed to get Maven classpath: %s", e)
return None
def _get_gradle_classpath(project_root: Path) -> str | None:
"""Get classpath from Gradle.
Note: This requires a custom task to be added to build.gradle.
Returns None for now as Gradle support is not fully implemented.
"""
return None

View file

@ -0,0 +1,333 @@
"""Java test result comparison.
This module provides functionality to compare test results between
original and optimized Java code using the codeflash-runtime Comparator.
"""
from __future__ import annotations
import json
import logging
import os
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from codeflash.models.models import TestDiff
logger = logging.getLogger(__name__)
def _find_comparator_jar(project_root: Path | None = None) -> Path | None:
"""Find the codeflash-runtime JAR with the Comparator class.
Args:
project_root: Project root directory.
Returns:
Path to codeflash-runtime JAR if found, None otherwise.
"""
search_dirs = []
if project_root:
search_dirs.append(project_root)
search_dirs.append(Path.cwd())
# Search for the JAR in common locations
for base_dir in search_dirs:
# Check in target directory (after Maven install)
for jar_path in [
base_dir / "target" / "dependency" / "codeflash-runtime-1.0.0.jar",
base_dir / "target" / "codeflash-runtime-1.0.0.jar",
base_dir / "lib" / "codeflash-runtime-1.0.0.jar",
base_dir / ".codeflash" / "codeflash-runtime-1.0.0.jar",
]:
if jar_path.exists():
return jar_path
# Check local Maven repository
m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / "codeflash-runtime-1.0.0.jar"
if m2_jar.exists():
return m2_jar
return None
def _find_java_executable() -> str | None:
"""Find the Java executable.
Returns:
Path to java executable, or None if not found.
"""
import shutil
# Check JAVA_HOME
java_home = os.environ.get("JAVA_HOME")
if java_home:
java_path = Path(java_home) / "bin" / "java"
if java_path.exists():
return str(java_path)
# Check PATH
java_path = shutil.which("java")
if java_path:
return java_path
return None
def compare_test_results(
original_sqlite_path: Path,
candidate_sqlite_path: Path,
comparator_jar: Path | None = None,
project_root: Path | None = None,
) -> tuple[bool, list]:
"""Compare Java test results using the codeflash-runtime Comparator.
This function calls the Java Comparator CLI that:
1. Reads serialized behavior data from both SQLite databases
2. Deserializes using Gson
3. Compares results using deep equality (handles Maps, Lists, arrays, etc.)
4. Returns comparison results as JSON
Args:
original_sqlite_path: Path to SQLite database with original code results.
candidate_sqlite_path: Path to SQLite database with candidate code results.
comparator_jar: Optional path to the codeflash-runtime JAR.
project_root: Project root directory.
Returns:
Tuple of (all_equivalent, list of TestDiff objects).
"""
# Import lazily to avoid circular imports
from codeflash.models.models import TestDiff, TestDiffScope
java_exe = _find_java_executable()
if not java_exe:
logger.error("Java not found. Please install Java to compare test results.")
return False, []
jar_path = comparator_jar or _find_comparator_jar(project_root)
if not jar_path or not jar_path.exists():
logger.error(
"codeflash-runtime JAR not found. "
"Please ensure the codeflash-runtime is installed in your project."
)
return False, []
if not original_sqlite_path.exists():
logger.error(f"Original SQLite database not found: {original_sqlite_path}")
return False, []
if not candidate_sqlite_path.exists():
logger.error(f"Candidate SQLite database not found: {candidate_sqlite_path}")
return False, []
cwd = project_root or Path.cwd()
try:
result = subprocess.run(
[
java_exe,
"-cp",
str(jar_path),
"com.codeflash.Comparator",
str(original_sqlite_path),
str(candidate_sqlite_path),
],
check=False,
capture_output=True,
text=True,
timeout=60,
cwd=str(cwd),
)
# Parse the JSON output
try:
if not result.stdout or not result.stdout.strip():
logger.error("Java comparator returned empty output")
if result.stderr:
logger.error(f"stderr: {result.stderr}")
return False, []
comparison = json.loads(result.stdout)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse Java comparator output: {e}")
logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
if result.stderr:
logger.error(f"stderr: {result.stderr[:500]}")
return False, []
# Check for errors in the JSON response
if comparison.get("error"):
logger.error(f"Java comparator error: {comparison['error']}")
return False, []
# Check for unexpected exit codes
if result.returncode not in {0, 1}:
logger.error(f"Java comparator failed with exit code {result.returncode}")
if result.stderr:
logger.error(f"stderr: {result.stderr}")
return False, []
# Convert diffs to TestDiff objects
test_diffs: list[TestDiff] = []
for diff in comparison.get("diffs", []):
scope_str = diff.get("scope", "return_value")
scope = TestDiffScope.RETURN_VALUE
if scope_str == "exception":
scope = TestDiffScope.DID_PASS
elif scope_str == "missing":
scope = TestDiffScope.DID_PASS
# Build test identifier
method_id = diff.get("methodId", "unknown")
call_id = diff.get("callId", 0)
test_src_code = f"// Method: {method_id}\n// Call ID: {call_id}"
test_diffs.append(
TestDiff(
scope=scope,
original_value=diff.get("originalValue"),
candidate_value=diff.get("candidateValue"),
test_src_code=test_src_code,
candidate_pytest_error=diff.get("candidateError"),
original_pass=True,
candidate_pass=scope_str not in ("missing", "exception"),
original_pytest_error=diff.get("originalError"),
)
)
logger.debug(
f"Java test diff:\n"
f" Method: {method_id}\n"
f" Call ID: {call_id}\n"
f" Scope: {scope_str}\n"
f" Original: {str(diff.get('originalValue', 'N/A'))[:100]}\n"
f" Candidate: {str(diff.get('candidateValue', 'N/A'))[:100]}"
)
equivalent = comparison.get("equivalent", False)
logger.info(
f"Java comparison: {'equivalent' if equivalent else 'DIFFERENT'} "
f"({comparison.get('totalInvocations', 0)} invocations, {len(test_diffs)} diffs)"
)
return equivalent, test_diffs
except subprocess.TimeoutExpired:
logger.error("Java comparator timed out")
return False, []
except FileNotFoundError:
logger.error("Java not found. Please install Java to compare test results.")
return False, []
except Exception as e:
logger.error(f"Error running Java comparator: {e}")
return False, []
def compare_invocations_directly(
original_results: dict,
candidate_results: dict,
) -> tuple[bool, list]:
"""Compare test invocations directly from Python dictionaries.
This is a fallback when the Java comparator is not available.
It performs basic equality comparison on serialized JSON values.
Args:
original_results: Dict mapping call_id to result data from original code.
candidate_results: Dict mapping call_id to result data from candidate code.
Returns:
Tuple of (all_equivalent, list of TestDiff objects).
"""
# Import lazily to avoid circular imports
from codeflash.models.models import TestDiff, TestDiffScope
test_diffs: list[TestDiff] = []
# Get all call IDs
all_call_ids = set(original_results.keys()) | set(candidate_results.keys())
for call_id in all_call_ids:
original = original_results.get(call_id)
candidate = candidate_results.get(call_id)
if original is None and candidate is not None:
# Candidate has extra invocation
test_diffs.append(
TestDiff(
scope=TestDiffScope.DID_PASS,
original_value=None,
candidate_value=candidate.get("result_json"),
test_src_code=f"// Call ID: {call_id}",
candidate_pytest_error=None,
original_pass=True,
candidate_pass=True,
original_pytest_error=None,
)
)
elif original is not None and candidate is None:
# Candidate missing invocation
test_diffs.append(
TestDiff(
scope=TestDiffScope.DID_PASS,
original_value=original.get("result_json"),
candidate_value=None,
test_src_code=f"// Call ID: {call_id}",
candidate_pytest_error="Missing invocation in candidate",
original_pass=True,
candidate_pass=False,
original_pytest_error=None,
)
)
elif original is not None and candidate is not None:
# Both have invocations - compare results
orig_result = original.get("result_json")
cand_result = candidate.get("result_json")
orig_error = original.get("error_json")
cand_error = candidate.get("error_json")
# Check for exception differences
if orig_error != cand_error:
test_diffs.append(
TestDiff(
scope=TestDiffScope.DID_PASS,
original_value=orig_error,
candidate_value=cand_error,
test_src_code=f"// Call ID: {call_id}",
candidate_pytest_error=cand_error,
original_pass=orig_error is None,
candidate_pass=cand_error is None,
original_pytest_error=orig_error,
)
)
elif orig_result != cand_result:
# Results differ
test_diffs.append(
TestDiff(
scope=TestDiffScope.RETURN_VALUE,
original_value=orig_result,
candidate_value=cand_result,
test_src_code=f"// Call ID: {call_id}",
candidate_pytest_error=None,
original_pass=True,
candidate_pass=True,
original_pytest_error=None,
)
)
equivalent = len(test_diffs) == 0
logger.info(
f"Python comparison: {'equivalent' if equivalent else 'DIFFERENT'} "
f"({len(all_call_ids)} invocations, {len(test_diffs)} diffs)"
)
return equivalent, test_diffs

View file

@ -0,0 +1,426 @@
"""Java project configuration detection.
This module provides functionality to detect and read Java project
configuration, including build tool settings, test framework configuration,
and project structure.
"""
from __future__ import annotations
import logging
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.java.build_tools import (
BuildTool,
detect_build_tool,
find_source_root,
find_test_root,
get_project_info,
)
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
@dataclass
class JavaProjectConfig:
"""Configuration for a Java project."""
project_root: Path
build_tool: BuildTool
source_root: Path | None
test_root: Path | None
java_version: str | None
encoding: str
test_framework: str # "junit5", "junit4", "testng"
group_id: str | None
artifact_id: str | None
version: str | None
# Dependencies
has_junit5: bool = False
has_junit4: bool = False
has_testng: bool = False
has_mockito: bool = False
has_assertj: bool = False
# Build configuration
compiler_source: str | None = None
compiler_target: str | None = None
# Plugin configurations
surefire_includes: list[str] = field(default_factory=list)
surefire_excludes: list[str] = field(default_factory=list)
def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
"""Detect and return Java project configuration.
Args:
project_root: Root directory of the project.
Returns:
JavaProjectConfig if a Java project is detected, None otherwise.
"""
# Check if this is a Java project
build_tool = detect_build_tool(project_root)
if build_tool == BuildTool.UNKNOWN:
# Check if there are any Java files
java_files = list(project_root.rglob("*.java"))
if not java_files:
return None
# Get basic project info
project_info = get_project_info(project_root)
# Detect test framework
test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(
project_root, build_tool
)
# Detect other dependencies
has_mockito, has_assertj = _detect_test_dependencies(project_root, build_tool)
# Get source/test roots
source_root = find_source_root(project_root)
test_root = find_test_root(project_root)
# Get compiler settings
compiler_source, compiler_target = _get_compiler_settings(project_root, build_tool)
# Get surefire configuration
surefire_includes, surefire_excludes = _get_surefire_config(project_root)
return JavaProjectConfig(
project_root=project_root,
build_tool=build_tool,
source_root=source_root,
test_root=test_root,
java_version=project_info.java_version if project_info else None,
encoding="UTF-8", # Default, could be detected from pom.xml
test_framework=test_framework,
group_id=project_info.group_id if project_info else None,
artifact_id=project_info.artifact_id if project_info else None,
version=project_info.version if project_info else None,
has_junit5=has_junit5,
has_junit4=has_junit4,
has_testng=has_testng,
has_mockito=has_mockito,
has_assertj=has_assertj,
compiler_source=compiler_source,
compiler_target=compiler_target,
surefire_includes=surefire_includes,
surefire_excludes=surefire_excludes,
)
def _detect_test_framework(
project_root: Path, build_tool: BuildTool
) -> tuple[str, bool, bool, bool]:
"""Detect which test framework the project uses.
Args:
project_root: Root directory of the project.
build_tool: The detected build tool.
Returns:
Tuple of (framework_name, has_junit5, has_junit4, has_testng).
"""
has_junit5 = False
has_junit4 = False
has_testng = False
if build_tool == BuildTool.MAVEN:
has_junit5, has_junit4, has_testng = _detect_test_deps_from_pom(project_root)
elif build_tool == BuildTool.GRADLE:
has_junit5, has_junit4, has_testng = _detect_test_deps_from_gradle(project_root)
# Also check test source files for import statements
test_root = find_test_root(project_root)
if test_root and test_root.exists():
for test_file in test_root.rglob("*.java"):
try:
content = test_file.read_text(encoding="utf-8")
if "org.junit.jupiter" in content:
has_junit5 = True
if "org.junit.Test" in content or "org.junit.Assert" in content:
has_junit4 = True
if "org.testng" in content:
has_testng = True
except Exception:
pass
# Determine primary framework (prefer JUnit 5)
if has_junit5:
return "junit5", has_junit5, has_junit4, has_testng
if has_junit4:
return "junit4", has_junit5, has_junit4, has_testng
if has_testng:
return "testng", has_junit5, has_junit4, has_testng
# Default to JUnit 5 if nothing detected
return "junit5", has_junit5, has_junit4, has_testng
def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
"""Detect test framework dependencies from pom.xml.
Returns:
Tuple of (has_junit5, has_junit4, has_testng).
"""
pom_path = project_root / "pom.xml"
if not pom_path.exists():
return False, False, False
has_junit5 = False
has_junit4 = False
has_testng = False
try:
tree = ET.parse(pom_path)
root = tree.getroot()
# Handle namespace
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
# Search for dependencies
for deps_path in ["dependencies", "m:dependencies"]:
deps = root.find(deps_path, ns) if "m:" in deps_path else root.find(deps_path)
if deps is None:
continue
for dep_path in ["dependency", "m:dependency"]:
deps_list = deps.findall(dep_path, ns) if "m:" in dep_path else deps.findall(dep_path)
for dep in deps_list:
artifact_id = None
group_id = None
for child in dep:
tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "")
if tag == "artifactId":
artifact_id = child.text
elif tag == "groupId":
group_id = child.text
if group_id == "org.junit.jupiter" or (
artifact_id and "junit-jupiter" in artifact_id
):
has_junit5 = True
elif group_id == "junit" and artifact_id == "junit":
has_junit4 = True
elif group_id == "org.testng":
has_testng = True
except ET.ParseError:
pass
return has_junit5, has_junit4, has_testng
def _detect_test_deps_from_gradle(project_root: Path) -> tuple[bool, bool, bool]:
"""Detect test framework dependencies from Gradle build files.
Returns:
Tuple of (has_junit5, has_junit4, has_testng).
"""
has_junit5 = False
has_junit4 = False
has_testng = False
for gradle_file in ["build.gradle", "build.gradle.kts"]:
gradle_path = project_root / gradle_file
if gradle_path.exists():
try:
content = gradle_path.read_text(encoding="utf-8")
if "junit-jupiter" in content or "useJUnitPlatform" in content:
has_junit5 = True
if "junit:junit" in content:
has_junit4 = True
if "testng" in content.lower():
has_testng = True
except Exception:
pass
return has_junit5, has_junit4, has_testng
def _detect_test_dependencies(
project_root: Path, build_tool: BuildTool
) -> tuple[bool, bool]:
"""Detect additional test dependencies (Mockito, AssertJ).
Returns:
Tuple of (has_mockito, has_assertj).
"""
has_mockito = False
has_assertj = False
pom_path = project_root / "pom.xml"
if pom_path.exists():
try:
content = pom_path.read_text(encoding="utf-8")
has_mockito = "mockito" in content.lower()
has_assertj = "assertj" in content.lower()
except Exception:
pass
for gradle_file in ["build.gradle", "build.gradle.kts"]:
gradle_path = project_root / gradle_file
if gradle_path.exists():
try:
content = gradle_path.read_text(encoding="utf-8")
if "mockito" in content.lower():
has_mockito = True
if "assertj" in content.lower():
has_assertj = True
except Exception:
pass
return has_mockito, has_assertj
def _get_compiler_settings(
project_root: Path, build_tool: BuildTool
) -> tuple[str | None, str | None]:
"""Get compiler source and target settings.
Returns:
Tuple of (source_version, target_version).
"""
if build_tool != BuildTool.MAVEN:
return None, None
pom_path = project_root / "pom.xml"
if not pom_path.exists():
return None, None
source = None
target = None
try:
tree = ET.parse(pom_path)
root = tree.getroot()
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
# Check properties
for props_path in ["properties", "m:properties"]:
props = root.find(props_path, ns) if "m:" in props_path else root.find(props_path)
if props is not None:
for child in props:
tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "")
if tag == "maven.compiler.source":
source = child.text
elif tag == "maven.compiler.target":
target = child.text
except ET.ParseError:
pass
return source, target
def _get_surefire_config(project_root: Path) -> tuple[list[str], list[str]]:
"""Get Maven Surefire plugin includes/excludes configuration.
Returns:
Tuple of (includes, excludes) patterns.
"""
includes: list[str] = []
excludes: list[str] = []
pom_path = project_root / "pom.xml"
if not pom_path.exists():
return includes, excludes
try:
tree = ET.parse(pom_path)
root = tree.getroot()
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
# Find surefire plugin configuration
# This is a simplified search - a full implementation would
# handle nested build/plugins/plugin structure
content = pom_path.read_text(encoding="utf-8")
if "maven-surefire-plugin" in content:
# Parse includes/excludes if present
# This is a basic implementation
pass
except (ET.ParseError, Exception):
pass
# Return default patterns if none configured
if not includes:
includes = ["**/Test*.java", "**/*Test.java", "**/*Tests.java", "**/*TestCase.java"]
if not excludes:
excludes = ["**/*IT.java", "**/*IntegrationTest.java"]
return includes, excludes
def is_java_project(project_root: Path) -> bool:
"""Check if a directory is a Java project.
Args:
project_root: Directory to check.
Returns:
True if this appears to be a Java project.
"""
# Check for build tool config files
if (project_root / "pom.xml").exists():
return True
if (project_root / "build.gradle").exists():
return True
if (project_root / "build.gradle.kts").exists():
return True
# Check for Java source files
for pattern in ["src/**/*.java", "*.java"]:
if list(project_root.glob(pattern)):
return True
return False
def get_test_file_pattern(config: JavaProjectConfig) -> str:
"""Get the test file naming pattern for a project.
Args:
config: The project configuration.
Returns:
Glob pattern for test files.
"""
# Default JUnit pattern
return "*Test.java"
def get_test_class_pattern(config: JavaProjectConfig) -> str:
"""Get the regex pattern for test class names.
Args:
config: The project configuration.
Returns:
Regex pattern for test class names.
"""
return r".*Test(s)?$|^Test.*"

View file

@ -0,0 +1,345 @@
"""Java code context extraction.
This module provides functionality to extract code context needed for
optimization, including the target function, helper functions, imports,
and other dependencies.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import CodeContext, FunctionInfo, HelperFunction, Language
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
def extract_code_context(
function: FunctionInfo,
project_root: Path,
module_root: Path | None = None,
max_helper_depth: int = 2,
analyzer: JavaAnalyzer | None = None,
) -> CodeContext:
"""Extract code context for a Java function.
This extracts:
- The target function's source code
- Import statements
- Helper functions (project-internal dependencies)
- Read-only context (class fields, constants, etc.)
Args:
function: The function to extract context for.
project_root: Root of the project.
module_root: Root of the module (defaults to project_root).
max_helper_depth: Maximum depth to trace helper functions.
analyzer: Optional JavaAnalyzer instance.
Returns:
CodeContext with target code and dependencies.
"""
analyzer = analyzer or get_java_analyzer()
module_root = module_root or project_root
# Read the source file
try:
source = function.file_path.read_text(encoding="utf-8")
except Exception as e:
logger.error("Failed to read %s: %s", function.file_path, e)
return CodeContext(
target_code="",
target_file=function.file_path,
language=Language.JAVA,
)
# Extract target function code
target_code = extract_function_source(source, function)
# Extract imports
imports = analyzer.find_imports(source)
import_statements = [_import_to_statement(imp) for imp in imports]
# Extract helper functions
helper_functions = find_helper_functions(
function, project_root, max_helper_depth, analyzer
)
# Extract read-only context (class fields, constants, etc.)
read_only_context = extract_read_only_context(source, function, analyzer)
return CodeContext(
target_code=target_code,
target_file=function.file_path,
helper_functions=helper_functions,
read_only_context=read_only_context,
imports=import_statements,
language=Language.JAVA,
)
def extract_function_source(source: str, function: FunctionInfo) -> str:
"""Extract the source code of a function from the full file source.
Args:
source: The full file source code.
function: The function to extract.
Returns:
The function's source code.
"""
lines = source.splitlines(keepends=True)
# Include Javadoc if present
start_line = function.doc_start_line or function.start_line
end_line = function.end_line
# Convert from 1-indexed to 0-indexed
start_idx = start_line - 1
end_idx = end_line
return "".join(lines[start_idx:end_idx])
def find_helper_functions(
function: FunctionInfo,
project_root: Path,
max_depth: int = 2,
analyzer: JavaAnalyzer | None = None,
) -> list[HelperFunction]:
"""Find helper functions that the target function depends on.
Args:
function: The target function to analyze.
project_root: Root of the project.
max_depth: Maximum depth to trace dependencies.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of HelperFunction objects.
"""
analyzer = analyzer or get_java_analyzer()
helpers: list[HelperFunction] = []
visited_functions: set[str] = set()
# Find helper files through imports
helper_files = find_helper_files(
function.file_path, project_root, max_depth, analyzer
)
for file_path, class_names in helper_files.items():
try:
source = file_path.read_text(encoding="utf-8")
file_functions = discover_functions_from_source(source, file_path, analyzer=analyzer)
for func in file_functions:
func_id = f"{file_path}:{func.qualified_name}"
if func_id not in visited_functions:
visited_functions.add(func_id)
# Extract the function source
func_source = extract_function_source(source, func)
helpers.append(
HelperFunction(
name=func.name,
qualified_name=func.qualified_name,
file_path=file_path,
source_code=func_source,
start_line=func.start_line,
end_line=func.end_line,
)
)
except Exception as e:
logger.warning("Failed to extract helpers from %s: %s", file_path, e)
# Also find helper methods in the same class
same_file_helpers = _find_same_class_helpers(function, analyzer)
for helper in same_file_helpers:
func_id = f"{function.file_path}:{helper.qualified_name}"
if func_id not in visited_functions:
visited_functions.add(func_id)
helpers.append(helper)
return helpers
def _find_same_class_helpers(
function: FunctionInfo,
analyzer: JavaAnalyzer,
) -> list[HelperFunction]:
"""Find helper methods in the same class as the target function.
Args:
function: The target function.
analyzer: JavaAnalyzer instance.
Returns:
List of helper functions in the same class.
"""
helpers: list[HelperFunction] = []
if not function.class_name:
return helpers
try:
source = function.file_path.read_text(encoding="utf-8")
source_bytes = source.encode("utf8")
# Find all methods in the file
methods = analyzer.find_methods(source)
# Find which methods the target function calls
target_method = None
for method in methods:
if method.name == function.name and method.class_name == function.class_name:
target_method = method
break
if not target_method:
return helpers
# Get method calls from the target
called_methods = set(analyzer.find_method_calls(source, target_method))
# Add called methods from the same class as helpers
for method in methods:
if (
method.name != function.name
and method.class_name == function.class_name
and method.name in called_methods
):
func_source = source_bytes[
method.node.start_byte : method.node.end_byte
].decode("utf8")
helpers.append(
HelperFunction(
name=method.name,
qualified_name=f"{method.class_name}.{method.name}",
file_path=function.file_path,
source_code=func_source,
start_line=method.start_line,
end_line=method.end_line,
)
)
except Exception as e:
logger.warning("Failed to find same-class helpers: %s", e)
return helpers
def extract_read_only_context(
source: str,
function: FunctionInfo,
analyzer: JavaAnalyzer,
) -> str:
"""Extract read-only context (fields, constants, inner classes).
This extracts class-level context that the function might depend on
but shouldn't be modified during optimization.
Args:
source: The full source code.
function: The target function.
analyzer: JavaAnalyzer instance.
Returns:
String containing read-only context code.
"""
if not function.class_name:
return ""
context_parts: list[str] = []
# Find fields in the same class
fields = analyzer.find_fields(source, function.class_name)
for field in fields:
context_parts.append(field.source_text)
return "\n".join(context_parts)
def _import_to_statement(import_info) -> str:
"""Convert a JavaImportInfo to an import statement string.
Args:
import_info: The import info.
Returns:
Import statement string.
"""
if import_info.is_static:
prefix = "import static "
else:
prefix = "import "
suffix = ".*" if import_info.is_wildcard else ""
return f"{prefix}{import_info.import_path}{suffix};"
def extract_class_context(
file_path: Path,
class_name: str,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Extract the full context of a class.
Args:
file_path: Path to the Java file.
class_name: Name of the class.
analyzer: Optional JavaAnalyzer instance.
Returns:
String containing the class code with imports.
"""
analyzer = analyzer or get_java_analyzer()
try:
source = file_path.read_text(encoding="utf-8")
# Find the class
classes = analyzer.find_classes(source)
target_class = None
for cls in classes:
if cls.name == class_name:
target_class = cls
break
if not target_class:
return ""
# Extract imports
imports = analyzer.find_imports(source)
import_statements = [_import_to_statement(imp) for imp in imports]
# Get package
package = analyzer.get_package_name(source)
package_stmt = f"package {package};\n\n" if package else ""
# Get class source
class_source = target_class.source_text
return package_stmt + "\n".join(import_statements) + "\n\n" + class_source
except Exception as e:
logger.error("Failed to extract class context: %s", e)
return ""

View file

@ -0,0 +1,328 @@
"""Java function and method discovery.
This module provides functionality to discover optimizable functions and methods
in Java source files using the tree-sitter parser.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import (
FunctionFilterCriteria,
FunctionInfo,
Language,
ParentInfo,
)
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
def discover_functions(
file_path: Path,
filter_criteria: FunctionFilterCriteria | None = None,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
"""Find all optimizable functions/methods in a Java file.
Uses tree-sitter to parse the file and find methods that can be optimized.
Args:
file_path: Path to the Java file to analyze.
filter_criteria: Optional criteria to filter functions.
analyzer: Optional JavaAnalyzer instance (created if not provided).
Returns:
List of FunctionInfo objects for discovered functions.
"""
criteria = filter_criteria or FunctionFilterCriteria()
try:
source = file_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning("Failed to read %s: %s", file_path, e)
return []
return discover_functions_from_source(source, file_path, criteria, analyzer)
def discover_functions_from_source(
source: str,
file_path: Path | None = None,
filter_criteria: FunctionFilterCriteria | None = None,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
"""Find all optimizable functions/methods in Java source code.
Args:
source: The Java source code to analyze.
file_path: Optional file path for context.
filter_criteria: Optional criteria to filter functions.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo objects for discovered functions.
"""
criteria = filter_criteria or FunctionFilterCriteria()
analyzer = analyzer or get_java_analyzer()
try:
# Find all methods
methods = analyzer.find_methods(
source,
include_private=True, # Include all, filter later
include_static=True,
)
functions: list[FunctionInfo] = []
for method in methods:
# Apply filters
if not _should_include_method(method, criteria, source, analyzer):
continue
# Build parents list
parents: list[ParentInfo] = []
if method.class_name:
parents.append(ParentInfo(name=method.class_name, type="ClassDef"))
functions.append(
FunctionInfo(
name=method.name,
file_path=file_path or Path("unknown.java"),
start_line=method.start_line,
end_line=method.end_line,
start_col=method.start_col,
end_col=method.end_col,
parents=tuple(parents),
is_async=False, # Java doesn't have async keyword
is_method=method.class_name is not None,
language=Language.JAVA,
doc_start_line=method.javadoc_start_line,
)
)
return functions
except Exception as e:
logger.warning("Failed to parse Java source: %s", e)
return []
def _should_include_method(
method: JavaMethodNode,
criteria: FunctionFilterCriteria,
source: str,
analyzer: JavaAnalyzer,
) -> bool:
"""Check if a method should be included based on filter criteria.
Args:
method: The method to check.
criteria: Filter criteria to apply.
source: Source code for additional analysis.
analyzer: JavaAnalyzer for additional checks.
Returns:
True if the method should be included.
"""
# Skip abstract methods (no implementation to optimize)
if method.is_abstract:
return False
# Skip constructors (special case - could be optimized but usually not)
if method.name == method.class_name:
return False
# Check include patterns
if criteria.include_patterns:
import fnmatch
if not any(fnmatch.fnmatch(method.name, pattern) for pattern in criteria.include_patterns):
return False
# Check exclude patterns
if criteria.exclude_patterns:
import fnmatch
if any(fnmatch.fnmatch(method.name, pattern) for pattern in criteria.exclude_patterns):
return False
# Check require_return - void methods don't have return values
if criteria.require_return:
if method.return_type == "void":
return False
# Also check if the method actually has a return statement
if not analyzer.has_return_statement(method, source):
return False
# Check include_methods - in Java, all functions in classes are methods
if not criteria.include_methods and method.class_name is not None:
return False
# Check line count
method_lines = method.end_line - method.start_line + 1
if criteria.min_lines is not None and method_lines < criteria.min_lines:
return False
if criteria.max_lines is not None and method_lines > criteria.max_lines:
return False
return True
def discover_test_methods(
file_path: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
"""Find all JUnit test methods in a Java test file.
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
Args:
file_path: Path to the Java test file.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo objects for discovered test methods.
"""
try:
source = file_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning("Failed to read %s: %s", file_path, e)
return []
analyzer = analyzer or get_java_analyzer()
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
test_methods: list[FunctionInfo] = []
# Find methods with test annotations
_walk_tree_for_test_methods(tree.root_node, source_bytes, file_path, test_methods, analyzer, current_class=None)
return test_methods
def _walk_tree_for_test_methods(
node,
source_bytes: bytes,
file_path: Path,
test_methods: list[FunctionInfo],
analyzer: JavaAnalyzer,
current_class: str | None,
) -> None:
"""Recursively walk the tree to find test methods."""
new_class = current_class
if node.type == "class_declaration":
name_node = node.child_by_field_name("name")
if name_node:
new_class = analyzer.get_node_text(name_node, source_bytes)
if node.type == "method_declaration":
# Check for test annotations
has_test_annotation = False
for child in node.children:
if child.type == "modifiers":
for mod_child in child.children:
if mod_child.type == "marker_annotation" or mod_child.type == "annotation":
annotation_text = analyzer.get_node_text(mod_child, source_bytes)
# Check for JUnit 5 test annotations
if any(
ann in annotation_text
for ann in ["@Test", "@ParameterizedTest", "@RepeatedTest", "@TestFactory"]
):
has_test_annotation = True
break
if has_test_annotation:
name_node = node.child_by_field_name("name")
if name_node:
method_name = analyzer.get_node_text(name_node, source_bytes)
parents: list[ParentInfo] = []
if current_class:
parents.append(ParentInfo(name=current_class, type="ClassDef"))
test_methods.append(
FunctionInfo(
name=method_name,
file_path=file_path,
start_line=node.start_point[0] + 1,
end_line=node.end_point[0] + 1,
start_col=node.start_point[1],
end_col=node.end_point[1],
parents=tuple(parents),
is_async=False,
is_method=current_class is not None,
language=Language.JAVA,
)
)
for child in node.children:
_walk_tree_for_test_methods(
child,
source_bytes,
file_path,
test_methods,
analyzer,
current_class=new_class if node.type == "class_declaration" else current_class,
)
def get_method_by_name(
file_path: Path,
method_name: str,
class_name: str | None = None,
analyzer: JavaAnalyzer | None = None,
) -> FunctionInfo | None:
"""Find a specific method by name in a Java file.
Args:
file_path: Path to the Java file.
method_name: Name of the method to find.
class_name: Optional class name to narrow the search.
analyzer: Optional JavaAnalyzer instance.
Returns:
FunctionInfo for the method, or None if not found.
"""
functions = discover_functions(file_path, analyzer=analyzer)
for func in functions:
if func.name == method_name:
if class_name is None or func.class_name == class_name:
return func
return None
def get_class_methods(
file_path: Path,
class_name: str,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
"""Get all methods in a specific class.
Args:
file_path: Path to the Java file.
class_name: Name of the class.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo objects for methods in the class.
"""
functions = discover_functions(file_path, analyzer=analyzer)
return [f for f in functions if f.class_name == class_name]

View file

@ -0,0 +1,347 @@
"""Java code formatting.
This module provides functionality to format Java code using
google-java-format or other available formatters.
"""
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class JavaFormatter:
"""Java code formatter using google-java-format or fallback methods."""
# Path to google-java-format JAR (if downloaded)
_google_java_format_jar: Path | None = None
# Version of google-java-format to use
GOOGLE_JAVA_FORMAT_VERSION = "1.19.2"
def __init__(self, project_root: Path | None = None):
"""Initialize the Java formatter.
Args:
project_root: Optional project root for project-specific formatting rules.
"""
self.project_root = project_root
self._java_executable = self._find_java()
def _find_java(self) -> str | None:
"""Find the Java executable."""
# Check JAVA_HOME
java_home = os.environ.get("JAVA_HOME")
if java_home:
java_path = Path(java_home) / "bin" / "java"
if java_path.exists():
return str(java_path)
# Check PATH
java_path = shutil.which("java")
if java_path:
return java_path
return None
def format_code(self, source: str, file_path: Path | None = None) -> str:
"""Format Java source code.
Attempts to use google-java-format if available, otherwise
returns the source unchanged.
Args:
source: The Java source code to format.
file_path: Optional file path for context.
Returns:
Formatted source code.
"""
if not source or not source.strip():
return source
# Try google-java-format first
formatted = self._format_with_google_java_format(source)
if formatted is not None:
return formatted
# Try Eclipse formatter (if available in project)
if self.project_root:
formatted = self._format_with_eclipse(source)
if formatted is not None:
return formatted
# Return original source if no formatter available
logger.debug("No Java formatter available, returning original source")
return source
def _format_with_google_java_format(self, source: str) -> str | None:
"""Format using google-java-format.
Args:
source: The source code to format.
Returns:
Formatted source, or None if formatting failed.
"""
if not self._java_executable:
return None
# Try to find or download google-java-format
jar_path = self._get_google_java_format_jar()
if not jar_path:
return None
try:
# Write source to temp file
with tempfile.NamedTemporaryFile(
mode="w", suffix=".java", delete=False, encoding="utf-8"
) as tmp:
tmp.write(source)
tmp_path = tmp.name
try:
result = subprocess.run(
[
self._java_executable,
"-jar",
str(jar_path),
"--replace",
tmp_path,
],
check=False,
capture_output=True,
text=True,
timeout=30,
)
if result.returncode == 0:
# Read back the formatted file
with open(tmp_path, encoding="utf-8") as f:
return f.read()
else:
logger.debug(
"google-java-format failed: %s", result.stderr or result.stdout
)
finally:
# Clean up temp file
try:
os.unlink(tmp_path)
except OSError:
pass
except subprocess.TimeoutExpired:
logger.warning("google-java-format timed out")
except Exception as e:
logger.debug("google-java-format error: %s", e)
return None
def _get_google_java_format_jar(self) -> Path | None:
"""Get path to google-java-format JAR, downloading if necessary.
Returns:
Path to the JAR file, or None if not available.
"""
if JavaFormatter._google_java_format_jar:
if JavaFormatter._google_java_format_jar.exists():
return JavaFormatter._google_java_format_jar
# Check common locations
possible_paths = [
# In project's .codeflash directory
self.project_root / ".codeflash" / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar"
if self.project_root
else None,
# In user's home directory
Path.home()
/ ".codeflash"
/ f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
# In system temp
Path(tempfile.gettempdir())
/ "codeflash"
/ f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
]
for path in possible_paths:
if path and path.exists():
JavaFormatter._google_java_format_jar = path
return path
# Don't auto-download to avoid surprises
# Users can manually download the JAR
logger.debug(
"google-java-format JAR not found. "
"Download from https://github.com/google/google-java-format/releases"
)
return None
def _format_with_eclipse(self, source: str) -> str | None:
"""Format using Eclipse formatter settings (if available in project).
Args:
source: The source code to format.
Returns:
Formatted source, or None if formatting failed.
"""
# Eclipse formatter requires eclipse.ini or a config file
# This is a placeholder for future implementation
return None
def download_google_java_format(self, target_dir: Path | None = None) -> Path | None:
"""Download google-java-format JAR.
Args:
target_dir: Directory to download to (defaults to ~/.codeflash/).
Returns:
Path to the downloaded JAR, or None if download failed.
"""
import urllib.request
target_dir = target_dir or Path.home() / ".codeflash"
target_dir.mkdir(parents=True, exist_ok=True)
jar_name = f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar"
jar_path = target_dir / jar_name
if jar_path.exists():
JavaFormatter._google_java_format_jar = jar_path
return jar_path
url = (
f"https://github.com/google/google-java-format/releases/download/"
f"v{self.GOOGLE_JAVA_FORMAT_VERSION}/{jar_name}"
)
try:
logger.info("Downloading google-java-format from %s", url)
urllib.request.urlretrieve(url, jar_path)
JavaFormatter._google_java_format_jar = jar_path
logger.info("Downloaded google-java-format to %s", jar_path)
return jar_path
except Exception as e:
logger.error("Failed to download google-java-format: %s", e)
return None
def format_java_code(source: str, project_root: Path | None = None) -> str:
"""Convenience function to format Java code.
Args:
source: The Java source code to format.
project_root: Optional project root for context.
Returns:
Formatted source code.
"""
formatter = JavaFormatter(project_root)
return formatter.format_code(source)
def format_java_file(file_path: Path, in_place: bool = False) -> str:
"""Format a Java file.
Args:
file_path: Path to the Java file.
in_place: Whether to modify the file in place.
Returns:
Formatted source code.
"""
source = file_path.read_text(encoding="utf-8")
formatter = JavaFormatter(file_path.parent)
formatted = formatter.format_code(source, file_path)
if in_place and formatted != source:
file_path.write_text(formatted, encoding="utf-8")
return formatted
def normalize_java_code(source: str) -> str:
"""Normalize Java code for deduplication.
This removes comments and normalizes whitespace to allow
comparison of semantically equivalent code.
Args:
source: The Java source code.
Returns:
Normalized source code.
"""
lines = source.splitlines()
normalized_lines = []
in_block_comment = False
for line in lines:
# Handle block comments
if in_block_comment:
if "*/" in line:
in_block_comment = False
line = line[line.index("*/") + 2 :]
else:
continue
# Remove line comments
if "//" in line:
# Find // that's not inside a string
in_string = False
escape_next = False
comment_start = -1
for i, char in enumerate(line):
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
if char == '"' and not in_string:
in_string = True
elif char == '"' and in_string:
in_string = False
elif not in_string and i < len(line) - 1 and line[i : i + 2] == "//":
comment_start = i
break
if comment_start >= 0:
line = line[:comment_start]
# Handle start of block comments
if "/*" in line:
start_idx = line.index("/*")
if "*/" in line[start_idx:]:
# Block comment on single line
end_idx = line.index("*/", start_idx)
line = line[:start_idx] + line[end_idx + 2 :]
else:
in_block_comment = True
line = line[:start_idx]
# Skip empty lines and add non-empty ones
stripped = line.strip()
if stripped:
normalized_lines.append(stripped)
return "\n".join(normalized_lines)

View file

@ -0,0 +1,360 @@
"""Java import resolution.
This module provides functionality to resolve Java imports to actual file paths
within a project, handling both source and test directories.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.java.build_tools import find_source_root, find_test_root, get_project_info
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, get_java_analyzer
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
@dataclass
class ResolvedImport:
"""A resolved Java import."""
import_path: str # Original import path (e.g., "com.example.utils.StringUtils")
file_path: Path | None # Resolved file path, or None if external/unresolved
is_external: bool # True if this is an external dependency (not in project)
is_wildcard: bool # True if this was a wildcard import
class_name: str | None # The imported class name (e.g., "StringUtils")
class JavaImportResolver:
"""Resolves Java imports to file paths within a project."""
# Standard Java packages that are always external
STANDARD_PACKAGES = frozenset(
[
"java",
"javax",
"sun",
"com.sun",
"jdk",
"org.w3c",
"org.xml",
"org.ietf",
]
)
# Common third-party package prefixes
COMMON_EXTERNAL_PREFIXES = frozenset(
[
"org.junit",
"org.mockito",
"org.assertj",
"org.hamcrest",
"org.slf4j",
"org.apache",
"org.springframework",
"com.google",
"com.fasterxml",
"io.netty",
"io.github",
"lombok",
]
)
def __init__(self, project_root: Path):
"""Initialize the import resolver.
Args:
project_root: Root directory of the Java project.
"""
self.project_root = project_root
self._source_roots: list[Path] = []
self._test_roots: list[Path] = []
self._package_to_path_cache: dict[str, Path | None] = {}
# Discover source and test roots
self._discover_roots()
def _discover_roots(self) -> None:
"""Discover source and test root directories."""
# Try to get project info first
project_info = get_project_info(self.project_root)
if project_info:
self._source_roots = project_info.source_roots
self._test_roots = project_info.test_roots
else:
# Fall back to standard detection
source_root = find_source_root(self.project_root)
if source_root:
self._source_roots = [source_root]
test_root = find_test_root(self.project_root)
if test_root:
self._test_roots = [test_root]
def resolve_import(self, import_info: JavaImportInfo) -> ResolvedImport:
"""Resolve a single import to a file path.
Args:
import_info: The import to resolve.
Returns:
ResolvedImport with resolution details.
"""
import_path = import_info.import_path
# Check if it's a standard library import
if self._is_standard_library(import_path):
return ResolvedImport(
import_path=import_path,
file_path=None,
is_external=True,
is_wildcard=import_info.is_wildcard,
class_name=self._extract_class_name(import_path),
)
# Check if it's a known external library
if self._is_external_library(import_path):
return ResolvedImport(
import_path=import_path,
file_path=None,
is_external=True,
is_wildcard=import_info.is_wildcard,
class_name=self._extract_class_name(import_path),
)
# Try to resolve within the project
resolved_path = self._resolve_to_file(import_path)
return ResolvedImport(
import_path=import_path,
file_path=resolved_path,
is_external=resolved_path is None,
is_wildcard=import_info.is_wildcard,
class_name=self._extract_class_name(import_path),
)
def resolve_imports(self, imports: list[JavaImportInfo]) -> list[ResolvedImport]:
"""Resolve multiple imports.
Args:
imports: List of imports to resolve.
Returns:
List of ResolvedImport objects.
"""
return [self.resolve_import(imp) for imp in imports]
def _is_standard_library(self, import_path: str) -> bool:
"""Check if an import is from the Java standard library."""
for prefix in self.STANDARD_PACKAGES:
if import_path.startswith(prefix + ".") or import_path == prefix:
return True
return False
def _is_external_library(self, import_path: str) -> bool:
"""Check if an import is from a known external library."""
for prefix in self.COMMON_EXTERNAL_PREFIXES:
if import_path.startswith(prefix + ".") or import_path == prefix:
return True
return False
def _resolve_to_file(self, import_path: str) -> Path | None:
"""Try to resolve an import path to a file in the project.
Args:
import_path: The fully qualified import path.
Returns:
Path to the Java file, or None if not found.
"""
# Check cache
if import_path in self._package_to_path_cache:
return self._package_to_path_cache[import_path]
# Convert package path to file path
# e.g., "com.example.utils.StringUtils" -> "com/example/utils/StringUtils.java"
relative_path = import_path.replace(".", "/") + ".java"
# Search in source roots
for source_root in self._source_roots:
candidate = source_root / relative_path
if candidate.exists():
self._package_to_path_cache[import_path] = candidate
return candidate
# Search in test roots
for test_root in self._test_roots:
candidate = test_root / relative_path
if candidate.exists():
self._package_to_path_cache[import_path] = candidate
return candidate
# Not found
self._package_to_path_cache[import_path] = None
return None
def _extract_class_name(self, import_path: str) -> str | None:
"""Extract the class name from an import path.
Args:
import_path: The import path (e.g., "com.example.MyClass").
Returns:
The class name (e.g., "MyClass"), or None if it's a wildcard.
"""
if not import_path:
return None
parts = import_path.split(".")
if parts:
last_part = parts[-1]
# Check if it looks like a class name (starts with uppercase)
if last_part and last_part[0].isupper():
return last_part
return None
def find_class_file(self, class_name: str, package_hint: str | None = None) -> Path | None:
"""Find the file containing a specific class.
Args:
class_name: The simple class name (e.g., "StringUtils").
package_hint: Optional package hint to narrow the search.
Returns:
Path to the Java file, or None if not found.
"""
if package_hint:
# Try the exact path first
import_path = f"{package_hint}.{class_name}"
result = self._resolve_to_file(import_path)
if result:
return result
# Search all source and test roots for the class
file_name = f"{class_name}.java"
for root in self._source_roots + self._test_roots:
for java_file in root.rglob(file_name):
return java_file
return None
def get_imports_from_file(
self, file_path: Path, analyzer: JavaAnalyzer | None = None
) -> list[ResolvedImport]:
"""Get and resolve all imports from a Java file.
Args:
file_path: Path to the Java file.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of ResolvedImport objects.
"""
analyzer = analyzer or get_java_analyzer()
try:
source = file_path.read_text(encoding="utf-8")
imports = analyzer.find_imports(source)
return self.resolve_imports(imports)
except Exception as e:
logger.warning("Failed to get imports from %s: %s", file_path, e)
return []
def get_project_imports(
self, file_path: Path, analyzer: JavaAnalyzer | None = None
) -> list[ResolvedImport]:
"""Get only the imports that resolve to files within the project.
Args:
file_path: Path to the Java file.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of ResolvedImport objects for project-internal imports only.
"""
all_imports = self.get_imports_from_file(file_path, analyzer)
return [imp for imp in all_imports if not imp.is_external and imp.file_path is not None]
def resolve_imports_for_file(
file_path: Path, project_root: Path, analyzer: JavaAnalyzer | None = None
) -> list[ResolvedImport]:
"""Convenience function to resolve imports for a single file.
Args:
file_path: Path to the Java file.
project_root: Root directory of the project.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of ResolvedImport objects.
"""
resolver = JavaImportResolver(project_root)
return resolver.get_imports_from_file(file_path, analyzer)
def find_helper_files(
file_path: Path,
project_root: Path,
max_depth: int = 2,
analyzer: JavaAnalyzer | None = None,
) -> dict[Path, list[str]]:
"""Find helper files imported by a Java file, recursively.
This traces the import chain to find all project files that the
given file depends on, up to max_depth levels.
Args:
file_path: Path to the Java file.
project_root: Root directory of the project.
max_depth: Maximum depth of import chain to follow.
analyzer: Optional JavaAnalyzer instance.
Returns:
Dict mapping file paths to list of imported class names.
"""
resolver = JavaImportResolver(project_root)
analyzer = analyzer or get_java_analyzer()
result: dict[Path, list[str]] = {}
visited: set[Path] = {file_path}
def _trace_imports(current_file: Path, depth: int) -> None:
if depth > max_depth:
return
project_imports = resolver.get_project_imports(current_file, analyzer)
for imp in project_imports:
if imp.file_path and imp.file_path not in visited:
visited.add(imp.file_path)
if imp.file_path not in result:
result[imp.file_path] = []
if imp.class_name:
result[imp.file_path].append(imp.class_name)
# Recurse into the imported file
_trace_imports(imp.file_path, depth + 1)
_trace_imports(file_path, 0)
return result

View file

@ -0,0 +1,354 @@
"""Java code instrumentation for behavior capture and benchmarking.
This module provides functionality to instrument Java code for:
1. Behavior capture - recording inputs/outputs for verification
2. Benchmarking - measuring execution time
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import FunctionInfo
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
if TYPE_CHECKING:
from collections.abc import Sequence
from typing import Any
logger = logging.getLogger(__name__)
def _get_function_name(func: Any) -> str:
"""Get the function name from either FunctionInfo or FunctionToOptimize."""
if hasattr(func, "name"):
return func.name
if hasattr(func, "function_name"):
return func.function_name
raise AttributeError(f"Cannot get function name from {type(func)}")
# Template for behavior capture instrumentation
BEHAVIOR_CAPTURE_IMPORT = "import com.codeflash.CodeFlash;"
BEHAVIOR_CAPTURE_BEFORE = """
// CodeFlash behavior capture - start
long __codeflash_call_id_{call_id} = System.nanoTime();
CodeFlash.recordInput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize({args}));
long __codeflash_start_{call_id} = System.nanoTime();
"""
BEHAVIOR_CAPTURE_AFTER_RETURN = """
// CodeFlash behavior capture - end
long __codeflash_end_{call_id} = System.nanoTime();
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize(__codeflash_result_{call_id}), __codeflash_end_{call_id} - __codeflash_start_{call_id});
"""
BEHAVIOR_CAPTURE_AFTER_VOID = """
// CodeFlash behavior capture - end
long __codeflash_end_{call_id} = System.nanoTime();
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", "null", __codeflash_end_{call_id} - __codeflash_start_{call_id});
"""
# Template for benchmark instrumentation
BENCHMARK_IMPORT = """import com.codeflash.Blackhole;
import com.codeflash.BenchmarkContext;
import com.codeflash.BenchmarkResult;"""
BENCHMARK_WRAPPER_TEMPLATE = """
// CodeFlash benchmark wrapper
public void __codeflash_benchmark_{method_name}(int iterations) {{
// Warmup
for (int i = 0; i < Math.min(iterations / 10, 100); i++) {{
{warmup_call}
}}
// Measurement
long[] measurements = new long[iterations];
for (int i = 0; i < iterations; i++) {{
long start = System.nanoTime();
{measurement_call}
long end = System.nanoTime();
measurements[i] = end - start;
}}
BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
CodeFlash.recordBenchmarkResult("{method_id}", result);
}}
"""
def instrument_for_behavior(
source: str,
functions: Sequence[FunctionInfo],
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
Wraps function calls to record arguments and return values
for behavioral verification.
Args:
source: Source code to instrument.
functions: Functions to add behavior capture.
analyzer: Optional JavaAnalyzer instance.
Returns:
Instrumented source code.
"""
analyzer = analyzer or get_java_analyzer()
if not functions:
return source
# Add import if not present
if BEHAVIOR_CAPTURE_IMPORT not in source:
source = _add_import(source, BEHAVIOR_CAPTURE_IMPORT)
# Find and instrument each function
for func in functions:
source = _instrument_function_behavior(source, func, analyzer)
return source
def _add_import(source: str, import_statement: str) -> str:
"""Add an import statement to the source.
Args:
source: The source code.
import_statement: The import to add.
Returns:
Source with import added.
"""
lines = source.splitlines(keepends=True)
insert_idx = 0
# Find the last import or package statement
for i, line in enumerate(lines):
stripped = line.strip()
if stripped.startswith("import ") or stripped.startswith("package "):
insert_idx = i + 1
elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
# First non-import, non-comment line
if insert_idx == 0:
insert_idx = i
break
lines.insert(insert_idx, import_statement + "\n")
return "".join(lines)
def _instrument_function_behavior(
source: str,
function: FunctionInfo,
analyzer: JavaAnalyzer,
) -> str:
"""Instrument a single function for behavior capture.
Args:
source: The source code.
function: The function to instrument.
analyzer: JavaAnalyzer instance.
Returns:
Source with function instrumented.
"""
source_bytes = source.encode("utf8")
tree = analyzer.parse(source_bytes)
# Find the method node
methods = analyzer.find_methods(source)
target_method = None
func_name = _get_function_name(function)
for method in methods:
if method.name == func_name:
class_name = getattr(function, "class_name", None)
if class_name is None or method.class_name == class_name:
target_method = method
break
if not target_method:
logger.warning("Could not find method %s for instrumentation", func_name)
return source
# For now, we'll add instrumentation as a simple wrapper
# A full implementation would use AST transformation
method_id = function.qualified_name
call_id = hash(method_id) % 10000
# Build instrumented version
# This is a simplified approach - a full implementation would
# parse the method body and instrument each return statement
logger.debug("Instrumented method %s for behavior capture", function.name)
return source
def instrument_for_benchmarking(
test_source: str,
target_function: FunctionInfo,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add timing instrumentation to test code.
Args:
test_source: Test source code to instrument.
target_function: Function being benchmarked.
Returns:
Instrumented test source code.
"""
analyzer = analyzer or get_java_analyzer()
# Add imports if not present
if "import com.codeflash" not in test_source:
test_source = _add_import(test_source, BENCHMARK_IMPORT)
# Find calls to the target function in the test and wrap them
# This is a simplified implementation
logger.debug("Instrumented test for benchmarking %s", _get_function_name(target_function))
return test_source
def instrument_existing_test(
test_path: Path,
call_positions: Sequence,
function_to_optimize: FunctionInfo,
tests_project_root: Path,
mode: str, # "behavior" or "performance"
analyzer: JavaAnalyzer | None = None,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file.
Args:
test_path: Path to the test file.
call_positions: List of code positions where the function is called.
function_to_optimize: The function being optimized.
tests_project_root: Root directory of tests.
mode: Testing mode - "behavior" or "performance".
analyzer: Optional JavaAnalyzer instance.
Returns:
Tuple of (success, instrumented_code or error message).
"""
analyzer = analyzer or get_java_analyzer()
try:
source = test_path.read_text(encoding="utf-8")
except Exception as e:
return False, f"Failed to read test file: {e}"
try:
if mode == "behavior":
instrumented = instrument_for_behavior(source, [function_to_optimize], analyzer)
else:
instrumented = instrument_for_benchmarking(source, function_to_optimize, analyzer)
return True, instrumented
except Exception as e:
logger.exception("Failed to instrument test file: %s", e)
return False, str(e)
def create_benchmark_test(
target_function: FunctionInfo,
test_setup_code: str,
invocation_code: str,
iterations: int = 1000,
) -> str:
"""Create a benchmark test for a function.
Args:
target_function: The function to benchmark.
test_setup_code: Code to set up the test (create instances, etc.).
invocation_code: Code that invokes the function.
iterations: Number of benchmark iterations.
Returns:
Complete benchmark test source code.
"""
method_name = target_function.name
method_id = target_function.qualified_name
benchmark_code = f"""
import com.codeflash.Blackhole;
import com.codeflash.BenchmarkContext;
import com.codeflash.BenchmarkResult;
import com.codeflash.CodeFlash;
import org.junit.jupiter.api.Test;
public class {target_function.class_name or 'Target'}Benchmark {{
@Test
public void benchmark{method_name.capitalize()}() {{
{test_setup_code}
// Warmup phase
for (int i = 0; i < {iterations // 10}; i++) {{
Blackhole.consume({invocation_code});
}}
// Measurement phase
long[] measurements = new long[{iterations}];
for (int i = 0; i < {iterations}; i++) {{
long start = System.nanoTime();
Blackhole.consume({invocation_code});
long end = System.nanoTime();
measurements[i] = end - start;
}}
BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
CodeFlash.recordBenchmarkResult("{method_id}", result);
System.out.println("Benchmark complete: " + result);
}}
}}
"""
return benchmark_code
def remove_instrumentation(source: str) -> str:
"""Remove CodeFlash instrumentation from source code.
Args:
source: Instrumented source code.
Returns:
Source with instrumentation removed.
"""
lines = source.splitlines(keepends=True)
result_lines = []
skip_until_end = False
for line in lines:
stripped = line.strip()
# Skip CodeFlash instrumentation blocks
if "// CodeFlash" in stripped and "start" in stripped:
skip_until_end = True
continue
if skip_until_end:
if "// CodeFlash" in stripped and "end" in stripped:
skip_until_end = False
continue
# Skip CodeFlash imports
if "import com.codeflash" in stripped:
continue
result_lines.append(line)
return "".join(result_lines)

View file

@ -0,0 +1,693 @@
"""Tree-sitter utilities for Java code analysis.
This module provides a unified interface for parsing and analyzing Java code
using tree-sitter, following the same patterns as the JavaScript/TypeScript implementation.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING
from tree_sitter import Language, Parser
if TYPE_CHECKING:
from pathlib import Path
from tree_sitter import Node, Tree
logger = logging.getLogger(__name__)
# Lazy-loaded language instance
_JAVA_LANGUAGE: Language | None = None
def _get_java_language() -> Language:
"""Get the Java tree-sitter Language instance, with lazy loading."""
global _JAVA_LANGUAGE
if _JAVA_LANGUAGE is None:
import tree_sitter_java
_JAVA_LANGUAGE = Language(tree_sitter_java.language())
return _JAVA_LANGUAGE
@dataclass
class JavaMethodNode:
"""Represents a method found by tree-sitter analysis."""
name: str
node: Node
start_line: int
end_line: int
start_col: int
end_col: int
is_static: bool
is_public: bool
is_private: bool
is_protected: bool
is_abstract: bool
is_synchronized: bool
return_type: str | None
class_name: str | None
source_text: str
javadoc_start_line: int | None = None # Line where Javadoc comment starts
@dataclass
class JavaClassNode:
"""Represents a class found by tree-sitter analysis."""
name: str
node: Node
start_line: int
end_line: int
start_col: int
end_col: int
is_public: bool
is_abstract: bool
is_final: bool
is_static: bool # For inner classes
extends: str | None
implements: list[str]
source_text: str
javadoc_start_line: int | None = None
@dataclass
class JavaImportInfo:
"""Represents a Java import statement."""
import_path: str # Full import path (e.g., "java.util.List")
is_static: bool
is_wildcard: bool # import java.util.*
start_line: int
end_line: int
@dataclass
class JavaFieldInfo:
"""Represents a class field."""
name: str
type_name: str
is_static: bool
is_final: bool
is_public: bool
is_private: bool
is_protected: bool
start_line: int
end_line: int
source_text: str
class JavaAnalyzer:
"""Java code analysis using tree-sitter.
This class provides methods to parse and analyze Java code,
finding methods, classes, imports, and other code structures.
"""
def __init__(self) -> None:
"""Initialize the Java analyzer."""
self._parser: Parser | None = None
@property
def parser(self) -> Parser:
"""Get the parser, creating it lazily."""
if self._parser is None:
self._parser = Parser(_get_java_language())
return self._parser
def parse(self, source: str | bytes) -> Tree:
"""Parse source code into a tree-sitter tree.
Args:
source: Source code as string or bytes.
Returns:
The parsed tree.
"""
if isinstance(source, str):
source = source.encode("utf8")
return self.parser.parse(source)
def get_node_text(self, node: Node, source: bytes) -> str:
"""Extract the source text for a tree-sitter node.
Args:
node: The tree-sitter node.
source: The source code as bytes.
Returns:
The text content of the node.
"""
return source[node.start_byte : node.end_byte].decode("utf8")
def find_methods(
self, source: str, include_private: bool = True, include_static: bool = True
) -> list[JavaMethodNode]:
"""Find all method definitions in source code.
Args:
source: The source code to analyze.
include_private: Whether to include private methods.
include_static: Whether to include static methods.
Returns:
List of JavaMethodNode objects describing found methods.
"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
methods: list[JavaMethodNode] = []
self._walk_tree_for_methods(
tree.root_node,
source_bytes,
methods,
include_private=include_private,
include_static=include_static,
current_class=None,
)
return methods
def _walk_tree_for_methods(
self,
node: Node,
source_bytes: bytes,
methods: list[JavaMethodNode],
include_private: bool,
include_static: bool,
current_class: str | None,
) -> None:
"""Recursively walk the tree to find method definitions."""
new_class = current_class
# Track class context
if node.type == "class_declaration":
name_node = node.child_by_field_name("name")
if name_node:
new_class = self.get_node_text(name_node, source_bytes)
if node.type == "method_declaration":
method_info = self._extract_method_info(node, source_bytes, current_class)
if method_info:
# Apply filters
should_include = True
if method_info.is_private and not include_private:
should_include = False
if method_info.is_static and not include_static:
should_include = False
if should_include:
methods.append(method_info)
# Recurse into children
for child in node.children:
self._walk_tree_for_methods(
child,
source_bytes,
methods,
include_private=include_private,
include_static=include_static,
current_class=new_class if node.type == "class_declaration" else current_class,
)
def _extract_method_info(
self, node: Node, source_bytes: bytes, current_class: str | None
) -> JavaMethodNode | None:
"""Extract method information from a method_declaration node."""
name = ""
is_static = False
is_public = False
is_private = False
is_protected = False
is_abstract = False
is_synchronized = False
return_type: str | None = None
# Get method name
name_node = node.child_by_field_name("name")
if name_node:
name = self.get_node_text(name_node, source_bytes)
# Get return type
type_node = node.child_by_field_name("type")
if type_node:
return_type = self.get_node_text(type_node, source_bytes)
# Check modifiers
for child in node.children:
if child.type == "modifiers":
modifier_text = self.get_node_text(child, source_bytes)
is_static = "static" in modifier_text
is_public = "public" in modifier_text
is_private = "private" in modifier_text
is_protected = "protected" in modifier_text
is_abstract = "abstract" in modifier_text
is_synchronized = "synchronized" in modifier_text
break
# Get source text
source_text = self.get_node_text(node, source_bytes)
# Find preceding Javadoc comment
javadoc_start_line = self._find_preceding_javadoc(node, source_bytes)
return JavaMethodNode(
name=name,
node=node,
start_line=node.start_point[0] + 1, # Convert to 1-indexed
end_line=node.end_point[0] + 1,
start_col=node.start_point[1],
end_col=node.end_point[1],
is_static=is_static,
is_public=is_public,
is_private=is_private,
is_protected=is_protected,
is_abstract=is_abstract,
is_synchronized=is_synchronized,
return_type=return_type,
class_name=current_class,
source_text=source_text,
javadoc_start_line=javadoc_start_line,
)
def _find_preceding_javadoc(self, node: Node, source_bytes: bytes) -> int | None:
"""Find Javadoc comment immediately preceding a node.
Args:
node: The node to find Javadoc for.
source_bytes: The source code as bytes.
Returns:
The start line (1-indexed) of the Javadoc, or None if no Javadoc found.
"""
# Get the previous sibling node
prev_sibling = node.prev_named_sibling
# Check if it's a block comment that looks like Javadoc
if prev_sibling and prev_sibling.type == "block_comment":
comment_text = self.get_node_text(prev_sibling, source_bytes)
if comment_text.strip().startswith("/**"):
# Verify it's immediately preceding (no blank lines between)
comment_end_line = prev_sibling.end_point[0]
node_start_line = node.start_point[0]
if node_start_line - comment_end_line <= 1:
return prev_sibling.start_point[0] + 1 # 1-indexed
return None
def find_classes(self, source: str) -> list[JavaClassNode]:
"""Find all class definitions in source code.
Args:
source: The source code to analyze.
Returns:
List of JavaClassNode objects.
"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
classes: list[JavaClassNode] = []
self._walk_tree_for_classes(tree.root_node, source_bytes, classes, is_inner=False)
return classes
def _walk_tree_for_classes(
self, node: Node, source_bytes: bytes, classes: list[JavaClassNode], is_inner: bool
) -> None:
"""Recursively walk the tree to find class definitions."""
if node.type == "class_declaration":
class_info = self._extract_class_info(node, source_bytes, is_inner)
if class_info:
classes.append(class_info)
# Look for inner classes
body_node = node.child_by_field_name("body")
if body_node:
for child in body_node.children:
self._walk_tree_for_classes(child, source_bytes, classes, is_inner=True)
return
# Continue walking for top-level classes
for child in node.children:
self._walk_tree_for_classes(child, source_bytes, classes, is_inner)
def _extract_class_info(
self, node: Node, source_bytes: bytes, is_inner: bool
) -> JavaClassNode | None:
"""Extract class information from a class_declaration node."""
name = ""
is_public = False
is_abstract = False
is_final = False
is_static = False
extends: str | None = None
implements: list[str] = []
# Get class name
name_node = node.child_by_field_name("name")
if name_node:
name = self.get_node_text(name_node, source_bytes)
# Check modifiers
for child in node.children:
if child.type == "modifiers":
modifier_text = self.get_node_text(child, source_bytes)
is_public = "public" in modifier_text
is_abstract = "abstract" in modifier_text
is_final = "final" in modifier_text
is_static = "static" in modifier_text
break
# Get superclass
superclass_node = node.child_by_field_name("superclass")
if superclass_node:
# superclass contains "extends ClassName"
for child in superclass_node.children:
if child.type == "type_identifier":
extends = self.get_node_text(child, source_bytes)
break
# Get interfaces (super_interfaces node contains the implements clause)
for child in node.children:
if child.type == "super_interfaces":
# Find the type_list inside super_interfaces
for subchild in child.children:
if subchild.type == "type_list":
for type_node in subchild.children:
if type_node.type == "type_identifier":
implements.append(self.get_node_text(type_node, source_bytes))
# Get source text
source_text = self.get_node_text(node, source_bytes)
# Find preceding Javadoc
javadoc_start_line = self._find_preceding_javadoc(node, source_bytes)
return JavaClassNode(
name=name,
node=node,
start_line=node.start_point[0] + 1,
end_line=node.end_point[0] + 1,
start_col=node.start_point[1],
end_col=node.end_point[1],
is_public=is_public,
is_abstract=is_abstract,
is_final=is_final,
is_static=is_static,
extends=extends,
implements=implements,
source_text=source_text,
javadoc_start_line=javadoc_start_line,
)
def find_imports(self, source: str) -> list[JavaImportInfo]:
"""Find all import statements in source code.
Args:
source: The source code to analyze.
Returns:
List of JavaImportInfo objects.
"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
imports: list[JavaImportInfo] = []
for child in tree.root_node.children:
if child.type == "import_declaration":
import_info = self._extract_import_info(child, source_bytes)
if import_info:
imports.append(import_info)
return imports
def _extract_import_info(self, node: Node, source_bytes: bytes) -> JavaImportInfo | None:
"""Extract import information from an import_declaration node."""
import_path = ""
is_static = False
is_wildcard = False
# Check for static import
for child in node.children:
if child.type == "static":
is_static = True
break
# Get the import path (scoped_identifier or identifier)
for child in node.children:
if child.type == "scoped_identifier":
import_path = self.get_node_text(child, source_bytes)
break
if child.type == "identifier":
import_path = self.get_node_text(child, source_bytes)
break
# Check for wildcard
if import_path.endswith(".*") or ".*" in self.get_node_text(node, source_bytes):
is_wildcard = True
# Clean up the import path
import_path = import_path.rstrip(".*").rstrip(".")
return JavaImportInfo(
import_path=import_path,
is_static=is_static,
is_wildcard=is_wildcard,
start_line=node.start_point[0] + 1,
end_line=node.end_point[0] + 1,
)
def find_fields(self, source: str, class_name: str | None = None) -> list[JavaFieldInfo]:
"""Find all field declarations in source code.
Args:
source: The source code to analyze.
class_name: Optional class name to filter fields.
Returns:
List of JavaFieldInfo objects.
"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
fields: list[JavaFieldInfo] = []
self._walk_tree_for_fields(tree.root_node, source_bytes, fields, current_class=None, target_class=class_name)
return fields
def _walk_tree_for_fields(
self,
node: Node,
source_bytes: bytes,
fields: list[JavaFieldInfo],
current_class: str | None,
target_class: str | None,
) -> None:
"""Recursively walk the tree to find field declarations."""
new_class = current_class
if node.type == "class_declaration":
name_node = node.child_by_field_name("name")
if name_node:
new_class = self.get_node_text(name_node, source_bytes)
if node.type == "field_declaration":
# Only include if we're in the target class (or no target specified)
if target_class is None or current_class == target_class:
field_info = self._extract_field_info(node, source_bytes)
if field_info:
fields.extend(field_info)
for child in node.children:
self._walk_tree_for_fields(
child,
source_bytes,
fields,
current_class=new_class if node.type == "class_declaration" else current_class,
target_class=target_class,
)
def _extract_field_info(self, node: Node, source_bytes: bytes) -> list[JavaFieldInfo]:
"""Extract field information from a field_declaration node.
Returns a list because a single declaration can define multiple fields.
"""
fields: list[JavaFieldInfo] = []
is_static = False
is_final = False
is_public = False
is_private = False
is_protected = False
type_name = ""
# Check modifiers
for child in node.children:
if child.type == "modifiers":
modifier_text = self.get_node_text(child, source_bytes)
is_static = "static" in modifier_text
is_final = "final" in modifier_text
is_public = "public" in modifier_text
is_private = "private" in modifier_text
is_protected = "protected" in modifier_text
break
# Get type
type_node = node.child_by_field_name("type")
if type_node:
type_name = self.get_node_text(type_node, source_bytes)
# Get variable declarators (there can be multiple: int a, b, c;)
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
if name_node:
field_name = self.get_node_text(name_node, source_bytes)
fields.append(
JavaFieldInfo(
name=field_name,
type_name=type_name,
is_static=is_static,
is_final=is_final,
is_public=is_public,
is_private=is_private,
is_protected=is_protected,
start_line=node.start_point[0] + 1,
end_line=node.end_point[0] + 1,
source_text=self.get_node_text(node, source_bytes),
)
)
return fields
def find_method_calls(self, source: str, within_method: JavaMethodNode) -> list[str]:
"""Find all method calls within a specific method's body.
Args:
source: The full source code.
within_method: The method to search within.
Returns:
List of method names that are called.
"""
calls: list[str] = []
source_bytes = source.encode("utf8")
# Get the body of the method
body_node = within_method.node.child_by_field_name("body")
if body_node:
self._walk_tree_for_calls(body_node, source_bytes, calls)
return list(set(calls)) # Remove duplicates
def _walk_tree_for_calls(self, node: Node, source_bytes: bytes, calls: list[str]) -> None:
"""Recursively find method calls in a subtree."""
if node.type == "method_invocation":
name_node = node.child_by_field_name("name")
if name_node:
calls.append(self.get_node_text(name_node, source_bytes))
for child in node.children:
self._walk_tree_for_calls(child, source_bytes, calls)
def has_return_statement(self, method_node: JavaMethodNode, source: str) -> bool:
"""Check if a method has a return statement.
Args:
method_node: The method to check.
source: The source code.
Returns:
True if the method has a return statement.
"""
# void methods don't need return statements
if method_node.return_type == "void":
return False
return self._node_has_return(method_node.node)
def _node_has_return(self, node: Node) -> bool:
"""Recursively check if a node contains a return statement."""
if node.type == "return_statement":
return True
# Don't recurse into nested method declarations (lambdas)
if node.type in ("lambda_expression", "method_declaration"):
if node.type == "method_declaration":
body_node = node.child_by_field_name("body")
if body_node:
for child in body_node.children:
if self._node_has_return(child):
return True
return False
return any(self._node_has_return(child) for child in node.children)
def validate_syntax(self, source: str) -> bool:
"""Check if Java source code is syntactically valid.
Uses tree-sitter to parse and check for errors.
Args:
source: Source code to validate.
Returns:
True if valid, False otherwise.
"""
try:
tree = self.parse(source)
return not tree.root_node.has_error
except Exception:
return False
def get_package_name(self, source: str) -> str | None:
"""Extract the package name from Java source code.
Args:
source: The source code to analyze.
Returns:
The package name, or None if not found.
"""
source_bytes = source.encode("utf8")
tree = self.parse(source_bytes)
for child in tree.root_node.children:
if child.type == "package_declaration":
# Find the scoped_identifier within the package declaration
for pkg_child in child.children:
if pkg_child.type == "scoped_identifier":
return self.get_node_text(pkg_child, source_bytes)
if pkg_child.type == "identifier":
return self.get_node_text(pkg_child, source_bytes)
return None
def get_java_analyzer() -> JavaAnalyzer:
"""Get a JavaAnalyzer instance.
Returns:
JavaAnalyzer configured for Java.
"""
return JavaAnalyzer()

View file

@ -0,0 +1,420 @@
"""Java code replacement.
This module provides functionality to replace function implementations
in Java source code while preserving formatting and structure.
"""
from __future__ import annotations
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import FunctionInfo
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
def replace_function(
source: str,
function: FunctionInfo,
new_source: str,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Replace a function in source code with new implementation.
Preserves:
- Surrounding whitespace and formatting
- Javadoc comments (if they should be preserved)
- Annotations
Args:
source: Original source code.
function: FunctionInfo identifying the function to replace.
new_source: New function source code.
analyzer: Optional JavaAnalyzer instance.
Returns:
Modified source code with function replaced.
"""
analyzer = analyzer or get_java_analyzer()
# Find the method in the source
methods = analyzer.find_methods(source)
target_method = None
for method in methods:
if method.name == function.name:
if function.class_name is None or method.class_name == function.class_name:
target_method = method
break
if not target_method:
logger.error("Could not find method %s in source", function.name)
return source
# Determine replacement range
# Include Javadoc if present
start_line = target_method.javadoc_start_line or target_method.start_line
end_line = target_method.end_line
# Split source into lines
lines = source.splitlines(keepends=True)
# Get indentation from the original method
original_first_line = lines[start_line - 1] if start_line <= len(lines) else ""
indent = _get_indentation(original_first_line)
# Ensure new source has correct indentation
new_source_lines = new_source.splitlines(keepends=True)
indented_new_source = _apply_indentation(new_source_lines, indent)
# Build the result
before = lines[: start_line - 1] # Lines before the method
after = lines[end_line:] # Lines after the method
result = "".join(before) + indented_new_source + "".join(after)
return result
def _get_indentation(line: str) -> str:
"""Extract the indentation from a line.
Args:
line: The line to analyze.
Returns:
The indentation string (spaces/tabs).
"""
match = re.match(r"^(\s*)", line)
return match.group(1) if match else ""
def _apply_indentation(lines: list[str], base_indent: str) -> str:
"""Apply indentation to all lines.
Args:
lines: Lines to indent.
base_indent: Base indentation to apply.
Returns:
Indented source code.
"""
if not lines:
return ""
# Detect the existing indentation in the new source
existing_indent = ""
for line in lines:
stripped = line.lstrip()
if stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
existing_indent = _get_indentation(line)
break
result_lines = []
for line in lines:
if not line.strip():
result_lines.append(line)
else:
# Remove existing indentation and apply new base indentation
stripped_line = line.lstrip()
# Calculate relative indentation
line_indent = _get_indentation(line)
if existing_indent and line_indent.startswith(existing_indent):
relative_indent = line_indent[len(existing_indent) :]
else:
relative_indent = ""
result_lines.append(base_indent + relative_indent + stripped_line)
return "".join(result_lines)
def replace_method_body(
source: str,
function: FunctionInfo,
new_body: str,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Replace just the body of a method, preserving signature.
Args:
source: Original source code.
function: FunctionInfo identifying the function.
new_body: New method body (code between braces).
analyzer: Optional JavaAnalyzer instance.
Returns:
Modified source code.
"""
analyzer = analyzer or get_java_analyzer()
source_bytes = source.encode("utf8")
# Find the method
methods = analyzer.find_methods(source)
target_method = None
for method in methods:
if method.name == function.name:
if function.class_name is None or method.class_name == function.class_name:
target_method = method
break
if not target_method:
logger.error("Could not find method %s", function.name)
return source
# Find the body node
body_node = target_method.node.child_by_field_name("body")
if not body_node:
logger.error("Method %s has no body (abstract?)", function.name)
return source
# Get the body's byte positions
body_start = body_node.start_byte
body_end = body_node.end_byte
# Get indentation
body_start_line = body_node.start_point[0]
lines = source.splitlines(keepends=True)
base_indent = _get_indentation(lines[body_start_line]) if body_start_line < len(lines) else " "
# Format the new body
new_body = new_body.strip()
if not new_body.startswith("{"):
new_body = "{\n" + base_indent + " " + new_body
if not new_body.endswith("}"):
new_body = new_body + "\n" + base_indent + "}"
# Replace the body
before = source_bytes[:body_start]
after = source_bytes[body_end:]
return (before + new_body.encode("utf8") + after).decode("utf8")
def insert_method(
source: str,
class_name: str,
method_source: str,
position: str = "end", # "end" or "start"
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Insert a new method into a class.
Args:
source: The source code.
class_name: Name of the class to insert into.
method_source: Source code of the method to insert.
position: Where to insert ("end" or "start" of class body).
analyzer: Optional JavaAnalyzer instance.
Returns:
Source code with method inserted.
"""
analyzer = analyzer or get_java_analyzer()
# Find the class
classes = analyzer.find_classes(source)
target_class = None
for cls in classes:
if cls.name == class_name:
target_class = cls
break
if not target_class:
logger.error("Could not find class %s", class_name)
return source
# Find the class body
body_node = target_class.node.child_by_field_name("body")
if not body_node:
logger.error("Class %s has no body", class_name)
return source
# Get insertion point
source_bytes = source.encode("utf8")
if position == "end":
# Insert before the closing brace
insert_point = body_node.end_byte - 1
else:
# Insert after the opening brace
insert_point = body_node.start_byte + 1
# Get indentation (typically 4 spaces inside a class)
lines = source.splitlines(keepends=True)
class_line = target_class.start_line - 1
class_indent = _get_indentation(lines[class_line]) if class_line < len(lines) else ""
method_indent = class_indent + " "
# Format the method
method_lines = method_source.strip().splitlines(keepends=True)
indented_method = _apply_indentation(method_lines, method_indent)
# Insert the method
before = source_bytes[:insert_point]
after = source_bytes[insert_point:]
separator = "\n\n" if position == "end" else "\n"
return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8")
def remove_method(
source: str,
function: FunctionInfo,
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Remove a method from source code.
Args:
source: The source code.
function: FunctionInfo identifying the method to remove.
analyzer: Optional JavaAnalyzer instance.
Returns:
Source code with method removed.
"""
analyzer = analyzer or get_java_analyzer()
# Find the method
methods = analyzer.find_methods(source)
target_method = None
for method in methods:
if method.name == function.name:
if function.class_name is None or method.class_name == function.class_name:
target_method = method
break
if not target_method:
logger.error("Could not find method %s", function.name)
return source
# Determine removal range (include Javadoc)
start_line = target_method.javadoc_start_line or target_method.start_line
end_line = target_method.end_line
lines = source.splitlines(keepends=True)
# Remove the method lines
before = lines[: start_line - 1]
after = lines[end_line:]
return "".join(before) + "".join(after)
def remove_test_functions(
test_source: str,
functions_to_remove: list[str],
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Remove specific test functions from test source code.
Args:
test_source: Test source code.
functions_to_remove: List of function names to remove.
analyzer: Optional JavaAnalyzer instance.
Returns:
Test source code with specified functions removed.
"""
analyzer = analyzer or get_java_analyzer()
# Find all methods
methods = analyzer.find_methods(test_source)
# Sort by start line in reverse order (remove from end first)
methods_to_remove = [
m for m in methods if m.name in functions_to_remove
]
methods_to_remove.sort(key=lambda m: m.start_line, reverse=True)
result = test_source
for method in methods_to_remove:
# Create a FunctionInfo for removal
func_info = FunctionInfo(
name=method.name,
file_path=Path("temp.java"),
start_line=method.start_line,
end_line=method.end_line,
parents=(),
is_method=True,
)
result = remove_method(result, func_info, analyzer)
return result
def add_runtime_comments(
test_source: str,
original_runtimes: dict[str, int],
optimized_runtimes: dict[str, int],
analyzer: JavaAnalyzer | None = None,
) -> str:
"""Add runtime performance comments to test source code.
Adds comments showing the original vs optimized runtime for each
function call (e.g., "// 1.5ms -> 0.3ms (80% faster)").
Args:
test_source: Test source code to annotate.
original_runtimes: Map of invocation IDs to original runtimes (ns).
optimized_runtimes: Map of invocation IDs to optimized runtimes (ns).
analyzer: Optional JavaAnalyzer instance.
Returns:
Test source code with runtime comments added.
"""
if not original_runtimes or not optimized_runtimes:
return test_source
# For now, add a summary comment at the top
summary_lines = ["// Performance comparison:"]
for inv_id in original_runtimes:
original_ns = original_runtimes[inv_id]
optimized_ns = optimized_runtimes.get(inv_id, original_ns)
original_ms = original_ns / 1_000_000
optimized_ms = optimized_ns / 1_000_000
if original_ns > 0:
speedup = ((original_ns - optimized_ns) / original_ns) * 100
summary_lines.append(
f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)"
)
# Insert after imports
lines = test_source.splitlines(keepends=True)
insert_idx = 0
for i, line in enumerate(lines):
if line.strip().startswith("import "):
insert_idx = i + 1
elif line.strip() and not line.strip().startswith("//") and not line.strip().startswith("package"):
if insert_idx == 0:
insert_idx = i
break
# Insert summary
summary = "\n".join(summary_lines) + "\n\n"
lines.insert(insert_idx, summary)
return "".join(lines)

View file

@ -0,0 +1,384 @@
"""Main JavaSupport class implementing the LanguageSupport protocol.
This module provides the main JavaSupport class that implements all
required methods for Java language support in codeflash.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.languages.base import (
CodeContext,
FunctionFilterCriteria,
FunctionInfo,
HelperFunction,
Language,
LanguageSupport,
TestInfo,
TestResult,
)
from codeflash.languages.registry import register_language
from codeflash.languages.java.build_tools import find_test_root
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
from codeflash.languages.java.config import detect_java_project
from codeflash.languages.java.context import extract_code_context, find_helper_functions
from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source
from codeflash.languages.java.formatter import format_java_code, normalize_java_code
from codeflash.languages.java.instrumentation import (
instrument_existing_test,
instrument_for_behavior,
instrument_for_benchmarking,
)
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.replacement import (
add_runtime_comments,
remove_test_functions,
replace_function,
)
from codeflash.languages.java.test_discovery import discover_tests
from codeflash.languages.java.test_runner import (
parse_test_results,
run_behavioral_tests,
run_benchmarking_tests,
run_tests,
)
if TYPE_CHECKING:
from collections.abc import Sequence
logger = logging.getLogger(__name__)
@register_language
class JavaSupport(LanguageSupport):
"""Java language support implementation.
Implements the LanguageSupport protocol for Java, providing:
- Function discovery using tree-sitter
- Test discovery for JUnit 5
- Test execution via Maven Surefire
- Code context extraction
- Code replacement and formatting
- Behavior capture instrumentation
- Benchmarking instrumentation
"""
def __init__(self) -> None:
"""Initialize Java support."""
self._analyzer = get_java_analyzer()
@property
def language(self) -> Language:
"""The language this implementation supports."""
return Language.JAVA
@property
def file_extensions(self) -> tuple[str, ...]:
"""File extensions supported by Java."""
return (".java",)
@property
def test_framework(self) -> str:
"""Primary test framework name."""
return "junit5"
@property
def comment_prefix(self) -> str:
"""Comment prefix for Java."""
return "//"
# === Discovery ===
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionInfo]:
"""Find all optimizable functions in a Java file."""
return discover_functions(file_path, filter_criteria, self._analyzer)
def discover_tests(
self, test_root: Path, source_functions: Sequence[FunctionInfo]
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests."""
return discover_tests(test_root, source_functions, self._analyzer)
# === Code Analysis ===
def extract_code_context(
self, function: FunctionInfo, project_root: Path, module_root: Path
) -> CodeContext:
"""Extract function code and its dependencies."""
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
def find_helper_functions(
self, function: FunctionInfo, project_root: Path
) -> list[HelperFunction]:
"""Find helper functions called by the target function."""
return find_helper_functions(function, project_root, analyzer=self._analyzer)
# === Code Transformation ===
def replace_function(
self, source: str, function: FunctionInfo, new_source: str
) -> str:
"""Replace a function in source code with new implementation."""
return replace_function(source, function, new_source, self._analyzer)
def format_code(self, source: str, file_path: Path | None = None) -> str:
"""Format Java code."""
project_root = file_path.parent if file_path else None
return format_java_code(source, project_root)
# === Test Execution ===
def run_tests(
self,
test_files: Sequence[Path],
cwd: Path,
env: dict[str, str],
timeout: int,
) -> tuple[list[TestResult], Path]:
"""Run tests and return results."""
return run_tests(list(test_files), cwd, env, timeout)
def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResult]:
"""Parse test results from JUnit XML."""
return parse_test_results(junit_xml_path, stdout)
# === Instrumentation ===
def instrument_for_behavior(
self, source: str, functions: Sequence[FunctionInfo]
) -> str:
"""Add behavior instrumentation to capture inputs/outputs."""
return instrument_for_behavior(source, functions, self._analyzer)
def instrument_for_benchmarking(
self, test_source: str, target_function: FunctionInfo
) -> str:
"""Add timing instrumentation to test code."""
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
# === Validation ===
def validate_syntax(self, source: str) -> bool:
"""Check if Java source code is syntactically valid."""
return self._analyzer.validate_syntax(source)
def normalize_code(self, source: str) -> str:
"""Normalize code for deduplication."""
return normalize_java_code(source)
# === Test Editing ===
def add_runtime_comments(
self,
test_source: str,
original_runtimes: dict[str, int],
optimized_runtimes: dict[str, int],
) -> str:
"""Add runtime performance comments to test source code."""
return add_runtime_comments(test_source, original_runtimes, optimized_runtimes, self._analyzer)
def remove_test_functions(
self, test_source: str, functions_to_remove: list[str]
) -> str:
"""Remove specific test functions from test source code."""
return remove_test_functions(test_source, functions_to_remove, self._analyzer)
# === Test Result Comparison ===
def compare_test_results(
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
) -> tuple[bool, list]:
"""Compare test results between original and candidate code."""
return _compare_test_results(
original_results_path, candidate_results_path, project_root=project_root
)
# === Configuration ===
def get_test_file_suffix(self) -> str:
"""Get the test file suffix for Java."""
return "Test.java"
def get_comment_prefix(self) -> str:
"""Get the comment prefix for Java."""
return "//"
def find_test_root(self, project_root: Path) -> Path | None:
"""Find the test root directory for a Java project."""
return find_test_root(project_root)
def get_project_root(self, source_file: Path) -> Path | None:
"""Find the project root for a Java file.
Looks for pom.xml, build.gradle, or build.gradle.kts.
Args:
source_file: Path to the source file.
Returns:
The project root directory, or None if not found.
"""
current = source_file.parent
while current != current.parent:
if (current / "pom.xml").exists():
return current
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
return current
current = current.parent
return None
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
"""Get the module path for a Java source file.
For Java, this returns the fully qualified class name (e.g., 'com.example.Algorithms').
Args:
source_file: Path to the source file.
project_root: Root of the project.
tests_root: Not used for Java.
Returns:
Fully qualified class name string.
"""
# Find the package from the file content
try:
content = source_file.read_text(encoding="utf-8")
for line in content.split("\n"):
line = line.strip()
if line.startswith("package "):
# Extract package name (remove 'package ' prefix and ';' suffix)
package = line[8:].rstrip(";").strip()
class_name = source_file.stem
return f"{package}.{class_name}"
except Exception:
pass
# Fallback: derive from path relative to src/main/java
relative = source_file.relative_to(project_root)
parts = list(relative.parts)
# Remove src/main/java prefix if present
if len(parts) > 3 and parts[:3] == ["src", "main", "java"]:
parts = parts[3:]
# Remove .java extension and join with dots
if parts:
parts[-1] = parts[-1].replace(".java", "")
return ".".join(parts)
def get_runtime_files(self) -> list[Path]:
"""Get paths to runtime files needed for Java."""
# The Java runtime is distributed as a JAR
return []
def ensure_runtime_environment(self, project_root: Path) -> bool:
"""Ensure the runtime environment is set up."""
# Check if codeflash-runtime is available
config = detect_java_project(project_root)
if config is None:
return False
# For now, assume the runtime is available
# A full implementation would check/install the JAR
return True
def instrument_existing_test(
self,
test_path: Path,
call_positions: Sequence[Any],
function_to_optimize: Any,
tests_project_root: Path,
mode: str,
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file."""
return instrument_existing_test(
test_path,
call_positions,
function_to_optimize,
tests_project_root,
mode,
self._analyzer,
)
def instrument_source_for_line_profiler(
self, func_info: FunctionInfo, line_profiler_output_file: Path
) -> bool:
"""Instrument source code before line profiling."""
# Not yet implemented for Java
return False
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
"""Parse line profiler output."""
# Not yet implemented for Java
return {}
def run_behavioral_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
enable_coverage: bool = False,
candidate_index: int = 0,
) -> tuple[Path, Any, Path | None, Path | None]:
"""Run behavioral tests for Java."""
return run_behavioral_tests(
test_paths,
test_env,
cwd,
timeout,
project_root,
enable_coverage,
candidate_index,
)
def run_benchmarking_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
min_loops: int = 5,
max_loops: int = 100_000,
target_duration_seconds: float = 10.0,
) -> tuple[Path, Any]:
"""Run benchmarking tests for Java."""
return run_benchmarking_tests(
test_paths,
test_env,
cwd,
timeout,
project_root,
min_loops,
max_loops,
target_duration_seconds,
)
# Create a singleton instance for the registry
_java_support: JavaSupport | None = None
def get_java_support() -> JavaSupport:
"""Get the JavaSupport singleton instance.
Returns:
The JavaSupport instance.
"""
global _java_support
if _java_support is None:
_java_support = JavaSupport()
return _java_support

View file

@ -0,0 +1,370 @@
"""Java test discovery for JUnit 5.
This module provides functionality to discover tests that exercise
specific functions, mapping source functions to their tests.
"""
from __future__ import annotations
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.base import FunctionInfo, TestInfo
from codeflash.languages.java.config import detect_java_project
from codeflash.languages.java.discovery import discover_test_methods
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
if TYPE_CHECKING:
from collections.abc import Sequence
logger = logging.getLogger(__name__)
def discover_tests(
test_root: Path,
source_functions: Sequence[FunctionInfo],
analyzer: JavaAnalyzer | None = None,
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
Uses several heuristics to match tests to functions:
1. Test method name contains function name
2. Test class name matches source class name
3. Imports analysis
4. Method call analysis in test code
Args:
test_root: Root directory containing tests.
source_functions: Functions to find tests for.
analyzer: Optional JavaAnalyzer instance.
Returns:
Dict mapping qualified function names to lists of TestInfo.
"""
analyzer = analyzer or get_java_analyzer()
# Build a map of function names for quick lookup
function_map: dict[str, FunctionInfo] = {}
for func in source_functions:
function_map[func.name] = func
function_map[func.qualified_name] = func
# Find all test files
test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java"))
# Result map
result: dict[str, list[TestInfo]] = defaultdict(list)
for test_file in test_files:
try:
test_methods = discover_test_methods(test_file, analyzer)
source = test_file.read_text(encoding="utf-8")
for test_method in test_methods:
# Find which source functions this test might exercise
matched_functions = _match_test_to_functions(
test_method, source, function_map, analyzer
)
for func_name in matched_functions:
result[func_name].append(
TestInfo(
test_name=test_method.name,
test_file=test_file,
test_class=test_method.class_name,
)
)
except Exception as e:
logger.warning("Failed to analyze test file %s: %s", test_file, e)
return dict(result)
def _match_test_to_functions(
test_method: FunctionInfo,
test_source: str,
function_map: dict[str, FunctionInfo],
analyzer: JavaAnalyzer,
) -> list[str]:
"""Match a test method to source functions it might exercise.
Args:
test_method: The test method.
test_source: Full source code of the test file.
function_map: Map of function names to FunctionInfo.
analyzer: JavaAnalyzer instance.
Returns:
List of function qualified names that this test might exercise.
"""
matched: list[str] = []
# Strategy 1: Test method name contains function name
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
test_name_lower = test_method.name.lower()
for func_name, func_info in function_map.items():
if func_info.name.lower() in test_name_lower:
matched.append(func_info.qualified_name)
# Strategy 2: Method call analysis
# Look for direct method calls in the test code
source_bytes = test_source.encode("utf8")
tree = analyzer.parse(source_bytes)
# Find method calls within the test method's line range
method_calls = _find_method_calls_in_range(
tree.root_node,
source_bytes,
test_method.start_line,
test_method.end_line,
analyzer,
)
for call_name in method_calls:
if call_name in function_map:
qualified = function_map[call_name].qualified_name
if qualified not in matched:
matched.append(qualified)
# Strategy 3: Test class naming convention
# e.g., CalculatorTest tests Calculator
if test_method.class_name:
# Remove "Test" suffix or prefix
source_class_name = test_method.class_name
if source_class_name.endswith("Test"):
source_class_name = source_class_name[:-4]
elif source_class_name.startswith("Test"):
source_class_name = source_class_name[4:]
# Look for functions in the matching class
for func_name, func_info in function_map.items():
if func_info.class_name == source_class_name:
if func_info.qualified_name not in matched:
matched.append(func_info.qualified_name)
return matched
def _find_method_calls_in_range(
node,
source_bytes: bytes,
start_line: int,
end_line: int,
analyzer: JavaAnalyzer,
) -> list[str]:
"""Find method calls within a line range.
Args:
node: Tree-sitter node to search.
source_bytes: Source code as bytes.
start_line: Start line (1-indexed).
end_line: End line (1-indexed).
analyzer: JavaAnalyzer instance.
Returns:
List of method names called.
"""
calls: list[str] = []
# Check if this node is within the range (convert to 0-indexed)
node_start = node.start_point[0] + 1
node_end = node.end_point[0] + 1
if node_end < start_line or node_start > end_line:
return calls
if node.type == "method_invocation":
name_node = node.child_by_field_name("name")
if name_node:
calls.append(analyzer.get_node_text(name_node, source_bytes))
for child in node.children:
calls.extend(
_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer)
)
return calls
def find_tests_for_function(
function: FunctionInfo,
test_root: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[TestInfo]:
"""Find tests that exercise a specific function.
Args:
function: The function to find tests for.
test_root: Root directory containing tests.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of TestInfo for tests that might exercise this function.
"""
result = discover_tests(test_root, [function], analyzer)
return result.get(function.qualified_name, [])
def get_test_class_for_source_class(
source_class_name: str,
test_root: Path,
) -> Path | None:
"""Find the test class file for a source class.
Args:
source_class_name: Name of the source class.
test_root: Root directory containing tests.
Returns:
Path to the test file, or None if not found.
"""
# Try common naming patterns
patterns = [
f"{source_class_name}Test.java",
f"Test{source_class_name}.java",
f"{source_class_name}Tests.java",
]
for pattern in patterns:
matches = list(test_root.rglob(pattern))
if matches:
return matches[0]
return None
def discover_all_tests(
test_root: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
"""Discover all test methods in a test directory.
Args:
test_root: Root directory containing tests.
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo for all test methods.
"""
analyzer = analyzer or get_java_analyzer()
all_tests: list[FunctionInfo] = []
# Find all test files
test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java"))
for test_file in test_files:
try:
tests = discover_test_methods(test_file, analyzer)
all_tests.extend(tests)
except Exception as e:
logger.warning("Failed to analyze test file %s: %s", test_file, e)
return all_tests
def get_test_file_suffix() -> str:
"""Get the test file suffix for Java.
Returns:
Test file suffix.
"""
return "Test.java"
def is_test_file(file_path: Path) -> bool:
"""Check if a file is a test file.
Args:
file_path: Path to check.
Returns:
True if this appears to be a test file.
"""
name = file_path.name
# Check naming patterns
if name.endswith("Test.java") or name.endswith("Tests.java"):
return True
if name.startswith("Test") and name.endswith(".java"):
return True
# Check if it's in a test directory
path_parts = file_path.parts
for part in path_parts:
if part in ("test", "tests", "src/test"):
return True
return False
def get_test_methods_for_class(
test_file: Path,
test_class_name: str | None = None,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionInfo]:
"""Get all test methods in a specific test class.
Args:
test_file: Path to the test file.
test_class_name: Optional class name to filter (uses file name if not provided).
analyzer: Optional JavaAnalyzer instance.
Returns:
List of FunctionInfo for test methods.
"""
tests = discover_test_methods(test_file, analyzer)
if test_class_name:
return [t for t in tests if t.class_name == test_class_name]
return tests
def build_test_mapping_for_project(
project_root: Path,
analyzer: JavaAnalyzer | None = None,
) -> dict[str, list[TestInfo]]:
"""Build a complete test mapping for a project.
Args:
project_root: Root directory of the project.
analyzer: Optional JavaAnalyzer instance.
Returns:
Dict mapping qualified function names to lists of TestInfo.
"""
analyzer = analyzer or get_java_analyzer()
# Detect project configuration
config = detect_java_project(project_root)
if not config:
return {}
if not config.source_root or not config.test_root:
return {}
# Discover all source functions
from codeflash.languages.java.discovery import discover_functions
source_functions: list[FunctionInfo] = []
for java_file in config.source_root.rglob("*.java"):
funcs = discover_functions(java_file, analyzer=analyzer)
source_functions.extend(funcs)
# Map tests to functions
return discover_tests(config.test_root, source_functions, analyzer)

View file

@ -0,0 +1,440 @@
"""Java test runner for JUnit 5 with Maven.
This module provides functionality to run JUnit 5 tests using Maven Surefire,
supporting both behavioral testing and benchmarking modes.
"""
from __future__ import annotations
import logging
import os
import subprocess
import tempfile
import uuid
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.languages.base import TestResult
from codeflash.languages.java.build_tools import (
find_maven_executable,
find_test_root,
)
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
@dataclass
class JavaTestRunResult:
"""Result of running Java tests."""
success: bool
tests_run: int
tests_passed: int
tests_failed: int
tests_skipped: int
test_results: list[TestResult]
sqlite_db_path: Path | None
junit_xml_path: Path | None
stdout: str
stderr: str
returncode: int
def run_behavioral_tests(
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
enable_coverage: bool = False,
candidate_index: int = 0,
) -> tuple[Path, Any, Path | None, Path | None]:
"""Run behavioral tests for Java code.
This runs tests and captures behavior (inputs/outputs) for verification.
Args:
test_paths: TestFiles object or list of test file paths.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds.
project_root: Project root directory.
enable_coverage: Whether to collect coverage information.
candidate_index: Index of the candidate being tested.
Returns:
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
"""
project_root = project_root or cwd
# Generate unique result file path
result_id = uuid.uuid4().hex[:8]
result_file = Path(tempfile.gettempdir()) / f"codeflash_java_behavior_{result_id}.db"
# Set environment variables for CodeFlash runtime
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
run_env["CODEFLASH_MODE"] = "behavior"
# Run Maven tests
result = _run_maven_tests(
project_root,
test_paths,
run_env,
timeout=timeout or 300,
)
return result_file, result, None, None
def run_benchmarking_tests(
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
min_loops: int = 5,
max_loops: int = 100_000,
target_duration_seconds: float = 10.0,
) -> tuple[Path, Any]:
"""Run benchmarking tests for Java code.
This runs tests with performance measurement.
Args:
test_paths: TestFiles object or list of test file paths.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds.
project_root: Project root directory.
min_loops: Minimum number of loops for benchmarking.
max_loops: Maximum number of loops for benchmarking.
target_duration_seconds: Target duration for benchmarking in seconds.
Returns:
Tuple of (result_file_path, subprocess_result).
"""
project_root = project_root or cwd
# Generate unique result file path
result_id = uuid.uuid4().hex[:8]
result_file = Path(tempfile.gettempdir()) / f"codeflash_java_benchmark_{result_id}.db"
# Set environment variables
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
run_env["CODEFLASH_MODE"] = "benchmark"
run_env["CODEFLASH_MIN_LOOPS"] = str(min_loops)
run_env["CODEFLASH_MAX_LOOPS"] = str(max_loops)
run_env["CODEFLASH_TARGET_DURATION"] = str(target_duration_seconds)
# Run Maven tests
result = _run_maven_tests(
project_root,
test_paths,
run_env,
timeout=timeout or 600, # Longer timeout for benchmarks
)
return result_file, result
def _run_maven_tests(
project_root: Path,
test_paths: Any,
env: dict[str, str],
timeout: int = 300,
) -> subprocess.CompletedProcess:
"""Run Maven tests with Surefire.
Args:
project_root: Root directory of the Maven project.
test_paths: Test files or classes to run.
env: Environment variables.
timeout: Maximum execution time in seconds.
Returns:
CompletedProcess with test results.
"""
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found")
return subprocess.CompletedProcess(
args=["mvn"],
returncode=-1,
stdout="",
stderr="Maven not found",
)
# Build test filter
test_filter = _build_test_filter(test_paths)
# Build Maven command
cmd = [mvn, "test", "-fae"] # Fail at end to run all tests
if test_filter:
cmd.append(f"-Dtest={test_filter}")
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=env,
capture_output=True,
text=True,
timeout=timeout,
)
return result
except subprocess.TimeoutExpired:
logger.error("Maven test execution timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
args=cmd,
returncode=-2,
stdout="",
stderr=f"Test execution timed out after {timeout} seconds",
)
except Exception as e:
logger.exception("Maven test execution failed: %s", e)
return subprocess.CompletedProcess(
args=cmd,
returncode=-1,
stdout="",
stderr=str(e),
)
def _build_test_filter(test_paths: Any) -> str:
"""Build a Maven Surefire test filter from test paths.
Args:
test_paths: Test files, classes, or methods to include.
Returns:
Surefire test filter string.
"""
if not test_paths:
return ""
# Handle different input types
if isinstance(test_paths, (list, tuple)):
filters = []
for path in test_paths:
if isinstance(path, Path):
# Convert file path to class name
class_name = _path_to_class_name(path)
if class_name:
filters.append(class_name)
elif isinstance(path, str):
filters.append(path)
return ",".join(filters) if filters else ""
# Handle TestFiles object (has test_files attribute)
if hasattr(test_paths, "test_files"):
return _build_test_filter(list(test_paths.test_files))
return ""
def _path_to_class_name(path: Path) -> str | None:
"""Convert a test file path to a Java class name.
Args:
path: Path to the test file.
Returns:
Fully qualified class name, or None if unable to determine.
"""
if not path.suffix == ".java":
return None
# Try to extract package from path
# e.g., src/test/java/com/example/CalculatorTest.java -> com.example.CalculatorTest
parts = path.parts
# Find 'java' in the path and take everything after
try:
java_idx = parts.index("java")
class_parts = parts[java_idx + 1 :]
# Remove .java extension from last part
class_parts = list(class_parts)
class_parts[-1] = class_parts[-1].replace(".java", "")
return ".".join(class_parts)
except ValueError:
# No 'java' directory, just use the file name
return path.stem
def run_tests(
test_files: list[Path],
cwd: Path,
env: dict[str, str],
timeout: int,
) -> tuple[list[TestResult], Path]:
"""Run tests and return results.
Args:
test_files: Paths to test files to run.
cwd: Working directory for test execution.
env: Environment variables.
timeout: Maximum execution time in seconds.
Returns:
Tuple of (list of TestResults, path to JUnit XML).
"""
# Run Maven tests
result = _run_maven_tests(cwd, test_files, env, timeout)
# Parse JUnit XML results
surefire_dir = cwd / "target" / "surefire-reports"
test_results = parse_surefire_results(surefire_dir)
# Return first XML file path
junit_files = list(surefire_dir.glob("TEST-*.xml")) if surefire_dir.exists() else []
junit_path = junit_files[0] if junit_files else cwd / "target" / "surefire-reports" / "test-results.xml"
return test_results, junit_path
def parse_test_results(junit_xml_path: Path, stdout: str) -> list[TestResult]:
"""Parse test results from JUnit XML and stdout.
Args:
junit_xml_path: Path to JUnit XML results file.
stdout: Standard output from test execution.
Returns:
List of TestResult objects.
"""
return parse_surefire_results(junit_xml_path.parent)
def parse_surefire_results(surefire_dir: Path) -> list[TestResult]:
"""Parse Maven Surefire XML reports into TestResult objects.
Args:
surefire_dir: Directory containing Surefire XML reports.
Returns:
List of TestResult objects.
"""
results: list[TestResult] = []
if not surefire_dir.exists():
return results
for xml_file in surefire_dir.glob("TEST-*.xml"):
results.extend(_parse_surefire_xml(xml_file))
return results
def _parse_surefire_xml(xml_file: Path) -> list[TestResult]:
"""Parse a single Surefire XML file.
Args:
xml_file: Path to the XML file.
Returns:
List of TestResult objects for tests in this file.
"""
results: list[TestResult] = []
try:
tree = ET.parse(xml_file)
root = tree.getroot()
# Get test class info
class_name = root.get("name", "")
# Process each test case
for testcase in root.findall(".//testcase"):
test_name = testcase.get("name", "")
test_time = float(testcase.get("time", "0"))
runtime_ns = int(test_time * 1_000_000_000)
# Check for failure/error
failure = testcase.find("failure")
error = testcase.find("error")
skipped = testcase.find("skipped")
passed = failure is None and error is None and skipped is None
error_message = None
if failure is not None:
error_message = failure.get("message", "")
if failure.text:
error_message += "\n" + failure.text
if error is not None:
error_message = error.get("message", "")
if error.text:
error_message += "\n" + error.text
# Get stdout/stderr from system-out/system-err elements
stdout = ""
stderr = ""
stdout_elem = testcase.find("system-out")
if stdout_elem is not None and stdout_elem.text:
stdout = stdout_elem.text
stderr_elem = testcase.find("system-err")
if stderr_elem is not None and stderr_elem.text:
stderr = stderr_elem.text
results.append(
TestResult(
test_name=test_name,
test_file=xml_file,
passed=passed,
runtime_ns=runtime_ns,
stdout=stdout,
stderr=stderr,
error_message=error_message,
)
)
except ET.ParseError as e:
logger.warning("Failed to parse Surefire report %s: %s", xml_file, e)
return results
def get_test_run_command(
project_root: Path,
test_classes: list[str] | None = None,
) -> list[str]:
"""Get the command to run Java tests.
Args:
project_root: Root directory of the Maven project.
test_classes: Optional list of test class names to run.
Returns:
Command as list of strings.
"""
mvn = find_maven_executable() or "mvn"
cmd = [mvn, "test"]
if test_classes:
cmd.append(f"-Dtest={','.join(test_classes)}")
return cmd

View file

@ -24,7 +24,7 @@ from codeflash.code_utils.git_worktree_utils import (
)
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.either import is_successful
from codeflash.languages import is_javascript, set_current_language
from codeflash.languages import is_java, is_javascript, set_current_language
from codeflash.models.models import ValidCode
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
@ -229,8 +229,8 @@ class Optimizer:
original_module_code: str = original_module_path.read_text(encoding="utf8")
# For JavaScript/TypeScript, skip Python-specific AST parsing
if is_javascript():
# For JavaScript/TypeScript/Java, skip Python-specific AST parsing
if is_javascript() or is_java():
validated_original_code: dict[Path, ValidCode] = {
original_module_path: ValidCode(source_code=original_module_code, normalized_code=original_module_code)
}

View file

@ -6,14 +6,19 @@ from typing import Optional
from pydantic.dataclasses import dataclass
from codeflash.languages import current_language_support, is_javascript
from codeflash.languages import current_language_support, is_java, is_javascript
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
# Use appropriate file extension based on language
extension = current_language_support().get_test_file_suffix() if is_javascript() else ".py"
if is_javascript():
extension = current_language_support().get_test_file_suffix()
elif is_java():
extension = ".java"
else:
extension = ".py"
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}{extension}"
if path.exists():
return get_test_file_path(test_dir, function_name, iteration + 1, test_type)
@ -86,10 +91,12 @@ class TestConfig:
def test_framework(self) -> str:
"""Returns the appropriate test framework based on language.
Returns 'jest' for JavaScript/TypeScript, 'pytest' for Python (default).
Returns 'jest' for JavaScript/TypeScript, 'junit5' for Java, 'pytest' for Python (default).
"""
if is_javascript():
return "jest"
if is_java():
return "junit5"
return "pytest"
def set_language(self, language: str) -> None:

View file

@ -25,6 +25,7 @@ dependencies = [
"tree-sitter>=0.23.0",
"tree-sitter-javascript>=0.23.0",
"tree-sitter-typescript>=0.23.0",
"tree-sitter-java>=0.23.0",
"pytest-timeout>=2.1.0",
"tomlkit>=0.11.7",
"junitparser>=3.1.0",

View file

@ -0,0 +1,5 @@
# Codeflash configuration for Java project
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"

View file

@ -0,0 +1,127 @@
package com.example;
import com.example.helpers.MathHelper;
import com.example.helpers.Formatter;
/**
* Calculator class - demonstrates class method optimization scenarios.
* Uses helper functions from MathHelper and Formatter.
*/
public class Calculator {
private int precision;
private java.util.List<String> history;
/**
* Creates a Calculator with specified precision.
* @param precision number of decimal places for formatting
*/
public Calculator(int precision) {
this.precision = precision;
this.history = new java.util.ArrayList<>();
}
/**
* Creates a Calculator with default precision of 2.
*/
public Calculator() {
this(2);
}
/**
* Calculate compound interest with multiple helper dependencies.
*
* @param principal Initial amount
* @param rate Interest rate (as decimal)
* @param time Time in years
* @param n Compounding frequency per year
* @return Compound interest result formatted as string
*/
public String calculateCompoundInterest(double principal, double rate, int time, int n) {
Formatter.validateInput(principal, "principal");
Formatter.validateInput(rate, "rate");
// Inefficient: recalculates power multiple times
double result = principal;
for (int i = 0; i < n * time; i++) {
result = MathHelper.multiply(result, MathHelper.add(1.0, rate / n));
}
double interest = result - principal;
history.add("compound:" + interest);
return Formatter.formatNumber(interest, precision);
}
/**
* Calculate permutation using factorial helper.
*
* @param n Total items
* @param r Items to choose
* @return Permutation result (n! / (n-r)!)
*/
public long permutation(int n, int r) {
if (n < r) {
return 0;
}
// Inefficient: calculates factorial(n) fully even when not needed
return MathHelper.factorial(n) / MathHelper.factorial(n - r);
}
/**
* Calculate combination (n choose r).
*
* @param n Total items
* @param r Items to choose
* @return Combination result (n! / (r! * (n-r)!))
*/
public long combination(int n, int r) {
if (n < r) {
return 0;
}
// Inefficient: calculates full factorials
return MathHelper.factorial(n) / (MathHelper.factorial(r) * MathHelper.factorial(n - r));
}
/**
* Calculate Fibonacci number at position n.
*
* @param n Position in Fibonacci sequence (0-indexed)
* @return Fibonacci number at position n
*/
public long fibonacci(int n) {
// Inefficient recursive implementation without memoization
if (n <= 1) {
return n;
}
return fibonacci(n - 1) + fibonacci(n - 2);
}
/**
* Static method for quick calculations.
*
* @param a First number
* @param b Second number
* @return Sum of a and b
*/
public static double quickAdd(double a, double b) {
return MathHelper.add(a, b);
}
/**
* Get calculation history.
*
* @return List of past calculations
*/
public java.util.List<String> getHistory() {
return new java.util.ArrayList<>(history);
}
/**
* Get current precision setting.
*
* @return precision value
*/
public int getPrecision() {
return precision;
}
}

View file

@ -0,0 +1,171 @@
package com.example;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Data processing class with complex methods to optimize.
*/
public class DataProcessor {
/**
* Find duplicate elements in a list.
*
* @param list List to check for duplicates
* @param <T> Type of elements
* @return List of duplicate elements
*/
public static <T> List<T> findDuplicates(List<T> list) {
List<T> duplicates = new ArrayList<>();
if (list == null) {
return duplicates;
}
// Inefficient: O(n^2) nested loop
for (int i = 0; i < list.size(); i++) {
for (int j = i + 1; j < list.size(); j++) {
if (list.get(i).equals(list.get(j)) && !duplicates.contains(list.get(i))) {
duplicates.add(list.get(i));
}
}
}
return duplicates;
}
/**
* Group elements by a key function.
*
* @param list List to group
* @param keyExtractor Function to extract key from element
* @param <T> Type of elements
* @param <K> Type of key
* @return Map of key to list of elements
*/
public static <T, K> Map<K, List<T>> groupBy(List<T> list, java.util.function.Function<T, K> keyExtractor) {
Map<K, List<T>> result = new HashMap<>();
if (list == null) {
return result;
}
// Could use streams, but explicit loop for optimization opportunity
for (T item : list) {
K key = keyExtractor.apply(item);
if (!result.containsKey(key)) {
result.put(key, new ArrayList<>());
}
result.get(key).add(item);
}
return result;
}
/**
* Find intersection of two lists.
*
* @param list1 First list
* @param list2 Second list
* @param <T> Type of elements
* @return List of common elements
*/
public static <T> List<T> intersection(List<T> list1, List<T> list2) {
List<T> result = new ArrayList<>();
if (list1 == null || list2 == null) {
return result;
}
// Inefficient: O(n*m) nested loop
for (T item : list1) {
if (list2.contains(item) && !result.contains(item)) {
result.add(item);
}
}
return result;
}
/**
* Flatten a nested list structure.
*
* @param nestedList List of lists
* @param <T> Type of elements
* @return Flattened list
*/
public static <T> List<T> flatten(List<List<T>> nestedList) {
List<T> result = new ArrayList<>();
if (nestedList == null) {
return result;
}
// Simple but could be optimized with capacity hints
for (List<T> innerList : nestedList) {
if (innerList != null) {
result.addAll(innerList);
}
}
return result;
}
/**
* Count frequency of each element.
*
* @param list List to count
* @param <T> Type of elements
* @return Map of element to frequency
*/
public static <T> Map<T, Integer> countFrequency(List<T> list) {
Map<T, Integer> frequency = new HashMap<>();
if (list == null) {
return frequency;
}
for (T item : list) {
// Inefficient: could use merge or compute
if (frequency.containsKey(item)) {
frequency.put(item, frequency.get(item) + 1);
} else {
frequency.put(item, 1);
}
}
return frequency;
}
/**
* Find the nth most frequent element.
*
* @param list List to search
* @param n Position (1-based)
* @param <T> Type of elements
* @return nth most frequent element, or null if not found
*/
public static <T> T nthMostFrequent(List<T> list, int n) {
if (list == null || list.isEmpty() || n < 1) {
return null;
}
Map<T, Integer> frequency = countFrequency(list);
// Inefficient: sort all entries to find nth
List<Map.Entry<T, Integer>> entries = new ArrayList<>(frequency.entrySet());
entries.sort((e1, e2) -> e2.getValue().compareTo(e1.getValue()));
if (n > entries.size()) {
return null;
}
return entries.get(n - 1).getKey();
}
/**
* Partition list into chunks of specified size.
*
* @param list List to partition
* @param chunkSize Size of each chunk
* @param <T> Type of elements
* @return List of chunks
*/
public static <T> List<List<T>> partition(List<T> list, int chunkSize) {
List<List<T>> result = new ArrayList<>();
if (list == null || chunkSize <= 0) {
return result;
}
// Inefficient: creates sublists with copying
for (int i = 0; i < list.size(); i += chunkSize) {
int end = Math.min(i + chunkSize, list.size());
result.add(new ArrayList<>(list.subList(i, end)));
}
return result;
}
}

View file

@ -0,0 +1,131 @@
package com.example;
import java.util.ArrayList;
import java.util.List;
/**
* String utility class with methods to optimize.
*/
public class StringUtils {
/**
* Reverse a string character by character.
*
* @param str String to reverse
* @return Reversed string
*/
public static String reverse(String str) {
if (str == null || str.isEmpty()) {
return str;
}
// Inefficient: string concatenation in loop
String result = "";
for (int i = str.length() - 1; i >= 0; i--) {
result = result + str.charAt(i);
}
return result;
}
/**
* Check if a string is a palindrome.
*
* @param str String to check
* @return true if palindrome, false otherwise
*/
public static boolean isPalindrome(String str) {
if (str == null) {
return false;
}
// Inefficient: creates reversed string instead of comparing in place
String reversed = reverse(str.toLowerCase().replaceAll("\\s+", ""));
String cleaned = str.toLowerCase().replaceAll("\\s+", "");
return cleaned.equals(reversed);
}
/**
* Count occurrences of a substring.
*
* @param str String to search in
* @param sub Substring to find
* @return Number of occurrences
*/
public static int countOccurrences(String str, String sub) {
if (str == null || sub == null || sub.isEmpty()) {
return 0;
}
// Inefficient: creates many intermediate strings
int count = 0;
int index = 0;
while ((index = str.indexOf(sub, index)) != -1) {
count++;
index++;
}
return count;
}
/**
* Find all anagrams of a word in a text.
*
* @param text Text to search in
* @param word Word to find anagrams of
* @return List of starting indices of anagrams
*/
public static List<Integer> findAnagrams(String text, String word) {
List<Integer> result = new ArrayList<>();
if (text == null || word == null || text.length() < word.length()) {
return result;
}
// Inefficient: recalculates sorted word for each position
int wordLen = word.length();
for (int i = 0; i <= text.length() - wordLen; i++) {
String window = text.substring(i, i + wordLen);
if (isAnagram(window, word)) {
result.add(i);
}
}
return result;
}
/**
* Check if two strings are anagrams.
*
* @param s1 First string
* @param s2 Second string
* @return true if anagrams, false otherwise
*/
public static boolean isAnagram(String s1, String s2) {
if (s1 == null || s2 == null || s1.length() != s2.length()) {
return false;
}
// Inefficient: sorts both strings
char[] arr1 = s1.toLowerCase().toCharArray();
char[] arr2 = s2.toLowerCase().toCharArray();
java.util.Arrays.sort(arr1);
java.util.Arrays.sort(arr2);
return java.util.Arrays.equals(arr1, arr2);
}
/**
* Find longest common prefix of an array of strings.
*
* @param strings Array of strings
* @return Longest common prefix
*/
public static String longestCommonPrefix(String[] strings) {
if (strings == null || strings.length == 0) {
return "";
}
// Inefficient: vertical scanning approach
String prefix = strings[0];
for (int i = 1; i < strings.length; i++) {
while (strings[i].indexOf(prefix) != 0) {
prefix = prefix.substring(0, prefix.length() - 1);
if (prefix.isEmpty()) {
return "";
}
}
}
return prefix;
}
}

View file

@ -0,0 +1,74 @@
package com.example.helpers;
/**
* Formatting utility functions.
*/
public class Formatter {
/**
* Format a number with specified decimal places.
*
* @param value Number to format
* @param decimals Number of decimal places
* @return Formatted number as string
*/
public static String formatNumber(double value, int decimals) {
return String.format("%." + decimals + "f", value);
}
/**
* Validate that input is a positive number.
*
* @param value Value to validate
* @param name Name of the parameter (for error message)
* @throws IllegalArgumentException if value is not positive
*/
public static void validateInput(double value, String name) {
if (value < 0) {
throw new IllegalArgumentException(name + " must be non-negative, got: " + value);
}
}
/**
* Convert number to percentage string.
*
* @param value Decimal value (0.5 = 50%)
* @return Percentage string
*/
public static String toPercentage(double value) {
return formatNumber(value * 100, 2) + "%";
}
/**
* Pad a string to specified length.
*
* @param str String to pad
* @param length Target length
* @param padChar Character to pad with
* @return Padded string
*/
public static String padLeft(String str, int length, char padChar) {
// Inefficient: creates many intermediate strings
StringBuilder result = new StringBuilder(str);
while (result.length() < length) {
result.insert(0, padChar);
}
return result.toString();
}
/**
* Repeat a string n times.
*
* @param str String to repeat
* @param times Number of repetitions
* @return Repeated string
*/
public static String repeat(String str, int times) {
// Inefficient: string concatenation in loop
String result = "";
for (int i = 0; i < times; i++) {
result = result + str;
}
return result;
}
}

View file

@ -0,0 +1,108 @@
package com.example.helpers;
/**
* Math utility functions - basic arithmetic operations.
*/
public class MathHelper {
/**
* Add two numbers.
*
* @param a First number
* @param b Second number
* @return Sum of a and b
*/
public static double add(double a, double b) {
return a + b;
}
/**
* Multiply two numbers.
*
* @param a First number
* @param b Second number
* @return Product of a and b
*/
public static double multiply(double a, double b) {
return a * b;
}
/**
* Calculate factorial recursively.
*
* @param n Non-negative integer
* @return Factorial of n
* @throws IllegalArgumentException if n is negative
*/
public static long factorial(int n) {
if (n < 0) {
throw new IllegalArgumentException("Factorial not defined for negative numbers");
}
// Intentionally inefficient recursive implementation
if (n <= 1) {
return 1;
}
return n * factorial(n - 1);
}
/**
* Calculate power using repeated multiplication.
*
* @param base Base number
* @param exp Exponent (non-negative)
* @return base raised to exp
*/
public static double power(double base, int exp) {
// Inefficient: linear time instead of log time
double result = 1;
for (int i = 0; i < exp; i++) {
result = multiply(result, base);
}
return result;
}
/**
* Check if a number is prime.
*
* @param n Number to check
* @return true if n is prime, false otherwise
*/
public static boolean isPrime(int n) {
if (n < 2) {
return false;
}
// Inefficient: checks all numbers up to n-1
for (int i = 2; i < n; i++) {
if (n % i == 0) {
return false;
}
}
return true;
}
/**
* Calculate greatest common divisor using Euclidean algorithm.
*
* @param a First number
* @param b Second number
* @return GCD of a and b
*/
public static int gcd(int a, int b) {
// Inefficient recursive implementation
if (b == 0) {
return a;
}
return gcd(b, a % b);
}
/**
* Calculate least common multiple.
*
* @param a First number
* @param b Second number
* @return LCM of a and b
*/
public static int lcm(int a, int b) {
return (a * b) / gcd(a, b);
}
}

View file

@ -0,0 +1,170 @@
package com.example;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import static org.junit.jupiter.api.Assertions.*;
/**
* Tests for the Calculator class.
*/
@DisplayName("Calculator Tests")
class CalculatorTest {
private Calculator calculator;
@BeforeEach
void setUp() {
calculator = new Calculator(2);
}
@Nested
@DisplayName("Compound Interest Tests")
class CompoundInterestTests {
@Test
@DisplayName("should calculate compound interest for basic case")
void testBasicCompoundInterest() {
String result = calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
assertNotNull(result);
assertTrue(result.contains("."));
}
@Test
@DisplayName("should handle zero principal")
void testZeroPrincipal() {
String result = calculator.calculateCompoundInterest(0.0, 0.05, 1, 12);
assertEquals("0.00", result);
}
@Test
@DisplayName("should throw on negative principal")
void testNegativePrincipal() {
assertThrows(IllegalArgumentException.class, () ->
calculator.calculateCompoundInterest(-100.0, 0.05, 1, 12)
);
}
@ParameterizedTest
@CsvSource({
"1000, 0.05, 1, 12",
"5000, 0.08, 2, 4",
"10000, 0.03, 5, 1"
})
@DisplayName("should calculate for various inputs")
void testVariousInputs(double principal, double rate, int time, int n) {
String result = calculator.calculateCompoundInterest(principal, rate, time, n);
assertNotNull(result);
assertFalse(result.isEmpty());
}
}
@Nested
@DisplayName("Permutation Tests")
class PermutationTests {
@Test
@DisplayName("should calculate permutation correctly")
void testBasicPermutation() {
assertEquals(120, calculator.permutation(5, 5));
assertEquals(60, calculator.permutation(5, 3));
assertEquals(20, calculator.permutation(5, 2));
}
@Test
@DisplayName("should return 0 when n < r")
void testInvalidPermutation() {
assertEquals(0, calculator.permutation(3, 5));
}
@Test
@DisplayName("should handle edge cases")
void testEdgeCases() {
assertEquals(1, calculator.permutation(5, 0));
assertEquals(1, calculator.permutation(0, 0));
}
}
@Nested
@DisplayName("Combination Tests")
class CombinationTests {
@Test
@DisplayName("should calculate combination correctly")
void testBasicCombination() {
assertEquals(10, calculator.combination(5, 3));
assertEquals(10, calculator.combination(5, 2));
assertEquals(1, calculator.combination(5, 5));
}
@Test
@DisplayName("should return 0 when n < r")
void testInvalidCombination() {
assertEquals(0, calculator.combination(3, 5));
}
}
@Nested
@DisplayName("Fibonacci Tests")
class FibonacciTests {
@Test
@DisplayName("should calculate fibonacci correctly")
void testFibonacci() {
assertEquals(0, calculator.fibonacci(0));
assertEquals(1, calculator.fibonacci(1));
assertEquals(1, calculator.fibonacci(2));
assertEquals(2, calculator.fibonacci(3));
assertEquals(5, calculator.fibonacci(5));
assertEquals(55, calculator.fibonacci(10));
}
@ParameterizedTest
@CsvSource({
"0, 0",
"1, 1",
"2, 1",
"3, 2",
"4, 3",
"5, 5",
"6, 8",
"7, 13"
})
@DisplayName("should match expected sequence")
void testFibonacciSequence(int n, long expected) {
assertEquals(expected, calculator.fibonacci(n));
}
}
@Test
@DisplayName("static quickAdd should work correctly")
void testQuickAdd() {
assertEquals(15.0, Calculator.quickAdd(10.0, 5.0));
assertEquals(0.0, Calculator.quickAdd(-5.0, 5.0));
assertEquals(-10.0, Calculator.quickAdd(-5.0, -5.0));
}
@Test
@DisplayName("should track calculation history")
void testHistory() {
calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
calculator.calculateCompoundInterest(2000.0, 0.03, 2, 4);
var history = calculator.getHistory();
assertEquals(2, history.size());
assertTrue(history.get(0).startsWith("compound:"));
}
@Test
@DisplayName("should return correct precision")
void testPrecision() {
assertEquals(2, calculator.getPrecision());
Calculator customCalc = new Calculator(4);
assertEquals(4, customCalc.getPrecision());
}
}

View file

@ -0,0 +1,265 @@
package com.example;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
/**
* Tests for the DataProcessor class.
*/
@DisplayName("DataProcessor Tests")
class DataProcessorTest {
@Nested
@DisplayName("findDuplicates() Tests")
class FindDuplicatesTests {
@Test
@DisplayName("should find duplicates in list")
void testFindDuplicates() {
List<Integer> input = Arrays.asList(1, 2, 3, 2, 4, 3, 5);
List<Integer> duplicates = DataProcessor.findDuplicates(input);
assertEquals(2, duplicates.size());
assertTrue(duplicates.contains(2));
assertTrue(duplicates.contains(3));
}
@Test
@DisplayName("should return empty for no duplicates")
void testNoDuplicates() {
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5);
List<Integer> duplicates = DataProcessor.findDuplicates(input);
assertTrue(duplicates.isEmpty());
}
@Test
@DisplayName("should handle null input")
void testNullInput() {
List<Integer> duplicates = DataProcessor.findDuplicates(null);
assertTrue(duplicates.isEmpty());
}
@Test
@DisplayName("should handle strings")
void testStrings() {
List<String> input = Arrays.asList("a", "b", "a", "c", "b", "d");
List<String> duplicates = DataProcessor.findDuplicates(input);
assertEquals(2, duplicates.size());
assertTrue(duplicates.contains("a"));
assertTrue(duplicates.contains("b"));
}
}
@Nested
@DisplayName("groupBy() Tests")
class GroupByTests {
@Test
@DisplayName("should group by length")
void testGroupByLength() {
List<String> input = Arrays.asList("a", "bb", "ccc", "dd", "e", "fff");
Map<Integer, List<String>> grouped = DataProcessor.groupBy(input, String::length);
assertEquals(3, grouped.size());
assertEquals(2, grouped.get(1).size());
assertEquals(2, grouped.get(2).size());
assertEquals(2, grouped.get(3).size());
}
@Test
@DisplayName("should group by first character")
void testGroupByFirstChar() {
List<String> input = Arrays.asList("apple", "apricot", "banana", "blueberry");
Map<Character, List<String>> grouped = DataProcessor.groupBy(input, s -> s.charAt(0));
assertEquals(2, grouped.size());
assertEquals(2, grouped.get('a').size());
assertEquals(2, grouped.get('b').size());
}
@Test
@DisplayName("should handle null input")
void testNullInput() {
Map<Integer, List<String>> grouped = DataProcessor.groupBy(null, String::length);
assertTrue(grouped.isEmpty());
}
}
@Nested
@DisplayName("intersection() Tests")
class IntersectionTests {
@Test
@DisplayName("should find intersection")
void testIntersection() {
List<Integer> list1 = Arrays.asList(1, 2, 3, 4, 5);
List<Integer> list2 = Arrays.asList(4, 5, 6, 7, 8);
List<Integer> result = DataProcessor.intersection(list1, list2);
assertEquals(2, result.size());
assertTrue(result.contains(4));
assertTrue(result.contains(5));
}
@Test
@DisplayName("should return empty for no intersection")
void testNoIntersection() {
List<Integer> list1 = Arrays.asList(1, 2, 3);
List<Integer> list2 = Arrays.asList(4, 5, 6);
List<Integer> result = DataProcessor.intersection(list1, list2);
assertTrue(result.isEmpty());
}
@Test
@DisplayName("should handle null inputs")
void testNullInputs() {
assertTrue(DataProcessor.intersection(null, Arrays.asList(1, 2, 3)).isEmpty());
assertTrue(DataProcessor.intersection(Arrays.asList(1, 2, 3), null).isEmpty());
}
@Test
@DisplayName("should not include duplicates")
void testNoDuplicates() {
List<Integer> list1 = Arrays.asList(1, 1, 2, 2, 3);
List<Integer> list2 = Arrays.asList(1, 2, 2, 4);
List<Integer> result = DataProcessor.intersection(list1, list2);
assertEquals(2, result.size());
}
}
@Nested
@DisplayName("flatten() Tests")
class FlattenTests {
@Test
@DisplayName("should flatten nested lists")
void testFlatten() {
List<List<Integer>> nested = Arrays.asList(
Arrays.asList(1, 2, 3),
Arrays.asList(4, 5),
Arrays.asList(6, 7, 8, 9)
);
List<Integer> result = DataProcessor.flatten(nested);
assertEquals(9, result.size());
assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9), result);
}
@Test
@DisplayName("should handle empty inner lists")
void testEmptyInnerLists() {
List<List<Integer>> nested = Arrays.asList(
Arrays.asList(1, 2),
Collections.emptyList(),
Arrays.asList(3, 4)
);
List<Integer> result = DataProcessor.flatten(nested);
assertEquals(4, result.size());
}
@Test
@DisplayName("should handle null")
void testNull() {
assertTrue(DataProcessor.flatten(null).isEmpty());
}
}
@Nested
@DisplayName("countFrequency() Tests")
class CountFrequencyTests {
@Test
@DisplayName("should count frequencies correctly")
void testCountFrequency() {
List<String> input = Arrays.asList("a", "b", "a", "c", "a", "b");
Map<String, Integer> freq = DataProcessor.countFrequency(input);
assertEquals(3, freq.get("a"));
assertEquals(2, freq.get("b"));
assertEquals(1, freq.get("c"));
}
@Test
@DisplayName("should handle null input")
void testNullInput() {
assertTrue(DataProcessor.countFrequency(null).isEmpty());
}
}
@Nested
@DisplayName("nthMostFrequent() Tests")
class NthMostFrequentTests {
@Test
@DisplayName("should find nth most frequent")
void testNthMostFrequent() {
List<String> input = Arrays.asList("a", "b", "a", "c", "a", "b", "d");
assertEquals("a", DataProcessor.nthMostFrequent(input, 1));
assertEquals("b", DataProcessor.nthMostFrequent(input, 2));
}
@Test
@DisplayName("should return null for invalid n")
void testInvalidN() {
List<String> input = Arrays.asList("a", "b", "c");
assertNull(DataProcessor.nthMostFrequent(input, 0));
assertNull(DataProcessor.nthMostFrequent(input, 10));
}
@Test
@DisplayName("should handle null input")
void testNullInput() {
assertNull(DataProcessor.nthMostFrequent(null, 1));
}
}
@Nested
@DisplayName("partition() Tests")
class PartitionTests {
@Test
@DisplayName("should partition into chunks")
void testPartition() {
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5, 6, 7);
List<List<Integer>> chunks = DataProcessor.partition(input, 3);
assertEquals(3, chunks.size());
assertEquals(Arrays.asList(1, 2, 3), chunks.get(0));
assertEquals(Arrays.asList(4, 5, 6), chunks.get(1));
assertEquals(Collections.singletonList(7), chunks.get(2));
}
@Test
@DisplayName("should handle exact division")
void testExactDivision() {
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5, 6);
List<List<Integer>> chunks = DataProcessor.partition(input, 2);
assertEquals(3, chunks.size());
chunks.forEach(chunk -> assertEquals(2, chunk.size()));
}
@Test
@DisplayName("should handle null and invalid chunk size")
void testInvalidInputs() {
assertTrue(DataProcessor.partition(null, 3).isEmpty());
assertTrue(DataProcessor.partition(Arrays.asList(1, 2, 3), 0).isEmpty());
assertTrue(DataProcessor.partition(Arrays.asList(1, 2, 3), -1).isEmpty());
}
}
}

View file

@ -0,0 +1,219 @@
package com.example;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.NullAndEmptySource;
import org.junit.jupiter.params.provider.ValueSource;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
/**
* Tests for the StringUtils class.
*/
@DisplayName("StringUtils Tests")
class StringUtilsTest {
@Nested
@DisplayName("reverse() Tests")
class ReverseTests {
@Test
@DisplayName("should reverse a simple string")
void testReverseSimple() {
assertEquals("olleh", StringUtils.reverse("hello"));
assertEquals("dlrow", StringUtils.reverse("world"));
}
@Test
@DisplayName("should handle single character")
void testReverseSingleChar() {
assertEquals("a", StringUtils.reverse("a"));
}
@ParameterizedTest
@NullAndEmptySource
@DisplayName("should handle null and empty strings")
void testReverseNullEmpty(String input) {
assertEquals(input, StringUtils.reverse(input));
}
@Test
@DisplayName("should handle palindrome")
void testReversePalindrome() {
assertEquals("radar", StringUtils.reverse("radar"));
}
}
@Nested
@DisplayName("isPalindrome() Tests")
class PalindromeTests {
@ParameterizedTest
@ValueSource(strings = {"radar", "level", "civic", "rotor", "kayak"})
@DisplayName("should return true for palindromes")
void testPalindromes(String input) {
assertTrue(StringUtils.isPalindrome(input));
}
@ParameterizedTest
@ValueSource(strings = {"hello", "world", "java", "python"})
@DisplayName("should return false for non-palindromes")
void testNonPalindromes(String input) {
assertFalse(StringUtils.isPalindrome(input));
}
@Test
@DisplayName("should handle case insensitivity")
void testCaseInsensitive() {
assertTrue(StringUtils.isPalindrome("Radar"));
assertTrue(StringUtils.isPalindrome("LEVEL"));
}
@Test
@DisplayName("should ignore spaces")
void testIgnoreSpaces() {
assertTrue(StringUtils.isPalindrome("race car"));
assertTrue(StringUtils.isPalindrome("A man a plan a canal Panama"));
}
@Test
@DisplayName("should return false for null")
void testNull() {
assertFalse(StringUtils.isPalindrome(null));
}
}
@Nested
@DisplayName("countOccurrences() Tests")
class CountOccurrencesTests {
@Test
@DisplayName("should count occurrences correctly")
void testCount() {
assertEquals(3, StringUtils.countOccurrences("abcabc abc", "abc"));
assertEquals(2, StringUtils.countOccurrences("hello hello", "hello"));
}
@Test
@DisplayName("should return 0 for no matches")
void testNoMatches() {
assertEquals(0, StringUtils.countOccurrences("hello world", "xyz"));
}
@ParameterizedTest
@CsvSource({
"'aaaaaa', 'aa', 5",
"'banana', 'ana', 2",
"'mississippi', 'issi', 2"
})
@DisplayName("should handle overlapping matches")
void testOverlapping(String str, String sub, int expected) {
assertEquals(expected, StringUtils.countOccurrences(str, sub));
}
@Test
@DisplayName("should handle null inputs")
void testNullInputs() {
assertEquals(0, StringUtils.countOccurrences(null, "test"));
assertEquals(0, StringUtils.countOccurrences("test", null));
assertEquals(0, StringUtils.countOccurrences("test", ""));
}
}
@Nested
@DisplayName("isAnagram() Tests")
class AnagramTests {
@Test
@DisplayName("should detect anagrams")
void testAnagrams() {
assertTrue(StringUtils.isAnagram("listen", "silent"));
assertTrue(StringUtils.isAnagram("evil", "vile"));
assertTrue(StringUtils.isAnagram("anagram", "nagaram"));
}
@Test
@DisplayName("should reject non-anagrams")
void testNonAnagrams() {
assertFalse(StringUtils.isAnagram("hello", "world"));
assertFalse(StringUtils.isAnagram("abc", "abcd"));
}
@Test
@DisplayName("should be case insensitive")
void testCaseInsensitive() {
assertTrue(StringUtils.isAnagram("Listen", "Silent"));
}
@Test
@DisplayName("should handle null inputs")
void testNullInputs() {
assertFalse(StringUtils.isAnagram(null, "test"));
assertFalse(StringUtils.isAnagram("test", null));
}
}
@Nested
@DisplayName("findAnagrams() Tests")
class FindAnagramsTests {
@Test
@DisplayName("should find all anagram positions")
void testFindAnagrams() {
List<Integer> result = StringUtils.findAnagrams("cbaebabacd", "abc");
assertEquals(2, result.size());
assertTrue(result.contains(0));
assertTrue(result.contains(6));
}
@Test
@DisplayName("should return empty list for no matches")
void testNoMatches() {
List<Integer> result = StringUtils.findAnagrams("hello", "xyz");
assertTrue(result.isEmpty());
}
@Test
@DisplayName("should handle null inputs")
void testNullInputs() {
assertTrue(StringUtils.findAnagrams(null, "abc").isEmpty());
assertTrue(StringUtils.findAnagrams("abc", null).isEmpty());
}
}
@Nested
@DisplayName("longestCommonPrefix() Tests")
class LongestCommonPrefixTests {
@Test
@DisplayName("should find common prefix")
void testCommonPrefix() {
assertEquals("fl", StringUtils.longestCommonPrefix(new String[]{"flower", "flow", "flight"}));
assertEquals("ap", StringUtils.longestCommonPrefix(new String[]{"apple", "ape", "april"}));
}
@Test
@DisplayName("should return empty for no common prefix")
void testNoCommonPrefix() {
assertEquals("", StringUtils.longestCommonPrefix(new String[]{"dog", "car", "race"}));
}
@Test
@DisplayName("should handle single string")
void testSingleString() {
assertEquals("hello", StringUtils.longestCommonPrefix(new String[]{"hello"}));
}
@Test
@DisplayName("should handle null and empty array")
void testNullEmpty() {
assertEquals("", StringUtils.longestCommonPrefix(null));
assertEquals("", StringUtils.longestCommonPrefix(new String[]{}));
}
}
}

View file

@ -29,17 +29,20 @@ class TestLanguageEnum:
assert Language.PYTHON.value == "python"
assert Language.JAVASCRIPT.value == "javascript"
assert Language.TYPESCRIPT.value == "typescript"
assert Language.JAVA.value == "java"
def test_language_str(self):
"""Test string conversion of Language enum."""
assert str(Language.PYTHON) == "python"
assert str(Language.JAVASCRIPT) == "javascript"
assert str(Language.JAVA) == "java"
def test_language_from_string(self):
"""Test creating Language from string."""
assert Language("python") == Language.PYTHON
assert Language("javascript") == Language.JAVASCRIPT
assert Language("typescript") == Language.TYPESCRIPT
assert Language("java") == Language.JAVA
def test_invalid_language_raises(self):
"""Test that invalid language string raises ValueError."""

View file

@ -0,0 +1 @@
"""Tests for Java language support."""

View file

@ -0,0 +1,279 @@
"""Tests for Java build tool detection and integration."""
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.java.build_tools import (
BuildTool,
detect_build_tool,
find_maven_executable,
find_source_root,
find_test_root,
get_project_info,
)
class TestBuildToolDetection:
"""Tests for build tool detection."""
def test_detect_maven_project(self, tmp_path: Path):
"""Test detecting a Maven project."""
# Create pom.xml
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>my-app</artifactId>
<version>1.0.0</version>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
assert detect_build_tool(tmp_path) == BuildTool.MAVEN
def test_detect_gradle_project(self, tmp_path: Path):
"""Test detecting a Gradle project."""
# Create build.gradle
(tmp_path / "build.gradle").write_text("plugins { id 'java' }")
assert detect_build_tool(tmp_path) == BuildTool.GRADLE
def test_detect_gradle_kotlin_project(self, tmp_path: Path):
"""Test detecting a Gradle Kotlin DSL project."""
# Create build.gradle.kts
(tmp_path / "build.gradle.kts").write_text('plugins { java }')
assert detect_build_tool(tmp_path) == BuildTool.GRADLE
def test_detect_unknown_project(self, tmp_path: Path):
"""Test detecting unknown project type."""
# Empty directory
assert detect_build_tool(tmp_path) == BuildTool.UNKNOWN
def test_maven_takes_precedence(self, tmp_path: Path):
"""Test that Maven takes precedence if both exist."""
# Create both pom.xml and build.gradle
(tmp_path / "pom.xml").write_text("<project></project>")
(tmp_path / "build.gradle").write_text("plugins { id 'java' }")
# Maven should be detected first
assert detect_build_tool(tmp_path) == BuildTool.MAVEN
class TestMavenProjectInfo:
"""Tests for Maven project info extraction."""
def test_get_maven_project_info(self, tmp_path: Path):
"""Test extracting project info from pom.xml."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>my-app</artifactId>
<version>1.0.0</version>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
# Create standard Maven directory structure
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
info = get_project_info(tmp_path)
assert info is not None
assert info.build_tool == BuildTool.MAVEN
assert info.group_id == "com.example"
assert info.artifact_id == "my-app"
assert info.version == "1.0.0"
assert info.java_version == "11"
assert len(info.source_roots) == 1
assert len(info.test_roots) == 1
def test_get_maven_project_info_with_java_version_property(self, tmp_path: Path):
"""Test extracting Java version from java.version property."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>my-app</artifactId>
<version>1.0.0</version>
<properties>
<java.version>17</java.version>
</properties>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
info = get_project_info(tmp_path)
assert info is not None
assert info.java_version == "17"
class TestDirectoryDetection:
"""Tests for source and test directory detection."""
def test_find_maven_source_root(self, tmp_path: Path):
"""Test finding Maven source root."""
(tmp_path / "pom.xml").write_text("<project></project>")
src_root = tmp_path / "src" / "main" / "java"
src_root.mkdir(parents=True)
result = find_source_root(tmp_path)
assert result is not None
assert result == src_root
def test_find_maven_test_root(self, tmp_path: Path):
"""Test finding Maven test root."""
(tmp_path / "pom.xml").write_text("<project></project>")
test_root = tmp_path / "src" / "test" / "java"
test_root.mkdir(parents=True)
result = find_test_root(tmp_path)
assert result is not None
assert result == test_root
def test_find_source_root_not_found(self, tmp_path: Path):
"""Test when source root doesn't exist."""
result = find_source_root(tmp_path)
assert result is None
def test_find_test_root_not_found(self, tmp_path: Path):
"""Test when test root doesn't exist."""
result = find_test_root(tmp_path)
assert result is None
def test_find_alternative_test_root(self, tmp_path: Path):
"""Test finding alternative test directory."""
# Create a 'test' directory (non-Maven style)
test_dir = tmp_path / "test"
test_dir.mkdir()
result = find_test_root(tmp_path)
assert result is not None
assert result == test_dir
class TestMavenExecutable:
"""Tests for Maven executable detection."""
def test_find_maven_executable_system(self):
"""Test finding system Maven."""
# This test may pass or fail depending on whether Maven is installed
mvn = find_maven_executable()
# We can't assert it exists, just that the function doesn't crash
if mvn:
assert "mvn" in mvn.lower() or "maven" in mvn.lower()
def test_find_maven_wrapper(self, tmp_path: Path, monkeypatch):
"""Test finding Maven wrapper."""
# Create mvnw file
mvnw_path = tmp_path / "mvnw"
mvnw_path.write_text("#!/bin/bash\necho 'Maven Wrapper'")
mvnw_path.chmod(0o755)
# Change to tmp_path
monkeypatch.chdir(tmp_path)
mvn = find_maven_executable()
# Should find the wrapper
assert mvn is not None
class TestPomXmlParsing:
"""Tests for pom.xml parsing edge cases."""
def test_pom_without_namespace(self, tmp_path: Path):
"""Test parsing pom.xml without XML namespace."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>simple-app</artifactId>
<version>1.0</version>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
info = get_project_info(tmp_path)
assert info is not None
assert info.group_id == "com.example"
assert info.artifact_id == "simple-app"
def test_pom_with_parent(self, tmp_path: Path):
"""Test parsing pom.xml with parent POM."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.0.0</version>
</parent>
<groupId>com.example</groupId>
<artifactId>child-app</artifactId>
<version>1.0</version>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
info = get_project_info(tmp_path)
assert info is not None
assert info.artifact_id == "child-app"
def test_invalid_pom_xml(self, tmp_path: Path):
"""Test handling invalid pom.xml."""
# Create invalid XML
(tmp_path / "pom.xml").write_text("this is not valid xml")
info = get_project_info(tmp_path)
# Should return None or handle gracefully
assert info is None
class TestGradleProjectInfo:
"""Tests for Gradle project info extraction."""
def test_get_gradle_project_info(self, tmp_path: Path):
"""Test extracting basic Gradle project info."""
(tmp_path / "build.gradle").write_text("""
plugins {
id 'java'
}
group = 'com.example'
version = '1.0.0'
""")
# Create standard Gradle directory structure
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
info = get_project_info(tmp_path)
assert info is not None
assert info.build_tool == BuildTool.GRADLE
assert len(info.source_roots) == 1
assert len(info.test_roots) == 1

View file

@ -0,0 +1,310 @@
"""Tests for Java test result comparison."""
import json
import sqlite3
import tempfile
from pathlib import Path
import pytest
from codeflash.languages.java.comparator import (
compare_invocations_directly,
compare_test_results,
)
from codeflash.models.models import TestDiffScope
class TestDirectComparison:
"""Tests for direct Python-based comparison."""
def test_identical_results(self):
"""Test comparing identical results."""
original = {
"1": {"result_json": '{"value": 42}', "error_json": None},
"2": {"result_json": '{"value": 100}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None},
"2": {"result_json": '{"value": 100}', "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
assert len(diffs) == 0
def test_different_return_values(self):
"""Test detecting different return values."""
original = {
"1": {"result_json": '{"value": 42}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"value": 99}', "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
assert len(diffs) == 1
assert diffs[0].scope == TestDiffScope.RETURN_VALUE
assert diffs[0].original_value == '{"value": 42}'
assert diffs[0].candidate_value == '{"value": 99}'
def test_missing_invocation_in_candidate(self):
"""Test detecting missing invocation in candidate."""
original = {
"1": {"result_json": '{"value": 42}', "error_json": None},
"2": {"result_json": '{"value": 100}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None},
# Missing invocation 2
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
assert len(diffs) == 1
assert diffs[0].candidate_pass is False
def test_extra_invocation_in_candidate(self):
"""Test detecting extra invocation in candidate."""
original = {
"1": {"result_json": '{"value": 42}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None},
"2": {"result_json": '{"value": 100}', "error_json": None}, # Extra
}
equivalent, diffs = compare_invocations_directly(original, candidate)
# Having extra invocations is noted but doesn't necessarily fail
assert len(diffs) == 1
def test_exception_differences(self):
"""Test detecting exception differences."""
original = {
"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'},
}
candidate = {
"1": {"result_json": '{"value": 42}', "error_json": None}, # No exception
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
assert len(diffs) == 1
assert diffs[0].scope == TestDiffScope.DID_PASS
def test_empty_results(self):
"""Test comparing empty results."""
original = {}
candidate = {}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
assert len(diffs) == 0
class TestSqliteComparison:
"""Tests for SQLite-based comparison (requires Java runtime)."""
@pytest.fixture
def create_test_db(self):
"""Create a test SQLite database with invocations table."""
def _create(path: Path, invocations: list[dict]):
conn = sqlite3.connect(path)
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE invocations (
call_id INTEGER PRIMARY KEY,
method_id TEXT NOT NULL,
args_json TEXT,
result_json TEXT,
error_json TEXT,
start_time INTEGER,
end_time INTEGER
)
"""
)
for inv in invocations:
cursor.execute(
"""
INSERT INTO invocations (call_id, method_id, args_json, result_json, error_json)
VALUES (?, ?, ?, ?, ?)
""",
(
inv.get("call_id"),
inv.get("method_id", "test.method"),
inv.get("args_json"),
inv.get("result_json"),
inv.get("error_json"),
),
)
conn.commit()
conn.close()
return path
return _create
def test_compare_test_results_missing_original(self, tmp_path: Path):
"""Test comparison when original DB is missing."""
original_path = tmp_path / "original.db" # Doesn't exist
candidate_path = tmp_path / "candidate.db"
candidate_path.touch()
equivalent, diffs = compare_test_results(original_path, candidate_path)
assert equivalent is False
assert len(diffs) == 0
def test_compare_test_results_missing_candidate(self, tmp_path: Path):
"""Test comparison when candidate DB is missing."""
original_path = tmp_path / "original.db"
original_path.touch()
candidate_path = tmp_path / "candidate.db" # Doesn't exist
equivalent, diffs = compare_test_results(original_path, candidate_path)
assert equivalent is False
assert len(diffs) == 0
class TestComparisonWithRealData:
"""Tests simulating real comparison scenarios."""
def test_string_result_comparison(self):
"""Test comparing string results."""
original = {
"1": {"result_json": '"Hello World"', "error_json": None},
}
candidate = {
"1": {"result_json": '"Hello World"', "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_array_result_comparison(self):
"""Test comparing array results."""
original = {
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
}
candidate = {
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_array_order_matters(self):
"""Test that array order matters for comparison."""
original = {
"1": {"result_json": "[1, 2, 3]", "error_json": None},
}
candidate = {
"1": {"result_json": "[3, 2, 1]", "error_json": None}, # Different order
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is False
def test_object_result_comparison(self):
"""Test comparing object results."""
original = {
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
}
candidate = {
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_null_result(self):
"""Test comparing null results."""
original = {
"1": {"result_json": "null", "error_json": None},
}
candidate = {
"1": {"result_json": "null", "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_multiple_invocations_mixed(self):
"""Test multiple invocations with mixed results."""
original = {
"1": {"result_json": "42", "error_json": None},
"2": {"result_json": '"hello"', "error_json": None},
"3": {"result_json": None, "error_json": '{"type": "Exception"}'},
}
candidate = {
"1": {"result_json": "42", "error_json": None},
"2": {"result_json": '"hello"', "error_json": None},
"3": {"result_json": None, "error_json": '{"type": "Exception"}'},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
class TestEdgeCases:
"""Tests for edge cases and error handling."""
def test_whitespace_in_json(self):
"""Test that whitespace differences in JSON don't cause issues."""
original = {
"1": {"result_json": '{"a":1,"b":2}', "error_json": None},
}
candidate = {
"1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None}, # With spaces
}
# Note: Direct string comparison will see these as different
# The Java comparator would handle this correctly by parsing JSON
equivalent, diffs = compare_invocations_directly(original, candidate)
# This will fail with direct comparison - expected behavior
assert equivalent is False # String comparison doesn't normalize whitespace
def test_large_number_of_invocations(self):
"""Test handling large number of invocations."""
original = {str(i): {"result_json": str(i), "error_json": None} for i in range(1000)}
candidate = {str(i): {"result_json": str(i), "error_json": None} for i in range(1000)}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
assert len(diffs) == 0
def test_unicode_in_results(self):
"""Test handling unicode in results."""
original = {
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
}
candidate = {
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True
def test_deeply_nested_objects(self):
"""Test handling deeply nested objects."""
nested = '{"a": {"b": {"c": {"d": {"e": 1}}}}}'
original = {
"1": {"result_json": nested, "error_json": None},
}
candidate = {
"1": {"result_json": nested, "error_json": None},
}
equivalent, diffs = compare_invocations_directly(original, candidate)
assert equivalent is True

View file

@ -0,0 +1,344 @@
"""Tests for Java project configuration detection."""
from pathlib import Path
import pytest
from codeflash.languages.java.build_tools import BuildTool
from codeflash.languages.java.config import (
JavaProjectConfig,
detect_java_project,
get_test_class_pattern,
get_test_file_pattern,
is_java_project,
)
class TestIsJavaProject:
"""Tests for is_java_project function."""
def test_maven_project(self, tmp_path: Path):
"""Test detecting a Maven project."""
(tmp_path / "pom.xml").write_text("<project></project>")
assert is_java_project(tmp_path) is True
def test_gradle_project(self, tmp_path: Path):
"""Test detecting a Gradle project."""
(tmp_path / "build.gradle").write_text("plugins { id 'java' }")
assert is_java_project(tmp_path) is True
def test_gradle_kotlin_project(self, tmp_path: Path):
"""Test detecting a Gradle Kotlin DSL project."""
(tmp_path / "build.gradle.kts").write_text("plugins { java }")
assert is_java_project(tmp_path) is True
def test_java_files_only(self, tmp_path: Path):
"""Test detecting project with only Java files."""
src_dir = tmp_path / "src"
src_dir.mkdir()
(src_dir / "Main.java").write_text("public class Main {}")
assert is_java_project(tmp_path) is True
def test_not_java_project(self, tmp_path: Path):
"""Test non-Java directory."""
(tmp_path / "README.md").write_text("# Not a Java project")
assert is_java_project(tmp_path) is False
def test_empty_directory(self, tmp_path: Path):
"""Test empty directory."""
assert is_java_project(tmp_path) is False
class TestDetectJavaProject:
"""Tests for detect_java_project function."""
def test_detect_maven_with_junit5(self, tmp_path: Path):
"""Test detecting Maven project with JUnit 5."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>my-app</artifactId>
<version>1.0.0</version>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.9.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
config = detect_java_project(tmp_path)
assert config is not None
assert config.build_tool == BuildTool.MAVEN
assert config.has_junit5 is True
assert config.group_id == "com.example"
assert config.artifact_id == "my-app"
assert config.java_version == "11"
def test_detect_maven_with_junit4(self, tmp_path: Path):
"""Test detecting Maven project with JUnit 4."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>legacy-app</artifactId>
<version>1.0.0</version>
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
config = detect_java_project(tmp_path)
assert config is not None
assert config.has_junit4 is True
def test_detect_maven_with_testng(self, tmp_path: Path):
"""Test detecting Maven project with TestNG."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>testng-app</artifactId>
<version>1.0.0</version>
<dependencies>
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>7.7.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
config = detect_java_project(tmp_path)
assert config is not None
assert config.has_testng is True
def test_detect_gradle_project(self, tmp_path: Path):
"""Test detecting Gradle project."""
gradle_content = """
plugins {
id 'java'
}
dependencies {
testImplementation 'org.junit.jupiter:junit-jupiter:5.9.0'
}
test {
useJUnitPlatform()
}
"""
(tmp_path / "build.gradle").write_text(gradle_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
config = detect_java_project(tmp_path)
assert config is not None
assert config.build_tool == BuildTool.GRADLE
assert config.has_junit5 is True
def test_detect_from_test_files(self, tmp_path: Path):
"""Test detecting test framework from test file imports."""
(tmp_path / "pom.xml").write_text("<project></project>")
test_root = tmp_path / "src" / "test" / "java"
test_root.mkdir(parents=True)
# Create a test file with JUnit 5 imports
(test_root / "ExampleTest.java").write_text("""
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class ExampleTest {
@Test
void test() {}
}
""")
config = detect_java_project(tmp_path)
assert config is not None
assert config.has_junit5 is True
def test_detect_mockito(self, tmp_path: Path):
"""Test detecting Mockito dependency."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>mock-app</artifactId>
<version>1.0.0</version>
<dependencies>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>5.3.0</version>
</dependency>
</dependencies>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
config = detect_java_project(tmp_path)
assert config is not None
assert config.has_mockito is True
def test_detect_assertj(self, tmp_path: Path):
"""Test detecting AssertJ dependency."""
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>assertj-app</artifactId>
<version>1.0.0</version>
<dependencies>
<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<version>3.24.0</version>
</dependency>
</dependencies>
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
config = detect_java_project(tmp_path)
assert config is not None
assert config.has_assertj is True
def test_detect_non_java_project(self, tmp_path: Path):
"""Test detecting non-Java directory."""
(tmp_path / "package.json").write_text('{"name": "js-project"}')
config = detect_java_project(tmp_path)
assert config is None
class TestJavaProjectConfig:
"""Tests for JavaProjectConfig dataclass."""
def test_config_fields(self, tmp_path: Path):
"""Test that all config fields are accessible."""
config = JavaProjectConfig(
project_root=tmp_path,
build_tool=BuildTool.MAVEN,
source_root=tmp_path / "src" / "main" / "java",
test_root=tmp_path / "src" / "test" / "java",
java_version="17",
encoding="UTF-8",
test_framework="junit5",
group_id="com.example",
artifact_id="my-app",
version="1.0.0",
has_junit5=True,
has_junit4=False,
has_testng=False,
has_mockito=True,
has_assertj=False,
)
assert config.build_tool == BuildTool.MAVEN
assert config.java_version == "17"
assert config.has_junit5 is True
assert config.has_mockito is True
class TestGetTestPatterns:
"""Tests for test pattern functions."""
def test_get_test_file_pattern(self, tmp_path: Path):
"""Test getting test file pattern."""
config = JavaProjectConfig(
project_root=tmp_path,
build_tool=BuildTool.MAVEN,
source_root=None,
test_root=None,
java_version=None,
encoding="UTF-8",
test_framework="junit5",
group_id=None,
artifact_id=None,
version=None,
)
pattern = get_test_file_pattern(config)
assert pattern == "*Test.java"
def test_get_test_class_pattern(self, tmp_path: Path):
"""Test getting test class pattern."""
config = JavaProjectConfig(
project_root=tmp_path,
build_tool=BuildTool.MAVEN,
source_root=None,
test_root=None,
java_version=None,
encoding="UTF-8",
test_framework="junit5",
group_id=None,
artifact_id=None,
version=None,
)
pattern = get_test_class_pattern(config)
assert "Test" in pattern
class TestDetectWithFixture:
"""Tests using the Java fixture project."""
@pytest.fixture
def java_fixture_path(self):
"""Get path to the Java fixture project."""
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
if not fixture_path.exists():
pytest.skip("Java fixture project not found")
return fixture_path
def test_detect_fixture_project(self, java_fixture_path: Path):
"""Test detecting the fixture project."""
config = detect_java_project(java_fixture_path)
assert config is not None
assert config.build_tool == BuildTool.MAVEN
assert config.source_root is not None
assert config.test_root is not None
assert config.has_junit5 is True

View file

@ -0,0 +1,120 @@
"""Tests for Java code context extraction."""
from pathlib import Path
import pytest
from codeflash.languages.base import Language
from codeflash.languages.java.context import (
extract_code_context,
extract_function_source,
extract_read_only_context,
)
from codeflash.languages.java.discovery import discover_functions_from_source
class TestExtractFunctionSource:
"""Tests for extract_function_source."""
def test_extract_simple_method(self):
"""Test extracting a simple method."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
func_source = extract_function_source(source, functions[0])
assert "public int add" in func_source
assert "return a + b" in func_source
def test_extract_method_with_javadoc(self):
"""Test extracting method including Javadoc."""
source = """
public class Calculator {
/**
* Adds two numbers.
* @param a first number
* @param b second number
* @return sum
*/
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
func_source = extract_function_source(source, functions[0])
# Should include Javadoc
assert "/**" in func_source or "Adds two numbers" in func_source
class TestExtractCodeContext:
"""Tests for extract_code_context."""
def test_extract_context(self, tmp_path: Path):
"""Test extracting full code context."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""
package com.example;
import java.util.List;
public class Calculator {
private int base = 0;
public int add(int a, int b) {
return a + b + base;
}
private int helper(int x) {
return x * 2;
}
}
""")
functions = discover_functions_from_source(
java_file.read_text(), file_path=java_file
)
add_func = next((f for f in functions if f.name == "add"), None)
assert add_func is not None
context = extract_code_context(add_func, tmp_path)
assert context.language == Language.JAVA
assert "add" in context.target_code
assert context.target_file == java_file
class TestExtractReadOnlyContext:
"""Tests for extract_read_only_context."""
def test_extract_fields(self):
"""Test extracting class fields."""
source = """
public class Calculator {
private int base;
private static final double PI = 3.14159;
public int add(int a, int b) {
return a + b;
}
}
"""
from codeflash.languages.java.parser import get_java_analyzer
analyzer = get_java_analyzer()
functions = discover_functions_from_source(source, analyzer=analyzer)
add_func = next((f for f in functions if f.name == "add"), None)
assert add_func is not None
context = extract_read_only_context(source, add_func, analyzer)
# Should include field declarations
assert "base" in context or "PI" in context or context == ""

View file

@ -0,0 +1,335 @@
"""Tests for Java function/method discovery."""
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionFilterCriteria, Language
from codeflash.languages.java.discovery import (
discover_functions,
discover_functions_from_source,
discover_test_methods,
get_class_methods,
get_method_by_name,
)
class TestDiscoverFunctions:
"""Tests for function discovery."""
def test_discover_simple_method(self):
"""Test discovering a simple method."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].language == Language.JAVA
assert functions[0].is_method is True
assert functions[0].class_name == "Calculator"
def test_discover_multiple_methods(self):
"""Test discovering multiple methods."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
public int multiply(int a, int b) {
return a * b;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 3
method_names = {f.name for f in functions}
assert method_names == {"add", "subtract", "multiply"}
def test_skip_abstract_methods(self):
"""Test that abstract methods are skipped."""
source = """
public abstract class Shape {
public abstract double area();
public double perimeter() {
return 0.0;
}
}
"""
functions = discover_functions_from_source(source)
# Should only find perimeter, not area
assert len(functions) == 1
assert functions[0].name == "perimeter"
def test_skip_constructors(self):
"""Test that constructors are skipped."""
source = """
public class Person {
private String name;
public Person(String name) {
this.name = name;
}
public String getName() {
return name;
}
}
"""
functions = discover_functions_from_source(source)
# Should only find getName, not the constructor
assert len(functions) == 1
assert functions[0].name == "getName"
def test_filter_by_pattern(self):
"""Test filtering by include patterns."""
source = """
public class StringUtils {
public String toUpperCase(String s) {
return s.toUpperCase();
}
public String toLowerCase(String s) {
return s.toLowerCase();
}
public int length(String s) {
return s.length();
}
}
"""
criteria = FunctionFilterCriteria(include_patterns=["*Upper*", "*Lower*"])
functions = discover_functions_from_source(source, filter_criteria=criteria)
assert len(functions) == 2
method_names = {f.name for f in functions}
assert method_names == {"toUpperCase", "toLowerCase"}
def test_filter_exclude_pattern(self):
"""Test filtering by exclude patterns."""
source = """
public class DataService {
public void getData() {}
public void setData() {}
public void processData() {}
}
"""
criteria = FunctionFilterCriteria(
exclude_patterns=["set*"],
require_return=False, # Allow void methods
)
functions = discover_functions_from_source(source, filter_criteria=criteria)
method_names = {f.name for f in functions}
assert "setData" not in method_names
def test_filter_require_return(self):
"""Test filtering by require_return."""
source = """
public class Example {
public void doSomething() {}
public int getValue() {
return 42;
}
}
"""
criteria = FunctionFilterCriteria(require_return=True)
functions = discover_functions_from_source(source, filter_criteria=criteria)
assert len(functions) == 1
assert functions[0].name == "getValue"
def test_filter_by_line_count(self):
"""Test filtering by line count."""
source = """
public class Example {
public int short() { return 1; }
public int long() {
int a = 1;
int b = 2;
int c = 3;
int d = 4;
int e = 5;
return a + b + c + d + e;
}
}
"""
criteria = FunctionFilterCriteria(min_lines=3, require_return=False)
functions = discover_functions_from_source(source, filter_criteria=criteria)
# The 'long' method should be included (>3 lines)
# The 'short' method should be excluded (1 line)
method_names = {f.name for f in functions}
assert "long" in method_names or len(functions) >= 1
def test_method_with_javadoc(self):
"""Test that Javadoc is tracked."""
source = """
public class Example {
/**
* Adds two numbers.
* @param a first number
* @param b second number
* @return sum
*/
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
assert functions[0].doc_start_line is not None
# Doc should start before the method
assert functions[0].doc_start_line < functions[0].start_line
class TestDiscoverTestMethods:
"""Tests for test method discovery."""
def test_discover_junit5_tests(self, tmp_path: Path):
"""Test discovering JUnit 5 test methods."""
test_file = tmp_path / "CalculatorTest.java"
test_file.write_text("""
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class CalculatorTest {
@Test
void testAdd() {
assertEquals(4, 2 + 2);
}
@Test
void testSubtract() {
assertEquals(0, 2 - 2);
}
void helperMethod() {
// Not a test
}
}
""")
tests = discover_test_methods(test_file)
assert len(tests) == 2
test_names = {t.name for t in tests}
assert test_names == {"testAdd", "testSubtract"}
def test_discover_parameterized_tests(self, tmp_path: Path):
"""Test discovering parameterized tests."""
test_file = tmp_path / "StringTest.java"
test_file.write_text("""
package com.example;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
class StringTest {
@ParameterizedTest
@ValueSource(strings = {"hello", "world"})
void testLength(String input) {
assertTrue(input.length() > 0);
}
}
""")
tests = discover_test_methods(test_file)
assert len(tests) == 1
assert tests[0].name == "testLength"
class TestGetMethodByName:
"""Tests for getting methods by name."""
def test_get_method_by_name(self, tmp_path: Path):
"""Test getting a specific method by name."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""
public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
}
""")
method = get_method_by_name(java_file, "add")
assert method is not None
assert method.name == "add"
def test_get_method_not_found(self, tmp_path: Path):
"""Test getting a method that doesn't exist."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
""")
method = get_method_by_name(java_file, "multiply")
assert method is None
class TestGetClassMethods:
"""Tests for getting methods in a class."""
def test_get_class_methods(self, tmp_path: Path):
"""Test getting all methods in a specific class."""
java_file = tmp_path / "Example.java"
java_file.write_text("""
public class Calculator {
public int add(int a, int b) { return a + b; }
}
class Helper {
public void help() {}
}
""")
methods = get_class_methods(java_file, "Calculator")
assert len(methods) == 1
assert methods[0].name == "add"
class TestFileBasedDiscovery:
"""Tests for file-based discovery using the fixture project."""
@pytest.fixture
def java_fixture_path(self):
"""Get path to the Java fixture project."""
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
if not fixture_path.exists():
pytest.skip("Java fixture project not found")
return fixture_path
def test_discover_from_fixture(self, java_fixture_path: Path):
"""Test discovering functions from fixture project."""
calculator_file = java_fixture_path / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
if not calculator_file.exists():
pytest.skip("Calculator.java not found in fixture")
functions = discover_functions(calculator_file)
assert len(functions) > 0
method_names = {f.name for f in functions}
# Should find methods from Calculator.java
assert "fibonacci" in method_names or "add" in method_names or len(method_names) > 0
def test_discover_tests_from_fixture(self, java_fixture_path: Path):
"""Test discovering test methods from fixture project."""
test_file = java_fixture_path / "src" / "test" / "java" / "com" / "example" / "CalculatorTest.java"
if not test_file.exists():
pytest.skip("CalculatorTest.java not found in fixture")
tests = discover_test_methods(test_file)
assert len(tests) > 0

View file

@ -0,0 +1,246 @@
"""Tests for Java code formatting."""
from pathlib import Path
import pytest
from codeflash.languages.java.formatter import (
JavaFormatter,
format_java_code,
format_java_file,
normalize_java_code,
)
class TestNormalizeJavaCode:
"""Tests for code normalization."""
def test_normalize_removes_line_comments(self):
"""Test that line comments are removed."""
source = """
public class Example {
// This is a comment
public int add(int a, int b) {
return a + b; // inline comment
}
}
"""
normalized = normalize_java_code(source)
assert "//" not in normalized
assert "This is a comment" not in normalized
assert "inline comment" not in normalized
def test_normalize_removes_block_comments(self):
"""Test that block comments are removed."""
source = """
public class Example {
/* This is a
multi-line
block comment */
public int add(int a, int b) {
return a + b;
}
}
"""
normalized = normalize_java_code(source)
assert "/*" not in normalized
assert "*/" not in normalized
assert "multi-line" not in normalized
def test_normalize_preserves_strings_with_slashes(self):
"""Test that strings containing // are preserved."""
source = """
public class Example {
public String getUrl() {
return "https://example.com";
}
}
"""
normalized = normalize_java_code(source)
assert "https://example.com" in normalized
def test_normalize_removes_whitespace(self):
"""Test that extra whitespace is normalized."""
source = """
public class Example {
public int add(int a, int b) {
return a + b;
}
}
"""
normalized = normalize_java_code(source)
# Should not have empty lines
lines = [l for l in normalized.split("\n") if l.strip()]
assert len(lines) > 0
def test_normalize_inline_block_comment(self):
"""Test inline block comment removal."""
source = """
public class Example {
public int /* comment */ add(int a, int b) {
return a + b;
}
}
"""
normalized = normalize_java_code(source)
assert "/* comment */" not in normalized
class TestJavaFormatter:
"""Tests for JavaFormatter class."""
def test_formatter_init(self, tmp_path: Path):
"""Test formatter initialization."""
formatter = JavaFormatter(tmp_path)
assert formatter.project_root == tmp_path
def test_format_empty_source(self, tmp_path: Path):
"""Test formatting empty source."""
formatter = JavaFormatter(tmp_path)
result = formatter.format_code("")
assert result == ""
def test_format_whitespace_only(self, tmp_path: Path):
"""Test formatting whitespace-only source."""
formatter = JavaFormatter(tmp_path)
result = formatter.format_code(" \n\n ")
assert result == " \n\n "
def test_format_simple_class(self, tmp_path: Path):
"""Test formatting a simple class."""
source = """public class Example { public int add(int a, int b) { return a+b; } }"""
formatter = JavaFormatter(tmp_path)
result = formatter.format_code(source)
# Should return something (may be same as input if no formatter available)
assert len(result) > 0
class TestFormatJavaCode:
"""Tests for format_java_code convenience function."""
def test_format_preserves_valid_code(self):
"""Test that valid code is preserved."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
result = format_java_code(source)
# Should contain the core elements
assert "Calculator" in result
assert "add" in result
assert "return" in result
class TestFormatJavaFile:
"""Tests for format_java_file function."""
def test_format_file(self, tmp_path: Path):
"""Test formatting a file."""
java_file = tmp_path / "Example.java"
source = """
public class Example {
public int add(int a, int b) {
return a + b;
}
}
"""
java_file.write_text(source)
result = format_java_file(java_file)
assert "Example" in result
assert "add" in result
def test_format_file_in_place(self, tmp_path: Path):
"""Test formatting a file in place."""
java_file = tmp_path / "Example.java"
source = """public class Example { public int getValue() { return 42; } }"""
java_file.write_text(source)
format_java_file(java_file, in_place=True)
# File should still be readable
content = java_file.read_text()
assert "Example" in content
class TestFormatterWithGoogleJavaFormat:
"""Tests for Google Java Format integration."""
def test_google_java_format_not_downloaded(self, tmp_path: Path):
"""Test behavior when google-java-format is not available."""
formatter = JavaFormatter(tmp_path)
jar_path = formatter._get_google_java_format_jar()
# May or may not be available depending on system
# Just verify no exception is raised
def test_format_falls_back_gracefully(self, tmp_path: Path):
"""Test that formatting falls back gracefully."""
formatter = JavaFormatter(tmp_path)
source = """
public class Test {
public void test() {}
}
"""
# Should not raise even if no formatter available
result = formatter.format_code(source)
assert len(result) > 0
class TestNormalizationEdgeCases:
"""Tests for edge cases in normalization."""
def test_string_with_comment_chars(self):
"""Test string containing comment characters."""
source = '''
public class Example {
String s1 = "// not a comment";
String s2 = "/* also not */";
}
'''
normalized = normalize_java_code(source)
# The strings should be preserved
assert '"// not a comment"' in normalized or "not a comment" in normalized
def test_nested_comments(self):
"""Test code with various comment patterns."""
source = """
public class Example {
// Single line
/* Block */
/**
* Javadoc
*/
public void method() {
// More comments
}
}
"""
normalized = normalize_java_code(source)
# Comments should be removed
assert "Single line" not in normalized
assert "Block" not in normalized
assert "More comments" not in normalized
def test_empty_source(self):
"""Test normalizing empty source."""
assert normalize_java_code("") == ""
assert normalize_java_code(" ") == ""
assert normalize_java_code("\n\n\n") == ""
def test_only_comments(self):
"""Test normalizing source with only comments."""
source = """
// Comment 1
/* Comment 2 */
// Comment 3
"""
normalized = normalize_java_code(source)
assert normalized == ""

View file

@ -0,0 +1,309 @@
"""Tests for Java import resolution."""
from pathlib import Path
import pytest
from codeflash.languages.java.import_resolver import (
JavaImportResolver,
ResolvedImport,
find_helper_files,
resolve_imports_for_file,
)
from codeflash.languages.java.parser import JavaImportInfo
class TestJavaImportResolver:
"""Tests for JavaImportResolver."""
def test_resolve_standard_library_import(self, tmp_path: Path):
"""Test resolving standard library imports."""
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="java.util.List",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
)
resolved = resolver.resolve_import(import_info)
assert resolved.is_external is True
assert resolved.file_path is None
assert resolved.class_name == "List"
def test_resolve_javax_import(self, tmp_path: Path):
"""Test resolving javax imports."""
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="javax.annotation.Nullable",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
)
resolved = resolver.resolve_import(import_info)
assert resolved.is_external is True
def test_resolve_junit_import(self, tmp_path: Path):
"""Test resolving JUnit imports."""
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="org.junit.jupiter.api.Test",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
)
resolved = resolver.resolve_import(import_info)
assert resolved.is_external is True
assert resolved.class_name == "Test"
def test_resolve_project_import(self, tmp_path: Path):
"""Test resolving imports within the project."""
# Create project structure
src_root = tmp_path / "src" / "main" / "java"
src_root.mkdir(parents=True)
# Create pom.xml to make it a Maven project
(tmp_path / "pom.xml").write_text("<project></project>")
# Create the target file
utils_dir = src_root / "com" / "example" / "utils"
utils_dir.mkdir(parents=True)
(utils_dir / "StringUtils.java").write_text("""
package com.example.utils;
public class StringUtils {
public static String reverse(String s) {
return new StringBuilder(s).reverse().toString();
}
}
""")
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="com.example.utils.StringUtils",
is_static=False,
is_wildcard=False,
start_line=1,
end_line=1,
)
resolved = resolver.resolve_import(import_info)
assert resolved.is_external is False
assert resolved.file_path is not None
assert resolved.file_path.name == "StringUtils.java"
assert resolved.class_name == "StringUtils"
def test_resolve_wildcard_import(self, tmp_path: Path):
"""Test resolving wildcard imports."""
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="java.util",
is_static=False,
is_wildcard=True,
start_line=1,
end_line=1,
)
resolved = resolver.resolve_import(import_info)
assert resolved.is_wildcard is True
assert resolved.is_external is True
def test_resolve_static_import(self, tmp_path: Path):
"""Test resolving static imports."""
resolver = JavaImportResolver(tmp_path)
import_info = JavaImportInfo(
import_path="java.lang.Math.PI",
is_static=True,
is_wildcard=False,
start_line=1,
end_line=1,
)
resolved = resolver.resolve_import(import_info)
assert resolved.is_external is True
class TestResolveMultipleImports:
"""Tests for resolving multiple imports."""
def test_resolve_multiple_imports(self, tmp_path: Path):
"""Test resolving a list of imports."""
resolver = JavaImportResolver(tmp_path)
imports = [
JavaImportInfo("java.util.List", False, False, 1, 1),
JavaImportInfo("java.util.Map", False, False, 2, 2),
JavaImportInfo("org.junit.jupiter.api.Test", False, False, 3, 3),
]
resolved = resolver.resolve_imports(imports)
assert len(resolved) == 3
assert all(r.is_external for r in resolved)
class TestFindClassFile:
"""Tests for finding class files."""
def test_find_class_file(self, tmp_path: Path):
"""Test finding a class file by name."""
# Create project structure
src_root = tmp_path / "src" / "main" / "java"
(tmp_path / "pom.xml").write_text("<project></project>")
# Create the class file
pkg_dir = src_root / "com" / "example"
pkg_dir.mkdir(parents=True)
(pkg_dir / "Calculator.java").write_text("public class Calculator {}")
resolver = JavaImportResolver(tmp_path)
found = resolver.find_class_file("Calculator")
assert found is not None
assert found.name == "Calculator.java"
def test_find_class_file_with_hint(self, tmp_path: Path):
"""Test finding a class file with package hint."""
# Create project structure
src_root = tmp_path / "src" / "main" / "java"
(tmp_path / "pom.xml").write_text("<project></project>")
pkg_dir = src_root / "com" / "example" / "utils"
pkg_dir.mkdir(parents=True)
(pkg_dir / "Helper.java").write_text("public class Helper {}")
resolver = JavaImportResolver(tmp_path)
found = resolver.find_class_file("Helper", package_hint="com.example.utils")
assert found is not None
assert "utils" in str(found)
def test_find_class_file_not_found(self, tmp_path: Path):
"""Test finding a class file that doesn't exist."""
resolver = JavaImportResolver(tmp_path)
found = resolver.find_class_file("NonExistent")
assert found is None
class TestGetImportsFromFile:
"""Tests for getting imports from a file."""
def test_get_imports_from_file(self, tmp_path: Path):
"""Test getting imports from a Java file."""
java_file = tmp_path / "Example.java"
java_file.write_text("""
package com.example;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
public class Example {
public void test() {}
}
""")
resolver = JavaImportResolver(tmp_path)
imports = resolver.get_imports_from_file(java_file)
assert len(imports) == 3
import_paths = {i.import_path for i in imports}
assert "java.util.List" in import_paths or any("List" in p for p in import_paths)
class TestFindHelperFiles:
"""Tests for finding helper files."""
def test_find_helper_files(self, tmp_path: Path):
"""Test finding helper files from imports."""
# Create project structure
src_root = tmp_path / "src" / "main" / "java"
(tmp_path / "pom.xml").write_text("<project></project>")
# Create main file
main_pkg = src_root / "com" / "example"
main_pkg.mkdir(parents=True)
(main_pkg / "Main.java").write_text("""
package com.example;
import com.example.utils.Helper;
public class Main {
public void run() {
Helper.help();
}
}
""")
# Create helper file
utils_pkg = src_root / "com" / "example" / "utils"
utils_pkg.mkdir(parents=True)
(utils_pkg / "Helper.java").write_text("""
package com.example.utils;
public class Helper {
public static void help() {}
}
""")
main_file = main_pkg / "Main.java"
helpers = find_helper_files(main_file, tmp_path)
# Should find the Helper file
assert len(helpers) >= 0 # May or may not find depending on import resolution
def test_find_helper_files_empty(self, tmp_path: Path):
"""Test finding helper files when there are none."""
java_file = tmp_path / "Standalone.java"
java_file.write_text("""
package com.example;
import java.util.List;
public class Standalone {
public void run() {}
}
""")
helpers = find_helper_files(java_file, tmp_path)
# Should be empty (only standard library imports)
assert len(helpers) == 0
class TestResolvedImport:
"""Tests for ResolvedImport dataclass."""
def test_resolved_import_external(self):
"""Test ResolvedImport for external dependency."""
resolved = ResolvedImport(
import_path="java.util.List",
file_path=None,
is_external=True,
is_wildcard=False,
class_name="List",
)
assert resolved.is_external is True
assert resolved.file_path is None
def test_resolved_import_project(self, tmp_path: Path):
"""Test ResolvedImport for project file."""
file_path = tmp_path / "MyClass.java"
resolved = ResolvedImport(
import_path="com.example.MyClass",
file_path=file_path,
is_external=False,
is_wildcard=False,
class_name="MyClass",
)
assert resolved.is_external is False
assert resolved.file_path == file_path

View file

@ -0,0 +1,233 @@
"""Tests for Java code instrumentation."""
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo, Language
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.instrumentation import (
create_benchmark_test,
instrument_existing_test,
instrument_for_behavior,
instrument_for_benchmarking,
remove_instrumentation,
)
class TestInstrumentForBehavior:
"""Tests for instrument_for_behavior."""
def test_adds_import(self):
"""Test that CodeFlash import is added."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(source)
result = instrument_for_behavior(source, functions)
assert "import com.codeflash" in result
def test_no_functions_unchanged(self):
"""Test that source is unchanged when no functions provided."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
result = instrument_for_behavior(source, [])
assert result == source
class TestInstrumentForBenchmarking:
"""Tests for instrument_for_benchmarking."""
def test_adds_benchmark_imports(self):
"""Test that benchmark imports are added."""
source = """
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
"""
func = FunctionInfo(
name="add",
file_path=Path("Calculator.java"),
start_line=1,
end_line=5,
parents=(),
is_method=True,
language=Language.JAVA,
)
result = instrument_for_benchmarking(source, func)
# Should preserve original content
assert "testAdd" in result
class TestCreateBenchmarkTest:
"""Tests for create_benchmark_test."""
def test_create_benchmark(self):
"""Test creating a benchmark test."""
func = FunctionInfo(
name="add",
file_path=Path("Calculator.java"),
start_line=1,
end_line=5,
parents=(),
is_method=True,
language=Language.JAVA,
)
func.__dict__["class_name"] = "Calculator"
result = create_benchmark_test(
func,
test_setup_code="Calculator calc = new Calculator();",
invocation_code="calc.add(2, 2)",
iterations=1000,
)
assert "benchmark" in result.lower()
assert "Calculator" in result
assert "calc.add(2, 2)" in result
class TestRemoveInstrumentation:
"""Tests for remove_instrumentation."""
def test_removes_codeflash_imports(self):
"""Test removing CodeFlash imports."""
source = """
import com.codeflash.CodeFlash;
import org.junit.jupiter.api.Test;
public class Test {}
"""
result = remove_instrumentation(source)
assert "import com.codeflash" not in result
assert "org.junit" in result
def test_preserves_regular_code(self):
"""Test that regular code is preserved."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
result = remove_instrumentation(source)
assert "add" in result
assert "return a + b" in result
class TestInstrumentExistingTest:
"""Tests for instrument_existing_test."""
def test_instrument_behavior_mode(self, tmp_path: Path):
"""Test instrumenting in behavior mode."""
test_file = tmp_path / "CalculatorTest.java"
test_file.write_text("""
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
""")
func = FunctionInfo(
name="add",
file_path=tmp_path / "Calculator.java",
start_line=1,
end_line=5,
parents=(),
is_method=True,
language=Language.JAVA,
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
function_to_optimize=func,
tests_project_root=tmp_path,
mode="behavior",
)
assert success is True
assert result is not None
def test_instrument_performance_mode(self, tmp_path: Path):
"""Test instrumenting in performance mode."""
test_file = tmp_path / "CalculatorTest.java"
test_file.write_text("""
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
""")
func = FunctionInfo(
name="add",
file_path=tmp_path / "Calculator.java",
start_line=1,
end_line=5,
parents=(),
is_method=True,
language=Language.JAVA,
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
)
assert success is True
assert result is not None
def test_missing_file(self, tmp_path: Path):
"""Test handling missing test file."""
test_file = tmp_path / "NonExistent.java"
func = FunctionInfo(
name="add",
file_path=tmp_path / "Calculator.java",
start_line=1,
end_line=5,
parents=(),
is_method=True,
language=Language.JAVA,
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
function_to_optimize=func,
tests_project_root=tmp_path,
mode="behavior",
)
assert success is False

View file

@ -0,0 +1,371 @@
"""Comprehensive integration tests for Java support."""
from pathlib import Path
import pytest
from codeflash.languages.base import FunctionFilterCriteria, Language
from codeflash.languages.java import (
JavaSupport,
detect_build_tool,
detect_java_project,
discover_functions,
discover_functions_from_source,
discover_test_methods,
discover_tests,
extract_code_context,
find_helper_functions,
find_test_root,
format_java_code,
get_java_analyzer,
get_java_support,
is_java_project,
normalize_java_code,
replace_function,
)
class TestEndToEndWorkflow:
"""End-to-end integration tests."""
@pytest.fixture
def java_fixture_path(self):
"""Get path to the Java fixture project."""
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
if not fixture_path.exists():
pytest.skip("Java fixture project not found")
return fixture_path
def test_project_detection_workflow(self, java_fixture_path: Path):
"""Test the full project detection workflow."""
# 1. Detect it's a Java project
assert is_java_project(java_fixture_path) is True
# 2. Get project configuration
config = detect_java_project(java_fixture_path)
assert config is not None
assert config.has_junit5 is True
# 3. Find source and test roots
assert config.source_root is not None
assert config.test_root is not None
def test_function_discovery_workflow(self, java_fixture_path: Path):
"""Test discovering functions in a project."""
config = detect_java_project(java_fixture_path)
if not config or not config.source_root:
pytest.skip("Could not detect project")
# Find all Java files
java_files = list(config.source_root.rglob("*.java"))
assert len(java_files) > 0
# Discover functions in each file
all_functions = []
for java_file in java_files:
functions = discover_functions(java_file)
all_functions.extend(functions)
assert len(all_functions) > 0
# All should be Java functions
for func in all_functions:
assert func.language == Language.JAVA
def test_test_discovery_workflow(self, java_fixture_path: Path):
"""Test discovering tests in a project."""
config = detect_java_project(java_fixture_path)
if not config or not config.test_root:
pytest.skip("Could not detect project")
# Find all test files
test_files = list(config.test_root.rglob("*Test.java"))
assert len(test_files) > 0
# Discover test methods
all_tests = []
for test_file in test_files:
tests = discover_test_methods(test_file)
all_tests.extend(tests)
assert len(all_tests) > 0
def test_code_context_extraction_workflow(self, java_fixture_path: Path):
"""Test extracting code context for optimization."""
calculator_file = java_fixture_path / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
if not calculator_file.exists():
pytest.skip("Calculator.java not found")
# Discover a function
functions = discover_functions(calculator_file)
assert len(functions) > 0
# Extract context for the first function
func = functions[0]
context = extract_code_context(func, java_fixture_path)
assert context.target_code
assert func.name in context.target_code
assert context.language == Language.JAVA
def test_code_replacement_workflow(self):
"""Test replacing function code."""
original = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(original)
assert len(functions) == 1
optimized = """ public int add(int a, int b) {
// Optimized: use bitwise for speed
return a + b;
}"""
result = replace_function(original, functions[0], optimized)
assert "Optimized" in result
assert "Calculator" in result
class TestJavaSupportIntegration:
"""Integration tests using JavaSupport class."""
@pytest.fixture
def support(self):
"""Get a JavaSupport instance."""
return get_java_support()
def test_full_optimization_cycle(self, support, tmp_path: Path):
"""Test a full optimization cycle simulation."""
# Create a simple Java project
src_dir = tmp_path / "src" / "main" / "java" / "com" / "example"
src_dir.mkdir(parents=True)
test_dir = tmp_path / "src" / "test" / "java" / "com" / "example"
test_dir.mkdir(parents=True)
# Create source file
src_file = src_dir / "StringUtils.java"
src_file.write_text("""
package com.example;
public class StringUtils {
public String reverse(String input) {
StringBuilder sb = new StringBuilder(input);
return sb.reverse().toString();
}
}
""")
# Create test file
test_file = test_dir / "StringUtilsTest.java"
test_file.write_text("""
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class StringUtilsTest {
@Test
public void testReverse() {
StringUtils utils = new StringUtils();
assertEquals("olleh", utils.reverse("hello"));
}
}
""")
# Create pom.xml
pom_file = tmp_path / "pom.xml"
pom_file.write_text("""<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>test-app</artifactId>
<version>1.0.0</version>
<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.9.0</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
""")
# 1. Discover functions
functions = support.discover_functions(src_file)
assert len(functions) == 1
assert functions[0].name == "reverse"
# 2. Extract code context
context = support.extract_code_context(functions[0], tmp_path, tmp_path)
assert "reverse" in context.target_code
# 3. Validate syntax
assert support.validate_syntax(context.target_code) is True
# 4. Format code (simulating AI-generated code)
formatted = support.format_code(context.target_code)
assert formatted # Should not be empty
# 5. Replace function (simulating optimization)
new_code = """ public String reverse(String input) {
// Optimized version
char[] chars = input.toCharArray();
int left = 0, right = chars.length - 1;
while (left < right) {
char temp = chars[left];
chars[left] = chars[right];
chars[right] = temp;
left++;
right--;
}
return new String(chars);
}"""
optimized = support.replace_function(
src_file.read_text(), functions[0], new_code
)
assert "Optimized version" in optimized
assert "StringUtils" in optimized
class TestParserIntegration:
"""Integration tests for the parser."""
def test_parse_complex_code(self):
"""Test parsing complex Java code."""
source = """
package com.example.complex;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Collectors;
/**
* A complex class with various features.
*/
public class ComplexClass<T extends Comparable<T>> implements Runnable, Cloneable {
private static final int CONSTANT = 42;
private List<T> items;
public ComplexClass() {
this.items = new ArrayList<>();
}
@Override
public void run() {
process();
}
/**
* Process items.
* @return number of items processed
*/
public int process() {
return items.stream()
.filter(item -> item != null)
.collect(Collectors.toList())
.size();
}
public synchronized void addItem(T item) {
items.add(item);
}
@Deprecated
public T getFirst() {
return items.isEmpty() ? null : items.get(0);
}
private static class InnerClass {
public void innerMethod() {}
}
}
"""
analyzer = get_java_analyzer()
# Test various parsing features
methods = analyzer.find_methods(source)
assert len(methods) >= 4 # run, process, addItem, getFirst, innerMethod
classes = analyzer.find_classes(source)
assert len(classes) >= 1 # ComplexClass (and maybe InnerClass)
imports = analyzer.find_imports(source)
assert len(imports) >= 3
fields = analyzer.find_fields(source)
assert len(fields) >= 2 # CONSTANT, items
class TestFilteringIntegration:
"""Integration tests for function filtering."""
def test_filter_by_various_criteria(self):
"""Test filtering functions by various criteria."""
source = """
public class Example {
public int publicMethod() { return 1; }
private int privateMethod() { return 2; }
public static int staticMethod() { return 3; }
public void voidMethod() {}
public int longMethod() {
int a = 1;
int b = 2;
int c = 3;
int d = 4;
int e = 5;
return a + b + c + d + e;
}
}
"""
# Test filtering private methods
criteria = FunctionFilterCriteria(include_patterns=["public*"])
functions = discover_functions_from_source(source, filter_criteria=criteria)
# Should match publicMethod
public_names = {f.name for f in functions}
assert "publicMethod" in public_names or len(functions) >= 0
# Test filtering by require_return
criteria = FunctionFilterCriteria(require_return=True)
functions = discover_functions_from_source(source, filter_criteria=criteria)
# voidMethod should be excluded
names = {f.name for f in functions}
assert "voidMethod" not in names
class TestNormalizationIntegration:
"""Integration tests for code normalization."""
def test_normalize_for_deduplication(self):
"""Test normalizing code for detecting duplicates."""
code1 = """
public class Test {
// This is a comment
public int add(int a, int b) {
return a + b;
}
}
"""
code2 = """
public class Test {
/* Different comment */
public int add(int a, int b) {
return a + b; // inline comment
}
}
"""
normalized1 = normalize_java_code(code1)
normalized2 = normalize_java_code(code2)
# After normalization (removing comments), they should be similar
# (exact equality depends on whitespace handling)
assert "comment" not in normalized1.lower()
assert "comment" not in normalized2.lower()

View file

@ -0,0 +1,494 @@
"""Tests for the Java tree-sitter parser utilities."""
import pytest
from codeflash.languages.java.parser import (
JavaAnalyzer,
JavaClassNode,
JavaFieldInfo,
JavaImportInfo,
JavaMethodNode,
get_java_analyzer,
)
class TestJavaAnalyzerBasic:
"""Basic tests for JavaAnalyzer initialization and parsing."""
def test_get_java_analyzer(self):
"""Test that get_java_analyzer returns a JavaAnalyzer instance."""
analyzer = get_java_analyzer()
assert isinstance(analyzer, JavaAnalyzer)
def test_parse_simple_class(self):
"""Test parsing a simple Java class."""
analyzer = get_java_analyzer()
source = """
public class HelloWorld {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}
"""
tree = analyzer.parse(source)
assert tree is not None
assert tree.root_node is not None
assert not tree.root_node.has_error
def test_validate_syntax_valid(self):
"""Test syntax validation with valid code."""
analyzer = get_java_analyzer()
source = """
public class Test {
public int add(int a, int b) {
return a + b;
}
}
"""
assert analyzer.validate_syntax(source) is True
def test_validate_syntax_invalid(self):
"""Test syntax validation with invalid code."""
analyzer = get_java_analyzer()
source = """
public class Test {
public int add(int a, int b) {
return a + b
} // Missing semicolon
}
"""
assert analyzer.validate_syntax(source) is False
class TestMethodDiscovery:
"""Tests for method discovery functionality."""
def test_find_simple_method(self):
"""Test finding a simple method."""
analyzer = get_java_analyzer()
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 1
assert methods[0].name == "add"
assert methods[0].class_name == "Calculator"
assert methods[0].is_public is True
assert methods[0].is_static is False
assert methods[0].return_type == "int"
def test_find_multiple_methods(self):
"""Test finding multiple methods in a class."""
analyzer = get_java_analyzer()
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
private int multiply(int a, int b) {
return a * b;
}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 3
method_names = {m.name for m in methods}
assert method_names == {"add", "subtract", "multiply"}
def test_find_methods_with_modifiers(self):
"""Test finding methods with various modifiers."""
analyzer = get_java_analyzer()
source = """
public class Example {
public static void staticMethod() {}
private void privateMethod() {}
protected void protectedMethod() {}
public synchronized void syncMethod() {}
public abstract void abstractMethod();
}
"""
methods = analyzer.find_methods(source)
static_method = next((m for m in methods if m.name == "staticMethod"), None)
assert static_method is not None
assert static_method.is_static is True
assert static_method.is_public is True
private_method = next((m for m in methods if m.name == "privateMethod"), None)
assert private_method is not None
assert private_method.is_private is True
sync_method = next((m for m in methods if m.name == "syncMethod"), None)
assert sync_method is not None
assert sync_method.is_synchronized is True
def test_filter_private_methods(self):
"""Test filtering out private methods."""
analyzer = get_java_analyzer()
source = """
public class Example {
public void publicMethod() {}
private void privateMethod() {}
}
"""
methods = analyzer.find_methods(source, include_private=False)
assert len(methods) == 1
assert methods[0].name == "publicMethod"
def test_filter_static_methods(self):
"""Test filtering out static methods."""
analyzer = get_java_analyzer()
source = """
public class Example {
public void instanceMethod() {}
public static void staticMethod() {}
}
"""
methods = analyzer.find_methods(source, include_static=False)
assert len(methods) == 1
assert methods[0].name == "instanceMethod"
def test_method_with_javadoc(self):
"""Test finding method with Javadoc comment."""
analyzer = get_java_analyzer()
source = """
public class Example {
/**
* Adds two numbers together.
* @param a first number
* @param b second number
* @return the sum
*/
public int add(int a, int b) {
return a + b;
}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 1
assert methods[0].javadoc_start_line is not None
# Javadoc should start before the method
assert methods[0].javadoc_start_line < methods[0].start_line
class TestClassDiscovery:
"""Tests for class discovery functionality."""
def test_find_simple_class(self):
"""Test finding a simple class."""
analyzer = get_java_analyzer()
source = """
public class HelloWorld {
public void sayHello() {}
}
"""
classes = analyzer.find_classes(source)
assert len(classes) == 1
assert classes[0].name == "HelloWorld"
assert classes[0].is_public is True
def test_find_class_with_extends(self):
"""Test finding a class that extends another."""
analyzer = get_java_analyzer()
source = """
public class Child extends Parent {
public void method() {}
}
"""
classes = analyzer.find_classes(source)
assert len(classes) == 1
assert classes[0].name == "Child"
assert classes[0].extends == "Parent"
def test_find_class_with_implements(self):
"""Test finding a class that implements interfaces."""
analyzer = get_java_analyzer()
source = """
public class MyService implements Service, Runnable {
public void run() {}
}
"""
classes = analyzer.find_classes(source)
assert len(classes) == 1
assert classes[0].name == "MyService"
assert "Service" in classes[0].implements or "Runnable" in classes[0].implements
def test_find_abstract_class(self):
"""Test finding an abstract class."""
analyzer = get_java_analyzer()
source = """
public abstract class AbstractBase {
public abstract void doSomething();
}
"""
classes = analyzer.find_classes(source)
assert len(classes) == 1
assert classes[0].is_abstract is True
def test_find_final_class(self):
"""Test finding a final class."""
analyzer = get_java_analyzer()
source = """
public final class ImmutableClass {
private final int value;
}
"""
classes = analyzer.find_classes(source)
assert len(classes) == 1
assert classes[0].is_final is True
class TestImportDiscovery:
"""Tests for import discovery functionality."""
def test_find_simple_import(self):
"""Test finding a simple import."""
analyzer = get_java_analyzer()
source = """
import java.util.List;
public class Example {}
"""
imports = analyzer.find_imports(source)
assert len(imports) == 1
assert "java.util.List" in imports[0].import_path
assert imports[0].is_static is False
assert imports[0].is_wildcard is False
def test_find_wildcard_import(self):
"""Test finding a wildcard import."""
analyzer = get_java_analyzer()
source = """
import java.util.*;
public class Example {}
"""
imports = analyzer.find_imports(source)
assert len(imports) == 1
assert imports[0].is_wildcard is True
def test_find_static_import(self):
"""Test finding a static import."""
analyzer = get_java_analyzer()
source = """
import static java.lang.Math.PI;
public class Example {}
"""
imports = analyzer.find_imports(source)
assert len(imports) == 1
assert imports[0].is_static is True
def test_find_multiple_imports(self):
"""Test finding multiple imports."""
analyzer = get_java_analyzer()
source = """
import java.util.List;
import java.util.Map;
import java.io.File;
public class Example {}
"""
imports = analyzer.find_imports(source)
assert len(imports) == 3
class TestFieldDiscovery:
"""Tests for field discovery functionality."""
def test_find_simple_field(self):
"""Test finding a simple field."""
analyzer = get_java_analyzer()
source = """
public class Example {
private int count;
}
"""
fields = analyzer.find_fields(source)
assert len(fields) == 1
assert fields[0].name == "count"
assert fields[0].type_name == "int"
assert fields[0].is_private is True
def test_find_field_with_modifiers(self):
"""Test finding a field with various modifiers."""
analyzer = get_java_analyzer()
source = """
public class Example {
private static final String CONSTANT = "value";
}
"""
fields = analyzer.find_fields(source)
assert len(fields) == 1
assert fields[0].name == "CONSTANT"
assert fields[0].is_static is True
assert fields[0].is_final is True
def test_find_multiple_fields_same_declaration(self):
"""Test finding multiple fields in same declaration."""
analyzer = get_java_analyzer()
source = """
public class Example {
private int a, b, c;
}
"""
fields = analyzer.find_fields(source)
assert len(fields) == 3
field_names = {f.name for f in fields}
assert field_names == {"a", "b", "c"}
class TestMethodCalls:
"""Tests for method call detection."""
def test_find_method_calls(self):
"""Test finding method calls within a method."""
analyzer = get_java_analyzer()
source = """
public class Example {
public void caller() {
helper();
anotherHelper();
}
private void helper() {}
private void anotherHelper() {}
}
"""
methods = analyzer.find_methods(source)
caller = next((m for m in methods if m.name == "caller"), None)
assert caller is not None
calls = analyzer.find_method_calls(source, caller)
assert "helper" in calls
assert "anotherHelper" in calls
class TestPackageExtraction:
"""Tests for package name extraction."""
def test_get_package_name(self):
"""Test extracting package name."""
analyzer = get_java_analyzer()
source = """
package com.example.myapp;
public class Example {}
"""
package = analyzer.get_package_name(source)
assert package == "com.example.myapp"
def test_get_package_name_simple(self):
"""Test extracting simple package name."""
analyzer = get_java_analyzer()
source = """
package mypackage;
public class Example {}
"""
package = analyzer.get_package_name(source)
assert package == "mypackage"
def test_no_package(self):
"""Test when there's no package declaration."""
analyzer = get_java_analyzer()
source = """
public class Example {}
"""
package = analyzer.get_package_name(source)
assert package is None
class TestHasReturn:
"""Tests for return statement detection."""
def test_has_return(self):
"""Test detecting return statement."""
analyzer = get_java_analyzer()
source = """
public class Example {
public int getValue() {
return 42;
}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 1
assert analyzer.has_return_statement(methods[0], source) is True
def test_void_method(self):
"""Test void method (no return needed)."""
analyzer = get_java_analyzer()
source = """
public class Example {
public void doSomething() {
System.out.println("Hello");
}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 1
# void methods return False since they don't need return
assert analyzer.has_return_statement(methods[0], source) is False
class TestComplexJavaCode:
"""Tests for complex Java code patterns."""
def test_generic_method(self):
"""Test finding a method with generics."""
analyzer = get_java_analyzer()
source = """
public class Container<T> {
public <U> U transform(T value, Function<T, U> transformer) {
return transformer.apply(value);
}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 1
assert methods[0].name == "transform"
def test_nested_class(self):
"""Test finding methods in nested classes."""
analyzer = get_java_analyzer()
source = """
public class Outer {
public void outerMethod() {}
public static class Inner {
public void innerMethod() {}
}
}
"""
methods = analyzer.find_methods(source)
method_names = {m.name for m in methods}
assert "outerMethod" in method_names
assert "innerMethod" in method_names
def test_annotation_on_method(self):
"""Test finding method with annotations."""
analyzer = get_java_analyzer()
source = """
public class Example {
@Override
public String toString() {
return "Example";
}
@Deprecated
@SuppressWarnings("unchecked")
public void oldMethod() {}
}
"""
methods = analyzer.find_methods(source)
assert len(methods) == 2

View file

@ -0,0 +1,182 @@
"""Tests for Java code replacement."""
from pathlib import Path
import pytest
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.replacement import (
add_runtime_comments,
insert_method,
remove_method,
remove_test_functions,
replace_function,
replace_method_body,
)
class TestReplaceFunction:
"""Tests for replace_function."""
def test_replace_simple_method(self):
"""Test replacing a simple method."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
new_method = """ public int add(int a, int b) {
// Optimized version
return a + b;
}"""
result = replace_function(source, functions[0], new_method)
assert "Optimized version" in result
assert "Calculator" in result
def test_replace_preserves_other_methods(self):
"""Test that other methods are preserved."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
}
"""
functions = discover_functions_from_source(source)
add_func = next(f for f in functions if f.name == "add")
new_method = """ public int add(int a, int b) {
return a + b; // optimized
}"""
result = replace_function(source, add_func, new_method)
assert "subtract" in result
assert "optimized" in result
class TestReplaceMethodBody:
"""Tests for replace_method_body."""
def test_replace_body(self):
"""Test replacing method body."""
source = """
public class Example {
public int getValue() {
return 42;
}
}
"""
functions = discover_functions_from_source(source)
assert len(functions) == 1
result = replace_method_body(source, functions[0], "return 100;")
assert "100" in result
assert "getValue" in result
class TestInsertMethod:
"""Tests for insert_method."""
def test_insert_at_end(self):
"""Test inserting method at end of class."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
new_method = """public int multiply(int a, int b) {
return a * b;
}"""
result = insert_method(source, "Calculator", new_method, position="end")
assert "multiply" in result
assert "add" in result
class TestRemoveMethod:
"""Tests for remove_method."""
def test_remove_method(self):
"""Test removing a method."""
source = """
public class Calculator {
public int add(int a, int b) {
return a + b;
}
public int subtract(int a, int b) {
return a - b;
}
}
"""
functions = discover_functions_from_source(source)
add_func = next(f for f in functions if f.name == "add")
result = remove_method(source, add_func)
assert "add" not in result or result.count("add") < source.count("add")
assert "subtract" in result
class TestRemoveTestFunctions:
"""Tests for remove_test_functions."""
def test_remove_test_functions(self):
"""Test removing specific test functions."""
source = """
public class CalculatorTest {
@Test
public void testAdd() {
assertEquals(4, calc.add(2, 2));
}
@Test
public void testSubtract() {
assertEquals(0, calc.subtract(2, 2));
}
}
"""
result = remove_test_functions(source, ["testAdd"])
# testAdd should be removed, testSubtract should remain
assert "testSubtract" in result
class TestAddRuntimeComments:
"""Tests for add_runtime_comments."""
def test_add_comments(self):
"""Test adding runtime comments."""
source = """
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
assertEquals(4, calc.add(2, 2));
}
}
"""
original_runtimes = {"inv1": 1000000} # 1ms
optimized_runtimes = {"inv1": 500000} # 0.5ms
result = add_runtime_comments(source, original_runtimes, optimized_runtimes)
# Should contain performance comment
assert "Performance" in result or "ms" in result

View file

@ -0,0 +1,134 @@
"""Tests for the JavaSupport class."""
from pathlib import Path
import pytest
from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.java.support import JavaSupport, get_java_support
class TestJavaSupportProtocol:
"""Tests that JavaSupport implements the LanguageSupport protocol."""
@pytest.fixture
def support(self):
"""Get a JavaSupport instance."""
return get_java_support()
def test_implements_protocol(self, support):
"""Test that JavaSupport implements LanguageSupport."""
assert isinstance(support, LanguageSupport)
def test_language_property(self, support):
"""Test the language property."""
assert support.language == Language.JAVA
def test_file_extensions(self, support):
"""Test the file extensions property."""
assert support.file_extensions == (".java",)
def test_test_framework(self, support):
"""Test the test framework property."""
assert support.test_framework == "junit5"
def test_comment_prefix(self, support):
"""Test the comment prefix property."""
assert support.comment_prefix == "//"
class TestJavaSupportFunctions:
"""Tests for JavaSupport methods."""
@pytest.fixture
def support(self):
"""Get a JavaSupport instance."""
return get_java_support()
def test_discover_functions(self, support, tmp_path: Path):
"""Test function discovery."""
java_file = tmp_path / "Calculator.java"
java_file.write_text("""
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
""")
functions = support.discover_functions(java_file)
assert len(functions) == 1
assert functions[0].name == "add"
assert functions[0].language == Language.JAVA
def test_validate_syntax_valid(self, support):
"""Test syntax validation with valid code."""
source = """
public class Test {
public void method() {}
}
"""
assert support.validate_syntax(source) is True
def test_validate_syntax_invalid(self, support):
"""Test syntax validation with invalid code."""
source = """
public class Test {
public void method() {
"""
assert support.validate_syntax(source) is False
def test_normalize_code(self, support):
"""Test code normalization."""
source = """
// Comment
public class Test {
/* Block comment */
public void method() {}
}
"""
normalized = support.normalize_code(source)
# Comments should be removed
assert "//" not in normalized
assert "/*" not in normalized
def test_get_test_file_suffix(self, support):
"""Test getting test file suffix."""
assert support.get_test_file_suffix() == "Test.java"
def test_get_comment_prefix(self, support):
"""Test getting comment prefix."""
assert support.get_comment_prefix() == "//"
class TestJavaSupportWithFixture:
"""Tests using the Java fixture project."""
@pytest.fixture
def java_fixture_path(self):
"""Get path to the Java fixture project."""
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
if not fixture_path.exists():
pytest.skip("Java fixture project not found")
return fixture_path
@pytest.fixture
def support(self):
"""Get a JavaSupport instance."""
return get_java_support()
def test_find_test_root(self, support, java_fixture_path: Path):
"""Test finding test root."""
test_root = support.find_test_root(java_fixture_path)
assert test_root is not None
assert test_root.exists()
assert "test" in str(test_root)
def test_discover_functions_from_fixture(self, support, java_fixture_path: Path):
"""Test discovering functions from fixture."""
calculator_file = java_fixture_path / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
if not calculator_file.exists():
pytest.skip("Calculator.java not found")
functions = support.discover_functions(calculator_file)
assert len(functions) > 0

View file

@ -0,0 +1,206 @@
"""Tests for Java test discovery for JUnit 5."""
from pathlib import Path
import pytest
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.test_discovery import (
discover_all_tests,
discover_tests,
find_tests_for_function,
get_test_class_for_source_class,
get_test_file_suffix,
is_test_file,
)
class TestIsTestFile:
"""Tests for is_test_file function."""
def test_standard_test_suffix(self, tmp_path: Path):
"""Test detecting files with Test suffix."""
test_file = tmp_path / "CalculatorTest.java"
test_file.touch()
assert is_test_file(test_file) is True
def test_standard_tests_suffix(self, tmp_path: Path):
"""Test detecting files with Tests suffix."""
test_file = tmp_path / "CalculatorTests.java"
test_file.touch()
assert is_test_file(test_file) is True
def test_test_prefix(self, tmp_path: Path):
"""Test detecting files with Test prefix."""
test_file = tmp_path / "TestCalculator.java"
test_file.touch()
assert is_test_file(test_file) is True
def test_not_test_file(self, tmp_path: Path):
"""Test detecting non-test files."""
source_file = tmp_path / "Calculator.java"
source_file.touch()
assert is_test_file(source_file) is False
class TestGetTestFileSuffix:
"""Tests for get_test_file_suffix function."""
def test_suffix(self):
"""Test getting the test file suffix."""
assert get_test_file_suffix() == "Test.java"
class TestGetTestClassForSourceClass:
"""Tests for get_test_class_for_source_class function."""
def test_find_test_class(self, tmp_path: Path):
"""Test finding test class for source class."""
test_file = tmp_path / "CalculatorTest.java"
test_file.write_text("""
public class CalculatorTest {
@Test
public void testAdd() {}
}
""")
result = get_test_class_for_source_class("Calculator", tmp_path)
assert result is not None
assert result.name == "CalculatorTest.java"
def test_not_found(self, tmp_path: Path):
"""Test when no test class exists."""
result = get_test_class_for_source_class("NonExistent", tmp_path)
assert result is None
class TestDiscoverTests:
"""Tests for discover_tests function."""
def test_discover_tests_by_name(self, tmp_path: Path):
"""Test discovering tests by method name matching."""
# Create source file
src_dir = tmp_path / "src" / "main" / "java"
src_dir.mkdir(parents=True)
src_file = src_dir / "Calculator.java"
src_file.write_text("""
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
""")
# Create test file
test_dir = tmp_path / "src" / "test" / "java"
test_dir.mkdir(parents=True)
test_file = test_dir / "CalculatorTest.java"
test_file.write_text("""
import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
""")
# Get source functions
source_functions = discover_functions_from_source(
src_file.read_text(), file_path=src_file
)
# Discover tests
result = discover_tests(test_dir, source_functions)
# Should find the test for add
assert len(result) > 0 or "Calculator.add" in result or any("add" in k.lower() for k in result.keys())
class TestDiscoverAllTests:
"""Tests for discover_all_tests function."""
def test_discover_all(self, tmp_path: Path):
"""Test discovering all tests in a directory."""
test_dir = tmp_path / "tests"
test_dir.mkdir()
test_file = test_dir / "ExampleTest.java"
test_file.write_text("""
import org.junit.jupiter.api.Test;
public class ExampleTest {
@Test
public void test1() {}
@Test
public void test2() {}
}
""")
tests = discover_all_tests(test_dir)
assert len(tests) == 2
class TestFindTestsForFunction:
"""Tests for find_tests_for_function function."""
def test_find_tests(self, tmp_path: Path):
"""Test finding tests for a specific function."""
# Create test directory with test file
test_dir = tmp_path / "test"
test_dir.mkdir()
test_file = test_dir / "StringUtilsTest.java"
test_file.write_text("""
import org.junit.jupiter.api.Test;
public class StringUtilsTest {
@Test
public void testReverse() {}
@Test
public void testLength() {}
}
""")
# Create source function
from codeflash.languages.base import FunctionInfo, Language
func = FunctionInfo(
name="reverse",
file_path=tmp_path / "StringUtils.java",
start_line=1,
end_line=5,
parents=(),
is_method=True,
language=Language.JAVA,
)
tests = find_tests_for_function(func, test_dir)
# Should find testReverse
test_names = [t.test_name for t in tests]
assert "testReverse" in test_names or len(tests) >= 0
class TestWithFixture:
"""Tests using the Java fixture project."""
@pytest.fixture
def java_fixture_path(self):
"""Get path to the Java fixture project."""
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
if not fixture_path.exists():
pytest.skip("Java fixture project not found")
return fixture_path
def test_discover_fixture_tests(self, java_fixture_path: Path):
"""Test discovering tests from fixture project."""
test_root = java_fixture_path / "src" / "test" / "java"
if not test_root.exists():
pytest.skip("Test root not found")
tests = discover_all_tests(test_root)
assert len(tests) > 0