mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
wip java support
This commit is contained in:
parent
351dd7539f
commit
29f266ee63
61 changed files with 13048 additions and 12 deletions
5
code_to_optimize/java/codeflash.toml
Normal file
5
code_to_optimize/java/codeflash.toml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Codeflash configuration for Java project
|
||||
|
||||
[tool.codeflash]
|
||||
module-root = "src/main/java"
|
||||
tests-root = "src/test/java"
|
||||
122
code_to_optimize/java/src/main/java/com/example/Algorithms.java
Normal file
122
code_to_optimize/java/src/main/java/com/example/Algorithms.java
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
package com.example;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Collection of algorithms that can be optimized by Codeflash.
|
||||
*/
|
||||
public class Algorithms {
|
||||
|
||||
/**
|
||||
* Calculate Fibonacci number using naive recursive approach.
|
||||
* This has O(2^n) time complexity and should be optimized.
|
||||
*
|
||||
* @param n The position in Fibonacci sequence (0-indexed)
|
||||
* @return The nth Fibonacci number
|
||||
*/
|
||||
public long fibonacci(int n) {
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Find all prime numbers up to n using naive approach.
|
||||
* This can be optimized with Sieve of Eratosthenes.
|
||||
*
|
||||
* @param n Upper bound for finding primes
|
||||
* @return List of all prime numbers <= n
|
||||
*/
|
||||
public List<Integer> findPrimes(int n) {
|
||||
List<Integer> primes = new ArrayList<>();
|
||||
for (int i = 2; i <= n; i++) {
|
||||
if (isPrime(i)) {
|
||||
primes.add(i);
|
||||
}
|
||||
}
|
||||
return primes;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is prime using naive trial division.
|
||||
*
|
||||
* @param num Number to check
|
||||
* @return true if num is prime
|
||||
*/
|
||||
private boolean isPrime(int num) {
|
||||
if (num < 2) return false;
|
||||
for (int i = 2; i < num; i++) {
|
||||
if (num % i == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find duplicates in an array using O(n^2) nested loops.
|
||||
* This can be optimized with HashSet to O(n).
|
||||
*
|
||||
* @param arr Input array
|
||||
* @return List of duplicate elements
|
||||
*/
|
||||
public List<Integer> findDuplicates(int[] arr) {
|
||||
List<Integer> duplicates = new ArrayList<>();
|
||||
for (int i = 0; i < arr.length; i++) {
|
||||
for (int j = i + 1; j < arr.length; j++) {
|
||||
if (arr[i] == arr[j] && !duplicates.contains(arr[i])) {
|
||||
duplicates.add(arr[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return duplicates;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate factorial recursively without tail optimization.
|
||||
*
|
||||
* @param n Number to calculate factorial for
|
||||
* @return n!
|
||||
*/
|
||||
public long factorial(int n) {
|
||||
if (n <= 1) {
|
||||
return 1;
|
||||
}
|
||||
return n * factorial(n - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Concatenate strings in a loop using String concatenation.
|
||||
* Should be optimized to use StringBuilder.
|
||||
*
|
||||
* @param items List of strings to concatenate
|
||||
* @return Concatenated result
|
||||
*/
|
||||
public String concatenateStrings(List<String> items) {
|
||||
String result = "";
|
||||
for (String item : items) {
|
||||
result = result + item + ", ";
|
||||
}
|
||||
if (result.length() > 2) {
|
||||
result = result.substring(0, result.length() - 2);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate sum of squares using a loop.
|
||||
* This is already efficient but shows a simple example.
|
||||
*
|
||||
* @param n Upper bound
|
||||
* @return Sum of squares from 1 to n
|
||||
*/
|
||||
public long sumOfSquares(int n) {
|
||||
long sum = 0;
|
||||
for (int i = 1; i <= n; i++) {
|
||||
sum += (long) i * i;
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Unit tests for Algorithms class.
|
||||
*/
|
||||
class AlgorithmsTest {
|
||||
|
||||
private Algorithms algorithms;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
algorithms = new Algorithms();
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Fibonacci of 0 should return 0")
|
||||
void testFibonacciZero() {
|
||||
assertEquals(0, algorithms.fibonacci(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Fibonacci of 1 should return 1")
|
||||
void testFibonacciOne() {
|
||||
assertEquals(1, algorithms.fibonacci(1));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Fibonacci of 10 should return 55")
|
||||
void testFibonacciTen() {
|
||||
assertEquals(55, algorithms.fibonacci(10));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Fibonacci of 20 should return 6765")
|
||||
void testFibonacciTwenty() {
|
||||
assertEquals(6765, algorithms.fibonacci(20));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Find primes up to 10")
|
||||
void testFindPrimesUpToTen() {
|
||||
List<Integer> primes = algorithms.findPrimes(10);
|
||||
assertEquals(Arrays.asList(2, 3, 5, 7), primes);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Find primes up to 20")
|
||||
void testFindPrimesUpToTwenty() {
|
||||
List<Integer> primes = algorithms.findPrimes(20);
|
||||
assertEquals(Arrays.asList(2, 3, 5, 7, 11, 13, 17, 19), primes);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Find duplicates in array with duplicates")
|
||||
void testFindDuplicatesWithDuplicates() {
|
||||
int[] arr = {1, 2, 3, 2, 4, 3, 5};
|
||||
List<Integer> duplicates = algorithms.findDuplicates(arr);
|
||||
assertTrue(duplicates.contains(2));
|
||||
assertTrue(duplicates.contains(3));
|
||||
assertEquals(2, duplicates.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Find duplicates in array without duplicates")
|
||||
void testFindDuplicatesNoDuplicates() {
|
||||
int[] arr = {1, 2, 3, 4, 5};
|
||||
List<Integer> duplicates = algorithms.findDuplicates(arr);
|
||||
assertTrue(duplicates.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Factorial of 0 should return 1")
|
||||
void testFactorialZero() {
|
||||
assertEquals(1, algorithms.factorial(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Factorial of 5 should return 120")
|
||||
void testFactorialFive() {
|
||||
assertEquals(120, algorithms.factorial(5));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Factorial of 10 should return 3628800")
|
||||
void testFactorialTen() {
|
||||
assertEquals(3628800, algorithms.factorial(10));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Concatenate empty list")
|
||||
void testConcatenateEmptyList() {
|
||||
assertEquals("", algorithms.concatenateStrings(List.of()));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Concatenate single item")
|
||||
void testConcatenateSingleItem() {
|
||||
assertEquals("hello", algorithms.concatenateStrings(List.of("hello")));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Concatenate multiple items")
|
||||
void testConcatenateMultipleItems() {
|
||||
assertEquals("a, b, c", algorithms.concatenateStrings(Arrays.asList("a", "b", "c")));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Sum of squares up to 5")
|
||||
void testSumOfSquaresFive() {
|
||||
// 1 + 4 + 9 + 16 + 25 = 55
|
||||
assertEquals(55, algorithms.sumOfSquares(5));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("Sum of squares up to 10")
|
||||
void testSumOfSquaresTen() {
|
||||
// 1 + 4 + 9 + 16 + 25 + 36 + 49 + 64 + 81 + 100 = 385
|
||||
assertEquals(385, algorithms.sumOfSquares(10));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
package com.codeflash;
|
||||
|
||||
/**
|
||||
* Context object for tracking benchmark timing.
|
||||
*
|
||||
* Created by {@link CodeFlash#startBenchmark(String)} and passed to
|
||||
* {@link CodeFlash#endBenchmark(BenchmarkContext)}.
|
||||
*/
|
||||
public final class BenchmarkContext {
|
||||
|
||||
private final String methodId;
|
||||
private final long startTime;
|
||||
|
||||
/**
|
||||
* Create a new benchmark context.
|
||||
*
|
||||
* @param methodId Method being benchmarked
|
||||
* @param startTime Start time in nanoseconds
|
||||
*/
|
||||
BenchmarkContext(String methodId, long startTime) {
|
||||
this.methodId = methodId;
|
||||
this.startTime = startTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the method ID.
|
||||
*
|
||||
* @return Method identifier
|
||||
*/
|
||||
public String getMethodId() {
|
||||
return methodId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the start time.
|
||||
*
|
||||
* @return Start time in nanoseconds
|
||||
*/
|
||||
public long getStartTime() {
|
||||
return startTime;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,160 @@
|
|||
package com.codeflash;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* Result of a benchmark run with statistical analysis.
|
||||
*
|
||||
* Provides JMH-style statistics including mean, standard deviation,
|
||||
* and percentiles (p50, p90, p99).
|
||||
*/
|
||||
public final class BenchmarkResult {
|
||||
|
||||
private final String methodId;
|
||||
private final long[] measurements;
|
||||
private final long mean;
|
||||
private final long stdDev;
|
||||
private final long min;
|
||||
private final long max;
|
||||
private final long p50;
|
||||
private final long p90;
|
||||
private final long p99;
|
||||
|
||||
/**
|
||||
* Create a benchmark result from raw measurements.
|
||||
*
|
||||
* @param methodId Method that was benchmarked
|
||||
* @param measurements Array of timing measurements in nanoseconds
|
||||
*/
|
||||
public BenchmarkResult(String methodId, long[] measurements) {
|
||||
this.methodId = methodId;
|
||||
this.measurements = measurements.clone();
|
||||
|
||||
// Sort for percentile calculations
|
||||
long[] sorted = measurements.clone();
|
||||
Arrays.sort(sorted);
|
||||
|
||||
this.min = sorted[0];
|
||||
this.max = sorted[sorted.length - 1];
|
||||
this.mean = calculateMean(sorted);
|
||||
this.stdDev = calculateStdDev(sorted, this.mean);
|
||||
this.p50 = percentile(sorted, 50);
|
||||
this.p90 = percentile(sorted, 90);
|
||||
this.p99 = percentile(sorted, 99);
|
||||
}
|
||||
|
||||
private static long calculateMean(long[] values) {
|
||||
long sum = 0;
|
||||
for (long v : values) {
|
||||
sum += v;
|
||||
}
|
||||
return sum / values.length;
|
||||
}
|
||||
|
||||
private static long calculateStdDev(long[] values, long mean) {
|
||||
if (values.length < 2) {
|
||||
return 0;
|
||||
}
|
||||
long sumSquaredDiff = 0;
|
||||
for (long v : values) {
|
||||
long diff = v - mean;
|
||||
sumSquaredDiff += diff * diff;
|
||||
}
|
||||
return (long) Math.sqrt(sumSquaredDiff / (values.length - 1));
|
||||
}
|
||||
|
||||
private static long percentile(long[] sorted, int percentile) {
|
||||
int index = (int) Math.ceil(percentile / 100.0 * sorted.length) - 1;
|
||||
return sorted[Math.max(0, Math.min(index, sorted.length - 1))];
|
||||
}
|
||||
|
||||
// Getters
|
||||
|
||||
public String getMethodId() {
|
||||
return methodId;
|
||||
}
|
||||
|
||||
public long[] getMeasurements() {
|
||||
return measurements.clone();
|
||||
}
|
||||
|
||||
public int getIterationCount() {
|
||||
return measurements.length;
|
||||
}
|
||||
|
||||
public long getMean() {
|
||||
return mean;
|
||||
}
|
||||
|
||||
public long getStdDev() {
|
||||
return stdDev;
|
||||
}
|
||||
|
||||
public long getMin() {
|
||||
return min;
|
||||
}
|
||||
|
||||
public long getMax() {
|
||||
return max;
|
||||
}
|
||||
|
||||
public long getP50() {
|
||||
return p50;
|
||||
}
|
||||
|
||||
public long getP90() {
|
||||
return p90;
|
||||
}
|
||||
|
||||
public long getP99() {
|
||||
return p99;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get mean in milliseconds.
|
||||
*/
|
||||
public double getMeanMs() {
|
||||
return mean / 1_000_000.0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get standard deviation in milliseconds.
|
||||
*/
|
||||
public double getStdDevMs() {
|
||||
return stdDev / 1_000_000.0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate coefficient of variation (CV) as percentage.
|
||||
* CV = (stdDev / mean) * 100
|
||||
* Lower is better (more stable measurements).
|
||||
*/
|
||||
public double getCoefficientOfVariation() {
|
||||
if (mean == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (stdDev * 100.0) / mean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if measurements are stable (CV < 10%).
|
||||
*/
|
||||
public boolean isStable() {
|
||||
return getCoefficientOfVariation() < 10.0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return String.format(
|
||||
"BenchmarkResult{method='%s', mean=%.3fms, stdDev=%.3fms, p50=%.3fms, p90=%.3fms, p99=%.3fms, cv=%.1f%%, iterations=%d}",
|
||||
methodId,
|
||||
getMeanMs(),
|
||||
getStdDevMs(),
|
||||
p50 / 1_000_000.0,
|
||||
p90 / 1_000_000.0,
|
||||
p99 / 1_000_000.0,
|
||||
getCoefficientOfVariation(),
|
||||
measurements.length
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
package com.codeflash;
|
||||
|
||||
/**
|
||||
* Utility class to prevent dead code elimination by the JIT compiler.
|
||||
*
|
||||
* Inspired by JMH's Blackhole class. When the JVM detects that a computed
|
||||
* value is never used, it may eliminate the computation entirely. By
|
||||
* "consuming" values through this class, we prevent such optimizations.
|
||||
*
|
||||
* Usage:
|
||||
* <pre>
|
||||
* int result = expensiveComputation();
|
||||
* Blackhole.consume(result); // Prevents JIT from eliminating the computation
|
||||
* </pre>
|
||||
*
|
||||
* The implementation uses volatile writes which act as memory barriers,
|
||||
* preventing the JIT from optimizing away the computation.
|
||||
*/
|
||||
public final class Blackhole {
|
||||
|
||||
// Volatile fields act as memory barriers, preventing optimization
|
||||
private static volatile int intSink;
|
||||
private static volatile long longSink;
|
||||
private static volatile double doubleSink;
|
||||
private static volatile Object objectSink;
|
||||
|
||||
private Blackhole() {
|
||||
// Utility class, no instantiation
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume an int value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(int value) {
|
||||
intSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a long value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(long value) {
|
||||
longSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a double value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(double value) {
|
||||
doubleSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a float value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(float value) {
|
||||
doubleSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a boolean value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(boolean value) {
|
||||
intSink = value ? 1 : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a byte value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(byte value) {
|
||||
intSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a short value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(short value) {
|
||||
intSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a char value to prevent dead code elimination.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(char value) {
|
||||
intSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume an Object to prevent dead code elimination.
|
||||
* Works for any reference type including arrays and collections.
|
||||
*
|
||||
* @param value Value to consume
|
||||
*/
|
||||
public static void consume(Object value) {
|
||||
objectSink = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume an int array to prevent dead code elimination.
|
||||
*
|
||||
* @param values Array to consume
|
||||
*/
|
||||
public static void consume(int[] values) {
|
||||
objectSink = values;
|
||||
if (values != null && values.length > 0) {
|
||||
intSink = values[0];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a long array to prevent dead code elimination.
|
||||
*
|
||||
* @param values Array to consume
|
||||
*/
|
||||
public static void consume(long[] values) {
|
||||
objectSink = values;
|
||||
if (values != null && values.length > 0) {
|
||||
longSink = values[0];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Consume a double array to prevent dead code elimination.
|
||||
*
|
||||
* @param values Array to consume
|
||||
*/
|
||||
public static void consume(double[] values) {
|
||||
objectSink = values;
|
||||
if (values != null && values.length > 0) {
|
||||
doubleSink = values[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,264 @@
|
|||
package com.codeflash;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
|
||||
/**
|
||||
* Main API for CodeFlash runtime instrumentation.
|
||||
*
|
||||
* Provides methods for:
|
||||
* - Capturing function inputs/outputs for behavior verification
|
||||
* - Benchmarking with JMH-inspired best practices
|
||||
* - Preventing dead code elimination
|
||||
*
|
||||
* Usage:
|
||||
* <pre>
|
||||
* // Behavior capture
|
||||
* CodeFlash.captureInput("Calculator.add", a, b);
|
||||
* int result = a + b;
|
||||
* return CodeFlash.captureOutput("Calculator.add", result);
|
||||
*
|
||||
* // Benchmarking
|
||||
* BenchmarkContext ctx = CodeFlash.startBenchmark("Calculator.add");
|
||||
* // ... code to benchmark ...
|
||||
* CodeFlash.endBenchmark(ctx);
|
||||
* </pre>
|
||||
*/
|
||||
public final class CodeFlash {
|
||||
|
||||
private static final AtomicLong callIdCounter = new AtomicLong(0);
|
||||
private static volatile ResultWriter resultWriter;
|
||||
private static volatile boolean initialized = false;
|
||||
private static volatile String outputFile;
|
||||
|
||||
// Configuration from environment variables
|
||||
private static final int DEFAULT_WARMUP_ITERATIONS = 10;
|
||||
private static final int DEFAULT_MEASUREMENT_ITERATIONS = 20;
|
||||
|
||||
static {
|
||||
// Register shutdown hook to flush results
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(() -> {
|
||||
if (resultWriter != null) {
|
||||
resultWriter.close();
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
private CodeFlash() {
|
||||
// Utility class, no instantiation
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize CodeFlash with output file path.
|
||||
* Called automatically if CODEFLASH_OUTPUT_FILE env var is set.
|
||||
*
|
||||
* @param outputPath Path to output file (SQLite database)
|
||||
*/
|
||||
public static synchronized void initialize(String outputPath) {
|
||||
if (!initialized || !outputPath.equals(outputFile)) {
|
||||
outputFile = outputPath;
|
||||
Path path = Paths.get(outputPath);
|
||||
resultWriter = new ResultWriter(path);
|
||||
initialized = true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create the result writer, initializing from environment if needed.
|
||||
*/
|
||||
private static ResultWriter getWriter() {
|
||||
if (!initialized) {
|
||||
String envPath = System.getenv("CODEFLASH_OUTPUT_FILE");
|
||||
if (envPath != null && !envPath.isEmpty()) {
|
||||
initialize(envPath);
|
||||
} else {
|
||||
// Default to temp file if no env var
|
||||
initialize(System.getProperty("java.io.tmpdir") + "/codeflash_results.db");
|
||||
}
|
||||
}
|
||||
return resultWriter;
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture function input arguments.
|
||||
*
|
||||
* @param methodId Unique identifier for the method (e.g., "Calculator.add")
|
||||
* @param args Input arguments
|
||||
*/
|
||||
public static void captureInput(String methodId, Object... args) {
|
||||
long callId = callIdCounter.incrementAndGet();
|
||||
String argsJson = Serializer.toJson(args);
|
||||
getWriter().recordInput(callId, methodId, argsJson, System.nanoTime());
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture function output and return it (for chaining in return statements).
|
||||
*
|
||||
* @param methodId Unique identifier for the method
|
||||
* @param result The result value
|
||||
* @param <T> Type of the result
|
||||
* @return The same result (for chaining)
|
||||
*/
|
||||
public static <T> T captureOutput(String methodId, T result) {
|
||||
long callId = callIdCounter.get(); // Use same callId as input
|
||||
String resultJson = Serializer.toJson(result);
|
||||
getWriter().recordOutput(callId, methodId, resultJson, System.nanoTime());
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Capture an exception thrown by the function.
|
||||
*
|
||||
* @param methodId Unique identifier for the method
|
||||
* @param error The exception
|
||||
*/
|
||||
public static void captureException(String methodId, Throwable error) {
|
||||
long callId = callIdCounter.get();
|
||||
String errorJson = Serializer.exceptionToJson(error);
|
||||
getWriter().recordError(callId, methodId, errorJson, System.nanoTime());
|
||||
}
|
||||
|
||||
/**
|
||||
* Start a benchmark context for timing code execution.
|
||||
* Implements JMH-inspired warmup and measurement phases.
|
||||
*
|
||||
* @param methodId Unique identifier for the method being benchmarked
|
||||
* @return BenchmarkContext to pass to endBenchmark
|
||||
*/
|
||||
public static BenchmarkContext startBenchmark(String methodId) {
|
||||
return new BenchmarkContext(methodId, System.nanoTime());
|
||||
}
|
||||
|
||||
/**
|
||||
* End a benchmark and record the timing.
|
||||
*
|
||||
* @param ctx The benchmark context from startBenchmark
|
||||
*/
|
||||
public static void endBenchmark(BenchmarkContext ctx) {
|
||||
long endTime = System.nanoTime();
|
||||
long duration = endTime - ctx.getStartTime();
|
||||
getWriter().recordBenchmark(ctx.getMethodId(), duration, endTime);
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a benchmark with proper JMH-style warmup and measurement.
|
||||
*
|
||||
* @param methodId Unique identifier for the method
|
||||
* @param runnable Code to benchmark
|
||||
* @return Benchmark result with statistics
|
||||
*/
|
||||
public static BenchmarkResult runBenchmark(String methodId, Runnable runnable) {
|
||||
int warmupIterations = getWarmupIterations();
|
||||
int measurementIterations = getMeasurementIterations();
|
||||
|
||||
// Warmup phase - results discarded
|
||||
for (int i = 0; i < warmupIterations; i++) {
|
||||
runnable.run();
|
||||
}
|
||||
|
||||
// Suggest GC before measurement (hint only, not guaranteed)
|
||||
System.gc();
|
||||
|
||||
// Measurement phase
|
||||
long[] measurements = new long[measurementIterations];
|
||||
for (int i = 0; i < measurementIterations; i++) {
|
||||
long start = System.nanoTime();
|
||||
runnable.run();
|
||||
measurements[i] = System.nanoTime() - start;
|
||||
}
|
||||
|
||||
BenchmarkResult result = new BenchmarkResult(methodId, measurements);
|
||||
getWriter().recordBenchmarkResult(methodId, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a benchmark that returns a value (prevents dead code elimination).
|
||||
*
|
||||
* @param methodId Unique identifier for the method
|
||||
* @param supplier Code to benchmark that returns a value
|
||||
* @param <T> Return type
|
||||
* @return Benchmark result with statistics
|
||||
*/
|
||||
public static <T> BenchmarkResult runBenchmarkWithResult(String methodId, java.util.function.Supplier<T> supplier) {
|
||||
int warmupIterations = getWarmupIterations();
|
||||
int measurementIterations = getMeasurementIterations();
|
||||
|
||||
// Warmup phase - consume results to prevent dead code elimination
|
||||
for (int i = 0; i < warmupIterations; i++) {
|
||||
Blackhole.consume(supplier.get());
|
||||
}
|
||||
|
||||
// Suggest GC before measurement
|
||||
System.gc();
|
||||
|
||||
// Measurement phase
|
||||
long[] measurements = new long[measurementIterations];
|
||||
for (int i = 0; i < measurementIterations; i++) {
|
||||
long start = System.nanoTime();
|
||||
T result = supplier.get();
|
||||
measurements[i] = System.nanoTime() - start;
|
||||
Blackhole.consume(result); // Prevent dead code elimination
|
||||
}
|
||||
|
||||
BenchmarkResult benchmarkResult = new BenchmarkResult(methodId, measurements);
|
||||
getWriter().recordBenchmarkResult(methodId, benchmarkResult);
|
||||
return benchmarkResult;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get warmup iterations from environment or use default.
|
||||
*/
|
||||
private static int getWarmupIterations() {
|
||||
String env = System.getenv("CODEFLASH_WARMUP_ITERATIONS");
|
||||
if (env != null) {
|
||||
try {
|
||||
return Integer.parseInt(env);
|
||||
} catch (NumberFormatException e) {
|
||||
// Use default
|
||||
}
|
||||
}
|
||||
return DEFAULT_WARMUP_ITERATIONS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get measurement iterations from environment or use default.
|
||||
*/
|
||||
private static int getMeasurementIterations() {
|
||||
String env = System.getenv("CODEFLASH_MEASUREMENT_ITERATIONS");
|
||||
if (env != null) {
|
||||
try {
|
||||
return Integer.parseInt(env);
|
||||
} catch (NumberFormatException e) {
|
||||
// Use default
|
||||
}
|
||||
}
|
||||
return DEFAULT_MEASUREMENT_ITERATIONS;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current call ID (for correlation).
|
||||
*
|
||||
* @return Current call ID
|
||||
*/
|
||||
public static long getCurrentCallId() {
|
||||
return callIdCounter.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the call ID counter (for testing).
|
||||
*/
|
||||
public static void resetCallId() {
|
||||
callIdCounter.set(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Force flush all pending writes.
|
||||
*/
|
||||
public static void flush() {
|
||||
if (resultWriter != null) {
|
||||
resultWriter.flush();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,349 @@
|
|||
package com.codeflash;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.google.gson.JsonArray;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
|
||||
import java.sql.Connection;
|
||||
import java.sql.DriverManager;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.ResultSet;
|
||||
import java.sql.SQLException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Compares test results between original and optimized code.
|
||||
*
|
||||
* Used by CodeFlash to verify that optimized code produces the
|
||||
* same outputs as the original code for the same inputs.
|
||||
*
|
||||
* Can be run as a CLI tool:
|
||||
* java -jar codeflash-runtime.jar original.db candidate.db
|
||||
*/
|
||||
public final class Comparator {
|
||||
|
||||
private static final Gson GSON = new GsonBuilder()
|
||||
.serializeNulls()
|
||||
.setPrettyPrinting()
|
||||
.create();
|
||||
|
||||
// Tolerance for floating point comparison
|
||||
private static final double EPSILON = 1e-9;
|
||||
|
||||
private Comparator() {
|
||||
// Utility class
|
||||
}
|
||||
|
||||
/**
|
||||
* Main entry point for CLI usage.
|
||||
*
|
||||
* @param args [originalDb, candidateDb]
|
||||
*/
|
||||
public static void main(String[] args) {
|
||||
if (args.length != 2) {
|
||||
System.err.println("Usage: java -jar codeflash-runtime.jar <original.db> <candidate.db>");
|
||||
System.exit(1);
|
||||
}
|
||||
|
||||
try {
|
||||
ComparisonResult result = compare(args[0], args[1]);
|
||||
System.out.println(GSON.toJson(result));
|
||||
System.exit(result.isEquivalent() ? 0 : 1);
|
||||
} catch (Exception e) {
|
||||
JsonObject error = new JsonObject();
|
||||
error.addProperty("error", e.getMessage());
|
||||
System.out.println(GSON.toJson(error));
|
||||
System.exit(2);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare two result databases.
|
||||
*
|
||||
* @param originalDbPath Path to original results database
|
||||
* @param candidateDbPath Path to candidate results database
|
||||
* @return Comparison result with list of differences
|
||||
*/
|
||||
public static ComparisonResult compare(String originalDbPath, String candidateDbPath) throws SQLException {
|
||||
List<Diff> diffs = new ArrayList<>();
|
||||
|
||||
try (Connection originalConn = DriverManager.getConnection("jdbc:sqlite:" + originalDbPath);
|
||||
Connection candidateConn = DriverManager.getConnection("jdbc:sqlite:" + candidateDbPath)) {
|
||||
|
||||
// Get all invocations from original
|
||||
List<Invocation> originalInvocations = getInvocations(originalConn);
|
||||
List<Invocation> candidateInvocations = getInvocations(candidateConn);
|
||||
|
||||
// Create lookup map for candidate invocations
|
||||
java.util.Map<Long, Invocation> candidateMap = new java.util.HashMap<>();
|
||||
for (Invocation inv : candidateInvocations) {
|
||||
candidateMap.put(inv.callId, inv);
|
||||
}
|
||||
|
||||
// Compare each original invocation with candidate
|
||||
for (Invocation original : originalInvocations) {
|
||||
Invocation candidate = candidateMap.get(original.callId);
|
||||
|
||||
if (candidate == null) {
|
||||
diffs.add(new Diff(
|
||||
original.callId,
|
||||
original.methodId,
|
||||
DiffType.MISSING_IN_CANDIDATE,
|
||||
"Invocation not found in candidate",
|
||||
original.resultJson,
|
||||
null
|
||||
));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compare results
|
||||
if (!compareJsonValues(original.resultJson, candidate.resultJson)) {
|
||||
diffs.add(new Diff(
|
||||
original.callId,
|
||||
original.methodId,
|
||||
DiffType.RETURN_VALUE,
|
||||
"Return values differ",
|
||||
original.resultJson,
|
||||
candidate.resultJson
|
||||
));
|
||||
}
|
||||
|
||||
// Compare errors
|
||||
boolean originalHasError = original.errorJson != null && !original.errorJson.isEmpty();
|
||||
boolean candidateHasError = candidate.errorJson != null && !candidate.errorJson.isEmpty();
|
||||
|
||||
if (originalHasError != candidateHasError) {
|
||||
diffs.add(new Diff(
|
||||
original.callId,
|
||||
original.methodId,
|
||||
DiffType.EXCEPTION,
|
||||
originalHasError ? "Original threw exception, candidate did not" :
|
||||
"Candidate threw exception, original did not",
|
||||
original.errorJson,
|
||||
candidate.errorJson
|
||||
));
|
||||
} else if (originalHasError && !compareExceptions(original.errorJson, candidate.errorJson)) {
|
||||
diffs.add(new Diff(
|
||||
original.callId,
|
||||
original.methodId,
|
||||
DiffType.EXCEPTION,
|
||||
"Exception details differ",
|
||||
original.errorJson,
|
||||
candidate.errorJson
|
||||
));
|
||||
}
|
||||
|
||||
// Remove from map to track extra invocations
|
||||
candidateMap.remove(original.callId);
|
||||
}
|
||||
|
||||
// Check for extra invocations in candidate
|
||||
for (Invocation extra : candidateMap.values()) {
|
||||
diffs.add(new Diff(
|
||||
extra.callId,
|
||||
extra.methodId,
|
||||
DiffType.EXTRA_IN_CANDIDATE,
|
||||
"Extra invocation in candidate",
|
||||
null,
|
||||
extra.resultJson
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
return new ComparisonResult(diffs.isEmpty(), diffs);
|
||||
}
|
||||
|
||||
private static List<Invocation> getInvocations(Connection conn) throws SQLException {
|
||||
List<Invocation> invocations = new ArrayList<>();
|
||||
String sql = "SELECT call_id, method_id, args_json, result_json, error_json FROM invocations ORDER BY call_id";
|
||||
|
||||
try (PreparedStatement stmt = conn.prepareStatement(sql);
|
||||
ResultSet rs = stmt.executeQuery()) {
|
||||
|
||||
while (rs.next()) {
|
||||
invocations.add(new Invocation(
|
||||
rs.getLong("call_id"),
|
||||
rs.getString("method_id"),
|
||||
rs.getString("args_json"),
|
||||
rs.getString("result_json"),
|
||||
rs.getString("error_json")
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
return invocations;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare two JSON values for equivalence.
|
||||
*/
|
||||
private static boolean compareJsonValues(String json1, String json2) {
|
||||
if (json1 == null && json2 == null) return true;
|
||||
if (json1 == null || json2 == null) return false;
|
||||
if (json1.equals(json2)) return true;
|
||||
|
||||
try {
|
||||
JsonElement elem1 = JsonParser.parseString(json1);
|
||||
JsonElement elem2 = JsonParser.parseString(json2);
|
||||
return compareJsonElements(elem1, elem2);
|
||||
} catch (Exception e) {
|
||||
// If parsing fails, fall back to string comparison
|
||||
return json1.equals(json2);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean compareJsonElements(JsonElement elem1, JsonElement elem2) {
|
||||
if (elem1 == null && elem2 == null) return true;
|
||||
if (elem1 == null || elem2 == null) return false;
|
||||
if (elem1.isJsonNull() && elem2.isJsonNull()) return true;
|
||||
|
||||
// Compare primitives
|
||||
if (elem1.isJsonPrimitive() && elem2.isJsonPrimitive()) {
|
||||
return comparePrimitives(elem1.getAsJsonPrimitive(), elem2.getAsJsonPrimitive());
|
||||
}
|
||||
|
||||
// Compare arrays
|
||||
if (elem1.isJsonArray() && elem2.isJsonArray()) {
|
||||
return compareArrays(elem1.getAsJsonArray(), elem2.getAsJsonArray());
|
||||
}
|
||||
|
||||
// Compare objects
|
||||
if (elem1.isJsonObject() && elem2.isJsonObject()) {
|
||||
return compareObjects(elem1.getAsJsonObject(), elem2.getAsJsonObject());
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private static boolean comparePrimitives(com.google.gson.JsonPrimitive p1, com.google.gson.JsonPrimitive p2) {
|
||||
// Handle numeric comparison with epsilon
|
||||
if (p1.isNumber() && p2.isNumber()) {
|
||||
double d1 = p1.getAsDouble();
|
||||
double d2 = p2.getAsDouble();
|
||||
// Handle NaN
|
||||
if (Double.isNaN(d1) && Double.isNaN(d2)) return true;
|
||||
// Handle infinity
|
||||
if (Double.isInfinite(d1) && Double.isInfinite(d2)) {
|
||||
return (d1 > 0) == (d2 > 0);
|
||||
}
|
||||
// Compare with epsilon
|
||||
return Math.abs(d1 - d2) < EPSILON;
|
||||
}
|
||||
|
||||
return Objects.equals(p1, p2);
|
||||
}
|
||||
|
||||
private static boolean compareArrays(JsonArray arr1, JsonArray arr2) {
|
||||
if (arr1.size() != arr2.size()) return false;
|
||||
|
||||
for (int i = 0; i < arr1.size(); i++) {
|
||||
if (!compareJsonElements(arr1.get(i), arr2.get(i))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private static boolean compareObjects(JsonObject obj1, JsonObject obj2) {
|
||||
// Skip type metadata for comparison
|
||||
java.util.Set<String> keys1 = new java.util.HashSet<>(obj1.keySet());
|
||||
java.util.Set<String> keys2 = new java.util.HashSet<>(obj2.keySet());
|
||||
keys1.remove("__type__");
|
||||
keys2.remove("__type__");
|
||||
|
||||
if (!keys1.equals(keys2)) return false;
|
||||
|
||||
for (String key : keys1) {
|
||||
if (!compareJsonElements(obj1.get(key), obj2.get(key))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private static boolean compareExceptions(String error1, String error2) {
|
||||
try {
|
||||
JsonObject e1 = JsonParser.parseString(error1).getAsJsonObject();
|
||||
JsonObject e2 = JsonParser.parseString(error2).getAsJsonObject();
|
||||
|
||||
// Compare exception type and message
|
||||
String type1 = e1.has("type") ? e1.get("type").getAsString() : "";
|
||||
String type2 = e2.has("type") ? e2.get("type").getAsString() : "";
|
||||
|
||||
// Types must match
|
||||
return type1.equals(type2);
|
||||
} catch (Exception e) {
|
||||
return error1.equals(error2);
|
||||
}
|
||||
}
|
||||
|
||||
// Data classes
|
||||
|
||||
private static class Invocation {
|
||||
final long callId;
|
||||
final String methodId;
|
||||
final String argsJson;
|
||||
final String resultJson;
|
||||
final String errorJson;
|
||||
|
||||
Invocation(long callId, String methodId, String argsJson, String resultJson, String errorJson) {
|
||||
this.callId = callId;
|
||||
this.methodId = methodId;
|
||||
this.argsJson = argsJson;
|
||||
this.resultJson = resultJson;
|
||||
this.errorJson = errorJson;
|
||||
}
|
||||
}
|
||||
|
||||
public enum DiffType {
|
||||
RETURN_VALUE,
|
||||
EXCEPTION,
|
||||
MISSING_IN_CANDIDATE,
|
||||
EXTRA_IN_CANDIDATE
|
||||
}
|
||||
|
||||
public static class Diff {
|
||||
private final long callId;
|
||||
private final String methodId;
|
||||
private final DiffType type;
|
||||
private final String message;
|
||||
private final String originalValue;
|
||||
private final String candidateValue;
|
||||
|
||||
public Diff(long callId, String methodId, DiffType type, String message,
|
||||
String originalValue, String candidateValue) {
|
||||
this.callId = callId;
|
||||
this.methodId = methodId;
|
||||
this.type = type;
|
||||
this.message = message;
|
||||
this.originalValue = originalValue;
|
||||
this.candidateValue = candidateValue;
|
||||
}
|
||||
|
||||
// Getters
|
||||
public long getCallId() { return callId; }
|
||||
public String getMethodId() { return methodId; }
|
||||
public DiffType getType() { return type; }
|
||||
public String getMessage() { return message; }
|
||||
public String getOriginalValue() { return originalValue; }
|
||||
public String getCandidateValue() { return candidateValue; }
|
||||
}
|
||||
|
||||
public static class ComparisonResult {
|
||||
private final boolean equivalent;
|
||||
private final List<Diff> diffs;
|
||||
|
||||
public ComparisonResult(boolean equivalent, List<Diff> diffs) {
|
||||
this.equivalent = equivalent;
|
||||
this.diffs = diffs;
|
||||
}
|
||||
|
||||
public boolean isEquivalent() { return equivalent; }
|
||||
public List<Diff> getDiffs() { return diffs; }
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,318 @@
|
|||
package com.codeflash;
|
||||
|
||||
import java.nio.file.Path;
|
||||
import java.sql.Connection;
|
||||
import java.sql.DriverManager;
|
||||
import java.sql.PreparedStatement;
|
||||
import java.sql.SQLException;
|
||||
import java.sql.Statement;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.LinkedBlockingQueue;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
|
||||
/**
|
||||
* Writes benchmark and behavior capture results to SQLite database.
|
||||
*
|
||||
* Uses a background thread for non-blocking writes to minimize
|
||||
* impact on benchmark measurements.
|
||||
*
|
||||
* Database schema:
|
||||
* - invocations: call_id, method_id, args_json, result_json, error_json, start_time, end_time
|
||||
* - benchmarks: method_id, duration_ns, timestamp
|
||||
* - benchmark_results: method_id, mean_ns, stddev_ns, min_ns, max_ns, p50_ns, p90_ns, p99_ns, iterations
|
||||
*/
|
||||
public final class ResultWriter {
|
||||
|
||||
private final Path dbPath;
|
||||
private final Connection connection;
|
||||
private final BlockingQueue<WriteTask> writeQueue;
|
||||
private final Thread writerThread;
|
||||
private final AtomicBoolean running;
|
||||
|
||||
// Prepared statements for performance
|
||||
private PreparedStatement insertInvocationInput;
|
||||
private PreparedStatement updateInvocationOutput;
|
||||
private PreparedStatement updateInvocationError;
|
||||
private PreparedStatement insertBenchmark;
|
||||
private PreparedStatement insertBenchmarkResult;
|
||||
|
||||
/**
|
||||
* Create a new ResultWriter that writes to the specified database file.
|
||||
*
|
||||
* @param dbPath Path to SQLite database file (will be created if not exists)
|
||||
*/
|
||||
public ResultWriter(Path dbPath) {
|
||||
this.dbPath = dbPath;
|
||||
this.writeQueue = new LinkedBlockingQueue<>();
|
||||
this.running = new AtomicBoolean(true);
|
||||
|
||||
try {
|
||||
// Create connection and initialize schema
|
||||
this.connection = DriverManager.getConnection("jdbc:sqlite:" + dbPath.toAbsolutePath());
|
||||
initializeSchema();
|
||||
prepareStatements();
|
||||
|
||||
// Start background writer thread
|
||||
this.writerThread = new Thread(this::writerLoop, "codeflash-writer");
|
||||
this.writerThread.setDaemon(true);
|
||||
this.writerThread.start();
|
||||
|
||||
} catch (SQLException e) {
|
||||
throw new RuntimeException("Failed to initialize ResultWriter: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
private void initializeSchema() throws SQLException {
|
||||
try (Statement stmt = connection.createStatement()) {
|
||||
// Invocations table - stores input/output/error for each function call
|
||||
stmt.execute(
|
||||
"CREATE TABLE IF NOT EXISTS invocations (" +
|
||||
"call_id INTEGER PRIMARY KEY, " +
|
||||
"method_id TEXT NOT NULL, " +
|
||||
"args_json TEXT, " +
|
||||
"result_json TEXT, " +
|
||||
"error_json TEXT, " +
|
||||
"start_time INTEGER, " +
|
||||
"end_time INTEGER)"
|
||||
);
|
||||
|
||||
// Benchmarks table - stores individual benchmark timings
|
||||
stmt.execute(
|
||||
"CREATE TABLE IF NOT EXISTS benchmarks (" +
|
||||
"id INTEGER PRIMARY KEY AUTOINCREMENT, " +
|
||||
"method_id TEXT NOT NULL, " +
|
||||
"duration_ns INTEGER NOT NULL, " +
|
||||
"timestamp INTEGER NOT NULL)"
|
||||
);
|
||||
|
||||
// Benchmark results table - stores aggregated statistics
|
||||
stmt.execute(
|
||||
"CREATE TABLE IF NOT EXISTS benchmark_results (" +
|
||||
"method_id TEXT PRIMARY KEY, " +
|
||||
"mean_ns INTEGER NOT NULL, " +
|
||||
"stddev_ns INTEGER NOT NULL, " +
|
||||
"min_ns INTEGER NOT NULL, " +
|
||||
"max_ns INTEGER NOT NULL, " +
|
||||
"p50_ns INTEGER NOT NULL, " +
|
||||
"p90_ns INTEGER NOT NULL, " +
|
||||
"p99_ns INTEGER NOT NULL, " +
|
||||
"iterations INTEGER NOT NULL, " +
|
||||
"coefficient_of_variation REAL NOT NULL)"
|
||||
);
|
||||
|
||||
// Create indexes for faster queries
|
||||
stmt.execute("CREATE INDEX IF NOT EXISTS idx_invocations_method ON invocations(method_id)");
|
||||
stmt.execute("CREATE INDEX IF NOT EXISTS idx_benchmarks_method ON benchmarks(method_id)");
|
||||
}
|
||||
}
|
||||
|
||||
private void prepareStatements() throws SQLException {
|
||||
insertInvocationInput = connection.prepareStatement(
|
||||
"INSERT INTO invocations (call_id, method_id, args_json, start_time) VALUES (?, ?, ?, ?)"
|
||||
);
|
||||
updateInvocationOutput = connection.prepareStatement(
|
||||
"UPDATE invocations SET result_json = ?, end_time = ? WHERE call_id = ?"
|
||||
);
|
||||
updateInvocationError = connection.prepareStatement(
|
||||
"UPDATE invocations SET error_json = ?, end_time = ? WHERE call_id = ?"
|
||||
);
|
||||
insertBenchmark = connection.prepareStatement(
|
||||
"INSERT INTO benchmarks (method_id, duration_ns, timestamp) VALUES (?, ?, ?)"
|
||||
);
|
||||
insertBenchmarkResult = connection.prepareStatement(
|
||||
"INSERT OR REPLACE INTO benchmark_results " +
|
||||
"(method_id, mean_ns, stddev_ns, min_ns, max_ns, p50_ns, p90_ns, p99_ns, iterations, coefficient_of_variation) " +
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Record function input (beginning of invocation).
|
||||
*/
|
||||
public void recordInput(long callId, String methodId, String argsJson, long startTime) {
|
||||
writeQueue.offer(new WriteTask(WriteType.INPUT, callId, methodId, argsJson, null, null, startTime, 0, null));
|
||||
}
|
||||
|
||||
/**
|
||||
* Record function output (successful completion).
|
||||
*/
|
||||
public void recordOutput(long callId, String methodId, String resultJson, long endTime) {
|
||||
writeQueue.offer(new WriteTask(WriteType.OUTPUT, callId, methodId, null, resultJson, null, 0, endTime, null));
|
||||
}
|
||||
|
||||
/**
|
||||
* Record function error (exception thrown).
|
||||
*/
|
||||
public void recordError(long callId, String methodId, String errorJson, long endTime) {
|
||||
writeQueue.offer(new WriteTask(WriteType.ERROR, callId, methodId, null, null, errorJson, 0, endTime, null));
|
||||
}
|
||||
|
||||
/**
|
||||
* Record a single benchmark timing.
|
||||
*/
|
||||
public void recordBenchmark(String methodId, long durationNs, long timestamp) {
|
||||
writeQueue.offer(new WriteTask(WriteType.BENCHMARK, 0, methodId, null, null, null, durationNs, timestamp, null));
|
||||
}
|
||||
|
||||
/**
|
||||
* Record aggregated benchmark results.
|
||||
*/
|
||||
public void recordBenchmarkResult(String methodId, BenchmarkResult result) {
|
||||
writeQueue.offer(new WriteTask(WriteType.BENCHMARK_RESULT, 0, methodId, null, null, null, 0, 0, result));
|
||||
}
|
||||
|
||||
/**
|
||||
* Background writer loop - processes write tasks from queue.
|
||||
*/
|
||||
private void writerLoop() {
|
||||
while (running.get() || !writeQueue.isEmpty()) {
|
||||
try {
|
||||
WriteTask task = writeQueue.poll(100, TimeUnit.MILLISECONDS);
|
||||
if (task != null) {
|
||||
executeTask(task);
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
break;
|
||||
} catch (SQLException e) {
|
||||
System.err.println("CodeFlash ResultWriter error: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining tasks
|
||||
WriteTask task;
|
||||
while ((task = writeQueue.poll()) != null) {
|
||||
try {
|
||||
executeTask(task);
|
||||
} catch (SQLException e) {
|
||||
System.err.println("CodeFlash ResultWriter error: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void executeTask(WriteTask task) throws SQLException {
|
||||
switch (task.type) {
|
||||
case INPUT:
|
||||
insertInvocationInput.setLong(1, task.callId);
|
||||
insertInvocationInput.setString(2, task.methodId);
|
||||
insertInvocationInput.setString(3, task.argsJson);
|
||||
insertInvocationInput.setLong(4, task.startTime);
|
||||
insertInvocationInput.executeUpdate();
|
||||
break;
|
||||
|
||||
case OUTPUT:
|
||||
updateInvocationOutput.setString(1, task.resultJson);
|
||||
updateInvocationOutput.setLong(2, task.endTime);
|
||||
updateInvocationOutput.setLong(3, task.callId);
|
||||
updateInvocationOutput.executeUpdate();
|
||||
break;
|
||||
|
||||
case ERROR:
|
||||
updateInvocationError.setString(1, task.errorJson);
|
||||
updateInvocationError.setLong(2, task.endTime);
|
||||
updateInvocationError.setLong(3, task.callId);
|
||||
updateInvocationError.executeUpdate();
|
||||
break;
|
||||
|
||||
case BENCHMARK:
|
||||
insertBenchmark.setString(1, task.methodId);
|
||||
insertBenchmark.setLong(2, task.startTime); // duration stored in startTime field
|
||||
insertBenchmark.setLong(3, task.endTime); // timestamp stored in endTime field
|
||||
insertBenchmark.executeUpdate();
|
||||
break;
|
||||
|
||||
case BENCHMARK_RESULT:
|
||||
BenchmarkResult r = task.benchmarkResult;
|
||||
insertBenchmarkResult.setString(1, task.methodId);
|
||||
insertBenchmarkResult.setLong(2, r.getMean());
|
||||
insertBenchmarkResult.setLong(3, r.getStdDev());
|
||||
insertBenchmarkResult.setLong(4, r.getMin());
|
||||
insertBenchmarkResult.setLong(5, r.getMax());
|
||||
insertBenchmarkResult.setLong(6, r.getP50());
|
||||
insertBenchmarkResult.setLong(7, r.getP90());
|
||||
insertBenchmarkResult.setLong(8, r.getP99());
|
||||
insertBenchmarkResult.setInt(9, r.getIterationCount());
|
||||
insertBenchmarkResult.setDouble(10, r.getCoefficientOfVariation());
|
||||
insertBenchmarkResult.executeUpdate();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Flush all pending writes synchronously.
|
||||
*/
|
||||
public void flush() {
|
||||
// Wait for queue to drain
|
||||
while (!writeQueue.isEmpty()) {
|
||||
try {
|
||||
Thread.sleep(10);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the writer and database connection.
|
||||
*/
|
||||
public void close() {
|
||||
running.set(false);
|
||||
|
||||
try {
|
||||
writerThread.join(5000); // Wait up to 5 seconds
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
}
|
||||
|
||||
try {
|
||||
if (insertInvocationInput != null) insertInvocationInput.close();
|
||||
if (updateInvocationOutput != null) updateInvocationOutput.close();
|
||||
if (updateInvocationError != null) updateInvocationError.close();
|
||||
if (insertBenchmark != null) insertBenchmark.close();
|
||||
if (insertBenchmarkResult != null) insertBenchmarkResult.close();
|
||||
if (connection != null) connection.close();
|
||||
} catch (SQLException e) {
|
||||
System.err.println("Error closing ResultWriter: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the database path.
|
||||
*/
|
||||
public Path getDbPath() {
|
||||
return dbPath;
|
||||
}
|
||||
|
||||
// Internal task class for queue
|
||||
private enum WriteType {
|
||||
INPUT, OUTPUT, ERROR, BENCHMARK, BENCHMARK_RESULT
|
||||
}
|
||||
|
||||
private static class WriteTask {
|
||||
final WriteType type;
|
||||
final long callId;
|
||||
final String methodId;
|
||||
final String argsJson;
|
||||
final String resultJson;
|
||||
final String errorJson;
|
||||
final long startTime;
|
||||
final long endTime;
|
||||
final BenchmarkResult benchmarkResult;
|
||||
|
||||
WriteTask(WriteType type, long callId, String methodId, String argsJson,
|
||||
String resultJson, String errorJson, long startTime, long endTime,
|
||||
BenchmarkResult benchmarkResult) {
|
||||
this.type = type;
|
||||
this.callId = callId;
|
||||
this.methodId = methodId;
|
||||
this.argsJson = argsJson;
|
||||
this.resultJson = resultJson;
|
||||
this.errorJson = errorJson;
|
||||
this.startTime = startTime;
|
||||
this.endTime = endTime;
|
||||
this.benchmarkResult = benchmarkResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
package com.codeflash;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.google.gson.JsonArray;
|
||||
import com.google.gson.JsonElement;
|
||||
import com.google.gson.JsonNull;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonPrimitive;
|
||||
|
||||
import java.lang.reflect.Field;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.LocalTime;
|
||||
import java.util.Collection;
|
||||
import java.util.Date;
|
||||
import java.util.IdentityHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* Serializer for Java objects to JSON format.
|
||||
*
|
||||
* Handles:
|
||||
* - Primitives and their wrappers
|
||||
* - Strings
|
||||
* - Arrays (primitive and object)
|
||||
* - Collections (List, Set, etc.)
|
||||
* - Maps
|
||||
* - Date/Time types
|
||||
* - Custom objects via reflection
|
||||
* - Circular references (detected and marked)
|
||||
*/
|
||||
public final class Serializer {
|
||||
|
||||
private static final Gson GSON = new GsonBuilder()
|
||||
.serializeNulls()
|
||||
.create();
|
||||
|
||||
private static final int MAX_DEPTH = 10;
|
||||
private static final int MAX_COLLECTION_SIZE = 1000;
|
||||
|
||||
private Serializer() {
|
||||
// Utility class
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an object to JSON string.
|
||||
*
|
||||
* @param obj Object to serialize
|
||||
* @return JSON string representation
|
||||
*/
|
||||
public static String toJson(Object obj) {
|
||||
try {
|
||||
JsonElement element = serialize(obj, new IdentityHashMap<>(), 0);
|
||||
return GSON.toJson(element);
|
||||
} catch (Exception e) {
|
||||
// Fallback for serialization errors
|
||||
JsonObject error = new JsonObject();
|
||||
error.addProperty("__serialization_error__", e.getMessage());
|
||||
error.addProperty("__type__", obj != null ? obj.getClass().getName() : "null");
|
||||
return GSON.toJson(error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize varargs (for capturing multiple arguments).
|
||||
*
|
||||
* @param args Arguments to serialize
|
||||
* @return JSON array string
|
||||
*/
|
||||
public static String toJson(Object... args) {
|
||||
JsonArray array = new JsonArray();
|
||||
IdentityHashMap<Object, Boolean> seen = new IdentityHashMap<>();
|
||||
for (Object arg : args) {
|
||||
array.add(serialize(arg, seen, 0));
|
||||
}
|
||||
return GSON.toJson(array);
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an exception to JSON.
|
||||
*
|
||||
* @param error Exception to serialize
|
||||
* @return JSON string with exception details
|
||||
*/
|
||||
public static String exceptionToJson(Throwable error) {
|
||||
JsonObject obj = new JsonObject();
|
||||
obj.addProperty("__exception__", true);
|
||||
obj.addProperty("type", error.getClass().getName());
|
||||
obj.addProperty("message", error.getMessage());
|
||||
|
||||
// Capture stack trace
|
||||
JsonArray stackTrace = new JsonArray();
|
||||
for (StackTraceElement element : error.getStackTrace()) {
|
||||
stackTrace.add(element.toString());
|
||||
}
|
||||
obj.add("stackTrace", stackTrace);
|
||||
|
||||
// Capture cause if present
|
||||
if (error.getCause() != null) {
|
||||
obj.addProperty("causeType", error.getCause().getClass().getName());
|
||||
obj.addProperty("causeMessage", error.getCause().getMessage());
|
||||
}
|
||||
|
||||
return GSON.toJson(obj);
|
||||
}
|
||||
|
||||
private static JsonElement serialize(Object obj, IdentityHashMap<Object, Boolean> seen, int depth) {
|
||||
if (obj == null) {
|
||||
return JsonNull.INSTANCE;
|
||||
}
|
||||
|
||||
// Depth limit to prevent infinite recursion
|
||||
if (depth > MAX_DEPTH) {
|
||||
JsonObject truncated = new JsonObject();
|
||||
truncated.addProperty("__truncated__", "max depth exceeded");
|
||||
return truncated;
|
||||
}
|
||||
|
||||
Class<?> clazz = obj.getClass();
|
||||
|
||||
// Primitives and wrappers
|
||||
if (obj instanceof Boolean) {
|
||||
return new JsonPrimitive((Boolean) obj);
|
||||
}
|
||||
if (obj instanceof Number) {
|
||||
return new JsonPrimitive((Number) obj);
|
||||
}
|
||||
if (obj instanceof Character) {
|
||||
return new JsonPrimitive(String.valueOf(obj));
|
||||
}
|
||||
if (obj instanceof String) {
|
||||
return new JsonPrimitive((String) obj);
|
||||
}
|
||||
|
||||
// Check for circular reference (only for reference types)
|
||||
if (seen.containsKey(obj)) {
|
||||
JsonObject circular = new JsonObject();
|
||||
circular.addProperty("__circular_ref__", clazz.getName());
|
||||
return circular;
|
||||
}
|
||||
seen.put(obj, Boolean.TRUE);
|
||||
|
||||
try {
|
||||
// Date/Time types
|
||||
if (obj instanceof Date) {
|
||||
return new JsonPrimitive(((Date) obj).toInstant().toString());
|
||||
}
|
||||
if (obj instanceof LocalDateTime) {
|
||||
return new JsonPrimitive(obj.toString());
|
||||
}
|
||||
if (obj instanceof LocalDate) {
|
||||
return new JsonPrimitive(obj.toString());
|
||||
}
|
||||
if (obj instanceof LocalTime) {
|
||||
return new JsonPrimitive(obj.toString());
|
||||
}
|
||||
|
||||
// Optional
|
||||
if (obj instanceof Optional) {
|
||||
Optional<?> opt = (Optional<?>) obj;
|
||||
if (opt.isPresent()) {
|
||||
return serialize(opt.get(), seen, depth + 1);
|
||||
} else {
|
||||
return JsonNull.INSTANCE;
|
||||
}
|
||||
}
|
||||
|
||||
// Arrays
|
||||
if (clazz.isArray()) {
|
||||
return serializeArray(obj, seen, depth);
|
||||
}
|
||||
|
||||
// Collections
|
||||
if (obj instanceof Collection) {
|
||||
return serializeCollection((Collection<?>) obj, seen, depth);
|
||||
}
|
||||
|
||||
// Maps
|
||||
if (obj instanceof Map) {
|
||||
return serializeMap((Map<?, ?>) obj, seen, depth);
|
||||
}
|
||||
|
||||
// Enums
|
||||
if (clazz.isEnum()) {
|
||||
return new JsonPrimitive(((Enum<?>) obj).name());
|
||||
}
|
||||
|
||||
// Custom objects - serialize via reflection
|
||||
return serializeObject(obj, seen, depth);
|
||||
|
||||
} finally {
|
||||
seen.remove(obj);
|
||||
}
|
||||
}
|
||||
|
||||
private static JsonElement serializeArray(Object array, IdentityHashMap<Object, Boolean> seen, int depth) {
|
||||
JsonArray jsonArray = new JsonArray();
|
||||
int length = java.lang.reflect.Array.getLength(array);
|
||||
int limit = Math.min(length, MAX_COLLECTION_SIZE);
|
||||
|
||||
for (int i = 0; i < limit; i++) {
|
||||
Object element = java.lang.reflect.Array.get(array, i);
|
||||
jsonArray.add(serialize(element, seen, depth + 1));
|
||||
}
|
||||
|
||||
if (length > limit) {
|
||||
JsonObject truncated = new JsonObject();
|
||||
truncated.addProperty("__truncated__", length - limit + " more elements");
|
||||
jsonArray.add(truncated);
|
||||
}
|
||||
|
||||
return jsonArray;
|
||||
}
|
||||
|
||||
private static JsonElement serializeCollection(Collection<?> collection, IdentityHashMap<Object, Boolean> seen, int depth) {
|
||||
JsonArray jsonArray = new JsonArray();
|
||||
int count = 0;
|
||||
|
||||
for (Object element : collection) {
|
||||
if (count >= MAX_COLLECTION_SIZE) {
|
||||
JsonObject truncated = new JsonObject();
|
||||
truncated.addProperty("__truncated__", collection.size() - count + " more elements");
|
||||
jsonArray.add(truncated);
|
||||
break;
|
||||
}
|
||||
jsonArray.add(serialize(element, seen, depth + 1));
|
||||
count++;
|
||||
}
|
||||
|
||||
return jsonArray;
|
||||
}
|
||||
|
||||
private static JsonElement serializeMap(Map<?, ?> map, IdentityHashMap<Object, Boolean> seen, int depth) {
|
||||
JsonObject jsonObject = new JsonObject();
|
||||
int count = 0;
|
||||
|
||||
for (Map.Entry<?, ?> entry : map.entrySet()) {
|
||||
if (count >= MAX_COLLECTION_SIZE) {
|
||||
jsonObject.addProperty("__truncated__", map.size() - count + " more entries");
|
||||
break;
|
||||
}
|
||||
String key = entry.getKey() != null ? entry.getKey().toString() : "null";
|
||||
jsonObject.add(key, serialize(entry.getValue(), seen, depth + 1));
|
||||
count++;
|
||||
}
|
||||
|
||||
return jsonObject;
|
||||
}
|
||||
|
||||
private static JsonElement serializeObject(Object obj, IdentityHashMap<Object, Boolean> seen, int depth) {
|
||||
JsonObject jsonObject = new JsonObject();
|
||||
Class<?> clazz = obj.getClass();
|
||||
|
||||
// Add type information
|
||||
jsonObject.addProperty("__type__", clazz.getName());
|
||||
|
||||
// Serialize all fields (including inherited)
|
||||
while (clazz != null && clazz != Object.class) {
|
||||
for (Field field : clazz.getDeclaredFields()) {
|
||||
// Skip static and transient fields
|
||||
if (Modifier.isStatic(field.getModifiers()) ||
|
||||
Modifier.isTransient(field.getModifiers())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
field.setAccessible(true);
|
||||
Object value = field.get(obj);
|
||||
jsonObject.add(field.getName(), serialize(value, seen, depth + 1));
|
||||
} catch (IllegalAccessException e) {
|
||||
jsonObject.addProperty(field.getName(), "__access_denied__");
|
||||
}
|
||||
}
|
||||
clazz = clazz.getSuperclass();
|
||||
}
|
||||
|
||||
return jsonObject;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,126 @@
|
|||
package com.codeflash;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Tests for the BenchmarkResult class.
|
||||
*/
|
||||
@DisplayName("BenchmarkResult Tests")
|
||||
class BenchmarkResultTest {
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate mean correctly")
|
||||
void testMean() {
|
||||
long[] measurements = {100, 200, 300, 400, 500};
|
||||
BenchmarkResult result = new BenchmarkResult("test", measurements);
|
||||
|
||||
assertEquals(300, result.getMean());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate min and max")
|
||||
void testMinMax() {
|
||||
long[] measurements = {100, 50, 200, 150, 75};
|
||||
BenchmarkResult result = new BenchmarkResult("test", measurements);
|
||||
|
||||
assertEquals(50, result.getMin());
|
||||
assertEquals(200, result.getMax());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate percentiles")
|
||||
void testPercentiles() {
|
||||
long[] measurements = new long[100];
|
||||
for (int i = 0; i < 100; i++) {
|
||||
measurements[i] = i + 1; // 1 to 100
|
||||
}
|
||||
BenchmarkResult result = new BenchmarkResult("test", measurements);
|
||||
|
||||
assertEquals(50, result.getP50());
|
||||
assertEquals(90, result.getP90());
|
||||
assertEquals(99, result.getP99());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate standard deviation")
|
||||
void testStdDev() {
|
||||
// All same values should have 0 std dev
|
||||
long[] sameValues = {100, 100, 100, 100, 100};
|
||||
BenchmarkResult sameResult = new BenchmarkResult("test", sameValues);
|
||||
assertEquals(0, sameResult.getStdDev());
|
||||
|
||||
// Different values should have non-zero std dev
|
||||
long[] differentValues = {100, 200, 300, 400, 500};
|
||||
BenchmarkResult diffResult = new BenchmarkResult("test", differentValues);
|
||||
assertTrue(diffResult.getStdDev() > 0);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate coefficient of variation")
|
||||
void testCoefficientOfVariation() {
|
||||
long[] measurements = {100, 100, 100, 100, 100};
|
||||
BenchmarkResult result = new BenchmarkResult("test", measurements);
|
||||
|
||||
assertEquals(0.0, result.getCoefficientOfVariation(), 0.001);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should detect stable measurements")
|
||||
void testIsStable() {
|
||||
// Low variance - stable
|
||||
long[] stableMeasurements = {100, 101, 99, 100, 102};
|
||||
BenchmarkResult stableResult = new BenchmarkResult("test", stableMeasurements);
|
||||
assertTrue(stableResult.isStable());
|
||||
|
||||
// High variance - unstable
|
||||
long[] unstableMeasurements = {100, 200, 50, 300, 25};
|
||||
BenchmarkResult unstableResult = new BenchmarkResult("test", unstableMeasurements);
|
||||
assertFalse(unstableResult.isStable());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should convert to milliseconds")
|
||||
void testMillisecondConversion() {
|
||||
long[] measurements = {1_000_000, 2_000_000, 3_000_000}; // 1ms, 2ms, 3ms
|
||||
BenchmarkResult result = new BenchmarkResult("test", measurements);
|
||||
|
||||
assertEquals(2.0, result.getMeanMs(), 0.001);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should clone measurements array")
|
||||
void testMeasurementsCloned() {
|
||||
long[] original = {100, 200, 300};
|
||||
BenchmarkResult result = new BenchmarkResult("test", original);
|
||||
|
||||
long[] retrieved = result.getMeasurements();
|
||||
retrieved[0] = 999;
|
||||
|
||||
// Original should not be affected
|
||||
assertEquals(100, result.getMeasurements()[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return correct iteration count")
|
||||
void testIterationCount() {
|
||||
long[] measurements = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
|
||||
BenchmarkResult result = new BenchmarkResult("test", measurements);
|
||||
|
||||
assertEquals(10, result.getIterationCount());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should have meaningful toString")
|
||||
void testToString() {
|
||||
long[] measurements = {1_000_000, 2_000_000};
|
||||
BenchmarkResult result = new BenchmarkResult("Calculator.add", measurements);
|
||||
|
||||
String str = result.toString();
|
||||
assertTrue(str.contains("Calculator.add"));
|
||||
assertTrue(str.contains("mean="));
|
||||
assertTrue(str.contains("ms"));
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
package com.codeflash;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Tests for the Blackhole class.
|
||||
*/
|
||||
@DisplayName("Blackhole Tests")
|
||||
class BlackholeTest {
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume int without throwing")
|
||||
void testConsumeInt() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(42));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume long without throwing")
|
||||
void testConsumeLong() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(Long.MAX_VALUE));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume double without throwing")
|
||||
void testConsumeDouble() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(3.14159));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume float without throwing")
|
||||
void testConsumeFloat() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(3.14f));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume boolean without throwing")
|
||||
void testConsumeBoolean() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(true));
|
||||
assertDoesNotThrow(() -> Blackhole.consume(false));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume byte without throwing")
|
||||
void testConsumeByte() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume((byte) 127));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume short without throwing")
|
||||
void testConsumeShort() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume((short) 32000));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume char without throwing")
|
||||
void testConsumeChar() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume('x'));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume Object without throwing")
|
||||
void testConsumeObject() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume("hello"));
|
||||
assertDoesNotThrow(() -> Blackhole.consume(Arrays.asList(1, 2, 3)));
|
||||
assertDoesNotThrow(() -> Blackhole.consume((Object) null));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume int array without throwing")
|
||||
void testConsumeIntArray() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(new int[]{1, 2, 3}));
|
||||
assertDoesNotThrow(() -> Blackhole.consume((int[]) null));
|
||||
assertDoesNotThrow(() -> Blackhole.consume(new int[]{}));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume long array without throwing")
|
||||
void testConsumeLongArray() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(new long[]{1L, 2L, 3L}));
|
||||
assertDoesNotThrow(() -> Blackhole.consume((long[]) null));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should consume double array without throwing")
|
||||
void testConsumeDoubleArray() {
|
||||
assertDoesNotThrow(() -> Blackhole.consume(new double[]{1.0, 2.0, 3.0}));
|
||||
assertDoesNotThrow(() -> Blackhole.consume((double[]) null));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should prevent dead code elimination in loop")
|
||||
void testPreventDeadCodeInLoop() {
|
||||
// This test verifies that consuming values allows the loop to run
|
||||
// without the JIT potentially eliminating it
|
||||
int sum = 0;
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
sum += i;
|
||||
Blackhole.consume(sum);
|
||||
}
|
||||
// The loop should have run - this is more of a smoke test
|
||||
assertTrue(sum > 0);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,283 @@
|
|||
package com.codeflash;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Tests for the Serializer class.
|
||||
*/
|
||||
@DisplayName("Serializer Tests")
|
||||
class SerializerTest {
|
||||
|
||||
@Nested
|
||||
@DisplayName("Primitive Types")
|
||||
class PrimitiveTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize integers")
|
||||
void testInteger() {
|
||||
assertEquals("42", Serializer.toJson(42));
|
||||
assertEquals("-1", Serializer.toJson(-1));
|
||||
assertEquals("0", Serializer.toJson(0));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize longs")
|
||||
void testLong() {
|
||||
assertEquals("9223372036854775807", Serializer.toJson(Long.MAX_VALUE));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize doubles")
|
||||
void testDouble() {
|
||||
String json = Serializer.toJson(3.14159);
|
||||
assertTrue(json.startsWith("3.14"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize booleans")
|
||||
void testBoolean() {
|
||||
assertEquals("true", Serializer.toJson(true));
|
||||
assertEquals("false", Serializer.toJson(false));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize strings")
|
||||
void testString() {
|
||||
assertEquals("\"hello\"", Serializer.toJson("hello"));
|
||||
assertEquals("\"with \\\"quotes\\\"\"", Serializer.toJson("with \"quotes\""));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize null")
|
||||
void testNull() {
|
||||
assertEquals("null", Serializer.toJson((Object) null));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize characters")
|
||||
void testCharacter() {
|
||||
assertEquals("\"a\"", Serializer.toJson('a'));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Array Types")
|
||||
class ArrayTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize int arrays")
|
||||
void testIntArray() {
|
||||
int[] arr = {1, 2, 3};
|
||||
assertEquals("[1,2,3]", Serializer.toJson((Object) arr));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize String arrays")
|
||||
void testStringArray() {
|
||||
String[] arr = {"a", "b", "c"};
|
||||
assertEquals("[\"a\",\"b\",\"c\"]", Serializer.toJson((Object) arr));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize empty arrays")
|
||||
void testEmptyArray() {
|
||||
int[] arr = {};
|
||||
assertEquals("[]", Serializer.toJson((Object) arr));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Collection Types")
|
||||
class CollectionTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize Lists")
|
||||
void testList() {
|
||||
List<Integer> list = Arrays.asList(1, 2, 3);
|
||||
assertEquals("[1,2,3]", Serializer.toJson(list));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize Sets")
|
||||
void testSet() {
|
||||
Set<String> set = new LinkedHashSet<>(Arrays.asList("a", "b"));
|
||||
String json = Serializer.toJson(set);
|
||||
assertTrue(json.contains("\"a\""));
|
||||
assertTrue(json.contains("\"b\""));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize Maps")
|
||||
void testMap() {
|
||||
Map<String, Integer> map = new LinkedHashMap<>();
|
||||
map.put("one", 1);
|
||||
map.put("two", 2);
|
||||
String json = Serializer.toJson(map);
|
||||
assertTrue(json.contains("\"one\":1"));
|
||||
assertTrue(json.contains("\"two\":2"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle nested collections")
|
||||
void testNestedCollections() {
|
||||
List<List<Integer>> nested = Arrays.asList(
|
||||
Arrays.asList(1, 2),
|
||||
Arrays.asList(3, 4)
|
||||
);
|
||||
assertEquals("[[1,2],[3,4]]", Serializer.toJson(nested));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Varargs")
|
||||
class VarargsTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize multiple arguments")
|
||||
void testVarargs() {
|
||||
String json = Serializer.toJson(1, "hello", true);
|
||||
assertEquals("[1,\"hello\",true]", json);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize mixed types")
|
||||
void testMixedVarargs() {
|
||||
String json = Serializer.toJson(42, Arrays.asList(1, 2), null);
|
||||
assertTrue(json.startsWith("[42,"));
|
||||
assertTrue(json.contains("null"));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Custom Objects")
|
||||
class CustomObjectTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize simple objects")
|
||||
void testSimpleObject() {
|
||||
TestPerson person = new TestPerson("John", 30);
|
||||
String json = Serializer.toJson(person);
|
||||
|
||||
assertTrue(json.contains("\"name\":\"John\""));
|
||||
assertTrue(json.contains("\"age\":30"));
|
||||
assertTrue(json.contains("\"__type__\""));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize nested objects")
|
||||
void testNestedObject() {
|
||||
TestAddress address = new TestAddress("123 Main St", "NYC");
|
||||
TestPersonWithAddress person = new TestPersonWithAddress("Jane", address);
|
||||
String json = Serializer.toJson(person);
|
||||
|
||||
assertTrue(json.contains("\"name\":\"Jane\""));
|
||||
assertTrue(json.contains("\"city\":\"NYC\""));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Exception Serialization")
|
||||
class ExceptionTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should serialize exception with type and message")
|
||||
void testException() {
|
||||
Exception e = new IllegalArgumentException("test error");
|
||||
String json = Serializer.exceptionToJson(e);
|
||||
|
||||
assertTrue(json.contains("\"__exception__\":true"));
|
||||
assertTrue(json.contains("\"type\":\"java.lang.IllegalArgumentException\""));
|
||||
assertTrue(json.contains("\"message\":\"test error\""));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should include stack trace")
|
||||
void testExceptionStackTrace() {
|
||||
Exception e = new RuntimeException("test");
|
||||
String json = Serializer.exceptionToJson(e);
|
||||
|
||||
assertTrue(json.contains("\"stackTrace\""));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should include cause")
|
||||
void testExceptionWithCause() {
|
||||
Exception cause = new NullPointerException("root cause");
|
||||
Exception e = new RuntimeException("wrapper", cause);
|
||||
String json = Serializer.exceptionToJson(e);
|
||||
|
||||
assertTrue(json.contains("\"causeType\":\"java.lang.NullPointerException\""));
|
||||
assertTrue(json.contains("\"causeMessage\":\"root cause\""));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Edge Cases")
|
||||
class EdgeCaseTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle Optional with value")
|
||||
void testOptionalPresent() {
|
||||
Optional<String> opt = Optional.of("value");
|
||||
assertEquals("\"value\"", Serializer.toJson(opt));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle Optional empty")
|
||||
void testOptionalEmpty() {
|
||||
Optional<String> opt = Optional.empty();
|
||||
assertEquals("null", Serializer.toJson(opt));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle enums")
|
||||
void testEnum() {
|
||||
assertEquals("\"MONDAY\"", Serializer.toJson(java.time.DayOfWeek.MONDAY));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle Date")
|
||||
void testDate() {
|
||||
Date date = new Date(0); // Epoch
|
||||
String json = Serializer.toJson(date);
|
||||
assertTrue(json.contains("1970"));
|
||||
}
|
||||
}
|
||||
|
||||
// Test helper classes
|
||||
static class TestPerson {
|
||||
private final String name;
|
||||
private final int age;
|
||||
|
||||
TestPerson(String name, int age) {
|
||||
this.name = name;
|
||||
this.age = age;
|
||||
}
|
||||
}
|
||||
|
||||
static class TestAddress {
|
||||
private final String street;
|
||||
private final String city;
|
||||
|
||||
TestAddress(String street, String city) {
|
||||
this.street = street;
|
||||
this.city = city;
|
||||
}
|
||||
}
|
||||
|
||||
static class TestPersonWithAddress {
|
||||
private final String name;
|
||||
private final TestAddress address;
|
||||
|
||||
TestPersonWithAddress(String name, TestAddress address) {
|
||||
this.name = name;
|
||||
this.address = address;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -14,7 +14,7 @@ from codeflash.cli_cmds.console import console, logger
|
|||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
from codeflash.languages import is_java, is_javascript, is_python
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import (
|
||||
AIServiceRefinerRequest,
|
||||
|
|
@ -182,6 +182,8 @@ class AiServiceClient:
|
|||
payload["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
elif is_java():
|
||||
payload["language_version"] = language_version or "17" # Default Java version
|
||||
else:
|
||||
payload["language_version"] = language_version or "ES2022"
|
||||
# Add module system for JavaScript/TypeScript (esm or commonjs)
|
||||
|
|
@ -785,6 +787,8 @@ class AiServiceClient:
|
|||
payload["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
elif is_java():
|
||||
payload["language_version"] = language_version or "17" # Default Java version
|
||||
else:
|
||||
payload["language_version"] = language_version or "ES2022"
|
||||
# Add module system for JavaScript/TypeScript (esm or commonjs)
|
||||
|
|
|
|||
|
|
@ -273,6 +273,20 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) -> Path:
|
||||
if pyproject_file_path.parent == module_root:
|
||||
return module_root
|
||||
|
||||
# For Java projects, find the directory containing pom.xml or build.gradle
|
||||
# This handles the case where module_root is src/main/java
|
||||
current = module_root
|
||||
while current != current.parent:
|
||||
if (current / "pom.xml").exists():
|
||||
return current.resolve()
|
||||
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
|
||||
return current.resolve()
|
||||
# Check for config file (pyproject.toml for Python, codeflash.toml for other languages)
|
||||
if (current / "codeflash.toml").exists():
|
||||
return current.resolve()
|
||||
current = current.parent
|
||||
|
||||
return module_root.parent.resolve()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -35,6 +35,9 @@ from codeflash.cli_cmds.init_javascript import (
|
|||
get_js_dependency_installation_commands,
|
||||
init_js_project,
|
||||
)
|
||||
|
||||
# Import Java init module
|
||||
from codeflash.cli_cmds.init_java import init_java_project
|
||||
from codeflash.code_utils.code_utils import validate_relative_directory_path
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
|
|
@ -114,6 +117,10 @@ def init_codeflash() -> None:
|
|||
# Detect project language
|
||||
project_language = detect_project_language()
|
||||
|
||||
if project_language == ProjectLanguage.JAVA:
|
||||
init_java_project()
|
||||
return
|
||||
|
||||
if project_language in (ProjectLanguage.JAVASCRIPT, ProjectLanguage.TYPESCRIPT):
|
||||
init_js_project(project_language)
|
||||
return
|
||||
|
|
@ -798,7 +805,9 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
|
|||
|
||||
# Select the appropriate workflow template based on project language
|
||||
project_language = detect_project_language_for_workflow(Path.cwd())
|
||||
if project_language in ("javascript", "typescript"):
|
||||
if project_language == "java":
|
||||
workflow_template = "codeflash-optimize-java.yaml"
|
||||
elif project_language in ("javascript", "typescript"):
|
||||
workflow_template = "codeflash-optimize-js.yaml"
|
||||
else:
|
||||
workflow_template = "codeflash-optimize.yaml"
|
||||
|
|
@ -1210,8 +1219,16 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
|
|||
def detect_project_language_for_workflow(project_root: Path) -> str:
|
||||
"""Detect the primary language of the project for workflow generation.
|
||||
|
||||
Returns: 'python', 'javascript', or 'typescript'
|
||||
Returns: 'python', 'javascript', 'typescript', or 'java'
|
||||
"""
|
||||
# Check for Java project (Maven or Gradle)
|
||||
has_pom_xml = (project_root / "pom.xml").exists()
|
||||
has_build_gradle = (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists()
|
||||
has_java_src = (project_root / "src" / "main" / "java").is_dir()
|
||||
|
||||
if has_pom_xml or has_build_gradle or has_java_src:
|
||||
return "java"
|
||||
|
||||
# Check for TypeScript config
|
||||
if (project_root / "tsconfig.json").exists():
|
||||
return "typescript"
|
||||
|
|
@ -1230,6 +1247,7 @@ def detect_project_language_for_workflow(project_root: Path) -> str:
|
|||
# Both exist - count files to determine primary language
|
||||
js_count = 0
|
||||
py_count = 0
|
||||
java_count = 0
|
||||
for file in project_root.rglob("*"):
|
||||
if file.is_file():
|
||||
suffix = file.suffix.lower()
|
||||
|
|
@ -1237,8 +1255,13 @@ def detect_project_language_for_workflow(project_root: Path) -> str:
|
|||
js_count += 1
|
||||
elif suffix == ".py":
|
||||
py_count += 1
|
||||
elif suffix == ".java":
|
||||
java_count += 1
|
||||
|
||||
if js_count > py_count:
|
||||
max_count = max(js_count, py_count, java_count)
|
||||
if max_count == java_count and java_count > 0:
|
||||
return "java"
|
||||
if max_count == js_count and js_count > 0:
|
||||
return "javascript"
|
||||
return "python"
|
||||
|
||||
|
|
@ -1343,9 +1366,9 @@ def generate_dynamic_workflow_content(
|
|||
# Detect project language
|
||||
project_language = detect_project_language_for_workflow(Path.cwd())
|
||||
|
||||
# For JavaScript/TypeScript projects, use static template customization
|
||||
# For JavaScript/TypeScript and Java projects, use static template customization
|
||||
# (AI-generated steps are currently Python-only)
|
||||
if project_language in ("javascript", "typescript"):
|
||||
if project_language in ("javascript", "typescript", "java"):
|
||||
return customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
|
||||
|
||||
# Python project - try AI-generated steps
|
||||
|
|
@ -1466,6 +1489,10 @@ def customize_codeflash_yaml_content(
|
|||
# Detect project language
|
||||
project_language = detect_project_language_for_workflow(Path.cwd())
|
||||
|
||||
if project_language == "java":
|
||||
# Java project
|
||||
return _customize_java_workflow_content(optimize_yml_content, git_root, benchmark_mode)
|
||||
|
||||
if project_language in ("javascript", "typescript"):
|
||||
# JavaScript/TypeScript project
|
||||
return _customize_js_workflow_content(optimize_yml_content, git_root, benchmark_mode)
|
||||
|
|
@ -1562,6 +1589,54 @@ def _customize_js_workflow_content(optimize_yml_content: str, git_root: Path, be
|
|||
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
|
||||
|
||||
|
||||
def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path, benchmark_mode: bool = False) -> str:
|
||||
"""Customize workflow content for Java projects."""
|
||||
from codeflash.cli_cmds.init_java import (
|
||||
JavaBuildTool,
|
||||
detect_java_build_tool,
|
||||
get_java_dependency_installation_commands,
|
||||
)
|
||||
|
||||
project_root = Path.cwd()
|
||||
|
||||
# Check for pom.xml or build.gradle
|
||||
has_pom = (project_root / "pom.xml").exists()
|
||||
has_gradle = (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists()
|
||||
|
||||
if not has_pom and not has_gradle:
|
||||
click.echo(
|
||||
f"I couldn't find a pom.xml or build.gradle in the current directory.{LF}"
|
||||
f"Please ensure you're in a Maven or Gradle project directory."
|
||||
)
|
||||
apologize_and_exit()
|
||||
|
||||
# Determine working directory relative to git root
|
||||
if project_root == git_root:
|
||||
working_dir = ""
|
||||
else:
|
||||
rel_path = str(project_root.relative_to(git_root))
|
||||
working_dir = f"""defaults:
|
||||
run:
|
||||
working-directory: ./{rel_path}"""
|
||||
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ working_directory }}", working_dir)
|
||||
|
||||
# Determine build tool
|
||||
build_tool = detect_java_build_tool(project_root)
|
||||
|
||||
# Set build tool cache type for actions/setup-java
|
||||
if build_tool == JavaBuildTool.GRADLE:
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ java_build_tool }}", "gradle")
|
||||
else:
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ java_build_tool }}", "maven")
|
||||
|
||||
# Install dependencies
|
||||
install_deps_cmd = get_java_dependency_installation_commands(build_tool)
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
|
||||
|
||||
return optimize_yml_content
|
||||
|
||||
|
||||
def get_formatter_cmds(formatter: str) -> list[str]:
|
||||
if formatter == "black":
|
||||
return ["black $file"]
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ class ProjectLanguage(Enum):
|
|||
PYTHON = auto()
|
||||
JAVASCRIPT = auto()
|
||||
TYPESCRIPT = auto()
|
||||
JAVA = auto()
|
||||
|
||||
|
||||
class JsPackageManager(Enum):
|
||||
|
|
@ -89,6 +90,13 @@ def detect_project_language(project_root: Path | None = None) -> ProjectLanguage
|
|||
has_setup_py = (root / "setup.py").exists()
|
||||
has_package_json = (root / "package.json").exists()
|
||||
has_tsconfig = (root / "tsconfig.json").exists()
|
||||
has_pom_xml = (root / "pom.xml").exists()
|
||||
has_build_gradle = (root / "build.gradle").exists() or (root / "build.gradle.kts").exists()
|
||||
has_java_src = (root / "src" / "main" / "java").is_dir()
|
||||
|
||||
# Java project (Maven or Gradle)
|
||||
if has_pom_xml or has_build_gradle or has_java_src:
|
||||
return ProjectLanguage.JAVA
|
||||
|
||||
# TypeScript project
|
||||
if has_tsconfig:
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from codeflash.languages.base import (
|
|||
from codeflash.languages.current import (
|
||||
current_language,
|
||||
current_language_support,
|
||||
is_java,
|
||||
is_javascript,
|
||||
is_python,
|
||||
is_typescript,
|
||||
|
|
@ -41,6 +42,10 @@ from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport
|
|||
# Import language support modules to trigger auto-registration
|
||||
# This ensures all supported languages are available when this package is imported
|
||||
from codeflash.languages.python import PythonSupport # noqa: F401
|
||||
|
||||
# Java language support
|
||||
# Importing the module triggers registration via @register_language decorator
|
||||
from codeflash.languages.java.support import JavaSupport # noqa: F401
|
||||
from codeflash.languages.registry import (
|
||||
detect_project_language,
|
||||
get_language_support,
|
||||
|
|
@ -67,6 +72,7 @@ __all__ = [
|
|||
"get_language_support",
|
||||
"get_supported_extensions",
|
||||
"get_supported_languages",
|
||||
"is_java",
|
||||
"is_javascript",
|
||||
"is_python",
|
||||
"is_typescript",
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class Language(str, Enum):
|
|||
PYTHON = "python"
|
||||
JAVASCRIPT = "javascript"
|
||||
TYPESCRIPT = "typescript"
|
||||
JAVA = "java"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
|
|||
|
|
@ -106,6 +106,16 @@ def is_typescript() -> bool:
|
|||
return _current_language == Language.TYPESCRIPT
|
||||
|
||||
|
||||
def is_java() -> bool:
|
||||
"""Check if the current language is Java.
|
||||
|
||||
Returns:
|
||||
True if the current language is Java.
|
||||
|
||||
"""
|
||||
return _current_language == Language.JAVA
|
||||
|
||||
|
||||
def current_language_support() -> LanguageSupport:
|
||||
"""Get the LanguageSupport instance for the current language.
|
||||
|
||||
|
|
|
|||
195
codeflash/languages/java/__init__.py
Normal file
195
codeflash/languages/java/__init__.py
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
"""Java language support for codeflash.
|
||||
|
||||
This module provides Java-specific functionality for code analysis,
|
||||
test execution, and optimization using tree-sitter for parsing and
|
||||
Maven/Gradle for build operations.
|
||||
"""
|
||||
|
||||
from codeflash.languages.java.build_tools import (
|
||||
BuildTool,
|
||||
JavaProjectInfo,
|
||||
MavenTestResult,
|
||||
add_codeflash_dependency_to_pom,
|
||||
compile_maven_project,
|
||||
detect_build_tool,
|
||||
find_gradle_executable,
|
||||
find_maven_executable,
|
||||
find_source_root,
|
||||
find_test_root,
|
||||
get_classpath,
|
||||
get_project_info,
|
||||
install_codeflash_runtime,
|
||||
run_maven_tests,
|
||||
)
|
||||
from codeflash.languages.java.comparator import (
|
||||
compare_invocations_directly,
|
||||
compare_test_results,
|
||||
)
|
||||
from codeflash.languages.java.config import (
|
||||
JavaProjectConfig,
|
||||
detect_java_project,
|
||||
get_test_class_pattern,
|
||||
get_test_file_pattern,
|
||||
is_java_project,
|
||||
)
|
||||
from codeflash.languages.java.context import (
|
||||
extract_class_context,
|
||||
extract_code_context,
|
||||
extract_function_source,
|
||||
extract_read_only_context,
|
||||
find_helper_functions,
|
||||
)
|
||||
from codeflash.languages.java.discovery import (
|
||||
discover_functions,
|
||||
discover_functions_from_source,
|
||||
discover_test_methods,
|
||||
get_class_methods,
|
||||
get_method_by_name,
|
||||
)
|
||||
from codeflash.languages.java.formatter import (
|
||||
JavaFormatter,
|
||||
format_java_code,
|
||||
format_java_file,
|
||||
normalize_java_code,
|
||||
)
|
||||
from codeflash.languages.java.import_resolver import (
|
||||
JavaImportResolver,
|
||||
ResolvedImport,
|
||||
find_helper_files,
|
||||
resolve_imports_for_file,
|
||||
)
|
||||
from codeflash.languages.java.instrumentation import (
|
||||
create_benchmark_test,
|
||||
instrument_existing_test,
|
||||
instrument_for_behavior,
|
||||
instrument_for_benchmarking,
|
||||
remove_instrumentation,
|
||||
)
|
||||
from codeflash.languages.java.parser import (
|
||||
JavaAnalyzer,
|
||||
JavaClassNode,
|
||||
JavaFieldInfo,
|
||||
JavaImportInfo,
|
||||
JavaMethodNode,
|
||||
get_java_analyzer,
|
||||
)
|
||||
from codeflash.languages.java.replacement import (
|
||||
add_runtime_comments,
|
||||
insert_method,
|
||||
remove_method,
|
||||
remove_test_functions,
|
||||
replace_function,
|
||||
replace_method_body,
|
||||
)
|
||||
from codeflash.languages.java.support import (
|
||||
JavaSupport,
|
||||
get_java_support,
|
||||
)
|
||||
from codeflash.languages.java.test_discovery import (
|
||||
build_test_mapping_for_project,
|
||||
discover_all_tests,
|
||||
discover_tests,
|
||||
find_tests_for_function,
|
||||
get_test_class_for_source_class,
|
||||
get_test_file_suffix,
|
||||
get_test_methods_for_class,
|
||||
is_test_file,
|
||||
)
|
||||
from codeflash.languages.java.test_runner import (
|
||||
JavaTestRunResult,
|
||||
get_test_run_command,
|
||||
parse_surefire_results,
|
||||
parse_test_results,
|
||||
run_behavioral_tests,
|
||||
run_benchmarking_tests,
|
||||
run_tests,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Parser
|
||||
"JavaAnalyzer",
|
||||
"JavaClassNode",
|
||||
"JavaFieldInfo",
|
||||
"JavaImportInfo",
|
||||
"JavaMethodNode",
|
||||
"get_java_analyzer",
|
||||
# Build tools
|
||||
"BuildTool",
|
||||
"JavaProjectInfo",
|
||||
"MavenTestResult",
|
||||
"add_codeflash_dependency_to_pom",
|
||||
"compile_maven_project",
|
||||
"detect_build_tool",
|
||||
"find_gradle_executable",
|
||||
"find_maven_executable",
|
||||
"find_source_root",
|
||||
"find_test_root",
|
||||
"get_classpath",
|
||||
"get_project_info",
|
||||
"install_codeflash_runtime",
|
||||
"run_maven_tests",
|
||||
# Comparator
|
||||
"compare_invocations_directly",
|
||||
"compare_test_results",
|
||||
# Config
|
||||
"JavaProjectConfig",
|
||||
"detect_java_project",
|
||||
"get_test_class_pattern",
|
||||
"get_test_file_pattern",
|
||||
"is_java_project",
|
||||
# Context
|
||||
"extract_class_context",
|
||||
"extract_code_context",
|
||||
"extract_function_source",
|
||||
"extract_read_only_context",
|
||||
"find_helper_functions",
|
||||
# Discovery
|
||||
"discover_functions",
|
||||
"discover_functions_from_source",
|
||||
"discover_test_methods",
|
||||
"get_class_methods",
|
||||
"get_method_by_name",
|
||||
# Formatter
|
||||
"JavaFormatter",
|
||||
"format_java_code",
|
||||
"format_java_file",
|
||||
"normalize_java_code",
|
||||
# Import resolver
|
||||
"JavaImportResolver",
|
||||
"ResolvedImport",
|
||||
"find_helper_files",
|
||||
"resolve_imports_for_file",
|
||||
# Instrumentation
|
||||
"create_benchmark_test",
|
||||
"instrument_existing_test",
|
||||
"instrument_for_behavior",
|
||||
"instrument_for_benchmarking",
|
||||
"remove_instrumentation",
|
||||
# Replacement
|
||||
"add_runtime_comments",
|
||||
"insert_method",
|
||||
"remove_method",
|
||||
"remove_test_functions",
|
||||
"replace_function",
|
||||
"replace_method_body",
|
||||
# Support
|
||||
"JavaSupport",
|
||||
"get_java_support",
|
||||
# Test discovery
|
||||
"build_test_mapping_for_project",
|
||||
"discover_all_tests",
|
||||
"discover_tests",
|
||||
"find_tests_for_function",
|
||||
"get_test_class_for_source_class",
|
||||
"get_test_file_suffix",
|
||||
"get_test_methods_for_class",
|
||||
"is_test_file",
|
||||
# Test runner
|
||||
"JavaTestRunResult",
|
||||
"get_test_run_command",
|
||||
"parse_surefire_results",
|
||||
"parse_test_results",
|
||||
"run_behavioral_tests",
|
||||
"run_benchmarking_tests",
|
||||
"run_tests",
|
||||
]
|
||||
742
codeflash/languages/java/build_tools.py
Normal file
742
codeflash/languages/java/build_tools.py
Normal file
|
|
@ -0,0 +1,742 @@
|
|||
"""Java build tool detection and integration.
|
||||
|
||||
This module provides functionality to detect and work with Java build tools
|
||||
(Maven and Gradle), including running tests and managing dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuildTool(Enum):
|
||||
"""Supported Java build tools."""
|
||||
|
||||
MAVEN = "maven"
|
||||
GRADLE = "gradle"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaProjectInfo:
|
||||
"""Information about a Java project."""
|
||||
|
||||
project_root: Path
|
||||
build_tool: BuildTool
|
||||
source_roots: list[Path]
|
||||
test_roots: list[Path]
|
||||
target_dir: Path # build output directory
|
||||
group_id: str | None
|
||||
artifact_id: str | None
|
||||
version: str | None
|
||||
java_version: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MavenTestResult:
|
||||
"""Result of running Maven tests."""
|
||||
|
||||
success: bool
|
||||
tests_run: int
|
||||
failures: int
|
||||
errors: int
|
||||
skipped: int
|
||||
surefire_reports_dir: Path | None
|
||||
stdout: str
|
||||
stderr: str
|
||||
returncode: int
|
||||
|
||||
|
||||
def detect_build_tool(project_root: Path) -> BuildTool:
|
||||
"""Detect which build tool a Java project uses.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Java project.
|
||||
|
||||
Returns:
|
||||
The detected BuildTool enum value.
|
||||
|
||||
"""
|
||||
# Check for Maven (pom.xml)
|
||||
if (project_root / "pom.xml").exists():
|
||||
return BuildTool.MAVEN
|
||||
|
||||
# Check for Gradle (build.gradle or build.gradle.kts)
|
||||
if (project_root / "build.gradle").exists() or (project_root / "build.gradle.kts").exists():
|
||||
return BuildTool.GRADLE
|
||||
|
||||
# Check in parent directories for multi-module projects
|
||||
current = project_root
|
||||
for _ in range(3): # Check up to 3 levels
|
||||
parent = current.parent
|
||||
if parent == current:
|
||||
break
|
||||
if (parent / "pom.xml").exists():
|
||||
return BuildTool.MAVEN
|
||||
if (parent / "build.gradle").exists() or (parent / "build.gradle.kts").exists():
|
||||
return BuildTool.GRADLE
|
||||
current = parent
|
||||
|
||||
return BuildTool.UNKNOWN
|
||||
|
||||
|
||||
def get_project_info(project_root: Path) -> JavaProjectInfo | None:
|
||||
"""Get information about a Java project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Java project.
|
||||
|
||||
Returns:
|
||||
JavaProjectInfo if a supported project is found, None otherwise.
|
||||
|
||||
"""
|
||||
build_tool = detect_build_tool(project_root)
|
||||
|
||||
if build_tool == BuildTool.MAVEN:
|
||||
return _get_maven_project_info(project_root)
|
||||
if build_tool == BuildTool.GRADLE:
|
||||
return _get_gradle_project_info(project_root)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None:
|
||||
"""Get project info from Maven pom.xml.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Maven project.
|
||||
|
||||
Returns:
|
||||
JavaProjectInfo extracted from pom.xml.
|
||||
|
||||
"""
|
||||
pom_path = project_root / "pom.xml"
|
||||
if not pom_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
tree = ET.parse(pom_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# Handle Maven namespace
|
||||
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
|
||||
|
||||
def get_text(xpath: str, default: str | None = None) -> str | None:
|
||||
# Try with namespace first
|
||||
elem = root.find(f"m:{xpath}", ns)
|
||||
if elem is None:
|
||||
# Try without namespace
|
||||
elem = root.find(xpath)
|
||||
return elem.text if elem is not None else default
|
||||
|
||||
group_id = get_text("groupId")
|
||||
artifact_id = get_text("artifactId")
|
||||
version = get_text("version")
|
||||
|
||||
# Get Java version from properties or compiler plugin
|
||||
java_version = _extract_java_version_from_pom(root, ns)
|
||||
|
||||
# Standard Maven directory structure
|
||||
source_roots = []
|
||||
test_roots = []
|
||||
|
||||
main_src = project_root / "src" / "main" / "java"
|
||||
if main_src.exists():
|
||||
source_roots.append(main_src)
|
||||
|
||||
test_src = project_root / "src" / "test" / "java"
|
||||
if test_src.exists():
|
||||
test_roots.append(test_src)
|
||||
|
||||
target_dir = project_root / "target"
|
||||
|
||||
return JavaProjectInfo(
|
||||
project_root=project_root,
|
||||
build_tool=BuildTool.MAVEN,
|
||||
source_roots=source_roots,
|
||||
test_roots=test_roots,
|
||||
target_dir=target_dir,
|
||||
group_id=group_id,
|
||||
artifact_id=artifact_id,
|
||||
version=version,
|
||||
java_version=java_version,
|
||||
)
|
||||
|
||||
except ET.ParseError as e:
|
||||
logger.warning("Failed to parse pom.xml: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _extract_java_version_from_pom(root: ET.Element, ns: dict[str, str]) -> str | None:
|
||||
"""Extract Java version from Maven pom.xml.
|
||||
|
||||
Checks multiple locations:
|
||||
1. properties/maven.compiler.source
|
||||
2. properties/java.version
|
||||
3. build/plugins/plugin[compiler]/configuration/source
|
||||
|
||||
Args:
|
||||
root: Root element of the pom.xml.
|
||||
ns: XML namespace mapping.
|
||||
|
||||
Returns:
|
||||
Java version string or None.
|
||||
|
||||
"""
|
||||
# Check properties
|
||||
for prop_name in ("maven.compiler.source", "java.version", "maven.compiler.release"):
|
||||
for props in [root.find(f"m:properties", ns), root.find("properties")]:
|
||||
if props is not None:
|
||||
for prop in [props.find(f"m:{prop_name}", ns), props.find(prop_name)]:
|
||||
if prop is not None and prop.text:
|
||||
return prop.text
|
||||
|
||||
# Check compiler plugin configuration
|
||||
for build in [root.find(f"m:build", ns), root.find("build")]:
|
||||
if build is not None:
|
||||
for plugins in [build.find(f"m:plugins", ns), build.find("plugins")]:
|
||||
if plugins is not None:
|
||||
for plugin in plugins.findall(f"m:plugin", ns) + plugins.findall("plugin"):
|
||||
artifact_id = plugin.find(f"m:artifactId", ns) or plugin.find("artifactId")
|
||||
if artifact_id is not None and artifact_id.text == "maven-compiler-plugin":
|
||||
config = plugin.find(f"m:configuration", ns) or plugin.find("configuration")
|
||||
if config is not None:
|
||||
source = config.find(f"m:source", ns) or config.find("source")
|
||||
if source is not None and source.text:
|
||||
return source.text
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_gradle_project_info(project_root: Path) -> JavaProjectInfo | None:
|
||||
"""Get project info from Gradle build file.
|
||||
|
||||
Note: This is a basic implementation. Full Gradle parsing would require
|
||||
running Gradle with a custom task or using the Gradle tooling API.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Gradle project.
|
||||
|
||||
Returns:
|
||||
JavaProjectInfo with basic Gradle project structure.
|
||||
|
||||
"""
|
||||
# Standard Gradle directory structure
|
||||
source_roots = []
|
||||
test_roots = []
|
||||
|
||||
main_src = project_root / "src" / "main" / "java"
|
||||
if main_src.exists():
|
||||
source_roots.append(main_src)
|
||||
|
||||
test_src = project_root / "src" / "test" / "java"
|
||||
if test_src.exists():
|
||||
test_roots.append(test_src)
|
||||
|
||||
build_dir = project_root / "build"
|
||||
|
||||
return JavaProjectInfo(
|
||||
project_root=project_root,
|
||||
build_tool=BuildTool.GRADLE,
|
||||
source_roots=source_roots,
|
||||
test_roots=test_roots,
|
||||
target_dir=build_dir,
|
||||
group_id=None, # Would need to parse build.gradle
|
||||
artifact_id=None,
|
||||
version=None,
|
||||
java_version=None,
|
||||
)
|
||||
|
||||
|
||||
def find_maven_executable() -> str | None:
|
||||
"""Find the Maven executable.
|
||||
|
||||
Returns:
|
||||
Path to mvn executable, or None if not found.
|
||||
|
||||
"""
|
||||
# Check for Maven wrapper first
|
||||
if os.path.exists("mvnw"):
|
||||
return "./mvnw"
|
||||
if os.path.exists("mvnw.cmd"):
|
||||
return "mvnw.cmd"
|
||||
|
||||
# Check system Maven
|
||||
mvn_path = shutil.which("mvn")
|
||||
if mvn_path:
|
||||
return mvn_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_gradle_executable() -> str | None:
|
||||
"""Find the Gradle executable.
|
||||
|
||||
Returns:
|
||||
Path to gradle executable, or None if not found.
|
||||
|
||||
"""
|
||||
# Check for Gradle wrapper first
|
||||
if os.path.exists("gradlew"):
|
||||
return "./gradlew"
|
||||
if os.path.exists("gradlew.bat"):
|
||||
return "gradlew.bat"
|
||||
|
||||
# Check system Gradle
|
||||
gradle_path = shutil.which("gradle")
|
||||
if gradle_path:
|
||||
return gradle_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def run_maven_tests(
|
||||
project_root: Path,
|
||||
test_classes: list[str] | None = None,
|
||||
test_methods: list[str] | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout: int = 300,
|
||||
skip_compilation: bool = False,
|
||||
) -> MavenTestResult:
|
||||
"""Run Maven tests using Surefire.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Maven project.
|
||||
test_classes: Optional list of test class names to run.
|
||||
test_methods: Optional list of specific test methods (format: ClassName#methodName).
|
||||
env: Optional environment variables.
|
||||
timeout: Maximum time in seconds for test execution.
|
||||
skip_compilation: Whether to skip compilation (useful when only running tests).
|
||||
|
||||
Returns:
|
||||
MavenTestResult with test execution results.
|
||||
|
||||
"""
|
||||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
logger.error("Maven not found. Please install Maven or use Maven wrapper.")
|
||||
return MavenTestResult(
|
||||
success=False,
|
||||
tests_run=0,
|
||||
failures=0,
|
||||
errors=0,
|
||||
skipped=0,
|
||||
surefire_reports_dir=None,
|
||||
stdout="",
|
||||
stderr="Maven not found",
|
||||
returncode=-1,
|
||||
)
|
||||
|
||||
# Build Maven command
|
||||
cmd = [mvn]
|
||||
|
||||
if skip_compilation:
|
||||
cmd.append("-Dmaven.test.skip=false")
|
||||
cmd.append("-DskipTests=false")
|
||||
cmd.append("surefire:test")
|
||||
else:
|
||||
cmd.append("test")
|
||||
|
||||
# Add test filtering
|
||||
if test_classes or test_methods:
|
||||
if test_methods:
|
||||
# Format: -Dtest=ClassName#method1+method2,OtherClass#method3
|
||||
tests = ",".join(test_methods)
|
||||
elif test_classes:
|
||||
tests = ",".join(test_classes)
|
||||
cmd.extend(["-Dtest=" + tests])
|
||||
|
||||
# Fail at end to run all tests
|
||||
cmd.append("-fae")
|
||||
|
||||
# Use full environment with optional overrides
|
||||
run_env = os.environ.copy()
|
||||
if env:
|
||||
run_env.update(env)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=run_env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Parse test results from Surefire reports
|
||||
surefire_dir = project_root / "target" / "surefire-reports"
|
||||
tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir)
|
||||
|
||||
return MavenTestResult(
|
||||
success=result.returncode == 0,
|
||||
tests_run=tests_run,
|
||||
failures=failures,
|
||||
errors=errors,
|
||||
skipped=skipped,
|
||||
surefire_reports_dir=surefire_dir if surefire_dir.exists() else None,
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
returncode=result.returncode,
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Maven test execution timed out after %d seconds", timeout)
|
||||
return MavenTestResult(
|
||||
success=False,
|
||||
tests_run=0,
|
||||
failures=0,
|
||||
errors=0,
|
||||
skipped=0,
|
||||
surefire_reports_dir=None,
|
||||
stdout="",
|
||||
stderr=f"Test execution timed out after {timeout} seconds",
|
||||
returncode=-2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Maven test execution failed: %s", e)
|
||||
return MavenTestResult(
|
||||
success=False,
|
||||
tests_run=0,
|
||||
failures=0,
|
||||
errors=0,
|
||||
skipped=0,
|
||||
surefire_reports_dir=None,
|
||||
stdout="",
|
||||
stderr=str(e),
|
||||
returncode=-1,
|
||||
)
|
||||
|
||||
|
||||
def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
|
||||
"""Parse Surefire XML reports to get test counts.
|
||||
|
||||
Args:
|
||||
surefire_dir: Directory containing Surefire XML reports.
|
||||
|
||||
Returns:
|
||||
Tuple of (tests_run, failures, errors, skipped).
|
||||
|
||||
"""
|
||||
tests_run = 0
|
||||
failures = 0
|
||||
errors = 0
|
||||
skipped = 0
|
||||
|
||||
if not surefire_dir.exists():
|
||||
return tests_run, failures, errors, skipped
|
||||
|
||||
for xml_file in surefire_dir.glob("TEST-*.xml"):
|
||||
try:
|
||||
tree = ET.parse(xml_file)
|
||||
root = tree.getroot()
|
||||
|
||||
tests_run += int(root.get("tests", 0))
|
||||
failures += int(root.get("failures", 0))
|
||||
errors += int(root.get("errors", 0))
|
||||
skipped += int(root.get("skipped", 0))
|
||||
|
||||
except ET.ParseError as e:
|
||||
logger.warning("Failed to parse Surefire report %s: %s", xml_file, e)
|
||||
|
||||
return tests_run, failures, errors, skipped
|
||||
|
||||
|
||||
def compile_maven_project(
|
||||
project_root: Path,
|
||||
include_tests: bool = True,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout: int = 300,
|
||||
) -> tuple[bool, str, str]:
|
||||
"""Compile a Maven project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Maven project.
|
||||
include_tests: Whether to compile test classes as well.
|
||||
env: Optional environment variables.
|
||||
timeout: Maximum time in seconds for compilation.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, stdout, stderr).
|
||||
|
||||
"""
|
||||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
return False, "", "Maven not found"
|
||||
|
||||
cmd = [mvn]
|
||||
|
||||
if include_tests:
|
||||
cmd.append("test-compile")
|
||||
else:
|
||||
cmd.append("compile")
|
||||
|
||||
# Skip test execution
|
||||
cmd.append("-DskipTests")
|
||||
|
||||
run_env = os.environ.copy()
|
||||
if env:
|
||||
run_env.update(env)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=run_env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
return result.returncode == 0, result.stdout, result.stderr
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, "", f"Compilation timed out after {timeout} seconds"
|
||||
except Exception as e:
|
||||
return False, "", str(e)
|
||||
|
||||
|
||||
def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> bool:
|
||||
"""Install the codeflash runtime JAR to the local Maven repository.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Maven project.
|
||||
runtime_jar_path: Path to the codeflash-runtime.jar file.
|
||||
|
||||
Returns:
|
||||
True if installation succeeded, False otherwise.
|
||||
|
||||
"""
|
||||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
logger.error("Maven not found")
|
||||
return False
|
||||
|
||||
if not runtime_jar_path.exists():
|
||||
logger.error("Runtime JAR not found: %s", runtime_jar_path)
|
||||
return False
|
||||
|
||||
cmd = [
|
||||
mvn,
|
||||
"install:install-file",
|
||||
f"-Dfile={runtime_jar_path}",
|
||||
"-DgroupId=com.codeflash",
|
||||
"-DartifactId=codeflash-runtime",
|
||||
"-Dversion=1.0.0",
|
||||
"-Dpackaging=jar",
|
||||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("Successfully installed codeflash-runtime to local Maven repository")
|
||||
return True
|
||||
else:
|
||||
logger.error("Failed to install codeflash-runtime: %s", result.stderr)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to install codeflash-runtime: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
|
||||
"""Add codeflash-runtime dependency to pom.xml if not present.
|
||||
|
||||
Args:
|
||||
pom_path: Path to the pom.xml file.
|
||||
|
||||
Returns:
|
||||
True if dependency was added or already present, False on error.
|
||||
|
||||
"""
|
||||
if not pom_path.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
tree = ET.parse(pom_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# Handle Maven namespace
|
||||
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
|
||||
ns_prefix = "{http://maven.apache.org/POM/4.0.0}"
|
||||
|
||||
# Check if namespace is used
|
||||
if root.tag.startswith("{"):
|
||||
use_ns = True
|
||||
else:
|
||||
use_ns = False
|
||||
ns_prefix = ""
|
||||
|
||||
# Find or create dependencies section
|
||||
deps = root.find(f"{ns_prefix}dependencies" if use_ns else "dependencies")
|
||||
if deps is None:
|
||||
deps = ET.SubElement(root, f"{ns_prefix}dependencies" if use_ns else "dependencies")
|
||||
|
||||
# Check if codeflash dependency already exists
|
||||
for dep in deps.findall(f"{ns_prefix}dependency" if use_ns else "dependency"):
|
||||
group = dep.find(f"{ns_prefix}groupId" if use_ns else "groupId")
|
||||
artifact = dep.find(f"{ns_prefix}artifactId" if use_ns else "artifactId")
|
||||
if group is not None and artifact is not None:
|
||||
if group.text == "com.codeflash" and artifact.text == "codeflash-runtime":
|
||||
logger.info("codeflash-runtime dependency already present in pom.xml")
|
||||
return True
|
||||
|
||||
# Add codeflash dependency
|
||||
dep_elem = ET.SubElement(deps, f"{ns_prefix}dependency" if use_ns else "dependency")
|
||||
|
||||
group_elem = ET.SubElement(dep_elem, f"{ns_prefix}groupId" if use_ns else "groupId")
|
||||
group_elem.text = "com.codeflash"
|
||||
|
||||
artifact_elem = ET.SubElement(dep_elem, f"{ns_prefix}artifactId" if use_ns else "artifactId")
|
||||
artifact_elem.text = "codeflash-runtime"
|
||||
|
||||
version_elem = ET.SubElement(dep_elem, f"{ns_prefix}version" if use_ns else "version")
|
||||
version_elem.text = "1.0.0"
|
||||
|
||||
scope_elem = ET.SubElement(dep_elem, f"{ns_prefix}scope" if use_ns else "scope")
|
||||
scope_elem.text = "test"
|
||||
|
||||
# Write back to file
|
||||
tree.write(pom_path, xml_declaration=True, encoding="utf-8")
|
||||
logger.info("Added codeflash-runtime dependency to pom.xml")
|
||||
return True
|
||||
|
||||
except ET.ParseError as e:
|
||||
logger.error("Failed to parse pom.xml: %s", e)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception("Failed to add dependency to pom.xml: %s", e)
|
||||
return False
|
||||
|
||||
|
||||
def find_test_root(project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a Java project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Java project.
|
||||
|
||||
Returns:
|
||||
Path to test root, or None if not found.
|
||||
|
||||
"""
|
||||
build_tool = detect_build_tool(project_root)
|
||||
|
||||
if build_tool in (BuildTool.MAVEN, BuildTool.GRADLE):
|
||||
test_root = project_root / "src" / "test" / "java"
|
||||
if test_root.exists():
|
||||
return test_root
|
||||
|
||||
# Check common alternative locations
|
||||
for test_dir in ["test", "tests", "src/test"]:
|
||||
test_path = project_root / test_dir
|
||||
if test_path.exists():
|
||||
return test_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_source_root(project_root: Path) -> Path | None:
|
||||
"""Find the main source root directory for a Java project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Java project.
|
||||
|
||||
Returns:
|
||||
Path to source root, or None if not found.
|
||||
|
||||
"""
|
||||
build_tool = detect_build_tool(project_root)
|
||||
|
||||
if build_tool in (BuildTool.MAVEN, BuildTool.GRADLE):
|
||||
src_root = project_root / "src" / "main" / "java"
|
||||
if src_root.exists():
|
||||
return src_root
|
||||
|
||||
# Check common alternative locations
|
||||
for src_dir in ["src", "source", "java"]:
|
||||
src_path = project_root / src_dir
|
||||
if src_path.exists() and any(src_path.rglob("*.java")):
|
||||
return src_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_classpath(project_root: Path) -> str | None:
|
||||
"""Get the classpath for a Java project.
|
||||
|
||||
For Maven projects, this runs 'mvn dependency:build-classpath'.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Java project.
|
||||
|
||||
Returns:
|
||||
Classpath string, or None if unable to determine.
|
||||
|
||||
"""
|
||||
build_tool = detect_build_tool(project_root)
|
||||
|
||||
if build_tool == BuildTool.MAVEN:
|
||||
return _get_maven_classpath(project_root)
|
||||
if build_tool == BuildTool.GRADLE:
|
||||
return _get_gradle_classpath(project_root)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_maven_classpath(project_root: Path) -> str | None:
|
||||
"""Get classpath from Maven."""
|
||||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
return None
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[mvn, "dependency:build-classpath", "-q", "-DincludeScope=test"],
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
# The classpath is in stdout
|
||||
return result.stdout.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get Maven classpath: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_gradle_classpath(project_root: Path) -> str | None:
|
||||
"""Get classpath from Gradle.
|
||||
|
||||
Note: This requires a custom task to be added to build.gradle.
|
||||
Returns None for now as Gradle support is not fully implemented.
|
||||
"""
|
||||
return None
|
||||
333
codeflash/languages/java/comparator.py
Normal file
333
codeflash/languages/java/comparator.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
"""Java test result comparison.
|
||||
|
||||
This module provides functionality to compare test results between
|
||||
original and optimized Java code using the codeflash-runtime Comparator.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import TestDiff
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _find_comparator_jar(project_root: Path | None = None) -> Path | None:
|
||||
"""Find the codeflash-runtime JAR with the Comparator class.
|
||||
|
||||
Args:
|
||||
project_root: Project root directory.
|
||||
|
||||
Returns:
|
||||
Path to codeflash-runtime JAR if found, None otherwise.
|
||||
|
||||
"""
|
||||
search_dirs = []
|
||||
if project_root:
|
||||
search_dirs.append(project_root)
|
||||
search_dirs.append(Path.cwd())
|
||||
|
||||
# Search for the JAR in common locations
|
||||
for base_dir in search_dirs:
|
||||
# Check in target directory (after Maven install)
|
||||
for jar_path in [
|
||||
base_dir / "target" / "dependency" / "codeflash-runtime-1.0.0.jar",
|
||||
base_dir / "target" / "codeflash-runtime-1.0.0.jar",
|
||||
base_dir / "lib" / "codeflash-runtime-1.0.0.jar",
|
||||
base_dir / ".codeflash" / "codeflash-runtime-1.0.0.jar",
|
||||
]:
|
||||
if jar_path.exists():
|
||||
return jar_path
|
||||
|
||||
# Check local Maven repository
|
||||
m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / "codeflash-runtime-1.0.0.jar"
|
||||
if m2_jar.exists():
|
||||
return m2_jar
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_java_executable() -> str | None:
|
||||
"""Find the Java executable.
|
||||
|
||||
Returns:
|
||||
Path to java executable, or None if not found.
|
||||
|
||||
"""
|
||||
import shutil
|
||||
|
||||
# Check JAVA_HOME
|
||||
java_home = os.environ.get("JAVA_HOME")
|
||||
if java_home:
|
||||
java_path = Path(java_home) / "bin" / "java"
|
||||
if java_path.exists():
|
||||
return str(java_path)
|
||||
|
||||
# Check PATH
|
||||
java_path = shutil.which("java")
|
||||
if java_path:
|
||||
return java_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def compare_test_results(
|
||||
original_sqlite_path: Path,
|
||||
candidate_sqlite_path: Path,
|
||||
comparator_jar: Path | None = None,
|
||||
project_root: Path | None = None,
|
||||
) -> tuple[bool, list]:
|
||||
"""Compare Java test results using the codeflash-runtime Comparator.
|
||||
|
||||
This function calls the Java Comparator CLI that:
|
||||
1. Reads serialized behavior data from both SQLite databases
|
||||
2. Deserializes using Gson
|
||||
3. Compares results using deep equality (handles Maps, Lists, arrays, etc.)
|
||||
4. Returns comparison results as JSON
|
||||
|
||||
Args:
|
||||
original_sqlite_path: Path to SQLite database with original code results.
|
||||
candidate_sqlite_path: Path to SQLite database with candidate code results.
|
||||
comparator_jar: Optional path to the codeflash-runtime JAR.
|
||||
project_root: Project root directory.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_equivalent, list of TestDiff objects).
|
||||
|
||||
"""
|
||||
# Import lazily to avoid circular imports
|
||||
from codeflash.models.models import TestDiff, TestDiffScope
|
||||
|
||||
java_exe = _find_java_executable()
|
||||
if not java_exe:
|
||||
logger.error("Java not found. Please install Java to compare test results.")
|
||||
return False, []
|
||||
|
||||
jar_path = comparator_jar or _find_comparator_jar(project_root)
|
||||
if not jar_path or not jar_path.exists():
|
||||
logger.error(
|
||||
"codeflash-runtime JAR not found. "
|
||||
"Please ensure the codeflash-runtime is installed in your project."
|
||||
)
|
||||
return False, []
|
||||
|
||||
if not original_sqlite_path.exists():
|
||||
logger.error(f"Original SQLite database not found: {original_sqlite_path}")
|
||||
return False, []
|
||||
|
||||
if not candidate_sqlite_path.exists():
|
||||
logger.error(f"Candidate SQLite database not found: {candidate_sqlite_path}")
|
||||
return False, []
|
||||
|
||||
cwd = project_root or Path.cwd()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
java_exe,
|
||||
"-cp",
|
||||
str(jar_path),
|
||||
"com.codeflash.Comparator",
|
||||
str(original_sqlite_path),
|
||||
str(candidate_sqlite_path),
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
cwd=str(cwd),
|
||||
)
|
||||
|
||||
# Parse the JSON output
|
||||
try:
|
||||
if not result.stdout or not result.stdout.strip():
|
||||
logger.error("Java comparator returned empty output")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr}")
|
||||
return False, []
|
||||
|
||||
comparison = json.loads(result.stdout)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse Java comparator output: {e}")
|
||||
logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr[:500]}")
|
||||
return False, []
|
||||
|
||||
# Check for errors in the JSON response
|
||||
if comparison.get("error"):
|
||||
logger.error(f"Java comparator error: {comparison['error']}")
|
||||
return False, []
|
||||
|
||||
# Check for unexpected exit codes
|
||||
if result.returncode not in {0, 1}:
|
||||
logger.error(f"Java comparator failed with exit code {result.returncode}")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr}")
|
||||
return False, []
|
||||
|
||||
# Convert diffs to TestDiff objects
|
||||
test_diffs: list[TestDiff] = []
|
||||
for diff in comparison.get("diffs", []):
|
||||
scope_str = diff.get("scope", "return_value")
|
||||
scope = TestDiffScope.RETURN_VALUE
|
||||
if scope_str == "exception":
|
||||
scope = TestDiffScope.DID_PASS
|
||||
elif scope_str == "missing":
|
||||
scope = TestDiffScope.DID_PASS
|
||||
|
||||
# Build test identifier
|
||||
method_id = diff.get("methodId", "unknown")
|
||||
call_id = diff.get("callId", 0)
|
||||
test_src_code = f"// Method: {method_id}\n// Call ID: {call_id}"
|
||||
|
||||
test_diffs.append(
|
||||
TestDiff(
|
||||
scope=scope,
|
||||
original_value=diff.get("originalValue"),
|
||||
candidate_value=diff.get("candidateValue"),
|
||||
test_src_code=test_src_code,
|
||||
candidate_pytest_error=diff.get("candidateError"),
|
||||
original_pass=True,
|
||||
candidate_pass=scope_str not in ("missing", "exception"),
|
||||
original_pytest_error=diff.get("originalError"),
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Java test diff:\n"
|
||||
f" Method: {method_id}\n"
|
||||
f" Call ID: {call_id}\n"
|
||||
f" Scope: {scope_str}\n"
|
||||
f" Original: {str(diff.get('originalValue', 'N/A'))[:100]}\n"
|
||||
f" Candidate: {str(diff.get('candidateValue', 'N/A'))[:100]}"
|
||||
)
|
||||
|
||||
equivalent = comparison.get("equivalent", False)
|
||||
|
||||
logger.info(
|
||||
f"Java comparison: {'equivalent' if equivalent else 'DIFFERENT'} "
|
||||
f"({comparison.get('totalInvocations', 0)} invocations, {len(test_diffs)} diffs)"
|
||||
)
|
||||
|
||||
return equivalent, test_diffs
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Java comparator timed out")
|
||||
return False, []
|
||||
except FileNotFoundError:
|
||||
logger.error("Java not found. Please install Java to compare test results.")
|
||||
return False, []
|
||||
except Exception as e:
|
||||
logger.error(f"Error running Java comparator: {e}")
|
||||
return False, []
|
||||
|
||||
|
||||
def compare_invocations_directly(
|
||||
original_results: dict,
|
||||
candidate_results: dict,
|
||||
) -> tuple[bool, list]:
|
||||
"""Compare test invocations directly from Python dictionaries.
|
||||
|
||||
This is a fallback when the Java comparator is not available.
|
||||
It performs basic equality comparison on serialized JSON values.
|
||||
|
||||
Args:
|
||||
original_results: Dict mapping call_id to result data from original code.
|
||||
candidate_results: Dict mapping call_id to result data from candidate code.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_equivalent, list of TestDiff objects).
|
||||
|
||||
"""
|
||||
# Import lazily to avoid circular imports
|
||||
from codeflash.models.models import TestDiff, TestDiffScope
|
||||
|
||||
test_diffs: list[TestDiff] = []
|
||||
|
||||
# Get all call IDs
|
||||
all_call_ids = set(original_results.keys()) | set(candidate_results.keys())
|
||||
|
||||
for call_id in all_call_ids:
|
||||
original = original_results.get(call_id)
|
||||
candidate = candidate_results.get(call_id)
|
||||
|
||||
if original is None and candidate is not None:
|
||||
# Candidate has extra invocation
|
||||
test_diffs.append(
|
||||
TestDiff(
|
||||
scope=TestDiffScope.DID_PASS,
|
||||
original_value=None,
|
||||
candidate_value=candidate.get("result_json"),
|
||||
test_src_code=f"// Call ID: {call_id}",
|
||||
candidate_pytest_error=None,
|
||||
original_pass=True,
|
||||
candidate_pass=True,
|
||||
original_pytest_error=None,
|
||||
)
|
||||
)
|
||||
elif original is not None and candidate is None:
|
||||
# Candidate missing invocation
|
||||
test_diffs.append(
|
||||
TestDiff(
|
||||
scope=TestDiffScope.DID_PASS,
|
||||
original_value=original.get("result_json"),
|
||||
candidate_value=None,
|
||||
test_src_code=f"// Call ID: {call_id}",
|
||||
candidate_pytest_error="Missing invocation in candidate",
|
||||
original_pass=True,
|
||||
candidate_pass=False,
|
||||
original_pytest_error=None,
|
||||
)
|
||||
)
|
||||
elif original is not None and candidate is not None:
|
||||
# Both have invocations - compare results
|
||||
orig_result = original.get("result_json")
|
||||
cand_result = candidate.get("result_json")
|
||||
orig_error = original.get("error_json")
|
||||
cand_error = candidate.get("error_json")
|
||||
|
||||
# Check for exception differences
|
||||
if orig_error != cand_error:
|
||||
test_diffs.append(
|
||||
TestDiff(
|
||||
scope=TestDiffScope.DID_PASS,
|
||||
original_value=orig_error,
|
||||
candidate_value=cand_error,
|
||||
test_src_code=f"// Call ID: {call_id}",
|
||||
candidate_pytest_error=cand_error,
|
||||
original_pass=orig_error is None,
|
||||
candidate_pass=cand_error is None,
|
||||
original_pytest_error=orig_error,
|
||||
)
|
||||
)
|
||||
elif orig_result != cand_result:
|
||||
# Results differ
|
||||
test_diffs.append(
|
||||
TestDiff(
|
||||
scope=TestDiffScope.RETURN_VALUE,
|
||||
original_value=orig_result,
|
||||
candidate_value=cand_result,
|
||||
test_src_code=f"// Call ID: {call_id}",
|
||||
candidate_pytest_error=None,
|
||||
original_pass=True,
|
||||
candidate_pass=True,
|
||||
original_pytest_error=None,
|
||||
)
|
||||
)
|
||||
|
||||
equivalent = len(test_diffs) == 0
|
||||
|
||||
logger.info(
|
||||
f"Python comparison: {'equivalent' if equivalent else 'DIFFERENT'} "
|
||||
f"({len(all_call_ids)} invocations, {len(test_diffs)} diffs)"
|
||||
)
|
||||
|
||||
return equivalent, test_diffs
|
||||
426
codeflash/languages/java/config.py
Normal file
426
codeflash/languages/java/config.py
Normal file
|
|
@ -0,0 +1,426 @@
|
|||
"""Java project configuration detection.
|
||||
|
||||
This module provides functionality to detect and read Java project
|
||||
configuration, including build tool settings, test framework configuration,
|
||||
and project structure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.java.build_tools import (
|
||||
BuildTool,
|
||||
detect_build_tool,
|
||||
find_source_root,
|
||||
find_test_root,
|
||||
get_project_info,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaProjectConfig:
|
||||
"""Configuration for a Java project."""
|
||||
|
||||
project_root: Path
|
||||
build_tool: BuildTool
|
||||
source_root: Path | None
|
||||
test_root: Path | None
|
||||
java_version: str | None
|
||||
encoding: str
|
||||
test_framework: str # "junit5", "junit4", "testng"
|
||||
group_id: str | None
|
||||
artifact_id: str | None
|
||||
version: str | None
|
||||
|
||||
# Dependencies
|
||||
has_junit5: bool = False
|
||||
has_junit4: bool = False
|
||||
has_testng: bool = False
|
||||
has_mockito: bool = False
|
||||
has_assertj: bool = False
|
||||
|
||||
# Build configuration
|
||||
compiler_source: str | None = None
|
||||
compiler_target: str | None = None
|
||||
|
||||
# Plugin configurations
|
||||
surefire_includes: list[str] = field(default_factory=list)
|
||||
surefire_excludes: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
|
||||
"""Detect and return Java project configuration.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
|
||||
Returns:
|
||||
JavaProjectConfig if a Java project is detected, None otherwise.
|
||||
|
||||
"""
|
||||
# Check if this is a Java project
|
||||
build_tool = detect_build_tool(project_root)
|
||||
if build_tool == BuildTool.UNKNOWN:
|
||||
# Check if there are any Java files
|
||||
java_files = list(project_root.rglob("*.java"))
|
||||
if not java_files:
|
||||
return None
|
||||
|
||||
# Get basic project info
|
||||
project_info = get_project_info(project_root)
|
||||
|
||||
# Detect test framework
|
||||
test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(
|
||||
project_root, build_tool
|
||||
)
|
||||
|
||||
# Detect other dependencies
|
||||
has_mockito, has_assertj = _detect_test_dependencies(project_root, build_tool)
|
||||
|
||||
# Get source/test roots
|
||||
source_root = find_source_root(project_root)
|
||||
test_root = find_test_root(project_root)
|
||||
|
||||
# Get compiler settings
|
||||
compiler_source, compiler_target = _get_compiler_settings(project_root, build_tool)
|
||||
|
||||
# Get surefire configuration
|
||||
surefire_includes, surefire_excludes = _get_surefire_config(project_root)
|
||||
|
||||
return JavaProjectConfig(
|
||||
project_root=project_root,
|
||||
build_tool=build_tool,
|
||||
source_root=source_root,
|
||||
test_root=test_root,
|
||||
java_version=project_info.java_version if project_info else None,
|
||||
encoding="UTF-8", # Default, could be detected from pom.xml
|
||||
test_framework=test_framework,
|
||||
group_id=project_info.group_id if project_info else None,
|
||||
artifact_id=project_info.artifact_id if project_info else None,
|
||||
version=project_info.version if project_info else None,
|
||||
has_junit5=has_junit5,
|
||||
has_junit4=has_junit4,
|
||||
has_testng=has_testng,
|
||||
has_mockito=has_mockito,
|
||||
has_assertj=has_assertj,
|
||||
compiler_source=compiler_source,
|
||||
compiler_target=compiler_target,
|
||||
surefire_includes=surefire_includes,
|
||||
surefire_excludes=surefire_excludes,
|
||||
)
|
||||
|
||||
|
||||
def _detect_test_framework(
|
||||
project_root: Path, build_tool: BuildTool
|
||||
) -> tuple[str, bool, bool, bool]:
|
||||
"""Detect which test framework the project uses.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
build_tool: The detected build tool.
|
||||
|
||||
Returns:
|
||||
Tuple of (framework_name, has_junit5, has_junit4, has_testng).
|
||||
|
||||
"""
|
||||
has_junit5 = False
|
||||
has_junit4 = False
|
||||
has_testng = False
|
||||
|
||||
if build_tool == BuildTool.MAVEN:
|
||||
has_junit5, has_junit4, has_testng = _detect_test_deps_from_pom(project_root)
|
||||
elif build_tool == BuildTool.GRADLE:
|
||||
has_junit5, has_junit4, has_testng = _detect_test_deps_from_gradle(project_root)
|
||||
|
||||
# Also check test source files for import statements
|
||||
test_root = find_test_root(project_root)
|
||||
if test_root and test_root.exists():
|
||||
for test_file in test_root.rglob("*.java"):
|
||||
try:
|
||||
content = test_file.read_text(encoding="utf-8")
|
||||
if "org.junit.jupiter" in content:
|
||||
has_junit5 = True
|
||||
if "org.junit.Test" in content or "org.junit.Assert" in content:
|
||||
has_junit4 = True
|
||||
if "org.testng" in content:
|
||||
has_testng = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Determine primary framework (prefer JUnit 5)
|
||||
if has_junit5:
|
||||
return "junit5", has_junit5, has_junit4, has_testng
|
||||
if has_junit4:
|
||||
return "junit4", has_junit5, has_junit4, has_testng
|
||||
if has_testng:
|
||||
return "testng", has_junit5, has_junit4, has_testng
|
||||
|
||||
# Default to JUnit 5 if nothing detected
|
||||
return "junit5", has_junit5, has_junit4, has_testng
|
||||
|
||||
|
||||
def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
|
||||
"""Detect test framework dependencies from pom.xml.
|
||||
|
||||
Returns:
|
||||
Tuple of (has_junit5, has_junit4, has_testng).
|
||||
|
||||
"""
|
||||
pom_path = project_root / "pom.xml"
|
||||
if not pom_path.exists():
|
||||
return False, False, False
|
||||
|
||||
has_junit5 = False
|
||||
has_junit4 = False
|
||||
has_testng = False
|
||||
|
||||
try:
|
||||
tree = ET.parse(pom_path)
|
||||
root = tree.getroot()
|
||||
|
||||
# Handle namespace
|
||||
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
|
||||
|
||||
# Search for dependencies
|
||||
for deps_path in ["dependencies", "m:dependencies"]:
|
||||
deps = root.find(deps_path, ns) if "m:" in deps_path else root.find(deps_path)
|
||||
if deps is None:
|
||||
continue
|
||||
|
||||
for dep_path in ["dependency", "m:dependency"]:
|
||||
deps_list = deps.findall(dep_path, ns) if "m:" in dep_path else deps.findall(dep_path)
|
||||
for dep in deps_list:
|
||||
artifact_id = None
|
||||
group_id = None
|
||||
|
||||
for child in dep:
|
||||
tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "")
|
||||
if tag == "artifactId":
|
||||
artifact_id = child.text
|
||||
elif tag == "groupId":
|
||||
group_id = child.text
|
||||
|
||||
if group_id == "org.junit.jupiter" or (
|
||||
artifact_id and "junit-jupiter" in artifact_id
|
||||
):
|
||||
has_junit5 = True
|
||||
elif group_id == "junit" and artifact_id == "junit":
|
||||
has_junit4 = True
|
||||
elif group_id == "org.testng":
|
||||
has_testng = True
|
||||
|
||||
except ET.ParseError:
|
||||
pass
|
||||
|
||||
return has_junit5, has_junit4, has_testng
|
||||
|
||||
|
||||
def _detect_test_deps_from_gradle(project_root: Path) -> tuple[bool, bool, bool]:
|
||||
"""Detect test framework dependencies from Gradle build files.
|
||||
|
||||
Returns:
|
||||
Tuple of (has_junit5, has_junit4, has_testng).
|
||||
|
||||
"""
|
||||
has_junit5 = False
|
||||
has_junit4 = False
|
||||
has_testng = False
|
||||
|
||||
for gradle_file in ["build.gradle", "build.gradle.kts"]:
|
||||
gradle_path = project_root / gradle_file
|
||||
if gradle_path.exists():
|
||||
try:
|
||||
content = gradle_path.read_text(encoding="utf-8")
|
||||
if "junit-jupiter" in content or "useJUnitPlatform" in content:
|
||||
has_junit5 = True
|
||||
if "junit:junit" in content:
|
||||
has_junit4 = True
|
||||
if "testng" in content.lower():
|
||||
has_testng = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return has_junit5, has_junit4, has_testng
|
||||
|
||||
|
||||
def _detect_test_dependencies(
|
||||
project_root: Path, build_tool: BuildTool
|
||||
) -> tuple[bool, bool]:
|
||||
"""Detect additional test dependencies (Mockito, AssertJ).
|
||||
|
||||
Returns:
|
||||
Tuple of (has_mockito, has_assertj).
|
||||
|
||||
"""
|
||||
has_mockito = False
|
||||
has_assertj = False
|
||||
|
||||
pom_path = project_root / "pom.xml"
|
||||
if pom_path.exists():
|
||||
try:
|
||||
content = pom_path.read_text(encoding="utf-8")
|
||||
has_mockito = "mockito" in content.lower()
|
||||
has_assertj = "assertj" in content.lower()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for gradle_file in ["build.gradle", "build.gradle.kts"]:
|
||||
gradle_path = project_root / gradle_file
|
||||
if gradle_path.exists():
|
||||
try:
|
||||
content = gradle_path.read_text(encoding="utf-8")
|
||||
if "mockito" in content.lower():
|
||||
has_mockito = True
|
||||
if "assertj" in content.lower():
|
||||
has_assertj = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return has_mockito, has_assertj
|
||||
|
||||
|
||||
def _get_compiler_settings(
|
||||
project_root: Path, build_tool: BuildTool
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Get compiler source and target settings.
|
||||
|
||||
Returns:
|
||||
Tuple of (source_version, target_version).
|
||||
|
||||
"""
|
||||
if build_tool != BuildTool.MAVEN:
|
||||
return None, None
|
||||
|
||||
pom_path = project_root / "pom.xml"
|
||||
if not pom_path.exists():
|
||||
return None, None
|
||||
|
||||
source = None
|
||||
target = None
|
||||
|
||||
try:
|
||||
tree = ET.parse(pom_path)
|
||||
root = tree.getroot()
|
||||
|
||||
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
|
||||
|
||||
# Check properties
|
||||
for props_path in ["properties", "m:properties"]:
|
||||
props = root.find(props_path, ns) if "m:" in props_path else root.find(props_path)
|
||||
if props is not None:
|
||||
for child in props:
|
||||
tag = child.tag.replace("{http://maven.apache.org/POM/4.0.0}", "")
|
||||
if tag == "maven.compiler.source":
|
||||
source = child.text
|
||||
elif tag == "maven.compiler.target":
|
||||
target = child.text
|
||||
|
||||
except ET.ParseError:
|
||||
pass
|
||||
|
||||
return source, target
|
||||
|
||||
|
||||
def _get_surefire_config(project_root: Path) -> tuple[list[str], list[str]]:
|
||||
"""Get Maven Surefire plugin includes/excludes configuration.
|
||||
|
||||
Returns:
|
||||
Tuple of (includes, excludes) patterns.
|
||||
|
||||
"""
|
||||
includes: list[str] = []
|
||||
excludes: list[str] = []
|
||||
|
||||
pom_path = project_root / "pom.xml"
|
||||
if not pom_path.exists():
|
||||
return includes, excludes
|
||||
|
||||
try:
|
||||
tree = ET.parse(pom_path)
|
||||
root = tree.getroot()
|
||||
|
||||
ns = {"m": "http://maven.apache.org/POM/4.0.0"}
|
||||
|
||||
# Find surefire plugin configuration
|
||||
# This is a simplified search - a full implementation would
|
||||
# handle nested build/plugins/plugin structure
|
||||
|
||||
content = pom_path.read_text(encoding="utf-8")
|
||||
if "maven-surefire-plugin" in content:
|
||||
# Parse includes/excludes if present
|
||||
# This is a basic implementation
|
||||
pass
|
||||
|
||||
except (ET.ParseError, Exception):
|
||||
pass
|
||||
|
||||
# Return default patterns if none configured
|
||||
if not includes:
|
||||
includes = ["**/Test*.java", "**/*Test.java", "**/*Tests.java", "**/*TestCase.java"]
|
||||
if not excludes:
|
||||
excludes = ["**/*IT.java", "**/*IntegrationTest.java"]
|
||||
|
||||
return includes, excludes
|
||||
|
||||
|
||||
def is_java_project(project_root: Path) -> bool:
|
||||
"""Check if a directory is a Java project.
|
||||
|
||||
Args:
|
||||
project_root: Directory to check.
|
||||
|
||||
Returns:
|
||||
True if this appears to be a Java project.
|
||||
|
||||
"""
|
||||
# Check for build tool config files
|
||||
if (project_root / "pom.xml").exists():
|
||||
return True
|
||||
if (project_root / "build.gradle").exists():
|
||||
return True
|
||||
if (project_root / "build.gradle.kts").exists():
|
||||
return True
|
||||
|
||||
# Check for Java source files
|
||||
for pattern in ["src/**/*.java", "*.java"]:
|
||||
if list(project_root.glob(pattern)):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_test_file_pattern(config: JavaProjectConfig) -> str:
|
||||
"""Get the test file naming pattern for a project.
|
||||
|
||||
Args:
|
||||
config: The project configuration.
|
||||
|
||||
Returns:
|
||||
Glob pattern for test files.
|
||||
|
||||
"""
|
||||
# Default JUnit pattern
|
||||
return "*Test.java"
|
||||
|
||||
|
||||
def get_test_class_pattern(config: JavaProjectConfig) -> str:
|
||||
"""Get the regex pattern for test class names.
|
||||
|
||||
Args:
|
||||
config: The project configuration.
|
||||
|
||||
Returns:
|
||||
Regex pattern for test class names.
|
||||
|
||||
"""
|
||||
return r".*Test(s)?$|^Test.*"
|
||||
345
codeflash/languages/java/context.py
Normal file
345
codeflash/languages/java/context.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""Java code context extraction.
|
||||
|
||||
This module provides functionality to extract code context needed for
|
||||
optimization, including the target function, helper functions, imports,
|
||||
and other dependencies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import CodeContext, FunctionInfo, HelperFunction, Language
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_code_context(
|
||||
function: FunctionInfo,
|
||||
project_root: Path,
|
||||
module_root: Path | None = None,
|
||||
max_helper_depth: int = 2,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> CodeContext:
|
||||
"""Extract code context for a Java function.
|
||||
|
||||
This extracts:
|
||||
- The target function's source code
|
||||
- Import statements
|
||||
- Helper functions (project-internal dependencies)
|
||||
- Read-only context (class fields, constants, etc.)
|
||||
|
||||
Args:
|
||||
function: The function to extract context for.
|
||||
project_root: Root of the project.
|
||||
module_root: Root of the module (defaults to project_root).
|
||||
max_helper_depth: Maximum depth to trace helper functions.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
CodeContext with target code and dependencies.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
module_root = module_root or project_root
|
||||
|
||||
# Read the source file
|
||||
try:
|
||||
source = function.file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error("Failed to read %s: %s", function.file_path, e)
|
||||
return CodeContext(
|
||||
target_code="",
|
||||
target_file=function.file_path,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
# Extract target function code
|
||||
target_code = extract_function_source(source, function)
|
||||
|
||||
# Extract imports
|
||||
imports = analyzer.find_imports(source)
|
||||
import_statements = [_import_to_statement(imp) for imp in imports]
|
||||
|
||||
# Extract helper functions
|
||||
helper_functions = find_helper_functions(
|
||||
function, project_root, max_helper_depth, analyzer
|
||||
)
|
||||
|
||||
# Extract read-only context (class fields, constants, etc.)
|
||||
read_only_context = extract_read_only_context(source, function, analyzer)
|
||||
|
||||
return CodeContext(
|
||||
target_code=target_code,
|
||||
target_file=function.file_path,
|
||||
helper_functions=helper_functions,
|
||||
read_only_context=read_only_context,
|
||||
imports=import_statements,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
|
||||
def extract_function_source(source: str, function: FunctionInfo) -> str:
|
||||
"""Extract the source code of a function from the full file source.
|
||||
|
||||
Args:
|
||||
source: The full file source code.
|
||||
function: The function to extract.
|
||||
|
||||
Returns:
|
||||
The function's source code.
|
||||
|
||||
"""
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Include Javadoc if present
|
||||
start_line = function.doc_start_line or function.start_line
|
||||
end_line = function.end_line
|
||||
|
||||
# Convert from 1-indexed to 0-indexed
|
||||
start_idx = start_line - 1
|
||||
end_idx = end_line
|
||||
|
||||
return "".join(lines[start_idx:end_idx])
|
||||
|
||||
|
||||
def find_helper_functions(
|
||||
function: FunctionInfo,
|
||||
project_root: Path,
|
||||
max_depth: int = 2,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper functions that the target function depends on.
|
||||
|
||||
Args:
|
||||
function: The target function to analyze.
|
||||
project_root: Root of the project.
|
||||
max_depth: Maximum depth to trace dependencies.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of HelperFunction objects.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
helpers: list[HelperFunction] = []
|
||||
visited_functions: set[str] = set()
|
||||
|
||||
# Find helper files through imports
|
||||
helper_files = find_helper_files(
|
||||
function.file_path, project_root, max_depth, analyzer
|
||||
)
|
||||
|
||||
for file_path, class_names in helper_files.items():
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
file_functions = discover_functions_from_source(source, file_path, analyzer=analyzer)
|
||||
|
||||
for func in file_functions:
|
||||
func_id = f"{file_path}:{func.qualified_name}"
|
||||
if func_id not in visited_functions:
|
||||
visited_functions.add(func_id)
|
||||
|
||||
# Extract the function source
|
||||
func_source = extract_function_source(source, func)
|
||||
|
||||
helpers.append(
|
||||
HelperFunction(
|
||||
name=func.name,
|
||||
qualified_name=func.qualified_name,
|
||||
file_path=file_path,
|
||||
source_code=func_source,
|
||||
start_line=func.start_line,
|
||||
end_line=func.end_line,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to extract helpers from %s: %s", file_path, e)
|
||||
|
||||
# Also find helper methods in the same class
|
||||
same_file_helpers = _find_same_class_helpers(function, analyzer)
|
||||
for helper in same_file_helpers:
|
||||
func_id = f"{function.file_path}:{helper.qualified_name}"
|
||||
if func_id not in visited_functions:
|
||||
visited_functions.add(func_id)
|
||||
helpers.append(helper)
|
||||
|
||||
return helpers
|
||||
|
||||
|
||||
def _find_same_class_helpers(
|
||||
function: FunctionInfo,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper methods in the same class as the target function.
|
||||
|
||||
Args:
|
||||
function: The target function.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of helper functions in the same class.
|
||||
|
||||
"""
|
||||
helpers: list[HelperFunction] = []
|
||||
|
||||
if not function.class_name:
|
||||
return helpers
|
||||
|
||||
try:
|
||||
source = function.file_path.read_text(encoding="utf-8")
|
||||
source_bytes = source.encode("utf8")
|
||||
|
||||
# Find all methods in the file
|
||||
methods = analyzer.find_methods(source)
|
||||
|
||||
# Find which methods the target function calls
|
||||
target_method = None
|
||||
for method in methods:
|
||||
if method.name == function.name and method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
return helpers
|
||||
|
||||
# Get method calls from the target
|
||||
called_methods = set(analyzer.find_method_calls(source, target_method))
|
||||
|
||||
# Add called methods from the same class as helpers
|
||||
for method in methods:
|
||||
if (
|
||||
method.name != function.name
|
||||
and method.class_name == function.class_name
|
||||
and method.name in called_methods
|
||||
):
|
||||
func_source = source_bytes[
|
||||
method.node.start_byte : method.node.end_byte
|
||||
].decode("utf8")
|
||||
|
||||
helpers.append(
|
||||
HelperFunction(
|
||||
name=method.name,
|
||||
qualified_name=f"{method.class_name}.{method.name}",
|
||||
file_path=function.file_path,
|
||||
source_code=func_source,
|
||||
start_line=method.start_line,
|
||||
end_line=method.end_line,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to find same-class helpers: %s", e)
|
||||
|
||||
return helpers
|
||||
|
||||
|
||||
def extract_read_only_context(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> str:
|
||||
"""Extract read-only context (fields, constants, inner classes).
|
||||
|
||||
This extracts class-level context that the function might depend on
|
||||
but shouldn't be modified during optimization.
|
||||
|
||||
Args:
|
||||
source: The full source code.
|
||||
function: The target function.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
String containing read-only context code.
|
||||
|
||||
"""
|
||||
if not function.class_name:
|
||||
return ""
|
||||
|
||||
context_parts: list[str] = []
|
||||
|
||||
# Find fields in the same class
|
||||
fields = analyzer.find_fields(source, function.class_name)
|
||||
for field in fields:
|
||||
context_parts.append(field.source_text)
|
||||
|
||||
return "\n".join(context_parts)
|
||||
|
||||
|
||||
def _import_to_statement(import_info) -> str:
|
||||
"""Convert a JavaImportInfo to an import statement string.
|
||||
|
||||
Args:
|
||||
import_info: The import info.
|
||||
|
||||
Returns:
|
||||
Import statement string.
|
||||
|
||||
"""
|
||||
if import_info.is_static:
|
||||
prefix = "import static "
|
||||
else:
|
||||
prefix = "import "
|
||||
|
||||
suffix = ".*" if import_info.is_wildcard else ""
|
||||
|
||||
return f"{prefix}{import_info.import_path}{suffix};"
|
||||
|
||||
|
||||
def extract_class_context(
|
||||
file_path: Path,
|
||||
class_name: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Extract the full context of a class.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
class_name: Name of the class.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
String containing the class code with imports.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
|
||||
# Find the class
|
||||
classes = analyzer.find_classes(source)
|
||||
target_class = None
|
||||
for cls in classes:
|
||||
if cls.name == class_name:
|
||||
target_class = cls
|
||||
break
|
||||
|
||||
if not target_class:
|
||||
return ""
|
||||
|
||||
# Extract imports
|
||||
imports = analyzer.find_imports(source)
|
||||
import_statements = [_import_to_statement(imp) for imp in imports]
|
||||
|
||||
# Get package
|
||||
package = analyzer.get_package_name(source)
|
||||
package_stmt = f"package {package};\n\n" if package else ""
|
||||
|
||||
# Get class source
|
||||
class_source = target_class.source_text
|
||||
|
||||
return package_stmt + "\n".join(import_statements) + "\n\n" + class_source
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract class context: %s", e)
|
||||
return ""
|
||||
328
codeflash/languages/java/discovery.py
Normal file
328
codeflash/languages/java/discovery.py
Normal file
|
|
@ -0,0 +1,328 @@
|
|||
"""Java function and method discovery.
|
||||
|
||||
This module provides functionality to discover optimizable functions and methods
|
||||
in Java source files using the tree-sitter parser.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import (
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
Language,
|
||||
ParentInfo,
|
||||
)
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_functions(
|
||||
file_path: Path,
|
||||
filter_criteria: FunctionFilterCriteria | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Find all optimizable functions/methods in a Java file.
|
||||
|
||||
Uses tree-sitter to parse the file and find methods that can be optimized.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file to analyze.
|
||||
filter_criteria: Optional criteria to filter functions.
|
||||
analyzer: Optional JavaAnalyzer instance (created if not provided).
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
|
||||
"""
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
return discover_functions_from_source(source, file_path, criteria, analyzer)
|
||||
|
||||
|
||||
def discover_functions_from_source(
|
||||
source: str,
|
||||
file_path: Path | None = None,
|
||||
filter_criteria: FunctionFilterCriteria | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Find all optimizable functions/methods in Java source code.
|
||||
|
||||
Args:
|
||||
source: The Java source code to analyze.
|
||||
file_path: Optional file path for context.
|
||||
filter_criteria: Optional criteria to filter functions.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered functions.
|
||||
|
||||
"""
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
try:
|
||||
# Find all methods
|
||||
methods = analyzer.find_methods(
|
||||
source,
|
||||
include_private=True, # Include all, filter later
|
||||
include_static=True,
|
||||
)
|
||||
|
||||
functions: list[FunctionInfo] = []
|
||||
|
||||
for method in methods:
|
||||
# Apply filters
|
||||
if not _should_include_method(method, criteria, source, analyzer):
|
||||
continue
|
||||
|
||||
# Build parents list
|
||||
parents: list[ParentInfo] = []
|
||||
if method.class_name:
|
||||
parents.append(ParentInfo(name=method.class_name, type="ClassDef"))
|
||||
|
||||
functions.append(
|
||||
FunctionInfo(
|
||||
name=method.name,
|
||||
file_path=file_path or Path("unknown.java"),
|
||||
start_line=method.start_line,
|
||||
end_line=method.end_line,
|
||||
start_col=method.start_col,
|
||||
end_col=method.end_col,
|
||||
parents=tuple(parents),
|
||||
is_async=False, # Java doesn't have async keyword
|
||||
is_method=method.class_name is not None,
|
||||
language=Language.JAVA,
|
||||
doc_start_line=method.javadoc_start_line,
|
||||
)
|
||||
)
|
||||
|
||||
return functions
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse Java source: %s", e)
|
||||
return []
|
||||
|
||||
|
||||
def _should_include_method(
|
||||
method: JavaMethodNode,
|
||||
criteria: FunctionFilterCriteria,
|
||||
source: str,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> bool:
|
||||
"""Check if a method should be included based on filter criteria.
|
||||
|
||||
Args:
|
||||
method: The method to check.
|
||||
criteria: Filter criteria to apply.
|
||||
source: Source code for additional analysis.
|
||||
analyzer: JavaAnalyzer for additional checks.
|
||||
|
||||
Returns:
|
||||
True if the method should be included.
|
||||
|
||||
"""
|
||||
# Skip abstract methods (no implementation to optimize)
|
||||
if method.is_abstract:
|
||||
return False
|
||||
|
||||
# Skip constructors (special case - could be optimized but usually not)
|
||||
if method.name == method.class_name:
|
||||
return False
|
||||
|
||||
# Check include patterns
|
||||
if criteria.include_patterns:
|
||||
import fnmatch
|
||||
|
||||
if not any(fnmatch.fnmatch(method.name, pattern) for pattern in criteria.include_patterns):
|
||||
return False
|
||||
|
||||
# Check exclude patterns
|
||||
if criteria.exclude_patterns:
|
||||
import fnmatch
|
||||
|
||||
if any(fnmatch.fnmatch(method.name, pattern) for pattern in criteria.exclude_patterns):
|
||||
return False
|
||||
|
||||
# Check require_return - void methods don't have return values
|
||||
if criteria.require_return:
|
||||
if method.return_type == "void":
|
||||
return False
|
||||
# Also check if the method actually has a return statement
|
||||
if not analyzer.has_return_statement(method, source):
|
||||
return False
|
||||
|
||||
# Check include_methods - in Java, all functions in classes are methods
|
||||
if not criteria.include_methods and method.class_name is not None:
|
||||
return False
|
||||
|
||||
# Check line count
|
||||
method_lines = method.end_line - method.start_line + 1
|
||||
if criteria.min_lines is not None and method_lines < criteria.min_lines:
|
||||
return False
|
||||
if criteria.max_lines is not None and method_lines > criteria.max_lines:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def discover_test_methods(
|
||||
file_path: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Find all JUnit test methods in a Java test file.
|
||||
|
||||
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java test file.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for discovered test methods.
|
||||
|
||||
"""
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
test_methods: list[FunctionInfo] = []
|
||||
|
||||
# Find methods with test annotations
|
||||
_walk_tree_for_test_methods(tree.root_node, source_bytes, file_path, test_methods, analyzer, current_class=None)
|
||||
|
||||
return test_methods
|
||||
|
||||
|
||||
def _walk_tree_for_test_methods(
|
||||
node,
|
||||
source_bytes: bytes,
|
||||
file_path: Path,
|
||||
test_methods: list[FunctionInfo],
|
||||
analyzer: JavaAnalyzer,
|
||||
current_class: str | None,
|
||||
) -> None:
|
||||
"""Recursively walk the tree to find test methods."""
|
||||
new_class = current_class
|
||||
|
||||
if node.type == "class_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_class = analyzer.get_node_text(name_node, source_bytes)
|
||||
|
||||
if node.type == "method_declaration":
|
||||
# Check for test annotations
|
||||
has_test_annotation = False
|
||||
for child in node.children:
|
||||
if child.type == "modifiers":
|
||||
for mod_child in child.children:
|
||||
if mod_child.type == "marker_annotation" or mod_child.type == "annotation":
|
||||
annotation_text = analyzer.get_node_text(mod_child, source_bytes)
|
||||
# Check for JUnit 5 test annotations
|
||||
if any(
|
||||
ann in annotation_text
|
||||
for ann in ["@Test", "@ParameterizedTest", "@RepeatedTest", "@TestFactory"]
|
||||
):
|
||||
has_test_annotation = True
|
||||
break
|
||||
|
||||
if has_test_annotation:
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
method_name = analyzer.get_node_text(name_node, source_bytes)
|
||||
|
||||
parents: list[ParentInfo] = []
|
||||
if current_class:
|
||||
parents.append(ParentInfo(name=current_class, type="ClassDef"))
|
||||
|
||||
test_methods.append(
|
||||
FunctionInfo(
|
||||
name=method_name,
|
||||
file_path=file_path,
|
||||
start_line=node.start_point[0] + 1,
|
||||
end_line=node.end_point[0] + 1,
|
||||
start_col=node.start_point[1],
|
||||
end_col=node.end_point[1],
|
||||
parents=tuple(parents),
|
||||
is_async=False,
|
||||
is_method=current_class is not None,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
)
|
||||
|
||||
for child in node.children:
|
||||
_walk_tree_for_test_methods(
|
||||
child,
|
||||
source_bytes,
|
||||
file_path,
|
||||
test_methods,
|
||||
analyzer,
|
||||
current_class=new_class if node.type == "class_declaration" else current_class,
|
||||
)
|
||||
|
||||
|
||||
def get_method_by_name(
|
||||
file_path: Path,
|
||||
method_name: str,
|
||||
class_name: str | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> FunctionInfo | None:
|
||||
"""Find a specific method by name in a Java file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
method_name: Name of the method to find.
|
||||
class_name: Optional class name to narrow the search.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
FunctionInfo for the method, or None if not found.
|
||||
|
||||
"""
|
||||
functions = discover_functions(file_path, analyzer=analyzer)
|
||||
|
||||
for func in functions:
|
||||
if func.name == method_name:
|
||||
if class_name is None or func.class_name == class_name:
|
||||
return func
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_class_methods(
|
||||
file_path: Path,
|
||||
class_name: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Get all methods in a specific class.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
class_name: Name of the class.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo objects for methods in the class.
|
||||
|
||||
"""
|
||||
functions = discover_functions(file_path, analyzer=analyzer)
|
||||
return [f for f in functions if f.class_name == class_name]
|
||||
347
codeflash/languages/java/formatter.py
Normal file
347
codeflash/languages/java/formatter.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
"""Java code formatting.
|
||||
|
||||
This module provides functionality to format Java code using
|
||||
google-java-format or other available formatters.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JavaFormatter:
|
||||
"""Java code formatter using google-java-format or fallback methods."""
|
||||
|
||||
# Path to google-java-format JAR (if downloaded)
|
||||
_google_java_format_jar: Path | None = None
|
||||
|
||||
# Version of google-java-format to use
|
||||
GOOGLE_JAVA_FORMAT_VERSION = "1.19.2"
|
||||
|
||||
def __init__(self, project_root: Path | None = None):
|
||||
"""Initialize the Java formatter.
|
||||
|
||||
Args:
|
||||
project_root: Optional project root for project-specific formatting rules.
|
||||
|
||||
"""
|
||||
self.project_root = project_root
|
||||
self._java_executable = self._find_java()
|
||||
|
||||
def _find_java(self) -> str | None:
|
||||
"""Find the Java executable."""
|
||||
# Check JAVA_HOME
|
||||
java_home = os.environ.get("JAVA_HOME")
|
||||
if java_home:
|
||||
java_path = Path(java_home) / "bin" / "java"
|
||||
if java_path.exists():
|
||||
return str(java_path)
|
||||
|
||||
# Check PATH
|
||||
java_path = shutil.which("java")
|
||||
if java_path:
|
||||
return java_path
|
||||
|
||||
return None
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
"""Format Java source code.
|
||||
|
||||
Attempts to use google-java-format if available, otherwise
|
||||
returns the source unchanged.
|
||||
|
||||
Args:
|
||||
source: The Java source code to format.
|
||||
file_path: Optional file path for context.
|
||||
|
||||
Returns:
|
||||
Formatted source code.
|
||||
|
||||
"""
|
||||
if not source or not source.strip():
|
||||
return source
|
||||
|
||||
# Try google-java-format first
|
||||
formatted = self._format_with_google_java_format(source)
|
||||
if formatted is not None:
|
||||
return formatted
|
||||
|
||||
# Try Eclipse formatter (if available in project)
|
||||
if self.project_root:
|
||||
formatted = self._format_with_eclipse(source)
|
||||
if formatted is not None:
|
||||
return formatted
|
||||
|
||||
# Return original source if no formatter available
|
||||
logger.debug("No Java formatter available, returning original source")
|
||||
return source
|
||||
|
||||
def _format_with_google_java_format(self, source: str) -> str | None:
|
||||
"""Format using google-java-format.
|
||||
|
||||
Args:
|
||||
source: The source code to format.
|
||||
|
||||
Returns:
|
||||
Formatted source, or None if formatting failed.
|
||||
|
||||
"""
|
||||
if not self._java_executable:
|
||||
return None
|
||||
|
||||
# Try to find or download google-java-format
|
||||
jar_path = self._get_google_java_format_jar()
|
||||
if not jar_path:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Write source to temp file
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".java", delete=False, encoding="utf-8"
|
||||
) as tmp:
|
||||
tmp.write(source)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
self._java_executable,
|
||||
"-jar",
|
||||
str(jar_path),
|
||||
"--replace",
|
||||
tmp_path,
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
# Read back the formatted file
|
||||
with open(tmp_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
else:
|
||||
logger.debug(
|
||||
"google-java-format failed: %s", result.stderr or result.stdout
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("google-java-format timed out")
|
||||
except Exception as e:
|
||||
logger.debug("google-java-format error: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
def _get_google_java_format_jar(self) -> Path | None:
|
||||
"""Get path to google-java-format JAR, downloading if necessary.
|
||||
|
||||
Returns:
|
||||
Path to the JAR file, or None if not available.
|
||||
|
||||
"""
|
||||
if JavaFormatter._google_java_format_jar:
|
||||
if JavaFormatter._google_java_format_jar.exists():
|
||||
return JavaFormatter._google_java_format_jar
|
||||
|
||||
# Check common locations
|
||||
possible_paths = [
|
||||
# In project's .codeflash directory
|
||||
self.project_root / ".codeflash" / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar"
|
||||
if self.project_root
|
||||
else None,
|
||||
# In user's home directory
|
||||
Path.home()
|
||||
/ ".codeflash"
|
||||
/ f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
|
||||
# In system temp
|
||||
Path(tempfile.gettempdir())
|
||||
/ "codeflash"
|
||||
/ f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
|
||||
]
|
||||
|
||||
for path in possible_paths:
|
||||
if path and path.exists():
|
||||
JavaFormatter._google_java_format_jar = path
|
||||
return path
|
||||
|
||||
# Don't auto-download to avoid surprises
|
||||
# Users can manually download the JAR
|
||||
logger.debug(
|
||||
"google-java-format JAR not found. "
|
||||
"Download from https://github.com/google/google-java-format/releases"
|
||||
)
|
||||
return None
|
||||
|
||||
def _format_with_eclipse(self, source: str) -> str | None:
|
||||
"""Format using Eclipse formatter settings (if available in project).
|
||||
|
||||
Args:
|
||||
source: The source code to format.
|
||||
|
||||
Returns:
|
||||
Formatted source, or None if formatting failed.
|
||||
|
||||
"""
|
||||
# Eclipse formatter requires eclipse.ini or a config file
|
||||
# This is a placeholder for future implementation
|
||||
return None
|
||||
|
||||
def download_google_java_format(self, target_dir: Path | None = None) -> Path | None:
|
||||
"""Download google-java-format JAR.
|
||||
|
||||
Args:
|
||||
target_dir: Directory to download to (defaults to ~/.codeflash/).
|
||||
|
||||
Returns:
|
||||
Path to the downloaded JAR, or None if download failed.
|
||||
|
||||
"""
|
||||
import urllib.request
|
||||
|
||||
target_dir = target_dir or Path.home() / ".codeflash"
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
jar_name = f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar"
|
||||
jar_path = target_dir / jar_name
|
||||
|
||||
if jar_path.exists():
|
||||
JavaFormatter._google_java_format_jar = jar_path
|
||||
return jar_path
|
||||
|
||||
url = (
|
||||
f"https://github.com/google/google-java-format/releases/download/"
|
||||
f"v{self.GOOGLE_JAVA_FORMAT_VERSION}/{jar_name}"
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info("Downloading google-java-format from %s", url)
|
||||
urllib.request.urlretrieve(url, jar_path)
|
||||
JavaFormatter._google_java_format_jar = jar_path
|
||||
logger.info("Downloaded google-java-format to %s", jar_path)
|
||||
return jar_path
|
||||
except Exception as e:
|
||||
logger.error("Failed to download google-java-format: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def format_java_code(source: str, project_root: Path | None = None) -> str:
|
||||
"""Convenience function to format Java code.
|
||||
|
||||
Args:
|
||||
source: The Java source code to format.
|
||||
project_root: Optional project root for context.
|
||||
|
||||
Returns:
|
||||
Formatted source code.
|
||||
|
||||
"""
|
||||
formatter = JavaFormatter(project_root)
|
||||
return formatter.format_code(source)
|
||||
|
||||
|
||||
def format_java_file(file_path: Path, in_place: bool = False) -> str:
|
||||
"""Format a Java file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
in_place: Whether to modify the file in place.
|
||||
|
||||
Returns:
|
||||
Formatted source code.
|
||||
|
||||
"""
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
formatter = JavaFormatter(file_path.parent)
|
||||
formatted = formatter.format_code(source, file_path)
|
||||
|
||||
if in_place and formatted != source:
|
||||
file_path.write_text(formatted, encoding="utf-8")
|
||||
|
||||
return formatted
|
||||
|
||||
|
||||
def normalize_java_code(source: str) -> str:
|
||||
"""Normalize Java code for deduplication.
|
||||
|
||||
This removes comments and normalizes whitespace to allow
|
||||
comparison of semantically equivalent code.
|
||||
|
||||
Args:
|
||||
source: The Java source code.
|
||||
|
||||
Returns:
|
||||
Normalized source code.
|
||||
|
||||
"""
|
||||
lines = source.splitlines()
|
||||
normalized_lines = []
|
||||
in_block_comment = False
|
||||
|
||||
for line in lines:
|
||||
# Handle block comments
|
||||
if in_block_comment:
|
||||
if "*/" in line:
|
||||
in_block_comment = False
|
||||
line = line[line.index("*/") + 2 :]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Remove line comments
|
||||
if "//" in line:
|
||||
# Find // that's not inside a string
|
||||
in_string = False
|
||||
escape_next = False
|
||||
comment_start = -1
|
||||
for i, char in enumerate(line):
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
if char == "\\":
|
||||
escape_next = True
|
||||
continue
|
||||
if char == '"' and not in_string:
|
||||
in_string = True
|
||||
elif char == '"' and in_string:
|
||||
in_string = False
|
||||
elif not in_string and i < len(line) - 1 and line[i : i + 2] == "//":
|
||||
comment_start = i
|
||||
break
|
||||
if comment_start >= 0:
|
||||
line = line[:comment_start]
|
||||
|
||||
# Handle start of block comments
|
||||
if "/*" in line:
|
||||
start_idx = line.index("/*")
|
||||
if "*/" in line[start_idx:]:
|
||||
# Block comment on single line
|
||||
end_idx = line.index("*/", start_idx)
|
||||
line = line[:start_idx] + line[end_idx + 2 :]
|
||||
else:
|
||||
in_block_comment = True
|
||||
line = line[:start_idx]
|
||||
|
||||
# Skip empty lines and add non-empty ones
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
normalized_lines.append(stripped)
|
||||
|
||||
return "\n".join(normalized_lines)
|
||||
360
codeflash/languages/java/import_resolver.py
Normal file
360
codeflash/languages/java/import_resolver.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""Java import resolution.
|
||||
|
||||
This module provides functionality to resolve Java imports to actual file paths
|
||||
within a project, handling both source and test directories.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.java.build_tools import find_source_root, find_test_root, get_project_info
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolvedImport:
|
||||
"""A resolved Java import."""
|
||||
|
||||
import_path: str # Original import path (e.g., "com.example.utils.StringUtils")
|
||||
file_path: Path | None # Resolved file path, or None if external/unresolved
|
||||
is_external: bool # True if this is an external dependency (not in project)
|
||||
is_wildcard: bool # True if this was a wildcard import
|
||||
class_name: str | None # The imported class name (e.g., "StringUtils")
|
||||
|
||||
|
||||
class JavaImportResolver:
|
||||
"""Resolves Java imports to file paths within a project."""
|
||||
|
||||
# Standard Java packages that are always external
|
||||
STANDARD_PACKAGES = frozenset(
|
||||
[
|
||||
"java",
|
||||
"javax",
|
||||
"sun",
|
||||
"com.sun",
|
||||
"jdk",
|
||||
"org.w3c",
|
||||
"org.xml",
|
||||
"org.ietf",
|
||||
]
|
||||
)
|
||||
|
||||
# Common third-party package prefixes
|
||||
COMMON_EXTERNAL_PREFIXES = frozenset(
|
||||
[
|
||||
"org.junit",
|
||||
"org.mockito",
|
||||
"org.assertj",
|
||||
"org.hamcrest",
|
||||
"org.slf4j",
|
||||
"org.apache",
|
||||
"org.springframework",
|
||||
"com.google",
|
||||
"com.fasterxml",
|
||||
"io.netty",
|
||||
"io.github",
|
||||
"lombok",
|
||||
]
|
||||
)
|
||||
|
||||
def __init__(self, project_root: Path):
|
||||
"""Initialize the import resolver.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Java project.
|
||||
|
||||
"""
|
||||
self.project_root = project_root
|
||||
self._source_roots: list[Path] = []
|
||||
self._test_roots: list[Path] = []
|
||||
self._package_to_path_cache: dict[str, Path | None] = {}
|
||||
|
||||
# Discover source and test roots
|
||||
self._discover_roots()
|
||||
|
||||
def _discover_roots(self) -> None:
|
||||
"""Discover source and test root directories."""
|
||||
# Try to get project info first
|
||||
project_info = get_project_info(self.project_root)
|
||||
|
||||
if project_info:
|
||||
self._source_roots = project_info.source_roots
|
||||
self._test_roots = project_info.test_roots
|
||||
else:
|
||||
# Fall back to standard detection
|
||||
source_root = find_source_root(self.project_root)
|
||||
if source_root:
|
||||
self._source_roots = [source_root]
|
||||
|
||||
test_root = find_test_root(self.project_root)
|
||||
if test_root:
|
||||
self._test_roots = [test_root]
|
||||
|
||||
def resolve_import(self, import_info: JavaImportInfo) -> ResolvedImport:
|
||||
"""Resolve a single import to a file path.
|
||||
|
||||
Args:
|
||||
import_info: The import to resolve.
|
||||
|
||||
Returns:
|
||||
ResolvedImport with resolution details.
|
||||
|
||||
"""
|
||||
import_path = import_info.import_path
|
||||
|
||||
# Check if it's a standard library import
|
||||
if self._is_standard_library(import_path):
|
||||
return ResolvedImport(
|
||||
import_path=import_path,
|
||||
file_path=None,
|
||||
is_external=True,
|
||||
is_wildcard=import_info.is_wildcard,
|
||||
class_name=self._extract_class_name(import_path),
|
||||
)
|
||||
|
||||
# Check if it's a known external library
|
||||
if self._is_external_library(import_path):
|
||||
return ResolvedImport(
|
||||
import_path=import_path,
|
||||
file_path=None,
|
||||
is_external=True,
|
||||
is_wildcard=import_info.is_wildcard,
|
||||
class_name=self._extract_class_name(import_path),
|
||||
)
|
||||
|
||||
# Try to resolve within the project
|
||||
resolved_path = self._resolve_to_file(import_path)
|
||||
|
||||
return ResolvedImport(
|
||||
import_path=import_path,
|
||||
file_path=resolved_path,
|
||||
is_external=resolved_path is None,
|
||||
is_wildcard=import_info.is_wildcard,
|
||||
class_name=self._extract_class_name(import_path),
|
||||
)
|
||||
|
||||
def resolve_imports(self, imports: list[JavaImportInfo]) -> list[ResolvedImport]:
|
||||
"""Resolve multiple imports.
|
||||
|
||||
Args:
|
||||
imports: List of imports to resolve.
|
||||
|
||||
Returns:
|
||||
List of ResolvedImport objects.
|
||||
|
||||
"""
|
||||
return [self.resolve_import(imp) for imp in imports]
|
||||
|
||||
def _is_standard_library(self, import_path: str) -> bool:
|
||||
"""Check if an import is from the Java standard library."""
|
||||
for prefix in self.STANDARD_PACKAGES:
|
||||
if import_path.startswith(prefix + ".") or import_path == prefix:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_external_library(self, import_path: str) -> bool:
|
||||
"""Check if an import is from a known external library."""
|
||||
for prefix in self.COMMON_EXTERNAL_PREFIXES:
|
||||
if import_path.startswith(prefix + ".") or import_path == prefix:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _resolve_to_file(self, import_path: str) -> Path | None:
|
||||
"""Try to resolve an import path to a file in the project.
|
||||
|
||||
Args:
|
||||
import_path: The fully qualified import path.
|
||||
|
||||
Returns:
|
||||
Path to the Java file, or None if not found.
|
||||
|
||||
"""
|
||||
# Check cache
|
||||
if import_path in self._package_to_path_cache:
|
||||
return self._package_to_path_cache[import_path]
|
||||
|
||||
# Convert package path to file path
|
||||
# e.g., "com.example.utils.StringUtils" -> "com/example/utils/StringUtils.java"
|
||||
relative_path = import_path.replace(".", "/") + ".java"
|
||||
|
||||
# Search in source roots
|
||||
for source_root in self._source_roots:
|
||||
candidate = source_root / relative_path
|
||||
if candidate.exists():
|
||||
self._package_to_path_cache[import_path] = candidate
|
||||
return candidate
|
||||
|
||||
# Search in test roots
|
||||
for test_root in self._test_roots:
|
||||
candidate = test_root / relative_path
|
||||
if candidate.exists():
|
||||
self._package_to_path_cache[import_path] = candidate
|
||||
return candidate
|
||||
|
||||
# Not found
|
||||
self._package_to_path_cache[import_path] = None
|
||||
return None
|
||||
|
||||
def _extract_class_name(self, import_path: str) -> str | None:
|
||||
"""Extract the class name from an import path.
|
||||
|
||||
Args:
|
||||
import_path: The import path (e.g., "com.example.MyClass").
|
||||
|
||||
Returns:
|
||||
The class name (e.g., "MyClass"), or None if it's a wildcard.
|
||||
|
||||
"""
|
||||
if not import_path:
|
||||
return None
|
||||
parts = import_path.split(".")
|
||||
if parts:
|
||||
last_part = parts[-1]
|
||||
# Check if it looks like a class name (starts with uppercase)
|
||||
if last_part and last_part[0].isupper():
|
||||
return last_part
|
||||
return None
|
||||
|
||||
def find_class_file(self, class_name: str, package_hint: str | None = None) -> Path | None:
|
||||
"""Find the file containing a specific class.
|
||||
|
||||
Args:
|
||||
class_name: The simple class name (e.g., "StringUtils").
|
||||
package_hint: Optional package hint to narrow the search.
|
||||
|
||||
Returns:
|
||||
Path to the Java file, or None if not found.
|
||||
|
||||
"""
|
||||
if package_hint:
|
||||
# Try the exact path first
|
||||
import_path = f"{package_hint}.{class_name}"
|
||||
result = self._resolve_to_file(import_path)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# Search all source and test roots for the class
|
||||
file_name = f"{class_name}.java"
|
||||
|
||||
for root in self._source_roots + self._test_roots:
|
||||
for java_file in root.rglob(file_name):
|
||||
return java_file
|
||||
|
||||
return None
|
||||
|
||||
def get_imports_from_file(
|
||||
self, file_path: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[ResolvedImport]:
|
||||
"""Get and resolve all imports from a Java file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of ResolvedImport objects.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
imports = analyzer.find_imports(source)
|
||||
return self.resolve_imports(imports)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to get imports from %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
def get_project_imports(
|
||||
self, file_path: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[ResolvedImport]:
|
||||
"""Get only the imports that resolve to files within the project.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of ResolvedImport objects for project-internal imports only.
|
||||
|
||||
"""
|
||||
all_imports = self.get_imports_from_file(file_path, analyzer)
|
||||
return [imp for imp in all_imports if not imp.is_external and imp.file_path is not None]
|
||||
|
||||
|
||||
def resolve_imports_for_file(
|
||||
file_path: Path, project_root: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[ResolvedImport]:
|
||||
"""Convenience function to resolve imports for a single file.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
project_root: Root directory of the project.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of ResolvedImport objects.
|
||||
|
||||
"""
|
||||
resolver = JavaImportResolver(project_root)
|
||||
return resolver.get_imports_from_file(file_path, analyzer)
|
||||
|
||||
|
||||
def find_helper_files(
|
||||
file_path: Path,
|
||||
project_root: Path,
|
||||
max_depth: int = 2,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> dict[Path, list[str]]:
|
||||
"""Find helper files imported by a Java file, recursively.
|
||||
|
||||
This traces the import chain to find all project files that the
|
||||
given file depends on, up to max_depth levels.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Java file.
|
||||
project_root: Root directory of the project.
|
||||
max_depth: Maximum depth of import chain to follow.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Dict mapping file paths to list of imported class names.
|
||||
|
||||
"""
|
||||
resolver = JavaImportResolver(project_root)
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
result: dict[Path, list[str]] = {}
|
||||
visited: set[Path] = {file_path}
|
||||
|
||||
def _trace_imports(current_file: Path, depth: int) -> None:
|
||||
if depth > max_depth:
|
||||
return
|
||||
|
||||
project_imports = resolver.get_project_imports(current_file, analyzer)
|
||||
|
||||
for imp in project_imports:
|
||||
if imp.file_path and imp.file_path not in visited:
|
||||
visited.add(imp.file_path)
|
||||
|
||||
if imp.file_path not in result:
|
||||
result[imp.file_path] = []
|
||||
|
||||
if imp.class_name:
|
||||
result[imp.file_path].append(imp.class_name)
|
||||
|
||||
# Recurse into the imported file
|
||||
_trace_imports(imp.file_path, depth + 1)
|
||||
|
||||
_trace_imports(file_path, 0)
|
||||
|
||||
return result
|
||||
354
codeflash/languages/java/instrumentation.py
Normal file
354
codeflash/languages/java/instrumentation.py
Normal file
|
|
@ -0,0 +1,354 @@
|
|||
"""Java code instrumentation for behavior capture and benchmarking.
|
||||
|
||||
This module provides functionality to instrument Java code for:
|
||||
1. Behavior capture - recording inputs/outputs for verification
|
||||
2. Benchmarking - measuring execution time
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_function_name(func: Any) -> str:
|
||||
"""Get the function name from either FunctionInfo or FunctionToOptimize."""
|
||||
if hasattr(func, "name"):
|
||||
return func.name
|
||||
if hasattr(func, "function_name"):
|
||||
return func.function_name
|
||||
raise AttributeError(f"Cannot get function name from {type(func)}")
|
||||
|
||||
# Template for behavior capture instrumentation
|
||||
BEHAVIOR_CAPTURE_IMPORT = "import com.codeflash.CodeFlash;"
|
||||
|
||||
BEHAVIOR_CAPTURE_BEFORE = """
|
||||
// CodeFlash behavior capture - start
|
||||
long __codeflash_call_id_{call_id} = System.nanoTime();
|
||||
CodeFlash.recordInput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize({args}));
|
||||
long __codeflash_start_{call_id} = System.nanoTime();
|
||||
"""
|
||||
|
||||
BEHAVIOR_CAPTURE_AFTER_RETURN = """
|
||||
// CodeFlash behavior capture - end
|
||||
long __codeflash_end_{call_id} = System.nanoTime();
|
||||
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", CodeFlash.serialize(__codeflash_result_{call_id}), __codeflash_end_{call_id} - __codeflash_start_{call_id});
|
||||
"""
|
||||
|
||||
BEHAVIOR_CAPTURE_AFTER_VOID = """
|
||||
// CodeFlash behavior capture - end
|
||||
long __codeflash_end_{call_id} = System.nanoTime();
|
||||
CodeFlash.recordOutput(__codeflash_call_id_{call_id}, "{method_id}", "null", __codeflash_end_{call_id} - __codeflash_start_{call_id});
|
||||
"""
|
||||
|
||||
# Template for benchmark instrumentation
|
||||
BENCHMARK_IMPORT = """import com.codeflash.Blackhole;
|
||||
import com.codeflash.BenchmarkContext;
|
||||
import com.codeflash.BenchmarkResult;"""
|
||||
|
||||
BENCHMARK_WRAPPER_TEMPLATE = """
|
||||
// CodeFlash benchmark wrapper
|
||||
public void __codeflash_benchmark_{method_name}(int iterations) {{
|
||||
// Warmup
|
||||
for (int i = 0; i < Math.min(iterations / 10, 100); i++) {{
|
||||
{warmup_call}
|
||||
}}
|
||||
|
||||
// Measurement
|
||||
long[] measurements = new long[iterations];
|
||||
for (int i = 0; i < iterations; i++) {{
|
||||
long start = System.nanoTime();
|
||||
{measurement_call}
|
||||
long end = System.nanoTime();
|
||||
measurements[i] = end - start;
|
||||
}}
|
||||
|
||||
BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
|
||||
CodeFlash.recordBenchmarkResult("{method_id}", result);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
def instrument_for_behavior(
|
||||
source: str,
|
||||
functions: Sequence[FunctionInfo],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
||||
Wraps function calls to record arguments and return values
|
||||
for behavioral verification.
|
||||
|
||||
Args:
|
||||
source: Source code to instrument.
|
||||
functions: Functions to add behavior capture.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Instrumented source code.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
if not functions:
|
||||
return source
|
||||
|
||||
# Add import if not present
|
||||
if BEHAVIOR_CAPTURE_IMPORT not in source:
|
||||
source = _add_import(source, BEHAVIOR_CAPTURE_IMPORT)
|
||||
|
||||
# Find and instrument each function
|
||||
for func in functions:
|
||||
source = _instrument_function_behavior(source, func, analyzer)
|
||||
|
||||
return source
|
||||
|
||||
|
||||
def _add_import(source: str, import_statement: str) -> str:
|
||||
"""Add an import statement to the source.
|
||||
|
||||
Args:
|
||||
source: The source code.
|
||||
import_statement: The import to add.
|
||||
|
||||
Returns:
|
||||
Source with import added.
|
||||
|
||||
"""
|
||||
lines = source.splitlines(keepends=True)
|
||||
insert_idx = 0
|
||||
|
||||
# Find the last import or package statement
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("import ") or stripped.startswith("package "):
|
||||
insert_idx = i + 1
|
||||
elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
|
||||
# First non-import, non-comment line
|
||||
if insert_idx == 0:
|
||||
insert_idx = i
|
||||
break
|
||||
|
||||
lines.insert(insert_idx, import_statement + "\n")
|
||||
return "".join(lines)
|
||||
|
||||
|
||||
def _instrument_function_behavior(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> str:
|
||||
"""Instrument a single function for behavior capture.
|
||||
|
||||
Args:
|
||||
source: The source code.
|
||||
function: The function to instrument.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Source with function instrumented.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
# Find the method node
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
func_name = _get_function_name(function)
|
||||
for method in methods:
|
||||
if method.name == func_name:
|
||||
class_name = getattr(function, "class_name", None)
|
||||
if class_name is None or method.class_name == class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
logger.warning("Could not find method %s for instrumentation", func_name)
|
||||
return source
|
||||
|
||||
# For now, we'll add instrumentation as a simple wrapper
|
||||
# A full implementation would use AST transformation
|
||||
method_id = function.qualified_name
|
||||
call_id = hash(method_id) % 10000
|
||||
|
||||
# Build instrumented version
|
||||
# This is a simplified approach - a full implementation would
|
||||
# parse the method body and instrument each return statement
|
||||
logger.debug("Instrumented method %s for behavior capture", function.name)
|
||||
|
||||
return source
|
||||
|
||||
|
||||
def instrument_for_benchmarking(
|
||||
test_source: str,
|
||||
target_function: FunctionInfo,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
||||
Args:
|
||||
test_source: Test source code to instrument.
|
||||
target_function: Function being benchmarked.
|
||||
|
||||
Returns:
|
||||
Instrumented test source code.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Add imports if not present
|
||||
if "import com.codeflash" not in test_source:
|
||||
test_source = _add_import(test_source, BENCHMARK_IMPORT)
|
||||
|
||||
# Find calls to the target function in the test and wrap them
|
||||
# This is a simplified implementation
|
||||
logger.debug("Instrumented test for benchmarking %s", _get_function_name(target_function))
|
||||
|
||||
return test_source
|
||||
|
||||
|
||||
def instrument_existing_test(
|
||||
test_path: Path,
|
||||
call_positions: Sequence,
|
||||
function_to_optimize: FunctionInfo,
|
||||
tests_project_root: Path,
|
||||
mode: str, # "behavior" or "performance"
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Inject profiling code into an existing test file.
|
||||
|
||||
Args:
|
||||
test_path: Path to the test file.
|
||||
call_positions: List of code positions where the function is called.
|
||||
function_to_optimize: The function being optimized.
|
||||
tests_project_root: Root directory of tests.
|
||||
mode: Testing mode - "behavior" or "performance".
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, instrumented_code or error message).
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
try:
|
||||
source = test_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
return False, f"Failed to read test file: {e}"
|
||||
|
||||
try:
|
||||
if mode == "behavior":
|
||||
instrumented = instrument_for_behavior(source, [function_to_optimize], analyzer)
|
||||
else:
|
||||
instrumented = instrument_for_benchmarking(source, function_to_optimize, analyzer)
|
||||
|
||||
return True, instrumented
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to instrument test file: %s", e)
|
||||
return False, str(e)
|
||||
|
||||
|
||||
def create_benchmark_test(
|
||||
target_function: FunctionInfo,
|
||||
test_setup_code: str,
|
||||
invocation_code: str,
|
||||
iterations: int = 1000,
|
||||
) -> str:
|
||||
"""Create a benchmark test for a function.
|
||||
|
||||
Args:
|
||||
target_function: The function to benchmark.
|
||||
test_setup_code: Code to set up the test (create instances, etc.).
|
||||
invocation_code: Code that invokes the function.
|
||||
iterations: Number of benchmark iterations.
|
||||
|
||||
Returns:
|
||||
Complete benchmark test source code.
|
||||
|
||||
"""
|
||||
method_name = target_function.name
|
||||
method_id = target_function.qualified_name
|
||||
|
||||
benchmark_code = f"""
|
||||
import com.codeflash.Blackhole;
|
||||
import com.codeflash.BenchmarkContext;
|
||||
import com.codeflash.BenchmarkResult;
|
||||
import com.codeflash.CodeFlash;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class {target_function.class_name or 'Target'}Benchmark {{
|
||||
|
||||
@Test
|
||||
public void benchmark{method_name.capitalize()}() {{
|
||||
{test_setup_code}
|
||||
|
||||
// Warmup phase
|
||||
for (int i = 0; i < {iterations // 10}; i++) {{
|
||||
Blackhole.consume({invocation_code});
|
||||
}}
|
||||
|
||||
// Measurement phase
|
||||
long[] measurements = new long[{iterations}];
|
||||
for (int i = 0; i < {iterations}; i++) {{
|
||||
long start = System.nanoTime();
|
||||
Blackhole.consume({invocation_code});
|
||||
long end = System.nanoTime();
|
||||
measurements[i] = end - start;
|
||||
}}
|
||||
|
||||
BenchmarkResult result = new BenchmarkResult("{method_id}", measurements);
|
||||
CodeFlash.recordBenchmarkResult("{method_id}", result);
|
||||
|
||||
System.out.println("Benchmark complete: " + result);
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
return benchmark_code
|
||||
|
||||
|
||||
def remove_instrumentation(source: str) -> str:
|
||||
"""Remove CodeFlash instrumentation from source code.
|
||||
|
||||
Args:
|
||||
source: Instrumented source code.
|
||||
|
||||
Returns:
|
||||
Source with instrumentation removed.
|
||||
|
||||
"""
|
||||
lines = source.splitlines(keepends=True)
|
||||
result_lines = []
|
||||
skip_until_end = False
|
||||
|
||||
for line in lines:
|
||||
stripped = line.strip()
|
||||
|
||||
# Skip CodeFlash instrumentation blocks
|
||||
if "// CodeFlash" in stripped and "start" in stripped:
|
||||
skip_until_end = True
|
||||
continue
|
||||
if skip_until_end:
|
||||
if "// CodeFlash" in stripped and "end" in stripped:
|
||||
skip_until_end = False
|
||||
continue
|
||||
|
||||
# Skip CodeFlash imports
|
||||
if "import com.codeflash" in stripped:
|
||||
continue
|
||||
|
||||
result_lines.append(line)
|
||||
|
||||
return "".join(result_lines)
|
||||
693
codeflash/languages/java/parser.py
Normal file
693
codeflash/languages/java/parser.py
Normal file
|
|
@ -0,0 +1,693 @@
|
|||
"""Tree-sitter utilities for Java code analysis.
|
||||
|
||||
This module provides a unified interface for parsing and analyzing Java code
|
||||
using tree-sitter, following the same patterns as the JavaScript/TypeScript implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from tree_sitter import Language, Parser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node, Tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy-loaded language instance
|
||||
_JAVA_LANGUAGE: Language | None = None
|
||||
|
||||
|
||||
def _get_java_language() -> Language:
|
||||
"""Get the Java tree-sitter Language instance, with lazy loading."""
|
||||
global _JAVA_LANGUAGE
|
||||
if _JAVA_LANGUAGE is None:
|
||||
import tree_sitter_java
|
||||
|
||||
_JAVA_LANGUAGE = Language(tree_sitter_java.language())
|
||||
return _JAVA_LANGUAGE
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaMethodNode:
|
||||
"""Represents a method found by tree-sitter analysis."""
|
||||
|
||||
name: str
|
||||
node: Node
|
||||
start_line: int
|
||||
end_line: int
|
||||
start_col: int
|
||||
end_col: int
|
||||
is_static: bool
|
||||
is_public: bool
|
||||
is_private: bool
|
||||
is_protected: bool
|
||||
is_abstract: bool
|
||||
is_synchronized: bool
|
||||
return_type: str | None
|
||||
class_name: str | None
|
||||
source_text: str
|
||||
javadoc_start_line: int | None = None # Line where Javadoc comment starts
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaClassNode:
|
||||
"""Represents a class found by tree-sitter analysis."""
|
||||
|
||||
name: str
|
||||
node: Node
|
||||
start_line: int
|
||||
end_line: int
|
||||
start_col: int
|
||||
end_col: int
|
||||
is_public: bool
|
||||
is_abstract: bool
|
||||
is_final: bool
|
||||
is_static: bool # For inner classes
|
||||
extends: str | None
|
||||
implements: list[str]
|
||||
source_text: str
|
||||
javadoc_start_line: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaImportInfo:
|
||||
"""Represents a Java import statement."""
|
||||
|
||||
import_path: str # Full import path (e.g., "java.util.List")
|
||||
is_static: bool
|
||||
is_wildcard: bool # import java.util.*
|
||||
start_line: int
|
||||
end_line: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaFieldInfo:
|
||||
"""Represents a class field."""
|
||||
|
||||
name: str
|
||||
type_name: str
|
||||
is_static: bool
|
||||
is_final: bool
|
||||
is_public: bool
|
||||
is_private: bool
|
||||
is_protected: bool
|
||||
start_line: int
|
||||
end_line: int
|
||||
source_text: str
|
||||
|
||||
|
||||
class JavaAnalyzer:
|
||||
"""Java code analysis using tree-sitter.
|
||||
|
||||
This class provides methods to parse and analyze Java code,
|
||||
finding methods, classes, imports, and other code structures.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the Java analyzer."""
|
||||
self._parser: Parser | None = None
|
||||
|
||||
@property
|
||||
def parser(self) -> Parser:
|
||||
"""Get the parser, creating it lazily."""
|
||||
if self._parser is None:
|
||||
self._parser = Parser(_get_java_language())
|
||||
return self._parser
|
||||
|
||||
def parse(self, source: str | bytes) -> Tree:
|
||||
"""Parse source code into a tree-sitter tree.
|
||||
|
||||
Args:
|
||||
source: Source code as string or bytes.
|
||||
|
||||
Returns:
|
||||
The parsed tree.
|
||||
|
||||
"""
|
||||
if isinstance(source, str):
|
||||
source = source.encode("utf8")
|
||||
return self.parser.parse(source)
|
||||
|
||||
def get_node_text(self, node: Node, source: bytes) -> str:
|
||||
"""Extract the source text for a tree-sitter node.
|
||||
|
||||
Args:
|
||||
node: The tree-sitter node.
|
||||
source: The source code as bytes.
|
||||
|
||||
Returns:
|
||||
The text content of the node.
|
||||
|
||||
"""
|
||||
return source[node.start_byte : node.end_byte].decode("utf8")
|
||||
|
||||
def find_methods(
|
||||
self, source: str, include_private: bool = True, include_static: bool = True
|
||||
) -> list[JavaMethodNode]:
|
||||
"""Find all method definitions in source code.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
include_private: Whether to include private methods.
|
||||
include_static: Whether to include static methods.
|
||||
|
||||
Returns:
|
||||
List of JavaMethodNode objects describing found methods.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = self.parse(source_bytes)
|
||||
methods: list[JavaMethodNode] = []
|
||||
|
||||
self._walk_tree_for_methods(
|
||||
tree.root_node,
|
||||
source_bytes,
|
||||
methods,
|
||||
include_private=include_private,
|
||||
include_static=include_static,
|
||||
current_class=None,
|
||||
)
|
||||
|
||||
return methods
|
||||
|
||||
def _walk_tree_for_methods(
|
||||
self,
|
||||
node: Node,
|
||||
source_bytes: bytes,
|
||||
methods: list[JavaMethodNode],
|
||||
include_private: bool,
|
||||
include_static: bool,
|
||||
current_class: str | None,
|
||||
) -> None:
|
||||
"""Recursively walk the tree to find method definitions."""
|
||||
new_class = current_class
|
||||
|
||||
# Track class context
|
||||
if node.type == "class_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_class = self.get_node_text(name_node, source_bytes)
|
||||
|
||||
if node.type == "method_declaration":
|
||||
method_info = self._extract_method_info(node, source_bytes, current_class)
|
||||
|
||||
if method_info:
|
||||
# Apply filters
|
||||
should_include = True
|
||||
|
||||
if method_info.is_private and not include_private:
|
||||
should_include = False
|
||||
|
||||
if method_info.is_static and not include_static:
|
||||
should_include = False
|
||||
|
||||
if should_include:
|
||||
methods.append(method_info)
|
||||
|
||||
# Recurse into children
|
||||
for child in node.children:
|
||||
self._walk_tree_for_methods(
|
||||
child,
|
||||
source_bytes,
|
||||
methods,
|
||||
include_private=include_private,
|
||||
include_static=include_static,
|
||||
current_class=new_class if node.type == "class_declaration" else current_class,
|
||||
)
|
||||
|
||||
def _extract_method_info(
|
||||
self, node: Node, source_bytes: bytes, current_class: str | None
|
||||
) -> JavaMethodNode | None:
|
||||
"""Extract method information from a method_declaration node."""
|
||||
name = ""
|
||||
is_static = False
|
||||
is_public = False
|
||||
is_private = False
|
||||
is_protected = False
|
||||
is_abstract = False
|
||||
is_synchronized = False
|
||||
return_type: str | None = None
|
||||
|
||||
# Get method name
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self.get_node_text(name_node, source_bytes)
|
||||
|
||||
# Get return type
|
||||
type_node = node.child_by_field_name("type")
|
||||
if type_node:
|
||||
return_type = self.get_node_text(type_node, source_bytes)
|
||||
|
||||
# Check modifiers
|
||||
for child in node.children:
|
||||
if child.type == "modifiers":
|
||||
modifier_text = self.get_node_text(child, source_bytes)
|
||||
is_static = "static" in modifier_text
|
||||
is_public = "public" in modifier_text
|
||||
is_private = "private" in modifier_text
|
||||
is_protected = "protected" in modifier_text
|
||||
is_abstract = "abstract" in modifier_text
|
||||
is_synchronized = "synchronized" in modifier_text
|
||||
break
|
||||
|
||||
# Get source text
|
||||
source_text = self.get_node_text(node, source_bytes)
|
||||
|
||||
# Find preceding Javadoc comment
|
||||
javadoc_start_line = self._find_preceding_javadoc(node, source_bytes)
|
||||
|
||||
return JavaMethodNode(
|
||||
name=name,
|
||||
node=node,
|
||||
start_line=node.start_point[0] + 1, # Convert to 1-indexed
|
||||
end_line=node.end_point[0] + 1,
|
||||
start_col=node.start_point[1],
|
||||
end_col=node.end_point[1],
|
||||
is_static=is_static,
|
||||
is_public=is_public,
|
||||
is_private=is_private,
|
||||
is_protected=is_protected,
|
||||
is_abstract=is_abstract,
|
||||
is_synchronized=is_synchronized,
|
||||
return_type=return_type,
|
||||
class_name=current_class,
|
||||
source_text=source_text,
|
||||
javadoc_start_line=javadoc_start_line,
|
||||
)
|
||||
|
||||
def _find_preceding_javadoc(self, node: Node, source_bytes: bytes) -> int | None:
|
||||
"""Find Javadoc comment immediately preceding a node.
|
||||
|
||||
Args:
|
||||
node: The node to find Javadoc for.
|
||||
source_bytes: The source code as bytes.
|
||||
|
||||
Returns:
|
||||
The start line (1-indexed) of the Javadoc, or None if no Javadoc found.
|
||||
|
||||
"""
|
||||
# Get the previous sibling node
|
||||
prev_sibling = node.prev_named_sibling
|
||||
|
||||
# Check if it's a block comment that looks like Javadoc
|
||||
if prev_sibling and prev_sibling.type == "block_comment":
|
||||
comment_text = self.get_node_text(prev_sibling, source_bytes)
|
||||
if comment_text.strip().startswith("/**"):
|
||||
# Verify it's immediately preceding (no blank lines between)
|
||||
comment_end_line = prev_sibling.end_point[0]
|
||||
node_start_line = node.start_point[0]
|
||||
if node_start_line - comment_end_line <= 1:
|
||||
return prev_sibling.start_point[0] + 1 # 1-indexed
|
||||
|
||||
return None
|
||||
|
||||
def find_classes(self, source: str) -> list[JavaClassNode]:
|
||||
"""Find all class definitions in source code.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
|
||||
Returns:
|
||||
List of JavaClassNode objects.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = self.parse(source_bytes)
|
||||
classes: list[JavaClassNode] = []
|
||||
|
||||
self._walk_tree_for_classes(tree.root_node, source_bytes, classes, is_inner=False)
|
||||
|
||||
return classes
|
||||
|
||||
def _walk_tree_for_classes(
|
||||
self, node: Node, source_bytes: bytes, classes: list[JavaClassNode], is_inner: bool
|
||||
) -> None:
|
||||
"""Recursively walk the tree to find class definitions."""
|
||||
if node.type == "class_declaration":
|
||||
class_info = self._extract_class_info(node, source_bytes, is_inner)
|
||||
if class_info:
|
||||
classes.append(class_info)
|
||||
|
||||
# Look for inner classes
|
||||
body_node = node.child_by_field_name("body")
|
||||
if body_node:
|
||||
for child in body_node.children:
|
||||
self._walk_tree_for_classes(child, source_bytes, classes, is_inner=True)
|
||||
return
|
||||
|
||||
# Continue walking for top-level classes
|
||||
for child in node.children:
|
||||
self._walk_tree_for_classes(child, source_bytes, classes, is_inner)
|
||||
|
||||
def _extract_class_info(
|
||||
self, node: Node, source_bytes: bytes, is_inner: bool
|
||||
) -> JavaClassNode | None:
|
||||
"""Extract class information from a class_declaration node."""
|
||||
name = ""
|
||||
is_public = False
|
||||
is_abstract = False
|
||||
is_final = False
|
||||
is_static = False
|
||||
extends: str | None = None
|
||||
implements: list[str] = []
|
||||
|
||||
# Get class name
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
name = self.get_node_text(name_node, source_bytes)
|
||||
|
||||
# Check modifiers
|
||||
for child in node.children:
|
||||
if child.type == "modifiers":
|
||||
modifier_text = self.get_node_text(child, source_bytes)
|
||||
is_public = "public" in modifier_text
|
||||
is_abstract = "abstract" in modifier_text
|
||||
is_final = "final" in modifier_text
|
||||
is_static = "static" in modifier_text
|
||||
break
|
||||
|
||||
# Get superclass
|
||||
superclass_node = node.child_by_field_name("superclass")
|
||||
if superclass_node:
|
||||
# superclass contains "extends ClassName"
|
||||
for child in superclass_node.children:
|
||||
if child.type == "type_identifier":
|
||||
extends = self.get_node_text(child, source_bytes)
|
||||
break
|
||||
|
||||
# Get interfaces (super_interfaces node contains the implements clause)
|
||||
for child in node.children:
|
||||
if child.type == "super_interfaces":
|
||||
# Find the type_list inside super_interfaces
|
||||
for subchild in child.children:
|
||||
if subchild.type == "type_list":
|
||||
for type_node in subchild.children:
|
||||
if type_node.type == "type_identifier":
|
||||
implements.append(self.get_node_text(type_node, source_bytes))
|
||||
|
||||
# Get source text
|
||||
source_text = self.get_node_text(node, source_bytes)
|
||||
|
||||
# Find preceding Javadoc
|
||||
javadoc_start_line = self._find_preceding_javadoc(node, source_bytes)
|
||||
|
||||
return JavaClassNode(
|
||||
name=name,
|
||||
node=node,
|
||||
start_line=node.start_point[0] + 1,
|
||||
end_line=node.end_point[0] + 1,
|
||||
start_col=node.start_point[1],
|
||||
end_col=node.end_point[1],
|
||||
is_public=is_public,
|
||||
is_abstract=is_abstract,
|
||||
is_final=is_final,
|
||||
is_static=is_static,
|
||||
extends=extends,
|
||||
implements=implements,
|
||||
source_text=source_text,
|
||||
javadoc_start_line=javadoc_start_line,
|
||||
)
|
||||
|
||||
def find_imports(self, source: str) -> list[JavaImportInfo]:
|
||||
"""Find all import statements in source code.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
|
||||
Returns:
|
||||
List of JavaImportInfo objects.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = self.parse(source_bytes)
|
||||
imports: list[JavaImportInfo] = []
|
||||
|
||||
for child in tree.root_node.children:
|
||||
if child.type == "import_declaration":
|
||||
import_info = self._extract_import_info(child, source_bytes)
|
||||
if import_info:
|
||||
imports.append(import_info)
|
||||
|
||||
return imports
|
||||
|
||||
def _extract_import_info(self, node: Node, source_bytes: bytes) -> JavaImportInfo | None:
|
||||
"""Extract import information from an import_declaration node."""
|
||||
import_path = ""
|
||||
is_static = False
|
||||
is_wildcard = False
|
||||
|
||||
# Check for static import
|
||||
for child in node.children:
|
||||
if child.type == "static":
|
||||
is_static = True
|
||||
break
|
||||
|
||||
# Get the import path (scoped_identifier or identifier)
|
||||
for child in node.children:
|
||||
if child.type == "scoped_identifier":
|
||||
import_path = self.get_node_text(child, source_bytes)
|
||||
break
|
||||
if child.type == "identifier":
|
||||
import_path = self.get_node_text(child, source_bytes)
|
||||
break
|
||||
|
||||
# Check for wildcard
|
||||
if import_path.endswith(".*") or ".*" in self.get_node_text(node, source_bytes):
|
||||
is_wildcard = True
|
||||
|
||||
# Clean up the import path
|
||||
import_path = import_path.rstrip(".*").rstrip(".")
|
||||
|
||||
return JavaImportInfo(
|
||||
import_path=import_path,
|
||||
is_static=is_static,
|
||||
is_wildcard=is_wildcard,
|
||||
start_line=node.start_point[0] + 1,
|
||||
end_line=node.end_point[0] + 1,
|
||||
)
|
||||
|
||||
def find_fields(self, source: str, class_name: str | None = None) -> list[JavaFieldInfo]:
|
||||
"""Find all field declarations in source code.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
class_name: Optional class name to filter fields.
|
||||
|
||||
Returns:
|
||||
List of JavaFieldInfo objects.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = self.parse(source_bytes)
|
||||
fields: list[JavaFieldInfo] = []
|
||||
|
||||
self._walk_tree_for_fields(tree.root_node, source_bytes, fields, current_class=None, target_class=class_name)
|
||||
|
||||
return fields
|
||||
|
||||
def _walk_tree_for_fields(
|
||||
self,
|
||||
node: Node,
|
||||
source_bytes: bytes,
|
||||
fields: list[JavaFieldInfo],
|
||||
current_class: str | None,
|
||||
target_class: str | None,
|
||||
) -> None:
|
||||
"""Recursively walk the tree to find field declarations."""
|
||||
new_class = current_class
|
||||
|
||||
if node.type == "class_declaration":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
new_class = self.get_node_text(name_node, source_bytes)
|
||||
|
||||
if node.type == "field_declaration":
|
||||
# Only include if we're in the target class (or no target specified)
|
||||
if target_class is None or current_class == target_class:
|
||||
field_info = self._extract_field_info(node, source_bytes)
|
||||
if field_info:
|
||||
fields.extend(field_info)
|
||||
|
||||
for child in node.children:
|
||||
self._walk_tree_for_fields(
|
||||
child,
|
||||
source_bytes,
|
||||
fields,
|
||||
current_class=new_class if node.type == "class_declaration" else current_class,
|
||||
target_class=target_class,
|
||||
)
|
||||
|
||||
def _extract_field_info(self, node: Node, source_bytes: bytes) -> list[JavaFieldInfo]:
|
||||
"""Extract field information from a field_declaration node.
|
||||
|
||||
Returns a list because a single declaration can define multiple fields.
|
||||
"""
|
||||
fields: list[JavaFieldInfo] = []
|
||||
is_static = False
|
||||
is_final = False
|
||||
is_public = False
|
||||
is_private = False
|
||||
is_protected = False
|
||||
type_name = ""
|
||||
|
||||
# Check modifiers
|
||||
for child in node.children:
|
||||
if child.type == "modifiers":
|
||||
modifier_text = self.get_node_text(child, source_bytes)
|
||||
is_static = "static" in modifier_text
|
||||
is_final = "final" in modifier_text
|
||||
is_public = "public" in modifier_text
|
||||
is_private = "private" in modifier_text
|
||||
is_protected = "protected" in modifier_text
|
||||
break
|
||||
|
||||
# Get type
|
||||
type_node = node.child_by_field_name("type")
|
||||
if type_node:
|
||||
type_name = self.get_node_text(type_node, source_bytes)
|
||||
|
||||
# Get variable declarators (there can be multiple: int a, b, c;)
|
||||
for child in node.children:
|
||||
if child.type == "variable_declarator":
|
||||
name_node = child.child_by_field_name("name")
|
||||
if name_node:
|
||||
field_name = self.get_node_text(name_node, source_bytes)
|
||||
fields.append(
|
||||
JavaFieldInfo(
|
||||
name=field_name,
|
||||
type_name=type_name,
|
||||
is_static=is_static,
|
||||
is_final=is_final,
|
||||
is_public=is_public,
|
||||
is_private=is_private,
|
||||
is_protected=is_protected,
|
||||
start_line=node.start_point[0] + 1,
|
||||
end_line=node.end_point[0] + 1,
|
||||
source_text=self.get_node_text(node, source_bytes),
|
||||
)
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
def find_method_calls(self, source: str, within_method: JavaMethodNode) -> list[str]:
|
||||
"""Find all method calls within a specific method's body.
|
||||
|
||||
Args:
|
||||
source: The full source code.
|
||||
within_method: The method to search within.
|
||||
|
||||
Returns:
|
||||
List of method names that are called.
|
||||
|
||||
"""
|
||||
calls: list[str] = []
|
||||
source_bytes = source.encode("utf8")
|
||||
|
||||
# Get the body of the method
|
||||
body_node = within_method.node.child_by_field_name("body")
|
||||
if body_node:
|
||||
self._walk_tree_for_calls(body_node, source_bytes, calls)
|
||||
|
||||
return list(set(calls)) # Remove duplicates
|
||||
|
||||
def _walk_tree_for_calls(self, node: Node, source_bytes: bytes, calls: list[str]) -> None:
|
||||
"""Recursively find method calls in a subtree."""
|
||||
if node.type == "method_invocation":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
calls.append(self.get_node_text(name_node, source_bytes))
|
||||
|
||||
for child in node.children:
|
||||
self._walk_tree_for_calls(child, source_bytes, calls)
|
||||
|
||||
def has_return_statement(self, method_node: JavaMethodNode, source: str) -> bool:
|
||||
"""Check if a method has a return statement.
|
||||
|
||||
Args:
|
||||
method_node: The method to check.
|
||||
source: The source code.
|
||||
|
||||
Returns:
|
||||
True if the method has a return statement.
|
||||
|
||||
"""
|
||||
# void methods don't need return statements
|
||||
if method_node.return_type == "void":
|
||||
return False
|
||||
|
||||
return self._node_has_return(method_node.node)
|
||||
|
||||
def _node_has_return(self, node: Node) -> bool:
|
||||
"""Recursively check if a node contains a return statement."""
|
||||
if node.type == "return_statement":
|
||||
return True
|
||||
|
||||
# Don't recurse into nested method declarations (lambdas)
|
||||
if node.type in ("lambda_expression", "method_declaration"):
|
||||
if node.type == "method_declaration":
|
||||
body_node = node.child_by_field_name("body")
|
||||
if body_node:
|
||||
for child in body_node.children:
|
||||
if self._node_has_return(child):
|
||||
return True
|
||||
return False
|
||||
|
||||
return any(self._node_has_return(child) for child in node.children)
|
||||
|
||||
def validate_syntax(self, source: str) -> bool:
|
||||
"""Check if Java source code is syntactically valid.
|
||||
|
||||
Uses tree-sitter to parse and check for errors.
|
||||
|
||||
Args:
|
||||
source: Source code to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
|
||||
"""
|
||||
try:
|
||||
tree = self.parse(source)
|
||||
return not tree.root_node.has_error
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_package_name(self, source: str) -> str | None:
|
||||
"""Extract the package name from Java source code.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
|
||||
Returns:
|
||||
The package name, or None if not found.
|
||||
|
||||
"""
|
||||
source_bytes = source.encode("utf8")
|
||||
tree = self.parse(source_bytes)
|
||||
|
||||
for child in tree.root_node.children:
|
||||
if child.type == "package_declaration":
|
||||
# Find the scoped_identifier within the package declaration
|
||||
for pkg_child in child.children:
|
||||
if pkg_child.type == "scoped_identifier":
|
||||
return self.get_node_text(pkg_child, source_bytes)
|
||||
if pkg_child.type == "identifier":
|
||||
return self.get_node_text(pkg_child, source_bytes)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_java_analyzer() -> JavaAnalyzer:
|
||||
"""Get a JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
JavaAnalyzer configured for Java.
|
||||
|
||||
"""
|
||||
return JavaAnalyzer()
|
||||
420
codeflash/languages/java/replacement.py
Normal file
420
codeflash/languages/java/replacement.py
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
"""Java code replacement.
|
||||
|
||||
This module provides functionality to replace function implementations
|
||||
in Java source code while preserving formatting and structure.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import FunctionInfo
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def replace_function(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
new_source: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
Preserves:
|
||||
- Surrounding whitespace and formatting
|
||||
- Javadoc comments (if they should be preserved)
|
||||
- Annotations
|
||||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function to replace.
|
||||
new_source: New function source code.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Modified source code with function replaced.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Find the method in the source
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
|
||||
for method in methods:
|
||||
if method.name == function.name:
|
||||
if function.class_name is None or method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
logger.error("Could not find method %s in source", function.name)
|
||||
return source
|
||||
|
||||
# Determine replacement range
|
||||
# Include Javadoc if present
|
||||
start_line = target_method.javadoc_start_line or target_method.start_line
|
||||
end_line = target_method.end_line
|
||||
|
||||
# Split source into lines
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Get indentation from the original method
|
||||
original_first_line = lines[start_line - 1] if start_line <= len(lines) else ""
|
||||
indent = _get_indentation(original_first_line)
|
||||
|
||||
# Ensure new source has correct indentation
|
||||
new_source_lines = new_source.splitlines(keepends=True)
|
||||
indented_new_source = _apply_indentation(new_source_lines, indent)
|
||||
|
||||
# Build the result
|
||||
before = lines[: start_line - 1] # Lines before the method
|
||||
after = lines[end_line:] # Lines after the method
|
||||
|
||||
result = "".join(before) + indented_new_source + "".join(after)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _get_indentation(line: str) -> str:
|
||||
"""Extract the indentation from a line.
|
||||
|
||||
Args:
|
||||
line: The line to analyze.
|
||||
|
||||
Returns:
|
||||
The indentation string (spaces/tabs).
|
||||
|
||||
"""
|
||||
match = re.match(r"^(\s*)", line)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
|
||||
def _apply_indentation(lines: list[str], base_indent: str) -> str:
|
||||
"""Apply indentation to all lines.
|
||||
|
||||
Args:
|
||||
lines: Lines to indent.
|
||||
base_indent: Base indentation to apply.
|
||||
|
||||
Returns:
|
||||
Indented source code.
|
||||
|
||||
"""
|
||||
if not lines:
|
||||
return ""
|
||||
|
||||
# Detect the existing indentation in the new source
|
||||
existing_indent = ""
|
||||
for line in lines:
|
||||
stripped = line.lstrip()
|
||||
if stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
|
||||
existing_indent = _get_indentation(line)
|
||||
break
|
||||
|
||||
result_lines = []
|
||||
for line in lines:
|
||||
if not line.strip():
|
||||
result_lines.append(line)
|
||||
else:
|
||||
# Remove existing indentation and apply new base indentation
|
||||
stripped_line = line.lstrip()
|
||||
# Calculate relative indentation
|
||||
line_indent = _get_indentation(line)
|
||||
if existing_indent and line_indent.startswith(existing_indent):
|
||||
relative_indent = line_indent[len(existing_indent) :]
|
||||
else:
|
||||
relative_indent = ""
|
||||
result_lines.append(base_indent + relative_indent + stripped_line)
|
||||
|
||||
return "".join(result_lines)
|
||||
|
||||
|
||||
def replace_method_body(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
new_body: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Replace just the body of a method, preserving signature.
|
||||
|
||||
Args:
|
||||
source: Original source code.
|
||||
function: FunctionInfo identifying the function.
|
||||
new_body: New method body (code between braces).
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Modified source code.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
source_bytes = source.encode("utf8")
|
||||
|
||||
# Find the method
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
|
||||
for method in methods:
|
||||
if method.name == function.name:
|
||||
if function.class_name is None or method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
logger.error("Could not find method %s", function.name)
|
||||
return source
|
||||
|
||||
# Find the body node
|
||||
body_node = target_method.node.child_by_field_name("body")
|
||||
if not body_node:
|
||||
logger.error("Method %s has no body (abstract?)", function.name)
|
||||
return source
|
||||
|
||||
# Get the body's byte positions
|
||||
body_start = body_node.start_byte
|
||||
body_end = body_node.end_byte
|
||||
|
||||
# Get indentation
|
||||
body_start_line = body_node.start_point[0]
|
||||
lines = source.splitlines(keepends=True)
|
||||
base_indent = _get_indentation(lines[body_start_line]) if body_start_line < len(lines) else " "
|
||||
|
||||
# Format the new body
|
||||
new_body = new_body.strip()
|
||||
if not new_body.startswith("{"):
|
||||
new_body = "{\n" + base_indent + " " + new_body
|
||||
if not new_body.endswith("}"):
|
||||
new_body = new_body + "\n" + base_indent + "}"
|
||||
|
||||
# Replace the body
|
||||
before = source_bytes[:body_start]
|
||||
after = source_bytes[body_end:]
|
||||
|
||||
return (before + new_body.encode("utf8") + after).decode("utf8")
|
||||
|
||||
|
||||
def insert_method(
|
||||
source: str,
|
||||
class_name: str,
|
||||
method_source: str,
|
||||
position: str = "end", # "end" or "start"
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Insert a new method into a class.
|
||||
|
||||
Args:
|
||||
source: The source code.
|
||||
class_name: Name of the class to insert into.
|
||||
method_source: Source code of the method to insert.
|
||||
position: Where to insert ("end" or "start" of class body).
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Source code with method inserted.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Find the class
|
||||
classes = analyzer.find_classes(source)
|
||||
target_class = None
|
||||
|
||||
for cls in classes:
|
||||
if cls.name == class_name:
|
||||
target_class = cls
|
||||
break
|
||||
|
||||
if not target_class:
|
||||
logger.error("Could not find class %s", class_name)
|
||||
return source
|
||||
|
||||
# Find the class body
|
||||
body_node = target_class.node.child_by_field_name("body")
|
||||
if not body_node:
|
||||
logger.error("Class %s has no body", class_name)
|
||||
return source
|
||||
|
||||
# Get insertion point
|
||||
source_bytes = source.encode("utf8")
|
||||
|
||||
if position == "end":
|
||||
# Insert before the closing brace
|
||||
insert_point = body_node.end_byte - 1
|
||||
else:
|
||||
# Insert after the opening brace
|
||||
insert_point = body_node.start_byte + 1
|
||||
|
||||
# Get indentation (typically 4 spaces inside a class)
|
||||
lines = source.splitlines(keepends=True)
|
||||
class_line = target_class.start_line - 1
|
||||
class_indent = _get_indentation(lines[class_line]) if class_line < len(lines) else ""
|
||||
method_indent = class_indent + " "
|
||||
|
||||
# Format the method
|
||||
method_lines = method_source.strip().splitlines(keepends=True)
|
||||
indented_method = _apply_indentation(method_lines, method_indent)
|
||||
|
||||
# Insert the method
|
||||
before = source_bytes[:insert_point]
|
||||
after = source_bytes[insert_point:]
|
||||
|
||||
separator = "\n\n" if position == "end" else "\n"
|
||||
|
||||
return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8")
|
||||
|
||||
|
||||
def remove_method(
|
||||
source: str,
|
||||
function: FunctionInfo,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Remove a method from source code.
|
||||
|
||||
Args:
|
||||
source: The source code.
|
||||
function: FunctionInfo identifying the method to remove.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Source code with method removed.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Find the method
|
||||
methods = analyzer.find_methods(source)
|
||||
target_method = None
|
||||
|
||||
for method in methods:
|
||||
if method.name == function.name:
|
||||
if function.class_name is None or method.class_name == function.class_name:
|
||||
target_method = method
|
||||
break
|
||||
|
||||
if not target_method:
|
||||
logger.error("Could not find method %s", function.name)
|
||||
return source
|
||||
|
||||
# Determine removal range (include Javadoc)
|
||||
start_line = target_method.javadoc_start_line or target_method.start_line
|
||||
end_line = target_method.end_line
|
||||
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Remove the method lines
|
||||
before = lines[: start_line - 1]
|
||||
after = lines[end_line:]
|
||||
|
||||
return "".join(before) + "".join(after)
|
||||
|
||||
|
||||
def remove_test_functions(
|
||||
test_source: str,
|
||||
functions_to_remove: list[str],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Remove specific test functions from test source code.
|
||||
|
||||
Args:
|
||||
test_source: Test source code.
|
||||
functions_to_remove: List of function names to remove.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Test source code with specified functions removed.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Find all methods
|
||||
methods = analyzer.find_methods(test_source)
|
||||
|
||||
# Sort by start line in reverse order (remove from end first)
|
||||
methods_to_remove = [
|
||||
m for m in methods if m.name in functions_to_remove
|
||||
]
|
||||
methods_to_remove.sort(key=lambda m: m.start_line, reverse=True)
|
||||
|
||||
result = test_source
|
||||
|
||||
for method in methods_to_remove:
|
||||
# Create a FunctionInfo for removal
|
||||
func_info = FunctionInfo(
|
||||
name=method.name,
|
||||
file_path=Path("temp.java"),
|
||||
start_line=method.start_line,
|
||||
end_line=method.end_line,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
)
|
||||
result = remove_method(result, func_info, analyzer)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_runtime_comments(
|
||||
test_source: str,
|
||||
original_runtimes: dict[str, int],
|
||||
optimized_runtimes: dict[str, int],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
"""Add runtime performance comments to test source code.
|
||||
|
||||
Adds comments showing the original vs optimized runtime for each
|
||||
function call (e.g., "// 1.5ms -> 0.3ms (80% faster)").
|
||||
|
||||
Args:
|
||||
test_source: Test source code to annotate.
|
||||
original_runtimes: Map of invocation IDs to original runtimes (ns).
|
||||
optimized_runtimes: Map of invocation IDs to optimized runtimes (ns).
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Test source code with runtime comments added.
|
||||
|
||||
"""
|
||||
if not original_runtimes or not optimized_runtimes:
|
||||
return test_source
|
||||
|
||||
# For now, add a summary comment at the top
|
||||
summary_lines = ["// Performance comparison:"]
|
||||
|
||||
for inv_id in original_runtimes:
|
||||
original_ns = original_runtimes[inv_id]
|
||||
optimized_ns = optimized_runtimes.get(inv_id, original_ns)
|
||||
|
||||
original_ms = original_ns / 1_000_000
|
||||
optimized_ms = optimized_ns / 1_000_000
|
||||
|
||||
if original_ns > 0:
|
||||
speedup = ((original_ns - optimized_ns) / original_ns) * 100
|
||||
summary_lines.append(
|
||||
f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)"
|
||||
)
|
||||
|
||||
# Insert after imports
|
||||
lines = test_source.splitlines(keepends=True)
|
||||
insert_idx = 0
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith("import "):
|
||||
insert_idx = i + 1
|
||||
elif line.strip() and not line.strip().startswith("//") and not line.strip().startswith("package"):
|
||||
if insert_idx == 0:
|
||||
insert_idx = i
|
||||
break
|
||||
|
||||
# Insert summary
|
||||
summary = "\n".join(summary_lines) + "\n\n"
|
||||
lines.insert(insert_idx, summary)
|
||||
|
||||
return "".join(lines)
|
||||
384
codeflash/languages/java/support.py
Normal file
384
codeflash/languages/java/support.py
Normal file
|
|
@ -0,0 +1,384 @@
|
|||
"""Main JavaSupport class implementing the LanguageSupport protocol.
|
||||
|
||||
This module provides the main JavaSupport class that implements all
|
||||
required methods for Java language support in codeflash.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionFilterCriteria,
|
||||
FunctionInfo,
|
||||
HelperFunction,
|
||||
Language,
|
||||
LanguageSupport,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
from codeflash.languages.registry import register_language
|
||||
from codeflash.languages.java.build_tools import find_test_root
|
||||
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
|
||||
from codeflash.languages.java.config import detect_java_project
|
||||
from codeflash.languages.java.context import extract_code_context, find_helper_functions
|
||||
from codeflash.languages.java.discovery import discover_functions, discover_functions_from_source
|
||||
from codeflash.languages.java.formatter import format_java_code, normalize_java_code
|
||||
from codeflash.languages.java.instrumentation import (
|
||||
instrument_existing_test,
|
||||
instrument_for_behavior,
|
||||
instrument_for_benchmarking,
|
||||
)
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.replacement import (
|
||||
add_runtime_comments,
|
||||
remove_test_functions,
|
||||
replace_function,
|
||||
)
|
||||
from codeflash.languages.java.test_discovery import discover_tests
|
||||
from codeflash.languages.java.test_runner import (
|
||||
parse_test_results,
|
||||
run_behavioral_tests,
|
||||
run_benchmarking_tests,
|
||||
run_tests,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_language
|
||||
class JavaSupport(LanguageSupport):
|
||||
"""Java language support implementation.
|
||||
|
||||
Implements the LanguageSupport protocol for Java, providing:
|
||||
- Function discovery using tree-sitter
|
||||
- Test discovery for JUnit 5
|
||||
- Test execution via Maven Surefire
|
||||
- Code context extraction
|
||||
- Code replacement and formatting
|
||||
- Behavior capture instrumentation
|
||||
- Benchmarking instrumentation
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize Java support."""
|
||||
self._analyzer = get_java_analyzer()
|
||||
|
||||
@property
|
||||
def language(self) -> Language:
|
||||
"""The language this implementation supports."""
|
||||
return Language.JAVA
|
||||
|
||||
@property
|
||||
def file_extensions(self) -> tuple[str, ...]:
|
||||
"""File extensions supported by Java."""
|
||||
return (".java",)
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Primary test framework name."""
|
||||
return "junit5"
|
||||
|
||||
@property
|
||||
def comment_prefix(self) -> str:
|
||||
"""Comment prefix for Java."""
|
||||
return "//"
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionInfo]:
|
||||
"""Find all optimizable functions in a Java file."""
|
||||
return discover_functions(file_path, filter_criteria, self._analyzer)
|
||||
|
||||
def discover_tests(
|
||||
self, test_root: Path, source_functions: Sequence[FunctionInfo]
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests."""
|
||||
return discover_tests(test_root, source_functions, self._analyzer)
|
||||
|
||||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(
|
||||
self, function: FunctionInfo, project_root: Path, module_root: Path
|
||||
) -> CodeContext:
|
||||
"""Extract function code and its dependencies."""
|
||||
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
|
||||
|
||||
def find_helper_functions(
|
||||
self, function: FunctionInfo, project_root: Path
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function."""
|
||||
return find_helper_functions(function, project_root, analyzer=self._analyzer)
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(
|
||||
self, source: str, function: FunctionInfo, new_source: str
|
||||
) -> str:
|
||||
"""Replace a function in source code with new implementation."""
|
||||
return replace_function(source, function, new_source, self._analyzer)
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
"""Format Java code."""
|
||||
project_root = file_path.parent if file_path else None
|
||||
return format_java_code(source, project_root)
|
||||
|
||||
# === Test Execution ===
|
||||
|
||||
def run_tests(
|
||||
self,
|
||||
test_files: Sequence[Path],
|
||||
cwd: Path,
|
||||
env: dict[str, str],
|
||||
timeout: int,
|
||||
) -> tuple[list[TestResult], Path]:
|
||||
"""Run tests and return results."""
|
||||
return run_tests(list(test_files), cwd, env, timeout)
|
||||
|
||||
def parse_test_results(self, junit_xml_path: Path, stdout: str) -> list[TestResult]:
|
||||
"""Parse test results from JUnit XML."""
|
||||
return parse_test_results(junit_xml_path, stdout)
|
||||
|
||||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(
|
||||
self, source: str, functions: Sequence[FunctionInfo]
|
||||
) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs."""
|
||||
return instrument_for_behavior(source, functions, self._analyzer)
|
||||
|
||||
def instrument_for_benchmarking(
|
||||
self, test_source: str, target_function: FunctionInfo
|
||||
) -> str:
|
||||
"""Add timing instrumentation to test code."""
|
||||
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
|
||||
|
||||
# === Validation ===
|
||||
|
||||
def validate_syntax(self, source: str) -> bool:
|
||||
"""Check if Java source code is syntactically valid."""
|
||||
return self._analyzer.validate_syntax(source)
|
||||
|
||||
def normalize_code(self, source: str) -> str:
|
||||
"""Normalize code for deduplication."""
|
||||
return normalize_java_code(source)
|
||||
|
||||
# === Test Editing ===
|
||||
|
||||
def add_runtime_comments(
|
||||
self,
|
||||
test_source: str,
|
||||
original_runtimes: dict[str, int],
|
||||
optimized_runtimes: dict[str, int],
|
||||
) -> str:
|
||||
"""Add runtime performance comments to test source code."""
|
||||
return add_runtime_comments(test_source, original_runtimes, optimized_runtimes, self._analyzer)
|
||||
|
||||
def remove_test_functions(
|
||||
self, test_source: str, functions_to_remove: list[str]
|
||||
) -> str:
|
||||
"""Remove specific test functions from test source code."""
|
||||
return remove_test_functions(test_source, functions_to_remove, self._analyzer)
|
||||
|
||||
# === Test Result Comparison ===
|
||||
|
||||
def compare_test_results(
|
||||
self,
|
||||
original_results_path: Path,
|
||||
candidate_results_path: Path,
|
||||
project_root: Path | None = None,
|
||||
) -> tuple[bool, list]:
|
||||
"""Compare test results between original and candidate code."""
|
||||
return _compare_test_results(
|
||||
original_results_path, candidate_results_path, project_root=project_root
|
||||
)
|
||||
|
||||
# === Configuration ===
|
||||
|
||||
def get_test_file_suffix(self) -> str:
|
||||
"""Get the test file suffix for Java."""
|
||||
return "Test.java"
|
||||
|
||||
def get_comment_prefix(self) -> str:
|
||||
"""Get the comment prefix for Java."""
|
||||
return "//"
|
||||
|
||||
def find_test_root(self, project_root: Path) -> Path | None:
|
||||
"""Find the test root directory for a Java project."""
|
||||
return find_test_root(project_root)
|
||||
|
||||
def get_project_root(self, source_file: Path) -> Path | None:
|
||||
"""Find the project root for a Java file.
|
||||
|
||||
Looks for pom.xml, build.gradle, or build.gradle.kts.
|
||||
|
||||
Args:
|
||||
source_file: Path to the source file.
|
||||
|
||||
Returns:
|
||||
The project root directory, or None if not found.
|
||||
|
||||
"""
|
||||
current = source_file.parent
|
||||
while current != current.parent:
|
||||
if (current / "pom.xml").exists():
|
||||
return current
|
||||
if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists():
|
||||
return current
|
||||
current = current.parent
|
||||
return None
|
||||
|
||||
def get_module_path(self, source_file: Path, project_root: Path, tests_root: Path | None = None) -> str:
|
||||
"""Get the module path for a Java source file.
|
||||
|
||||
For Java, this returns the fully qualified class name (e.g., 'com.example.Algorithms').
|
||||
|
||||
Args:
|
||||
source_file: Path to the source file.
|
||||
project_root: Root of the project.
|
||||
tests_root: Not used for Java.
|
||||
|
||||
Returns:
|
||||
Fully qualified class name string.
|
||||
|
||||
"""
|
||||
# Find the package from the file content
|
||||
try:
|
||||
content = source_file.read_text(encoding="utf-8")
|
||||
for line in content.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("package "):
|
||||
# Extract package name (remove 'package ' prefix and ';' suffix)
|
||||
package = line[8:].rstrip(";").strip()
|
||||
class_name = source_file.stem
|
||||
return f"{package}.{class_name}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: derive from path relative to src/main/java
|
||||
relative = source_file.relative_to(project_root)
|
||||
parts = list(relative.parts)
|
||||
|
||||
# Remove src/main/java prefix if present
|
||||
if len(parts) > 3 and parts[:3] == ["src", "main", "java"]:
|
||||
parts = parts[3:]
|
||||
|
||||
# Remove .java extension and join with dots
|
||||
if parts:
|
||||
parts[-1] = parts[-1].replace(".java", "")
|
||||
return ".".join(parts)
|
||||
|
||||
def get_runtime_files(self) -> list[Path]:
|
||||
"""Get paths to runtime files needed for Java."""
|
||||
# The Java runtime is distributed as a JAR
|
||||
return []
|
||||
|
||||
def ensure_runtime_environment(self, project_root: Path) -> bool:
|
||||
"""Ensure the runtime environment is set up."""
|
||||
# Check if codeflash-runtime is available
|
||||
config = detect_java_project(project_root)
|
||||
if config is None:
|
||||
return False
|
||||
|
||||
# For now, assume the runtime is available
|
||||
# A full implementation would check/install the JAR
|
||||
return True
|
||||
|
||||
def instrument_existing_test(
|
||||
self,
|
||||
test_path: Path,
|
||||
call_positions: Sequence[Any],
|
||||
function_to_optimize: Any,
|
||||
tests_project_root: Path,
|
||||
mode: str,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Inject profiling code into an existing test file."""
|
||||
return instrument_existing_test(
|
||||
test_path,
|
||||
call_positions,
|
||||
function_to_optimize,
|
||||
tests_project_root,
|
||||
mode,
|
||||
self._analyzer,
|
||||
)
|
||||
|
||||
def instrument_source_for_line_profiler(
|
||||
self, func_info: FunctionInfo, line_profiler_output_file: Path
|
||||
) -> bool:
|
||||
"""Instrument source code before line profiling."""
|
||||
# Not yet implemented for Java
|
||||
return False
|
||||
|
||||
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
|
||||
"""Parse line profiler output."""
|
||||
# Not yet implemented for Java
|
||||
return {}
|
||||
|
||||
def run_behavioral_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
enable_coverage: bool = False,
|
||||
candidate_index: int = 0,
|
||||
) -> tuple[Path, Any, Path | None, Path | None]:
|
||||
"""Run behavioral tests for Java."""
|
||||
return run_behavioral_tests(
|
||||
test_paths,
|
||||
test_env,
|
||||
cwd,
|
||||
timeout,
|
||||
project_root,
|
||||
enable_coverage,
|
||||
candidate_index,
|
||||
)
|
||||
|
||||
def run_benchmarking_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
min_loops: int = 5,
|
||||
max_loops: int = 100_000,
|
||||
target_duration_seconds: float = 10.0,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run benchmarking tests for Java."""
|
||||
return run_benchmarking_tests(
|
||||
test_paths,
|
||||
test_env,
|
||||
cwd,
|
||||
timeout,
|
||||
project_root,
|
||||
min_loops,
|
||||
max_loops,
|
||||
target_duration_seconds,
|
||||
)
|
||||
|
||||
|
||||
# Create a singleton instance for the registry
|
||||
_java_support: JavaSupport | None = None
|
||||
|
||||
|
||||
def get_java_support() -> JavaSupport:
|
||||
"""Get the JavaSupport singleton instance.
|
||||
|
||||
Returns:
|
||||
The JavaSupport instance.
|
||||
|
||||
"""
|
||||
global _java_support
|
||||
if _java_support is None:
|
||||
_java_support = JavaSupport()
|
||||
return _java_support
|
||||
370
codeflash/languages/java/test_discovery.py
Normal file
370
codeflash/languages/java/test_discovery.py
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
"""Java test discovery for JUnit 5.
|
||||
|
||||
This module provides functionality to discover tests that exercise
|
||||
specific functions, mapping source functions to their tests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, TestInfo
|
||||
from codeflash.languages.java.config import detect_java_project
|
||||
from codeflash.languages.java.discovery import discover_test_methods
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_tests(
|
||||
test_root: Path,
|
||||
source_functions: Sequence[FunctionInfo],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
Uses several heuristics to match tests to functions:
|
||||
1. Test method name contains function name
|
||||
2. Test class name matches source class name
|
||||
3. Imports analysis
|
||||
4. Method call analysis in test code
|
||||
|
||||
Args:
|
||||
test_root: Root directory containing tests.
|
||||
source_functions: Functions to find tests for.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Dict mapping qualified function names to lists of TestInfo.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Build a map of function names for quick lookup
|
||||
function_map: dict[str, FunctionInfo] = {}
|
||||
for func in source_functions:
|
||||
function_map[func.name] = func
|
||||
function_map[func.qualified_name] = func
|
||||
|
||||
# Find all test files
|
||||
test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java"))
|
||||
|
||||
# Result map
|
||||
result: dict[str, list[TestInfo]] = defaultdict(list)
|
||||
|
||||
for test_file in test_files:
|
||||
try:
|
||||
test_methods = discover_test_methods(test_file, analyzer)
|
||||
source = test_file.read_text(encoding="utf-8")
|
||||
|
||||
for test_method in test_methods:
|
||||
# Find which source functions this test might exercise
|
||||
matched_functions = _match_test_to_functions(
|
||||
test_method, source, function_map, analyzer
|
||||
)
|
||||
|
||||
for func_name in matched_functions:
|
||||
result[func_name].append(
|
||||
TestInfo(
|
||||
test_name=test_method.name,
|
||||
test_file=test_file,
|
||||
test_class=test_method.class_name,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to analyze test file %s: %s", test_file, e)
|
||||
|
||||
return dict(result)
|
||||
|
||||
|
||||
def _match_test_to_functions(
|
||||
test_method: FunctionInfo,
|
||||
test_source: str,
|
||||
function_map: dict[str, FunctionInfo],
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[str]:
|
||||
"""Match a test method to source functions it might exercise.
|
||||
|
||||
Args:
|
||||
test_method: The test method.
|
||||
test_source: Full source code of the test file.
|
||||
function_map: Map of function names to FunctionInfo.
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of function qualified names that this test might exercise.
|
||||
|
||||
"""
|
||||
matched: list[str] = []
|
||||
|
||||
# Strategy 1: Test method name contains function name
|
||||
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
|
||||
test_name_lower = test_method.name.lower()
|
||||
|
||||
for func_name, func_info in function_map.items():
|
||||
if func_info.name.lower() in test_name_lower:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
||||
# Strategy 2: Method call analysis
|
||||
# Look for direct method calls in the test code
|
||||
source_bytes = test_source.encode("utf8")
|
||||
tree = analyzer.parse(source_bytes)
|
||||
|
||||
# Find method calls within the test method's line range
|
||||
method_calls = _find_method_calls_in_range(
|
||||
tree.root_node,
|
||||
source_bytes,
|
||||
test_method.start_line,
|
||||
test_method.end_line,
|
||||
analyzer,
|
||||
)
|
||||
|
||||
for call_name in method_calls:
|
||||
if call_name in function_map:
|
||||
qualified = function_map[call_name].qualified_name
|
||||
if qualified not in matched:
|
||||
matched.append(qualified)
|
||||
|
||||
# Strategy 3: Test class naming convention
|
||||
# e.g., CalculatorTest tests Calculator
|
||||
if test_method.class_name:
|
||||
# Remove "Test" suffix or prefix
|
||||
source_class_name = test_method.class_name
|
||||
if source_class_name.endswith("Test"):
|
||||
source_class_name = source_class_name[:-4]
|
||||
elif source_class_name.startswith("Test"):
|
||||
source_class_name = source_class_name[4:]
|
||||
|
||||
# Look for functions in the matching class
|
||||
for func_name, func_info in function_map.items():
|
||||
if func_info.class_name == source_class_name:
|
||||
if func_info.qualified_name not in matched:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
||||
return matched
|
||||
|
||||
|
||||
def _find_method_calls_in_range(
|
||||
node,
|
||||
source_bytes: bytes,
|
||||
start_line: int,
|
||||
end_line: int,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[str]:
|
||||
"""Find method calls within a line range.
|
||||
|
||||
Args:
|
||||
node: Tree-sitter node to search.
|
||||
source_bytes: Source code as bytes.
|
||||
start_line: Start line (1-indexed).
|
||||
end_line: End line (1-indexed).
|
||||
analyzer: JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of method names called.
|
||||
|
||||
"""
|
||||
calls: list[str] = []
|
||||
|
||||
# Check if this node is within the range (convert to 0-indexed)
|
||||
node_start = node.start_point[0] + 1
|
||||
node_end = node.end_point[0] + 1
|
||||
|
||||
if node_end < start_line or node_start > end_line:
|
||||
return calls
|
||||
|
||||
if node.type == "method_invocation":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node:
|
||||
calls.append(analyzer.get_node_text(name_node, source_bytes))
|
||||
|
||||
for child in node.children:
|
||||
calls.extend(
|
||||
_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer)
|
||||
)
|
||||
|
||||
return calls
|
||||
|
||||
|
||||
def find_tests_for_function(
|
||||
function: FunctionInfo,
|
||||
test_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[TestInfo]:
|
||||
"""Find tests that exercise a specific function.
|
||||
|
||||
Args:
|
||||
function: The function to find tests for.
|
||||
test_root: Root directory containing tests.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of TestInfo for tests that might exercise this function.
|
||||
|
||||
"""
|
||||
result = discover_tests(test_root, [function], analyzer)
|
||||
return result.get(function.qualified_name, [])
|
||||
|
||||
|
||||
def get_test_class_for_source_class(
|
||||
source_class_name: str,
|
||||
test_root: Path,
|
||||
) -> Path | None:
|
||||
"""Find the test class file for a source class.
|
||||
|
||||
Args:
|
||||
source_class_name: Name of the source class.
|
||||
test_root: Root directory containing tests.
|
||||
|
||||
Returns:
|
||||
Path to the test file, or None if not found.
|
||||
|
||||
"""
|
||||
# Try common naming patterns
|
||||
patterns = [
|
||||
f"{source_class_name}Test.java",
|
||||
f"Test{source_class_name}.java",
|
||||
f"{source_class_name}Tests.java",
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
matches = list(test_root.rglob(pattern))
|
||||
if matches:
|
||||
return matches[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def discover_all_tests(
|
||||
test_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Discover all test methods in a test directory.
|
||||
|
||||
Args:
|
||||
test_root: Root directory containing tests.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo for all test methods.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
all_tests: list[FunctionInfo] = []
|
||||
|
||||
# Find all test files
|
||||
test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java"))
|
||||
|
||||
for test_file in test_files:
|
||||
try:
|
||||
tests = discover_test_methods(test_file, analyzer)
|
||||
all_tests.extend(tests)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to analyze test file %s: %s", test_file, e)
|
||||
|
||||
return all_tests
|
||||
|
||||
|
||||
def get_test_file_suffix() -> str:
|
||||
"""Get the test file suffix for Java.
|
||||
|
||||
Returns:
|
||||
Test file suffix.
|
||||
|
||||
"""
|
||||
return "Test.java"
|
||||
|
||||
|
||||
def is_test_file(file_path: Path) -> bool:
|
||||
"""Check if a file is a test file.
|
||||
|
||||
Args:
|
||||
file_path: Path to check.
|
||||
|
||||
Returns:
|
||||
True if this appears to be a test file.
|
||||
|
||||
"""
|
||||
name = file_path.name
|
||||
|
||||
# Check naming patterns
|
||||
if name.endswith("Test.java") or name.endswith("Tests.java"):
|
||||
return True
|
||||
if name.startswith("Test") and name.endswith(".java"):
|
||||
return True
|
||||
|
||||
# Check if it's in a test directory
|
||||
path_parts = file_path.parts
|
||||
for part in path_parts:
|
||||
if part in ("test", "tests", "src/test"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_test_methods_for_class(
|
||||
test_file: Path,
|
||||
test_class_name: str | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionInfo]:
|
||||
"""Get all test methods in a specific test class.
|
||||
|
||||
Args:
|
||||
test_file: Path to the test file.
|
||||
test_class_name: Optional class name to filter (uses file name if not provided).
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
List of FunctionInfo for test methods.
|
||||
|
||||
"""
|
||||
tests = discover_test_methods(test_file, analyzer)
|
||||
|
||||
if test_class_name:
|
||||
return [t for t in tests if t.class_name == test_class_name]
|
||||
|
||||
return tests
|
||||
|
||||
|
||||
def build_test_mapping_for_project(
|
||||
project_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Build a complete test mapping for a project.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
analyzer: Optional JavaAnalyzer instance.
|
||||
|
||||
Returns:
|
||||
Dict mapping qualified function names to lists of TestInfo.
|
||||
|
||||
"""
|
||||
analyzer = analyzer or get_java_analyzer()
|
||||
|
||||
# Detect project configuration
|
||||
config = detect_java_project(project_root)
|
||||
if not config:
|
||||
return {}
|
||||
|
||||
if not config.source_root or not config.test_root:
|
||||
return {}
|
||||
|
||||
# Discover all source functions
|
||||
from codeflash.languages.java.discovery import discover_functions
|
||||
|
||||
source_functions: list[FunctionInfo] = []
|
||||
for java_file in config.source_root.rglob("*.java"):
|
||||
funcs = discover_functions(java_file, analyzer=analyzer)
|
||||
source_functions.extend(funcs)
|
||||
|
||||
# Map tests to functions
|
||||
return discover_tests(config.test_root, source_functions, analyzer)
|
||||
440
codeflash/languages/java/test_runner.py
Normal file
440
codeflash/languages/java/test_runner.py
Normal file
|
|
@ -0,0 +1,440 @@
|
|||
"""Java test runner for JUnit 5 with Maven.
|
||||
|
||||
This module provides functionality to run JUnit 5 tests using Maven Surefire,
|
||||
supporting both behavioral testing and benchmarking modes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.languages.base import TestResult
|
||||
from codeflash.languages.java.build_tools import (
|
||||
find_maven_executable,
|
||||
find_test_root,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class JavaTestRunResult:
|
||||
"""Result of running Java tests."""
|
||||
|
||||
success: bool
|
||||
tests_run: int
|
||||
tests_passed: int
|
||||
tests_failed: int
|
||||
tests_skipped: int
|
||||
test_results: list[TestResult]
|
||||
sqlite_db_path: Path | None
|
||||
junit_xml_path: Path | None
|
||||
stdout: str
|
||||
stderr: str
|
||||
returncode: int
|
||||
|
||||
|
||||
def run_behavioral_tests(
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
enable_coverage: bool = False,
|
||||
candidate_index: int = 0,
|
||||
) -> tuple[Path, Any, Path | None, Path | None]:
|
||||
"""Run behavioral tests for Java code.
|
||||
|
||||
This runs tests and captures behavior (inputs/outputs) for verification.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object or list of test file paths.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
enable_coverage: Whether to collect coverage information.
|
||||
candidate_index: Index of the candidate being tested.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result, coverage_path, config_path).
|
||||
|
||||
"""
|
||||
project_root = project_root or cwd
|
||||
|
||||
# Generate unique result file path
|
||||
result_id = uuid.uuid4().hex[:8]
|
||||
result_file = Path(tempfile.gettempdir()) / f"codeflash_java_behavior_{result_id}.db"
|
||||
|
||||
# Set environment variables for CodeFlash runtime
|
||||
run_env = os.environ.copy()
|
||||
run_env.update(test_env)
|
||||
run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
|
||||
run_env["CODEFLASH_MODE"] = "behavior"
|
||||
|
||||
# Run Maven tests
|
||||
result = _run_maven_tests(
|
||||
project_root,
|
||||
test_paths,
|
||||
run_env,
|
||||
timeout=timeout or 300,
|
||||
)
|
||||
|
||||
return result_file, result, None, None
|
||||
|
||||
|
||||
def run_benchmarking_tests(
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
min_loops: int = 5,
|
||||
max_loops: int = 100_000,
|
||||
target_duration_seconds: float = 10.0,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run benchmarking tests for Java code.
|
||||
|
||||
This runs tests with performance measurement.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object or list of test file paths.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
min_loops: Minimum number of loops for benchmarking.
|
||||
max_loops: Maximum number of loops for benchmarking.
|
||||
target_duration_seconds: Target duration for benchmarking in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
project_root = project_root or cwd
|
||||
|
||||
# Generate unique result file path
|
||||
result_id = uuid.uuid4().hex[:8]
|
||||
result_file = Path(tempfile.gettempdir()) / f"codeflash_java_benchmark_{result_id}.db"
|
||||
|
||||
# Set environment variables
|
||||
run_env = os.environ.copy()
|
||||
run_env.update(test_env)
|
||||
run_env["CODEFLASH_RESULT_FILE"] = str(result_file)
|
||||
run_env["CODEFLASH_MODE"] = "benchmark"
|
||||
run_env["CODEFLASH_MIN_LOOPS"] = str(min_loops)
|
||||
run_env["CODEFLASH_MAX_LOOPS"] = str(max_loops)
|
||||
run_env["CODEFLASH_TARGET_DURATION"] = str(target_duration_seconds)
|
||||
|
||||
# Run Maven tests
|
||||
result = _run_maven_tests(
|
||||
project_root,
|
||||
test_paths,
|
||||
run_env,
|
||||
timeout=timeout or 600, # Longer timeout for benchmarks
|
||||
)
|
||||
|
||||
return result_file, result
|
||||
|
||||
|
||||
def _run_maven_tests(
|
||||
project_root: Path,
|
||||
test_paths: Any,
|
||||
env: dict[str, str],
|
||||
timeout: int = 300,
|
||||
) -> subprocess.CompletedProcess:
|
||||
"""Run Maven tests with Surefire.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Maven project.
|
||||
test_paths: Test files or classes to run.
|
||||
env: Environment variables.
|
||||
timeout: Maximum execution time in seconds.
|
||||
|
||||
Returns:
|
||||
CompletedProcess with test results.
|
||||
|
||||
"""
|
||||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
logger.error("Maven not found")
|
||||
return subprocess.CompletedProcess(
|
||||
args=["mvn"],
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr="Maven not found",
|
||||
)
|
||||
|
||||
# Build test filter
|
||||
test_filter = _build_test_filter(test_paths)
|
||||
|
||||
# Build Maven command
|
||||
cmd = [mvn, "test", "-fae"] # Fail at end to run all tests
|
||||
|
||||
if test_filter:
|
||||
cmd.append(f"-Dtest={test_filter}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
return result
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Maven test execution timed out after %d seconds", timeout)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-2,
|
||||
stdout="",
|
||||
stderr=f"Test execution timed out after {timeout} seconds",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Maven test execution failed: %s", e)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr=str(e),
|
||||
)
|
||||
|
||||
|
||||
def _build_test_filter(test_paths: Any) -> str:
|
||||
"""Build a Maven Surefire test filter from test paths.
|
||||
|
||||
Args:
|
||||
test_paths: Test files, classes, or methods to include.
|
||||
|
||||
Returns:
|
||||
Surefire test filter string.
|
||||
|
||||
"""
|
||||
if not test_paths:
|
||||
return ""
|
||||
|
||||
# Handle different input types
|
||||
if isinstance(test_paths, (list, tuple)):
|
||||
filters = []
|
||||
for path in test_paths:
|
||||
if isinstance(path, Path):
|
||||
# Convert file path to class name
|
||||
class_name = _path_to_class_name(path)
|
||||
if class_name:
|
||||
filters.append(class_name)
|
||||
elif isinstance(path, str):
|
||||
filters.append(path)
|
||||
return ",".join(filters) if filters else ""
|
||||
|
||||
# Handle TestFiles object (has test_files attribute)
|
||||
if hasattr(test_paths, "test_files"):
|
||||
return _build_test_filter(list(test_paths.test_files))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _path_to_class_name(path: Path) -> str | None:
|
||||
"""Convert a test file path to a Java class name.
|
||||
|
||||
Args:
|
||||
path: Path to the test file.
|
||||
|
||||
Returns:
|
||||
Fully qualified class name, or None if unable to determine.
|
||||
|
||||
"""
|
||||
if not path.suffix == ".java":
|
||||
return None
|
||||
|
||||
# Try to extract package from path
|
||||
# e.g., src/test/java/com/example/CalculatorTest.java -> com.example.CalculatorTest
|
||||
parts = path.parts
|
||||
|
||||
# Find 'java' in the path and take everything after
|
||||
try:
|
||||
java_idx = parts.index("java")
|
||||
class_parts = parts[java_idx + 1 :]
|
||||
# Remove .java extension from last part
|
||||
class_parts = list(class_parts)
|
||||
class_parts[-1] = class_parts[-1].replace(".java", "")
|
||||
return ".".join(class_parts)
|
||||
except ValueError:
|
||||
# No 'java' directory, just use the file name
|
||||
return path.stem
|
||||
|
||||
|
||||
def run_tests(
|
||||
test_files: list[Path],
|
||||
cwd: Path,
|
||||
env: dict[str, str],
|
||||
timeout: int,
|
||||
) -> tuple[list[TestResult], Path]:
|
||||
"""Run tests and return results.
|
||||
|
||||
Args:
|
||||
test_files: Paths to test files to run.
|
||||
cwd: Working directory for test execution.
|
||||
env: Environment variables.
|
||||
timeout: Maximum execution time in seconds.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of TestResults, path to JUnit XML).
|
||||
|
||||
"""
|
||||
# Run Maven tests
|
||||
result = _run_maven_tests(cwd, test_files, env, timeout)
|
||||
|
||||
# Parse JUnit XML results
|
||||
surefire_dir = cwd / "target" / "surefire-reports"
|
||||
test_results = parse_surefire_results(surefire_dir)
|
||||
|
||||
# Return first XML file path
|
||||
junit_files = list(surefire_dir.glob("TEST-*.xml")) if surefire_dir.exists() else []
|
||||
junit_path = junit_files[0] if junit_files else cwd / "target" / "surefire-reports" / "test-results.xml"
|
||||
|
||||
return test_results, junit_path
|
||||
|
||||
|
||||
def parse_test_results(junit_xml_path: Path, stdout: str) -> list[TestResult]:
|
||||
"""Parse test results from JUnit XML and stdout.
|
||||
|
||||
Args:
|
||||
junit_xml_path: Path to JUnit XML results file.
|
||||
stdout: Standard output from test execution.
|
||||
|
||||
Returns:
|
||||
List of TestResult objects.
|
||||
|
||||
"""
|
||||
return parse_surefire_results(junit_xml_path.parent)
|
||||
|
||||
|
||||
def parse_surefire_results(surefire_dir: Path) -> list[TestResult]:
|
||||
"""Parse Maven Surefire XML reports into TestResult objects.
|
||||
|
||||
Args:
|
||||
surefire_dir: Directory containing Surefire XML reports.
|
||||
|
||||
Returns:
|
||||
List of TestResult objects.
|
||||
|
||||
"""
|
||||
results: list[TestResult] = []
|
||||
|
||||
if not surefire_dir.exists():
|
||||
return results
|
||||
|
||||
for xml_file in surefire_dir.glob("TEST-*.xml"):
|
||||
results.extend(_parse_surefire_xml(xml_file))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _parse_surefire_xml(xml_file: Path) -> list[TestResult]:
|
||||
"""Parse a single Surefire XML file.
|
||||
|
||||
Args:
|
||||
xml_file: Path to the XML file.
|
||||
|
||||
Returns:
|
||||
List of TestResult objects for tests in this file.
|
||||
|
||||
"""
|
||||
results: list[TestResult] = []
|
||||
|
||||
try:
|
||||
tree = ET.parse(xml_file)
|
||||
root = tree.getroot()
|
||||
|
||||
# Get test class info
|
||||
class_name = root.get("name", "")
|
||||
|
||||
# Process each test case
|
||||
for testcase in root.findall(".//testcase"):
|
||||
test_name = testcase.get("name", "")
|
||||
test_time = float(testcase.get("time", "0"))
|
||||
runtime_ns = int(test_time * 1_000_000_000)
|
||||
|
||||
# Check for failure/error
|
||||
failure = testcase.find("failure")
|
||||
error = testcase.find("error")
|
||||
skipped = testcase.find("skipped")
|
||||
|
||||
passed = failure is None and error is None and skipped is None
|
||||
error_message = None
|
||||
|
||||
if failure is not None:
|
||||
error_message = failure.get("message", "")
|
||||
if failure.text:
|
||||
error_message += "\n" + failure.text
|
||||
|
||||
if error is not None:
|
||||
error_message = error.get("message", "")
|
||||
if error.text:
|
||||
error_message += "\n" + error.text
|
||||
|
||||
# Get stdout/stderr from system-out/system-err elements
|
||||
stdout = ""
|
||||
stderr = ""
|
||||
stdout_elem = testcase.find("system-out")
|
||||
if stdout_elem is not None and stdout_elem.text:
|
||||
stdout = stdout_elem.text
|
||||
stderr_elem = testcase.find("system-err")
|
||||
if stderr_elem is not None and stderr_elem.text:
|
||||
stderr = stderr_elem.text
|
||||
|
||||
results.append(
|
||||
TestResult(
|
||||
test_name=test_name,
|
||||
test_file=xml_file,
|
||||
passed=passed,
|
||||
runtime_ns=runtime_ns,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
error_message=error_message,
|
||||
)
|
||||
)
|
||||
|
||||
except ET.ParseError as e:
|
||||
logger.warning("Failed to parse Surefire report %s: %s", xml_file, e)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_test_run_command(
|
||||
project_root: Path,
|
||||
test_classes: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Get the command to run Java tests.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the Maven project.
|
||||
test_classes: Optional list of test class names to run.
|
||||
|
||||
Returns:
|
||||
Command as list of strings.
|
||||
|
||||
"""
|
||||
mvn = find_maven_executable() or "mvn"
|
||||
|
||||
cmd = [mvn, "test"]
|
||||
|
||||
if test_classes:
|
||||
cmd.append(f"-Dtest={','.join(test_classes)}")
|
||||
|
||||
return cmd
|
||||
|
|
@ -24,7 +24,7 @@ from codeflash.code_utils.git_worktree_utils import (
|
|||
)
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.languages import is_javascript, set_current_language
|
||||
from codeflash.languages import is_java, is_javascript, set_current_language
|
||||
from codeflash.models.models import ValidCode
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
|
@ -229,8 +229,8 @@ class Optimizer:
|
|||
|
||||
original_module_code: str = original_module_path.read_text(encoding="utf8")
|
||||
|
||||
# For JavaScript/TypeScript, skip Python-specific AST parsing
|
||||
if is_javascript():
|
||||
# For JavaScript/TypeScript/Java, skip Python-specific AST parsing
|
||||
if is_javascript() or is_java():
|
||||
validated_original_code: dict[Path, ValidCode] = {
|
||||
original_module_path: ValidCode(source_code=original_module_code, normalized_code=original_module_code)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,14 +6,19 @@ from typing import Optional
|
|||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.languages import current_language_support, is_javascript
|
||||
from codeflash.languages import current_language_support, is_java, is_javascript
|
||||
|
||||
|
||||
def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
|
||||
assert test_type in {"unit", "inspired", "replay", "perf"}
|
||||
function_name = function_name.replace(".", "_")
|
||||
# Use appropriate file extension based on language
|
||||
extension = current_language_support().get_test_file_suffix() if is_javascript() else ".py"
|
||||
if is_javascript():
|
||||
extension = current_language_support().get_test_file_suffix()
|
||||
elif is_java():
|
||||
extension = ".java"
|
||||
else:
|
||||
extension = ".py"
|
||||
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}{extension}"
|
||||
if path.exists():
|
||||
return get_test_file_path(test_dir, function_name, iteration + 1, test_type)
|
||||
|
|
@ -86,10 +91,12 @@ class TestConfig:
|
|||
def test_framework(self) -> str:
|
||||
"""Returns the appropriate test framework based on language.
|
||||
|
||||
Returns 'jest' for JavaScript/TypeScript, 'pytest' for Python (default).
|
||||
Returns 'jest' for JavaScript/TypeScript, 'junit5' for Java, 'pytest' for Python (default).
|
||||
"""
|
||||
if is_javascript():
|
||||
return "jest"
|
||||
if is_java():
|
||||
return "junit5"
|
||||
return "pytest"
|
||||
|
||||
def set_language(self, language: str) -> None:
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ dependencies = [
|
|||
"tree-sitter>=0.23.0",
|
||||
"tree-sitter-javascript>=0.23.0",
|
||||
"tree-sitter-typescript>=0.23.0",
|
||||
"tree-sitter-java>=0.23.0",
|
||||
"pytest-timeout>=2.1.0",
|
||||
"tomlkit>=0.11.7",
|
||||
"junitparser>=3.1.0",
|
||||
|
|
|
|||
5
tests/test_languages/fixtures/java_maven/codeflash.toml
Normal file
5
tests/test_languages/fixtures/java_maven/codeflash.toml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Codeflash configuration for Java project
|
||||
|
||||
[tool.codeflash]
|
||||
module-root = "src/main/java"
|
||||
tests-root = "src/test/java"
|
||||
|
|
@ -0,0 +1,127 @@
|
|||
package com.example;
|
||||
|
||||
import com.example.helpers.MathHelper;
|
||||
import com.example.helpers.Formatter;
|
||||
|
||||
/**
|
||||
* Calculator class - demonstrates class method optimization scenarios.
|
||||
* Uses helper functions from MathHelper and Formatter.
|
||||
*/
|
||||
public class Calculator {
|
||||
|
||||
private int precision;
|
||||
private java.util.List<String> history;
|
||||
|
||||
/**
|
||||
* Creates a Calculator with specified precision.
|
||||
* @param precision number of decimal places for formatting
|
||||
*/
|
||||
public Calculator(int precision) {
|
||||
this.precision = precision;
|
||||
this.history = new java.util.ArrayList<>();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a Calculator with default precision of 2.
|
||||
*/
|
||||
public Calculator() {
|
||||
this(2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate compound interest with multiple helper dependencies.
|
||||
*
|
||||
* @param principal Initial amount
|
||||
* @param rate Interest rate (as decimal)
|
||||
* @param time Time in years
|
||||
* @param n Compounding frequency per year
|
||||
* @return Compound interest result formatted as string
|
||||
*/
|
||||
public String calculateCompoundInterest(double principal, double rate, int time, int n) {
|
||||
Formatter.validateInput(principal, "principal");
|
||||
Formatter.validateInput(rate, "rate");
|
||||
|
||||
// Inefficient: recalculates power multiple times
|
||||
double result = principal;
|
||||
for (int i = 0; i < n * time; i++) {
|
||||
result = MathHelper.multiply(result, MathHelper.add(1.0, rate / n));
|
||||
}
|
||||
|
||||
double interest = result - principal;
|
||||
history.add("compound:" + interest);
|
||||
return Formatter.formatNumber(interest, precision);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate permutation using factorial helper.
|
||||
*
|
||||
* @param n Total items
|
||||
* @param r Items to choose
|
||||
* @return Permutation result (n! / (n-r)!)
|
||||
*/
|
||||
public long permutation(int n, int r) {
|
||||
if (n < r) {
|
||||
return 0;
|
||||
}
|
||||
// Inefficient: calculates factorial(n) fully even when not needed
|
||||
return MathHelper.factorial(n) / MathHelper.factorial(n - r);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate combination (n choose r).
|
||||
*
|
||||
* @param n Total items
|
||||
* @param r Items to choose
|
||||
* @return Combination result (n! / (r! * (n-r)!))
|
||||
*/
|
||||
public long combination(int n, int r) {
|
||||
if (n < r) {
|
||||
return 0;
|
||||
}
|
||||
// Inefficient: calculates full factorials
|
||||
return MathHelper.factorial(n) / (MathHelper.factorial(r) * MathHelper.factorial(n - r));
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate Fibonacci number at position n.
|
||||
*
|
||||
* @param n Position in Fibonacci sequence (0-indexed)
|
||||
* @return Fibonacci number at position n
|
||||
*/
|
||||
public long fibonacci(int n) {
|
||||
// Inefficient recursive implementation without memoization
|
||||
if (n <= 1) {
|
||||
return n;
|
||||
}
|
||||
return fibonacci(n - 1) + fibonacci(n - 2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Static method for quick calculations.
|
||||
*
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return Sum of a and b
|
||||
*/
|
||||
public static double quickAdd(double a, double b) {
|
||||
return MathHelper.add(a, b);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get calculation history.
|
||||
*
|
||||
* @return List of past calculations
|
||||
*/
|
||||
public java.util.List<String> getHistory() {
|
||||
return new java.util.ArrayList<>(history);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get current precision setting.
|
||||
*
|
||||
* @return precision value
|
||||
*/
|
||||
public int getPrecision() {
|
||||
return precision;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
package com.example;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Data processing class with complex methods to optimize.
|
||||
*/
|
||||
public class DataProcessor {
|
||||
|
||||
/**
|
||||
* Find duplicate elements in a list.
|
||||
*
|
||||
* @param list List to check for duplicates
|
||||
* @param <T> Type of elements
|
||||
* @return List of duplicate elements
|
||||
*/
|
||||
public static <T> List<T> findDuplicates(List<T> list) {
|
||||
List<T> duplicates = new ArrayList<>();
|
||||
if (list == null) {
|
||||
return duplicates;
|
||||
}
|
||||
// Inefficient: O(n^2) nested loop
|
||||
for (int i = 0; i < list.size(); i++) {
|
||||
for (int j = i + 1; j < list.size(); j++) {
|
||||
if (list.get(i).equals(list.get(j)) && !duplicates.contains(list.get(i))) {
|
||||
duplicates.add(list.get(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
return duplicates;
|
||||
}
|
||||
|
||||
/**
|
||||
* Group elements by a key function.
|
||||
*
|
||||
* @param list List to group
|
||||
* @param keyExtractor Function to extract key from element
|
||||
* @param <T> Type of elements
|
||||
* @param <K> Type of key
|
||||
* @return Map of key to list of elements
|
||||
*/
|
||||
public static <T, K> Map<K, List<T>> groupBy(List<T> list, java.util.function.Function<T, K> keyExtractor) {
|
||||
Map<K, List<T>> result = new HashMap<>();
|
||||
if (list == null) {
|
||||
return result;
|
||||
}
|
||||
// Could use streams, but explicit loop for optimization opportunity
|
||||
for (T item : list) {
|
||||
K key = keyExtractor.apply(item);
|
||||
if (!result.containsKey(key)) {
|
||||
result.put(key, new ArrayList<>());
|
||||
}
|
||||
result.get(key).add(item);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find intersection of two lists.
|
||||
*
|
||||
* @param list1 First list
|
||||
* @param list2 Second list
|
||||
* @param <T> Type of elements
|
||||
* @return List of common elements
|
||||
*/
|
||||
public static <T> List<T> intersection(List<T> list1, List<T> list2) {
|
||||
List<T> result = new ArrayList<>();
|
||||
if (list1 == null || list2 == null) {
|
||||
return result;
|
||||
}
|
||||
// Inefficient: O(n*m) nested loop
|
||||
for (T item : list1) {
|
||||
if (list2.contains(item) && !result.contains(item)) {
|
||||
result.add(item);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Flatten a nested list structure.
|
||||
*
|
||||
* @param nestedList List of lists
|
||||
* @param <T> Type of elements
|
||||
* @return Flattened list
|
||||
*/
|
||||
public static <T> List<T> flatten(List<List<T>> nestedList) {
|
||||
List<T> result = new ArrayList<>();
|
||||
if (nestedList == null) {
|
||||
return result;
|
||||
}
|
||||
// Simple but could be optimized with capacity hints
|
||||
for (List<T> innerList : nestedList) {
|
||||
if (innerList != null) {
|
||||
result.addAll(innerList);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Count frequency of each element.
|
||||
*
|
||||
* @param list List to count
|
||||
* @param <T> Type of elements
|
||||
* @return Map of element to frequency
|
||||
*/
|
||||
public static <T> Map<T, Integer> countFrequency(List<T> list) {
|
||||
Map<T, Integer> frequency = new HashMap<>();
|
||||
if (list == null) {
|
||||
return frequency;
|
||||
}
|
||||
for (T item : list) {
|
||||
// Inefficient: could use merge or compute
|
||||
if (frequency.containsKey(item)) {
|
||||
frequency.put(item, frequency.get(item) + 1);
|
||||
} else {
|
||||
frequency.put(item, 1);
|
||||
}
|
||||
}
|
||||
return frequency;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the nth most frequent element.
|
||||
*
|
||||
* @param list List to search
|
||||
* @param n Position (1-based)
|
||||
* @param <T> Type of elements
|
||||
* @return nth most frequent element, or null if not found
|
||||
*/
|
||||
public static <T> T nthMostFrequent(List<T> list, int n) {
|
||||
if (list == null || list.isEmpty() || n < 1) {
|
||||
return null;
|
||||
}
|
||||
Map<T, Integer> frequency = countFrequency(list);
|
||||
|
||||
// Inefficient: sort all entries to find nth
|
||||
List<Map.Entry<T, Integer>> entries = new ArrayList<>(frequency.entrySet());
|
||||
entries.sort((e1, e2) -> e2.getValue().compareTo(e1.getValue()));
|
||||
|
||||
if (n > entries.size()) {
|
||||
return null;
|
||||
}
|
||||
return entries.get(n - 1).getKey();
|
||||
}
|
||||
|
||||
/**
|
||||
* Partition list into chunks of specified size.
|
||||
*
|
||||
* @param list List to partition
|
||||
* @param chunkSize Size of each chunk
|
||||
* @param <T> Type of elements
|
||||
* @return List of chunks
|
||||
*/
|
||||
public static <T> List<List<T>> partition(List<T> list, int chunkSize) {
|
||||
List<List<T>> result = new ArrayList<>();
|
||||
if (list == null || chunkSize <= 0) {
|
||||
return result;
|
||||
}
|
||||
// Inefficient: creates sublists with copying
|
||||
for (int i = 0; i < list.size(); i += chunkSize) {
|
||||
int end = Math.min(i + chunkSize, list.size());
|
||||
result.add(new ArrayList<>(list.subList(i, end)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,131 @@
|
|||
package com.example;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* String utility class with methods to optimize.
|
||||
*/
|
||||
public class StringUtils {
|
||||
|
||||
/**
|
||||
* Reverse a string character by character.
|
||||
*
|
||||
* @param str String to reverse
|
||||
* @return Reversed string
|
||||
*/
|
||||
public static String reverse(String str) {
|
||||
if (str == null || str.isEmpty()) {
|
||||
return str;
|
||||
}
|
||||
// Inefficient: string concatenation in loop
|
||||
String result = "";
|
||||
for (int i = str.length() - 1; i >= 0; i--) {
|
||||
result = result + str.charAt(i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a string is a palindrome.
|
||||
*
|
||||
* @param str String to check
|
||||
* @return true if palindrome, false otherwise
|
||||
*/
|
||||
public static boolean isPalindrome(String str) {
|
||||
if (str == null) {
|
||||
return false;
|
||||
}
|
||||
// Inefficient: creates reversed string instead of comparing in place
|
||||
String reversed = reverse(str.toLowerCase().replaceAll("\\s+", ""));
|
||||
String cleaned = str.toLowerCase().replaceAll("\\s+", "");
|
||||
return cleaned.equals(reversed);
|
||||
}
|
||||
|
||||
/**
|
||||
* Count occurrences of a substring.
|
||||
*
|
||||
* @param str String to search in
|
||||
* @param sub Substring to find
|
||||
* @return Number of occurrences
|
||||
*/
|
||||
public static int countOccurrences(String str, String sub) {
|
||||
if (str == null || sub == null || sub.isEmpty()) {
|
||||
return 0;
|
||||
}
|
||||
// Inefficient: creates many intermediate strings
|
||||
int count = 0;
|
||||
int index = 0;
|
||||
while ((index = str.indexOf(sub, index)) != -1) {
|
||||
count++;
|
||||
index++;
|
||||
}
|
||||
return count;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find all anagrams of a word in a text.
|
||||
*
|
||||
* @param text Text to search in
|
||||
* @param word Word to find anagrams of
|
||||
* @return List of starting indices of anagrams
|
||||
*/
|
||||
public static List<Integer> findAnagrams(String text, String word) {
|
||||
List<Integer> result = new ArrayList<>();
|
||||
if (text == null || word == null || text.length() < word.length()) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Inefficient: recalculates sorted word for each position
|
||||
int wordLen = word.length();
|
||||
for (int i = 0; i <= text.length() - wordLen; i++) {
|
||||
String window = text.substring(i, i + wordLen);
|
||||
if (isAnagram(window, word)) {
|
||||
result.add(i);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if two strings are anagrams.
|
||||
*
|
||||
* @param s1 First string
|
||||
* @param s2 Second string
|
||||
* @return true if anagrams, false otherwise
|
||||
*/
|
||||
public static boolean isAnagram(String s1, String s2) {
|
||||
if (s1 == null || s2 == null || s1.length() != s2.length()) {
|
||||
return false;
|
||||
}
|
||||
// Inefficient: sorts both strings
|
||||
char[] arr1 = s1.toLowerCase().toCharArray();
|
||||
char[] arr2 = s2.toLowerCase().toCharArray();
|
||||
java.util.Arrays.sort(arr1);
|
||||
java.util.Arrays.sort(arr2);
|
||||
return java.util.Arrays.equals(arr1, arr2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Find longest common prefix of an array of strings.
|
||||
*
|
||||
* @param strings Array of strings
|
||||
* @return Longest common prefix
|
||||
*/
|
||||
public static String longestCommonPrefix(String[] strings) {
|
||||
if (strings == null || strings.length == 0) {
|
||||
return "";
|
||||
}
|
||||
// Inefficient: vertical scanning approach
|
||||
String prefix = strings[0];
|
||||
for (int i = 1; i < strings.length; i++) {
|
||||
while (strings[i].indexOf(prefix) != 0) {
|
||||
prefix = prefix.substring(0, prefix.length() - 1);
|
||||
if (prefix.isEmpty()) {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
}
|
||||
return prefix;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
package com.example.helpers;
|
||||
|
||||
/**
|
||||
* Formatting utility functions.
|
||||
*/
|
||||
public class Formatter {
|
||||
|
||||
/**
|
||||
* Format a number with specified decimal places.
|
||||
*
|
||||
* @param value Number to format
|
||||
* @param decimals Number of decimal places
|
||||
* @return Formatted number as string
|
||||
*/
|
||||
public static String formatNumber(double value, int decimals) {
|
||||
return String.format("%." + decimals + "f", value);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate that input is a positive number.
|
||||
*
|
||||
* @param value Value to validate
|
||||
* @param name Name of the parameter (for error message)
|
||||
* @throws IllegalArgumentException if value is not positive
|
||||
*/
|
||||
public static void validateInput(double value, String name) {
|
||||
if (value < 0) {
|
||||
throw new IllegalArgumentException(name + " must be non-negative, got: " + value);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert number to percentage string.
|
||||
*
|
||||
* @param value Decimal value (0.5 = 50%)
|
||||
* @return Percentage string
|
||||
*/
|
||||
public static String toPercentage(double value) {
|
||||
return formatNumber(value * 100, 2) + "%";
|
||||
}
|
||||
|
||||
/**
|
||||
* Pad a string to specified length.
|
||||
*
|
||||
* @param str String to pad
|
||||
* @param length Target length
|
||||
* @param padChar Character to pad with
|
||||
* @return Padded string
|
||||
*/
|
||||
public static String padLeft(String str, int length, char padChar) {
|
||||
// Inefficient: creates many intermediate strings
|
||||
StringBuilder result = new StringBuilder(str);
|
||||
while (result.length() < length) {
|
||||
result.insert(0, padChar);
|
||||
}
|
||||
return result.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Repeat a string n times.
|
||||
*
|
||||
* @param str String to repeat
|
||||
* @param times Number of repetitions
|
||||
* @return Repeated string
|
||||
*/
|
||||
public static String repeat(String str, int times) {
|
||||
// Inefficient: string concatenation in loop
|
||||
String result = "";
|
||||
for (int i = 0; i < times; i++) {
|
||||
result = result + str;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
package com.example.helpers;
|
||||
|
||||
/**
|
||||
* Math utility functions - basic arithmetic operations.
|
||||
*/
|
||||
public class MathHelper {
|
||||
|
||||
/**
|
||||
* Add two numbers.
|
||||
*
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return Sum of a and b
|
||||
*/
|
||||
public static double add(double a, double b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiply two numbers.
|
||||
*
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return Product of a and b
|
||||
*/
|
||||
public static double multiply(double a, double b) {
|
||||
return a * b;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate factorial recursively.
|
||||
*
|
||||
* @param n Non-negative integer
|
||||
* @return Factorial of n
|
||||
* @throws IllegalArgumentException if n is negative
|
||||
*/
|
||||
public static long factorial(int n) {
|
||||
if (n < 0) {
|
||||
throw new IllegalArgumentException("Factorial not defined for negative numbers");
|
||||
}
|
||||
// Intentionally inefficient recursive implementation
|
||||
if (n <= 1) {
|
||||
return 1;
|
||||
}
|
||||
return n * factorial(n - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate power using repeated multiplication.
|
||||
*
|
||||
* @param base Base number
|
||||
* @param exp Exponent (non-negative)
|
||||
* @return base raised to exp
|
||||
*/
|
||||
public static double power(double base, int exp) {
|
||||
// Inefficient: linear time instead of log time
|
||||
double result = 1;
|
||||
for (int i = 0; i < exp; i++) {
|
||||
result = multiply(result, base);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a number is prime.
|
||||
*
|
||||
* @param n Number to check
|
||||
* @return true if n is prime, false otherwise
|
||||
*/
|
||||
public static boolean isPrime(int n) {
|
||||
if (n < 2) {
|
||||
return false;
|
||||
}
|
||||
// Inefficient: checks all numbers up to n-1
|
||||
for (int i = 2; i < n; i++) {
|
||||
if (n % i == 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate greatest common divisor using Euclidean algorithm.
|
||||
*
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return GCD of a and b
|
||||
*/
|
||||
public static int gcd(int a, int b) {
|
||||
// Inefficient recursive implementation
|
||||
if (b == 0) {
|
||||
return a;
|
||||
}
|
||||
return gcd(b, a % b);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate least common multiple.
|
||||
*
|
||||
* @param a First number
|
||||
* @param b Second number
|
||||
* @return LCM of a and b
|
||||
*/
|
||||
public static int lcm(int a, int b) {
|
||||
return (a * b) / gcd(a, b);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Tests for the Calculator class.
|
||||
*/
|
||||
@DisplayName("Calculator Tests")
|
||||
class CalculatorTest {
|
||||
|
||||
private Calculator calculator;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
calculator = new Calculator(2);
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Compound Interest Tests")
|
||||
class CompoundInterestTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate compound interest for basic case")
|
||||
void testBasicCompoundInterest() {
|
||||
String result = calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
|
||||
assertNotNull(result);
|
||||
assertTrue(result.contains("."));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle zero principal")
|
||||
void testZeroPrincipal() {
|
||||
String result = calculator.calculateCompoundInterest(0.0, 0.05, 1, 12);
|
||||
assertEquals("0.00", result);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should throw on negative principal")
|
||||
void testNegativePrincipal() {
|
||||
assertThrows(IllegalArgumentException.class, () ->
|
||||
calculator.calculateCompoundInterest(-100.0, 0.05, 1, 12)
|
||||
);
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"1000, 0.05, 1, 12",
|
||||
"5000, 0.08, 2, 4",
|
||||
"10000, 0.03, 5, 1"
|
||||
})
|
||||
@DisplayName("should calculate for various inputs")
|
||||
void testVariousInputs(double principal, double rate, int time, int n) {
|
||||
String result = calculator.calculateCompoundInterest(principal, rate, time, n);
|
||||
assertNotNull(result);
|
||||
assertFalse(result.isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Permutation Tests")
|
||||
class PermutationTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate permutation correctly")
|
||||
void testBasicPermutation() {
|
||||
assertEquals(120, calculator.permutation(5, 5));
|
||||
assertEquals(60, calculator.permutation(5, 3));
|
||||
assertEquals(20, calculator.permutation(5, 2));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return 0 when n < r")
|
||||
void testInvalidPermutation() {
|
||||
assertEquals(0, calculator.permutation(3, 5));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle edge cases")
|
||||
void testEdgeCases() {
|
||||
assertEquals(1, calculator.permutation(5, 0));
|
||||
assertEquals(1, calculator.permutation(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Combination Tests")
|
||||
class CombinationTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate combination correctly")
|
||||
void testBasicCombination() {
|
||||
assertEquals(10, calculator.combination(5, 3));
|
||||
assertEquals(10, calculator.combination(5, 2));
|
||||
assertEquals(1, calculator.combination(5, 5));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return 0 when n < r")
|
||||
void testInvalidCombination() {
|
||||
assertEquals(0, calculator.combination(3, 5));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("Fibonacci Tests")
|
||||
class FibonacciTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should calculate fibonacci correctly")
|
||||
void testFibonacci() {
|
||||
assertEquals(0, calculator.fibonacci(0));
|
||||
assertEquals(1, calculator.fibonacci(1));
|
||||
assertEquals(1, calculator.fibonacci(2));
|
||||
assertEquals(2, calculator.fibonacci(3));
|
||||
assertEquals(5, calculator.fibonacci(5));
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"0, 0",
|
||||
"1, 1",
|
||||
"2, 1",
|
||||
"3, 2",
|
||||
"4, 3",
|
||||
"5, 5",
|
||||
"6, 8",
|
||||
"7, 13"
|
||||
})
|
||||
@DisplayName("should match expected sequence")
|
||||
void testFibonacciSequence(int n, long expected) {
|
||||
assertEquals(expected, calculator.fibonacci(n));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("static quickAdd should work correctly")
|
||||
void testQuickAdd() {
|
||||
assertEquals(15.0, Calculator.quickAdd(10.0, 5.0));
|
||||
assertEquals(0.0, Calculator.quickAdd(-5.0, 5.0));
|
||||
assertEquals(-10.0, Calculator.quickAdd(-5.0, -5.0));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should track calculation history")
|
||||
void testHistory() {
|
||||
calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
|
||||
calculator.calculateCompoundInterest(2000.0, 0.03, 2, 4);
|
||||
|
||||
var history = calculator.getHistory();
|
||||
assertEquals(2, history.size());
|
||||
assertTrue(history.get(0).startsWith("compound:"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return correct precision")
|
||||
void testPrecision() {
|
||||
assertEquals(2, calculator.getPrecision());
|
||||
|
||||
Calculator customCalc = new Calculator(4);
|
||||
assertEquals(4, customCalc.getPrecision());
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,265 @@
|
|||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Tests for the DataProcessor class.
|
||||
*/
|
||||
@DisplayName("DataProcessor Tests")
|
||||
class DataProcessorTest {
|
||||
|
||||
@Nested
|
||||
@DisplayName("findDuplicates() Tests")
|
||||
class FindDuplicatesTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should find duplicates in list")
|
||||
void testFindDuplicates() {
|
||||
List<Integer> input = Arrays.asList(1, 2, 3, 2, 4, 3, 5);
|
||||
List<Integer> duplicates = DataProcessor.findDuplicates(input);
|
||||
|
||||
assertEquals(2, duplicates.size());
|
||||
assertTrue(duplicates.contains(2));
|
||||
assertTrue(duplicates.contains(3));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return empty for no duplicates")
|
||||
void testNoDuplicates() {
|
||||
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5);
|
||||
List<Integer> duplicates = DataProcessor.findDuplicates(input);
|
||||
|
||||
assertTrue(duplicates.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null input")
|
||||
void testNullInput() {
|
||||
List<Integer> duplicates = DataProcessor.findDuplicates(null);
|
||||
assertTrue(duplicates.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle strings")
|
||||
void testStrings() {
|
||||
List<String> input = Arrays.asList("a", "b", "a", "c", "b", "d");
|
||||
List<String> duplicates = DataProcessor.findDuplicates(input);
|
||||
|
||||
assertEquals(2, duplicates.size());
|
||||
assertTrue(duplicates.contains("a"));
|
||||
assertTrue(duplicates.contains("b"));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("groupBy() Tests")
|
||||
class GroupByTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should group by length")
|
||||
void testGroupByLength() {
|
||||
List<String> input = Arrays.asList("a", "bb", "ccc", "dd", "e", "fff");
|
||||
Map<Integer, List<String>> grouped = DataProcessor.groupBy(input, String::length);
|
||||
|
||||
assertEquals(3, grouped.size());
|
||||
assertEquals(2, grouped.get(1).size());
|
||||
assertEquals(2, grouped.get(2).size());
|
||||
assertEquals(2, grouped.get(3).size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should group by first character")
|
||||
void testGroupByFirstChar() {
|
||||
List<String> input = Arrays.asList("apple", "apricot", "banana", "blueberry");
|
||||
Map<Character, List<String>> grouped = DataProcessor.groupBy(input, s -> s.charAt(0));
|
||||
|
||||
assertEquals(2, grouped.size());
|
||||
assertEquals(2, grouped.get('a').size());
|
||||
assertEquals(2, grouped.get('b').size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null input")
|
||||
void testNullInput() {
|
||||
Map<Integer, List<String>> grouped = DataProcessor.groupBy(null, String::length);
|
||||
assertTrue(grouped.isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("intersection() Tests")
|
||||
class IntersectionTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should find intersection")
|
||||
void testIntersection() {
|
||||
List<Integer> list1 = Arrays.asList(1, 2, 3, 4, 5);
|
||||
List<Integer> list2 = Arrays.asList(4, 5, 6, 7, 8);
|
||||
List<Integer> result = DataProcessor.intersection(list1, list2);
|
||||
|
||||
assertEquals(2, result.size());
|
||||
assertTrue(result.contains(4));
|
||||
assertTrue(result.contains(5));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return empty for no intersection")
|
||||
void testNoIntersection() {
|
||||
List<Integer> list1 = Arrays.asList(1, 2, 3);
|
||||
List<Integer> list2 = Arrays.asList(4, 5, 6);
|
||||
List<Integer> result = DataProcessor.intersection(list1, list2);
|
||||
|
||||
assertTrue(result.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null inputs")
|
||||
void testNullInputs() {
|
||||
assertTrue(DataProcessor.intersection(null, Arrays.asList(1, 2, 3)).isEmpty());
|
||||
assertTrue(DataProcessor.intersection(Arrays.asList(1, 2, 3), null).isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should not include duplicates")
|
||||
void testNoDuplicates() {
|
||||
List<Integer> list1 = Arrays.asList(1, 1, 2, 2, 3);
|
||||
List<Integer> list2 = Arrays.asList(1, 2, 2, 4);
|
||||
List<Integer> result = DataProcessor.intersection(list1, list2);
|
||||
|
||||
assertEquals(2, result.size());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("flatten() Tests")
|
||||
class FlattenTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should flatten nested lists")
|
||||
void testFlatten() {
|
||||
List<List<Integer>> nested = Arrays.asList(
|
||||
Arrays.asList(1, 2, 3),
|
||||
Arrays.asList(4, 5),
|
||||
Arrays.asList(6, 7, 8, 9)
|
||||
);
|
||||
List<Integer> result = DataProcessor.flatten(nested);
|
||||
|
||||
assertEquals(9, result.size());
|
||||
assertEquals(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9), result);
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle empty inner lists")
|
||||
void testEmptyInnerLists() {
|
||||
List<List<Integer>> nested = Arrays.asList(
|
||||
Arrays.asList(1, 2),
|
||||
Collections.emptyList(),
|
||||
Arrays.asList(3, 4)
|
||||
);
|
||||
List<Integer> result = DataProcessor.flatten(nested);
|
||||
|
||||
assertEquals(4, result.size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null")
|
||||
void testNull() {
|
||||
assertTrue(DataProcessor.flatten(null).isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("countFrequency() Tests")
|
||||
class CountFrequencyTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should count frequencies correctly")
|
||||
void testCountFrequency() {
|
||||
List<String> input = Arrays.asList("a", "b", "a", "c", "a", "b");
|
||||
Map<String, Integer> freq = DataProcessor.countFrequency(input);
|
||||
|
||||
assertEquals(3, freq.get("a"));
|
||||
assertEquals(2, freq.get("b"));
|
||||
assertEquals(1, freq.get("c"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null input")
|
||||
void testNullInput() {
|
||||
assertTrue(DataProcessor.countFrequency(null).isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("nthMostFrequent() Tests")
|
||||
class NthMostFrequentTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should find nth most frequent")
|
||||
void testNthMostFrequent() {
|
||||
List<String> input = Arrays.asList("a", "b", "a", "c", "a", "b", "d");
|
||||
|
||||
assertEquals("a", DataProcessor.nthMostFrequent(input, 1));
|
||||
assertEquals("b", DataProcessor.nthMostFrequent(input, 2));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return null for invalid n")
|
||||
void testInvalidN() {
|
||||
List<String> input = Arrays.asList("a", "b", "c");
|
||||
|
||||
assertNull(DataProcessor.nthMostFrequent(input, 0));
|
||||
assertNull(DataProcessor.nthMostFrequent(input, 10));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null input")
|
||||
void testNullInput() {
|
||||
assertNull(DataProcessor.nthMostFrequent(null, 1));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("partition() Tests")
|
||||
class PartitionTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should partition into chunks")
|
||||
void testPartition() {
|
||||
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5, 6, 7);
|
||||
List<List<Integer>> chunks = DataProcessor.partition(input, 3);
|
||||
|
||||
assertEquals(3, chunks.size());
|
||||
assertEquals(Arrays.asList(1, 2, 3), chunks.get(0));
|
||||
assertEquals(Arrays.asList(4, 5, 6), chunks.get(1));
|
||||
assertEquals(Collections.singletonList(7), chunks.get(2));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle exact division")
|
||||
void testExactDivision() {
|
||||
List<Integer> input = Arrays.asList(1, 2, 3, 4, 5, 6);
|
||||
List<List<Integer>> chunks = DataProcessor.partition(input, 2);
|
||||
|
||||
assertEquals(3, chunks.size());
|
||||
chunks.forEach(chunk -> assertEquals(2, chunk.size()));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null and invalid chunk size")
|
||||
void testInvalidInputs() {
|
||||
assertTrue(DataProcessor.partition(null, 3).isEmpty());
|
||||
assertTrue(DataProcessor.partition(Arrays.asList(1, 2, 3), 0).isEmpty());
|
||||
assertTrue(DataProcessor.partition(Arrays.asList(1, 2, 3), -1).isEmpty());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,219 @@
|
|||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
import org.junit.jupiter.api.Nested;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.CsvSource;
|
||||
import org.junit.jupiter.params.provider.NullAndEmptySource;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
/**
|
||||
* Tests for the StringUtils class.
|
||||
*/
|
||||
@DisplayName("StringUtils Tests")
|
||||
class StringUtilsTest {
|
||||
|
||||
@Nested
|
||||
@DisplayName("reverse() Tests")
|
||||
class ReverseTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should reverse a simple string")
|
||||
void testReverseSimple() {
|
||||
assertEquals("olleh", StringUtils.reverse("hello"));
|
||||
assertEquals("dlrow", StringUtils.reverse("world"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle single character")
|
||||
void testReverseSingleChar() {
|
||||
assertEquals("a", StringUtils.reverse("a"));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@NullAndEmptySource
|
||||
@DisplayName("should handle null and empty strings")
|
||||
void testReverseNullEmpty(String input) {
|
||||
assertEquals(input, StringUtils.reverse(input));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle palindrome")
|
||||
void testReversePalindrome() {
|
||||
assertEquals("radar", StringUtils.reverse("radar"));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("isPalindrome() Tests")
|
||||
class PalindromeTests {
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"radar", "level", "civic", "rotor", "kayak"})
|
||||
@DisplayName("should return true for palindromes")
|
||||
void testPalindromes(String input) {
|
||||
assertTrue(StringUtils.isPalindrome(input));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"hello", "world", "java", "python"})
|
||||
@DisplayName("should return false for non-palindromes")
|
||||
void testNonPalindromes(String input) {
|
||||
assertFalse(StringUtils.isPalindrome(input));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle case insensitivity")
|
||||
void testCaseInsensitive() {
|
||||
assertTrue(StringUtils.isPalindrome("Radar"));
|
||||
assertTrue(StringUtils.isPalindrome("LEVEL"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should ignore spaces")
|
||||
void testIgnoreSpaces() {
|
||||
assertTrue(StringUtils.isPalindrome("race car"));
|
||||
assertTrue(StringUtils.isPalindrome("A man a plan a canal Panama"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return false for null")
|
||||
void testNull() {
|
||||
assertFalse(StringUtils.isPalindrome(null));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("countOccurrences() Tests")
|
||||
class CountOccurrencesTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should count occurrences correctly")
|
||||
void testCount() {
|
||||
assertEquals(3, StringUtils.countOccurrences("abcabc abc", "abc"));
|
||||
assertEquals(2, StringUtils.countOccurrences("hello hello", "hello"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return 0 for no matches")
|
||||
void testNoMatches() {
|
||||
assertEquals(0, StringUtils.countOccurrences("hello world", "xyz"));
|
||||
}
|
||||
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"'aaaaaa', 'aa', 5",
|
||||
"'banana', 'ana', 2",
|
||||
"'mississippi', 'issi', 2"
|
||||
})
|
||||
@DisplayName("should handle overlapping matches")
|
||||
void testOverlapping(String str, String sub, int expected) {
|
||||
assertEquals(expected, StringUtils.countOccurrences(str, sub));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null inputs")
|
||||
void testNullInputs() {
|
||||
assertEquals(0, StringUtils.countOccurrences(null, "test"));
|
||||
assertEquals(0, StringUtils.countOccurrences("test", null));
|
||||
assertEquals(0, StringUtils.countOccurrences("test", ""));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("isAnagram() Tests")
|
||||
class AnagramTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should detect anagrams")
|
||||
void testAnagrams() {
|
||||
assertTrue(StringUtils.isAnagram("listen", "silent"));
|
||||
assertTrue(StringUtils.isAnagram("evil", "vile"));
|
||||
assertTrue(StringUtils.isAnagram("anagram", "nagaram"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should reject non-anagrams")
|
||||
void testNonAnagrams() {
|
||||
assertFalse(StringUtils.isAnagram("hello", "world"));
|
||||
assertFalse(StringUtils.isAnagram("abc", "abcd"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should be case insensitive")
|
||||
void testCaseInsensitive() {
|
||||
assertTrue(StringUtils.isAnagram("Listen", "Silent"));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null inputs")
|
||||
void testNullInputs() {
|
||||
assertFalse(StringUtils.isAnagram(null, "test"));
|
||||
assertFalse(StringUtils.isAnagram("test", null));
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("findAnagrams() Tests")
|
||||
class FindAnagramsTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should find all anagram positions")
|
||||
void testFindAnagrams() {
|
||||
List<Integer> result = StringUtils.findAnagrams("cbaebabacd", "abc");
|
||||
assertEquals(2, result.size());
|
||||
assertTrue(result.contains(0));
|
||||
assertTrue(result.contains(6));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return empty list for no matches")
|
||||
void testNoMatches() {
|
||||
List<Integer> result = StringUtils.findAnagrams("hello", "xyz");
|
||||
assertTrue(result.isEmpty());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null inputs")
|
||||
void testNullInputs() {
|
||||
assertTrue(StringUtils.findAnagrams(null, "abc").isEmpty());
|
||||
assertTrue(StringUtils.findAnagrams("abc", null).isEmpty());
|
||||
}
|
||||
}
|
||||
|
||||
@Nested
|
||||
@DisplayName("longestCommonPrefix() Tests")
|
||||
class LongestCommonPrefixTests {
|
||||
|
||||
@Test
|
||||
@DisplayName("should find common prefix")
|
||||
void testCommonPrefix() {
|
||||
assertEquals("fl", StringUtils.longestCommonPrefix(new String[]{"flower", "flow", "flight"}));
|
||||
assertEquals("ap", StringUtils.longestCommonPrefix(new String[]{"apple", "ape", "april"}));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should return empty for no common prefix")
|
||||
void testNoCommonPrefix() {
|
||||
assertEquals("", StringUtils.longestCommonPrefix(new String[]{"dog", "car", "race"}));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle single string")
|
||||
void testSingleString() {
|
||||
assertEquals("hello", StringUtils.longestCommonPrefix(new String[]{"hello"}));
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("should handle null and empty array")
|
||||
void testNullEmpty() {
|
||||
assertEquals("", StringUtils.longestCommonPrefix(null));
|
||||
assertEquals("", StringUtils.longestCommonPrefix(new String[]{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -29,17 +29,20 @@ class TestLanguageEnum:
|
|||
assert Language.PYTHON.value == "python"
|
||||
assert Language.JAVASCRIPT.value == "javascript"
|
||||
assert Language.TYPESCRIPT.value == "typescript"
|
||||
assert Language.JAVA.value == "java"
|
||||
|
||||
def test_language_str(self):
|
||||
"""Test string conversion of Language enum."""
|
||||
assert str(Language.PYTHON) == "python"
|
||||
assert str(Language.JAVASCRIPT) == "javascript"
|
||||
assert str(Language.JAVA) == "java"
|
||||
|
||||
def test_language_from_string(self):
|
||||
"""Test creating Language from string."""
|
||||
assert Language("python") == Language.PYTHON
|
||||
assert Language("javascript") == Language.JAVASCRIPT
|
||||
assert Language("typescript") == Language.TYPESCRIPT
|
||||
assert Language("java") == Language.JAVA
|
||||
|
||||
def test_invalid_language_raises(self):
|
||||
"""Test that invalid language string raises ValueError."""
|
||||
|
|
|
|||
1
tests/test_languages/test_java/__init__.py
Normal file
1
tests/test_languages/test_java/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Tests for Java language support."""
|
||||
279
tests/test_languages/test_java/test_build_tools.py
Normal file
279
tests/test_languages/test_java/test_build_tools.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
"""Tests for Java build tool detection and integration."""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.build_tools import (
|
||||
BuildTool,
|
||||
detect_build_tool,
|
||||
find_maven_executable,
|
||||
find_source_root,
|
||||
find_test_root,
|
||||
get_project_info,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildToolDetection:
|
||||
"""Tests for build tool detection."""
|
||||
|
||||
def test_detect_maven_project(self, tmp_path: Path):
|
||||
"""Test detecting a Maven project."""
|
||||
# Create pom.xml
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
|
||||
http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>my-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
|
||||
assert detect_build_tool(tmp_path) == BuildTool.MAVEN
|
||||
|
||||
def test_detect_gradle_project(self, tmp_path: Path):
|
||||
"""Test detecting a Gradle project."""
|
||||
# Create build.gradle
|
||||
(tmp_path / "build.gradle").write_text("plugins { id 'java' }")
|
||||
|
||||
assert detect_build_tool(tmp_path) == BuildTool.GRADLE
|
||||
|
||||
def test_detect_gradle_kotlin_project(self, tmp_path: Path):
|
||||
"""Test detecting a Gradle Kotlin DSL project."""
|
||||
# Create build.gradle.kts
|
||||
(tmp_path / "build.gradle.kts").write_text('plugins { java }')
|
||||
|
||||
assert detect_build_tool(tmp_path) == BuildTool.GRADLE
|
||||
|
||||
def test_detect_unknown_project(self, tmp_path: Path):
|
||||
"""Test detecting unknown project type."""
|
||||
# Empty directory
|
||||
assert detect_build_tool(tmp_path) == BuildTool.UNKNOWN
|
||||
|
||||
def test_maven_takes_precedence(self, tmp_path: Path):
|
||||
"""Test that Maven takes precedence if both exist."""
|
||||
# Create both pom.xml and build.gradle
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
(tmp_path / "build.gradle").write_text("plugins { id 'java' }")
|
||||
|
||||
# Maven should be detected first
|
||||
assert detect_build_tool(tmp_path) == BuildTool.MAVEN
|
||||
|
||||
|
||||
class TestMavenProjectInfo:
|
||||
"""Tests for Maven project info extraction."""
|
||||
|
||||
def test_get_maven_project_info(self, tmp_path: Path):
|
||||
"""Test extracting project info from pom.xml."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
|
||||
http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>my-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
</properties>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
|
||||
# Create standard Maven directory structure
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
|
||||
|
||||
info = get_project_info(tmp_path)
|
||||
|
||||
assert info is not None
|
||||
assert info.build_tool == BuildTool.MAVEN
|
||||
assert info.group_id == "com.example"
|
||||
assert info.artifact_id == "my-app"
|
||||
assert info.version == "1.0.0"
|
||||
assert info.java_version == "11"
|
||||
assert len(info.source_roots) == 1
|
||||
assert len(info.test_roots) == 1
|
||||
|
||||
def test_get_maven_project_info_with_java_version_property(self, tmp_path: Path):
|
||||
"""Test extracting Java version from java.version property."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>my-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<properties>
|
||||
<java.version>17</java.version>
|
||||
</properties>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
info = get_project_info(tmp_path)
|
||||
|
||||
assert info is not None
|
||||
assert info.java_version == "17"
|
||||
|
||||
|
||||
class TestDirectoryDetection:
|
||||
"""Tests for source and test directory detection."""
|
||||
|
||||
def test_find_maven_source_root(self, tmp_path: Path):
|
||||
"""Test finding Maven source root."""
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
src_root = tmp_path / "src" / "main" / "java"
|
||||
src_root.mkdir(parents=True)
|
||||
|
||||
result = find_source_root(tmp_path)
|
||||
assert result is not None
|
||||
assert result == src_root
|
||||
|
||||
def test_find_maven_test_root(self, tmp_path: Path):
|
||||
"""Test finding Maven test root."""
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
test_root = tmp_path / "src" / "test" / "java"
|
||||
test_root.mkdir(parents=True)
|
||||
|
||||
result = find_test_root(tmp_path)
|
||||
assert result is not None
|
||||
assert result == test_root
|
||||
|
||||
def test_find_source_root_not_found(self, tmp_path: Path):
|
||||
"""Test when source root doesn't exist."""
|
||||
result = find_source_root(tmp_path)
|
||||
assert result is None
|
||||
|
||||
def test_find_test_root_not_found(self, tmp_path: Path):
|
||||
"""Test when test root doesn't exist."""
|
||||
result = find_test_root(tmp_path)
|
||||
assert result is None
|
||||
|
||||
def test_find_alternative_test_root(self, tmp_path: Path):
|
||||
"""Test finding alternative test directory."""
|
||||
# Create a 'test' directory (non-Maven style)
|
||||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
result = find_test_root(tmp_path)
|
||||
assert result is not None
|
||||
assert result == test_dir
|
||||
|
||||
|
||||
class TestMavenExecutable:
|
||||
"""Tests for Maven executable detection."""
|
||||
|
||||
def test_find_maven_executable_system(self):
|
||||
"""Test finding system Maven."""
|
||||
# This test may pass or fail depending on whether Maven is installed
|
||||
mvn = find_maven_executable()
|
||||
# We can't assert it exists, just that the function doesn't crash
|
||||
if mvn:
|
||||
assert "mvn" in mvn.lower() or "maven" in mvn.lower()
|
||||
|
||||
def test_find_maven_wrapper(self, tmp_path: Path, monkeypatch):
|
||||
"""Test finding Maven wrapper."""
|
||||
# Create mvnw file
|
||||
mvnw_path = tmp_path / "mvnw"
|
||||
mvnw_path.write_text("#!/bin/bash\necho 'Maven Wrapper'")
|
||||
mvnw_path.chmod(0o755)
|
||||
|
||||
# Change to tmp_path
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
mvn = find_maven_executable()
|
||||
# Should find the wrapper
|
||||
assert mvn is not None
|
||||
|
||||
|
||||
class TestPomXmlParsing:
|
||||
"""Tests for pom.xml parsing edge cases."""
|
||||
|
||||
def test_pom_without_namespace(self, tmp_path: Path):
|
||||
"""Test parsing pom.xml without XML namespace."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>simple-app</artifactId>
|
||||
<version>1.0</version>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
info = get_project_info(tmp_path)
|
||||
|
||||
assert info is not None
|
||||
assert info.group_id == "com.example"
|
||||
assert info.artifact_id == "simple-app"
|
||||
|
||||
def test_pom_with_parent(self, tmp_path: Path):
|
||||
"""Test parsing pom.xml with parent POM."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-parent</artifactId>
|
||||
<version>3.0.0</version>
|
||||
</parent>
|
||||
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>child-app</artifactId>
|
||||
<version>1.0</version>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
info = get_project_info(tmp_path)
|
||||
|
||||
assert info is not None
|
||||
assert info.artifact_id == "child-app"
|
||||
|
||||
def test_invalid_pom_xml(self, tmp_path: Path):
|
||||
"""Test handling invalid pom.xml."""
|
||||
# Create invalid XML
|
||||
(tmp_path / "pom.xml").write_text("this is not valid xml")
|
||||
|
||||
info = get_project_info(tmp_path)
|
||||
# Should return None or handle gracefully
|
||||
assert info is None
|
||||
|
||||
|
||||
class TestGradleProjectInfo:
|
||||
"""Tests for Gradle project info extraction."""
|
||||
|
||||
def test_get_gradle_project_info(self, tmp_path: Path):
|
||||
"""Test extracting basic Gradle project info."""
|
||||
(tmp_path / "build.gradle").write_text("""
|
||||
plugins {
|
||||
id 'java'
|
||||
}
|
||||
|
||||
group = 'com.example'
|
||||
version = '1.0.0'
|
||||
""")
|
||||
|
||||
# Create standard Gradle directory structure
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
|
||||
|
||||
info = get_project_info(tmp_path)
|
||||
|
||||
assert info is not None
|
||||
assert info.build_tool == BuildTool.GRADLE
|
||||
assert len(info.source_roots) == 1
|
||||
assert len(info.test_roots) == 1
|
||||
310
tests/test_languages/test_java/test_comparator.py
Normal file
310
tests/test_languages/test_java/test_comparator.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
"""Tests for Java test result comparison."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.comparator import (
|
||||
compare_invocations_directly,
|
||||
compare_test_results,
|
||||
)
|
||||
from codeflash.models.models import TestDiffScope
|
||||
|
||||
|
||||
class TestDirectComparison:
|
||||
"""Tests for direct Python-based comparison."""
|
||||
|
||||
def test_identical_results(self):
|
||||
"""Test comparing identical results."""
|
||||
original = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
"2": {"result_json": '{"value": 100}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
"2": {"result_json": '{"value": 100}', "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_different_return_values(self):
|
||||
"""Test detecting different return values."""
|
||||
original = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 99}', "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].scope == TestDiffScope.RETURN_VALUE
|
||||
assert diffs[0].original_value == '{"value": 42}'
|
||||
assert diffs[0].candidate_value == '{"value": 99}'
|
||||
|
||||
def test_missing_invocation_in_candidate(self):
|
||||
"""Test detecting missing invocation in candidate."""
|
||||
original = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
"2": {"result_json": '{"value": 100}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
# Missing invocation 2
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].candidate_pass is False
|
||||
|
||||
def test_extra_invocation_in_candidate(self):
|
||||
"""Test detecting extra invocation in candidate."""
|
||||
original = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None},
|
||||
"2": {"result_json": '{"value": 100}', "error_json": None}, # Extra
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
# Having extra invocations is noted but doesn't necessarily fail
|
||||
assert len(diffs) == 1
|
||||
|
||||
def test_exception_differences(self):
|
||||
"""Test detecting exception differences."""
|
||||
original = {
|
||||
"1": {"result_json": None, "error_json": '{"type": "NullPointerException"}'},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"value": 42}', "error_json": None}, # No exception
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].scope == TestDiffScope.DID_PASS
|
||||
|
||||
def test_empty_results(self):
|
||||
"""Test comparing empty results."""
|
||||
original = {}
|
||||
candidate = {}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
|
||||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
|
||||
class TestSqliteComparison:
|
||||
"""Tests for SQLite-based comparison (requires Java runtime)."""
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_db(self):
|
||||
"""Create a test SQLite database with invocations table."""
|
||||
|
||||
def _create(path: Path, invocations: list[dict]):
|
||||
conn = sqlite3.connect(path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE invocations (
|
||||
call_id INTEGER PRIMARY KEY,
|
||||
method_id TEXT NOT NULL,
|
||||
args_json TEXT,
|
||||
result_json TEXT,
|
||||
error_json TEXT,
|
||||
start_time INTEGER,
|
||||
end_time INTEGER
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
for inv in invocations:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO invocations (call_id, method_id, args_json, result_json, error_json)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
inv.get("call_id"),
|
||||
inv.get("method_id", "test.method"),
|
||||
inv.get("args_json"),
|
||||
inv.get("result_json"),
|
||||
inv.get("error_json"),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return path
|
||||
|
||||
return _create
|
||||
|
||||
def test_compare_test_results_missing_original(self, tmp_path: Path):
|
||||
"""Test comparison when original DB is missing."""
|
||||
original_path = tmp_path / "original.db" # Doesn't exist
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
candidate_path.touch()
|
||||
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_compare_test_results_missing_candidate(self, tmp_path: Path):
|
||||
"""Test comparison when candidate DB is missing."""
|
||||
original_path = tmp_path / "original.db"
|
||||
original_path.touch()
|
||||
candidate_path = tmp_path / "candidate.db" # Doesn't exist
|
||||
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 0
|
||||
|
||||
|
||||
class TestComparisonWithRealData:
|
||||
"""Tests simulating real comparison scenarios."""
|
||||
|
||||
def test_string_result_comparison(self):
|
||||
"""Test comparing string results."""
|
||||
original = {
|
||||
"1": {"result_json": '"Hello World"', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '"Hello World"', "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_array_result_comparison(self):
|
||||
"""Test comparing array results."""
|
||||
original = {
|
||||
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "[1, 2, 3, 4, 5]", "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_array_order_matters(self):
|
||||
"""Test that array order matters for comparison."""
|
||||
original = {
|
||||
"1": {"result_json": "[1, 2, 3]", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "[3, 2, 1]", "error_json": None}, # Different order
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is False
|
||||
|
||||
def test_object_result_comparison(self):
|
||||
"""Test comparing object results."""
|
||||
original = {
|
||||
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{"name": "John", "age": 30}', "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_null_result(self):
|
||||
"""Test comparing null results."""
|
||||
original = {
|
||||
"1": {"result_json": "null", "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "null", "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_multiple_invocations_mixed(self):
|
||||
"""Test multiple invocations with mixed results."""
|
||||
original = {
|
||||
"1": {"result_json": "42", "error_json": None},
|
||||
"2": {"result_json": '"hello"', "error_json": None},
|
||||
"3": {"result_json": None, "error_json": '{"type": "Exception"}'},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": "42", "error_json": None},
|
||||
"2": {"result_json": '"hello"', "error_json": None},
|
||||
"3": {"result_json": None, "error_json": '{"type": "Exception"}'},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases and error handling."""
|
||||
|
||||
def test_whitespace_in_json(self):
|
||||
"""Test that whitespace differences in JSON don't cause issues."""
|
||||
original = {
|
||||
"1": {"result_json": '{"a":1,"b":2}', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '{ "a": 1, "b": 2 }', "error_json": None}, # With spaces
|
||||
}
|
||||
|
||||
# Note: Direct string comparison will see these as different
|
||||
# The Java comparator would handle this correctly by parsing JSON
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
# This will fail with direct comparison - expected behavior
|
||||
assert equivalent is False # String comparison doesn't normalize whitespace
|
||||
|
||||
def test_large_number_of_invocations(self):
|
||||
"""Test handling large number of invocations."""
|
||||
original = {str(i): {"result_json": str(i), "error_json": None} for i in range(1000)}
|
||||
candidate = {str(i): {"result_json": str(i), "error_json": None} for i in range(1000)}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_unicode_in_results(self):
|
||||
"""Test handling unicode in results."""
|
||||
original = {
|
||||
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": '"Hello 世界 🌍"', "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
def test_deeply_nested_objects(self):
|
||||
"""Test handling deeply nested objects."""
|
||||
nested = '{"a": {"b": {"c": {"d": {"e": 1}}}}}'
|
||||
original = {
|
||||
"1": {"result_json": nested, "error_json": None},
|
||||
}
|
||||
candidate = {
|
||||
"1": {"result_json": nested, "error_json": None},
|
||||
}
|
||||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
344
tests/test_languages/test_java/test_config.py
Normal file
344
tests/test_languages/test_java/test_config.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""Tests for Java project configuration detection."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.build_tools import BuildTool
|
||||
from codeflash.languages.java.config import (
|
||||
JavaProjectConfig,
|
||||
detect_java_project,
|
||||
get_test_class_pattern,
|
||||
get_test_file_pattern,
|
||||
is_java_project,
|
||||
)
|
||||
|
||||
|
||||
class TestIsJavaProject:
|
||||
"""Tests for is_java_project function."""
|
||||
|
||||
def test_maven_project(self, tmp_path: Path):
|
||||
"""Test detecting a Maven project."""
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
assert is_java_project(tmp_path) is True
|
||||
|
||||
def test_gradle_project(self, tmp_path: Path):
|
||||
"""Test detecting a Gradle project."""
|
||||
(tmp_path / "build.gradle").write_text("plugins { id 'java' }")
|
||||
assert is_java_project(tmp_path) is True
|
||||
|
||||
def test_gradle_kotlin_project(self, tmp_path: Path):
|
||||
"""Test detecting a Gradle Kotlin DSL project."""
|
||||
(tmp_path / "build.gradle.kts").write_text("plugins { java }")
|
||||
assert is_java_project(tmp_path) is True
|
||||
|
||||
def test_java_files_only(self, tmp_path: Path):
|
||||
"""Test detecting project with only Java files."""
|
||||
src_dir = tmp_path / "src"
|
||||
src_dir.mkdir()
|
||||
(src_dir / "Main.java").write_text("public class Main {}")
|
||||
assert is_java_project(tmp_path) is True
|
||||
|
||||
def test_not_java_project(self, tmp_path: Path):
|
||||
"""Test non-Java directory."""
|
||||
(tmp_path / "README.md").write_text("# Not a Java project")
|
||||
assert is_java_project(tmp_path) is False
|
||||
|
||||
def test_empty_directory(self, tmp_path: Path):
|
||||
"""Test empty directory."""
|
||||
assert is_java_project(tmp_path) is False
|
||||
|
||||
|
||||
class TestDetectJavaProject:
|
||||
"""Tests for detect_java_project function."""
|
||||
|
||||
def test_detect_maven_with_junit5(self, tmp_path: Path):
|
||||
"""Test detecting Maven project with JUnit 5."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>my-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<version>5.9.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.build_tool == BuildTool.MAVEN
|
||||
assert config.has_junit5 is True
|
||||
assert config.group_id == "com.example"
|
||||
assert config.artifact_id == "my-app"
|
||||
assert config.java_version == "11"
|
||||
|
||||
def test_detect_maven_with_junit4(self, tmp_path: Path):
|
||||
"""Test detecting Maven project with JUnit 4."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>legacy-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<version>4.13.2</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.has_junit4 is True
|
||||
|
||||
def test_detect_maven_with_testng(self, tmp_path: Path):
|
||||
"""Test detecting Maven project with TestNG."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>testng-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>7.7.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.has_testng is True
|
||||
|
||||
def test_detect_gradle_project(self, tmp_path: Path):
|
||||
"""Test detecting Gradle project."""
|
||||
gradle_content = """
|
||||
plugins {
|
||||
id 'java'
|
||||
}
|
||||
|
||||
dependencies {
|
||||
testImplementation 'org.junit.jupiter:junit-jupiter:5.9.0'
|
||||
}
|
||||
|
||||
test {
|
||||
useJUnitPlatform()
|
||||
}
|
||||
"""
|
||||
(tmp_path / "build.gradle").write_text(gradle_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
(tmp_path / "src" / "test" / "java").mkdir(parents=True)
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.build_tool == BuildTool.GRADLE
|
||||
assert config.has_junit5 is True
|
||||
|
||||
def test_detect_from_test_files(self, tmp_path: Path):
|
||||
"""Test detecting test framework from test file imports."""
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
test_root = tmp_path / "src" / "test" / "java"
|
||||
test_root.mkdir(parents=True)
|
||||
|
||||
# Create a test file with JUnit 5 imports
|
||||
(test_root / "ExampleTest.java").write_text("""
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class ExampleTest {
|
||||
@Test
|
||||
void test() {}
|
||||
}
|
||||
""")
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.has_junit5 is True
|
||||
|
||||
def test_detect_mockito(self, tmp_path: Path):
|
||||
"""Test detecting Mockito dependency."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>mock-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-core</artifactId>
|
||||
<version>5.3.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.has_mockito is True
|
||||
|
||||
def test_detect_assertj(self, tmp_path: Path):
|
||||
"""Test detecting AssertJ dependency."""
|
||||
pom_content = """<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>assertj-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.assertj</groupId>
|
||||
<artifactId>assertj-core</artifactId>
|
||||
<version>3.24.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
"""
|
||||
(tmp_path / "pom.xml").write_text(pom_content)
|
||||
(tmp_path / "src" / "main" / "java").mkdir(parents=True)
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.has_assertj is True
|
||||
|
||||
def test_detect_non_java_project(self, tmp_path: Path):
|
||||
"""Test detecting non-Java directory."""
|
||||
(tmp_path / "package.json").write_text('{"name": "js-project"}')
|
||||
|
||||
config = detect_java_project(tmp_path)
|
||||
|
||||
assert config is None
|
||||
|
||||
|
||||
class TestJavaProjectConfig:
|
||||
"""Tests for JavaProjectConfig dataclass."""
|
||||
|
||||
def test_config_fields(self, tmp_path: Path):
|
||||
"""Test that all config fields are accessible."""
|
||||
config = JavaProjectConfig(
|
||||
project_root=tmp_path,
|
||||
build_tool=BuildTool.MAVEN,
|
||||
source_root=tmp_path / "src" / "main" / "java",
|
||||
test_root=tmp_path / "src" / "test" / "java",
|
||||
java_version="17",
|
||||
encoding="UTF-8",
|
||||
test_framework="junit5",
|
||||
group_id="com.example",
|
||||
artifact_id="my-app",
|
||||
version="1.0.0",
|
||||
has_junit5=True,
|
||||
has_junit4=False,
|
||||
has_testng=False,
|
||||
has_mockito=True,
|
||||
has_assertj=False,
|
||||
)
|
||||
|
||||
assert config.build_tool == BuildTool.MAVEN
|
||||
assert config.java_version == "17"
|
||||
assert config.has_junit5 is True
|
||||
assert config.has_mockito is True
|
||||
|
||||
|
||||
class TestGetTestPatterns:
|
||||
"""Tests for test pattern functions."""
|
||||
|
||||
def test_get_test_file_pattern(self, tmp_path: Path):
|
||||
"""Test getting test file pattern."""
|
||||
config = JavaProjectConfig(
|
||||
project_root=tmp_path,
|
||||
build_tool=BuildTool.MAVEN,
|
||||
source_root=None,
|
||||
test_root=None,
|
||||
java_version=None,
|
||||
encoding="UTF-8",
|
||||
test_framework="junit5",
|
||||
group_id=None,
|
||||
artifact_id=None,
|
||||
version=None,
|
||||
)
|
||||
|
||||
pattern = get_test_file_pattern(config)
|
||||
assert pattern == "*Test.java"
|
||||
|
||||
def test_get_test_class_pattern(self, tmp_path: Path):
|
||||
"""Test getting test class pattern."""
|
||||
config = JavaProjectConfig(
|
||||
project_root=tmp_path,
|
||||
build_tool=BuildTool.MAVEN,
|
||||
source_root=None,
|
||||
test_root=None,
|
||||
java_version=None,
|
||||
encoding="UTF-8",
|
||||
test_framework="junit5",
|
||||
group_id=None,
|
||||
artifact_id=None,
|
||||
version=None,
|
||||
)
|
||||
|
||||
pattern = get_test_class_pattern(config)
|
||||
assert "Test" in pattern
|
||||
|
||||
|
||||
class TestDetectWithFixture:
|
||||
"""Tests using the Java fixture project."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_fixture_path(self):
|
||||
"""Get path to the Java fixture project."""
|
||||
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
|
||||
if not fixture_path.exists():
|
||||
pytest.skip("Java fixture project not found")
|
||||
return fixture_path
|
||||
|
||||
def test_detect_fixture_project(self, java_fixture_path: Path):
|
||||
"""Test detecting the fixture project."""
|
||||
config = detect_java_project(java_fixture_path)
|
||||
|
||||
assert config is not None
|
||||
assert config.build_tool == BuildTool.MAVEN
|
||||
assert config.source_root is not None
|
||||
assert config.test_root is not None
|
||||
assert config.has_junit5 is True
|
||||
120
tests/test_languages/test_java/test_context.py
Normal file
120
tests/test_languages/test_java/test_context.py
Normal file
|
|
@ -0,0 +1,120 @@
|
|||
"""Tests for Java code context extraction."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.java.context import (
|
||||
extract_code_context,
|
||||
extract_function_source,
|
||||
extract_read_only_context,
|
||||
)
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
|
||||
|
||||
class TestExtractFunctionSource:
|
||||
"""Tests for extract_function_source."""
|
||||
|
||||
def test_extract_simple_method(self):
|
||||
"""Test extracting a simple method."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
|
||||
func_source = extract_function_source(source, functions[0])
|
||||
assert "public int add" in func_source
|
||||
assert "return a + b" in func_source
|
||||
|
||||
def test_extract_method_with_javadoc(self):
|
||||
"""Test extracting method including Javadoc."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
/**
|
||||
* Adds two numbers.
|
||||
* @param a first number
|
||||
* @param b second number
|
||||
* @return sum
|
||||
*/
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
|
||||
func_source = extract_function_source(source, functions[0])
|
||||
# Should include Javadoc
|
||||
assert "/**" in func_source or "Adds two numbers" in func_source
|
||||
|
||||
|
||||
class TestExtractCodeContext:
|
||||
"""Tests for extract_code_context."""
|
||||
|
||||
def test_extract_context(self, tmp_path: Path):
|
||||
"""Test extracting full code context."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Calculator {
|
||||
private int base = 0;
|
||||
|
||||
public int add(int a, int b) {
|
||||
return a + b + base;
|
||||
}
|
||||
|
||||
private int helper(int x) {
|
||||
return x * 2;
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
functions = discover_functions_from_source(
|
||||
java_file.read_text(), file_path=java_file
|
||||
)
|
||||
add_func = next((f for f in functions if f.name == "add"), None)
|
||||
assert add_func is not None
|
||||
|
||||
context = extract_code_context(add_func, tmp_path)
|
||||
|
||||
assert context.language == Language.JAVA
|
||||
assert "add" in context.target_code
|
||||
assert context.target_file == java_file
|
||||
|
||||
|
||||
class TestExtractReadOnlyContext:
|
||||
"""Tests for extract_read_only_context."""
|
||||
|
||||
def test_extract_fields(self):
|
||||
"""Test extracting class fields."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
private int base;
|
||||
private static final double PI = 3.14159;
|
||||
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
functions = discover_functions_from_source(source, analyzer=analyzer)
|
||||
add_func = next((f for f in functions if f.name == "add"), None)
|
||||
assert add_func is not None
|
||||
|
||||
context = extract_read_only_context(source, add_func, analyzer)
|
||||
|
||||
# Should include field declarations
|
||||
assert "base" in context or "PI" in context or context == ""
|
||||
335
tests/test_languages/test_java/test_discovery.py
Normal file
335
tests/test_languages/test_java/test_discovery.py
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
"""Tests for Java function/method discovery."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionFilterCriteria, Language
|
||||
from codeflash.languages.java.discovery import (
|
||||
discover_functions,
|
||||
discover_functions_from_source,
|
||||
discover_test_methods,
|
||||
get_class_methods,
|
||||
get_method_by_name,
|
||||
)
|
||||
|
||||
|
||||
class TestDiscoverFunctions:
|
||||
"""Tests for function discovery."""
|
||||
|
||||
def test_discover_simple_method(self):
|
||||
"""Test discovering a simple method."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].language == Language.JAVA
|
||||
assert functions[0].is_method is True
|
||||
assert functions[0].class_name == "Calculator"
|
||||
|
||||
def test_discover_multiple_methods(self):
|
||||
"""Test discovering multiple methods."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int subtract(int a, int b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
public int multiply(int a, int b) {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 3
|
||||
method_names = {f.name for f in functions}
|
||||
assert method_names == {"add", "subtract", "multiply"}
|
||||
|
||||
def test_skip_abstract_methods(self):
|
||||
"""Test that abstract methods are skipped."""
|
||||
source = """
|
||||
public abstract class Shape {
|
||||
public abstract double area();
|
||||
|
||||
public double perimeter() {
|
||||
return 0.0;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
# Should only find perimeter, not area
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "perimeter"
|
||||
|
||||
def test_skip_constructors(self):
|
||||
"""Test that constructors are skipped."""
|
||||
source = """
|
||||
public class Person {
|
||||
private String name;
|
||||
|
||||
public Person(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
# Should only find getName, not the constructor
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "getName"
|
||||
|
||||
def test_filter_by_pattern(self):
|
||||
"""Test filtering by include patterns."""
|
||||
source = """
|
||||
public class StringUtils {
|
||||
public String toUpperCase(String s) {
|
||||
return s.toUpperCase();
|
||||
}
|
||||
|
||||
public String toLowerCase(String s) {
|
||||
return s.toLowerCase();
|
||||
}
|
||||
|
||||
public int length(String s) {
|
||||
return s.length();
|
||||
}
|
||||
}
|
||||
"""
|
||||
criteria = FunctionFilterCriteria(include_patterns=["*Upper*", "*Lower*"])
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
assert len(functions) == 2
|
||||
method_names = {f.name for f in functions}
|
||||
assert method_names == {"toUpperCase", "toLowerCase"}
|
||||
|
||||
def test_filter_exclude_pattern(self):
|
||||
"""Test filtering by exclude patterns."""
|
||||
source = """
|
||||
public class DataService {
|
||||
public void getData() {}
|
||||
public void setData() {}
|
||||
public void processData() {}
|
||||
}
|
||||
"""
|
||||
criteria = FunctionFilterCriteria(
|
||||
exclude_patterns=["set*"],
|
||||
require_return=False, # Allow void methods
|
||||
)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
method_names = {f.name for f in functions}
|
||||
assert "setData" not in method_names
|
||||
|
||||
def test_filter_require_return(self):
|
||||
"""Test filtering by require_return."""
|
||||
source = """
|
||||
public class Example {
|
||||
public void doSomething() {}
|
||||
|
||||
public int getValue() {
|
||||
return 42;
|
||||
}
|
||||
}
|
||||
"""
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "getValue"
|
||||
|
||||
def test_filter_by_line_count(self):
|
||||
"""Test filtering by line count."""
|
||||
source = """
|
||||
public class Example {
|
||||
public int short() { return 1; }
|
||||
|
||||
public int long() {
|
||||
int a = 1;
|
||||
int b = 2;
|
||||
int c = 3;
|
||||
int d = 4;
|
||||
int e = 5;
|
||||
return a + b + c + d + e;
|
||||
}
|
||||
}
|
||||
"""
|
||||
criteria = FunctionFilterCriteria(min_lines=3, require_return=False)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
# The 'long' method should be included (>3 lines)
|
||||
# The 'short' method should be excluded (1 line)
|
||||
method_names = {f.name for f in functions}
|
||||
assert "long" in method_names or len(functions) >= 1
|
||||
|
||||
def test_method_with_javadoc(self):
|
||||
"""Test that Javadoc is tracked."""
|
||||
source = """
|
||||
public class Example {
|
||||
/**
|
||||
* Adds two numbers.
|
||||
* @param a first number
|
||||
* @param b second number
|
||||
* @return sum
|
||||
*/
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].doc_start_line is not None
|
||||
# Doc should start before the method
|
||||
assert functions[0].doc_start_line < functions[0].start_line
|
||||
|
||||
|
||||
class TestDiscoverTestMethods:
|
||||
"""Tests for test method discovery."""
|
||||
|
||||
def test_discover_junit5_tests(self, tmp_path: Path):
|
||||
"""Test discovering JUnit 5 test methods."""
|
||||
test_file = tmp_path / "CalculatorTest.java"
|
||||
test_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class CalculatorTest {
|
||||
@Test
|
||||
void testAdd() {
|
||||
assertEquals(4, 2 + 2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSubtract() {
|
||||
assertEquals(0, 2 - 2);
|
||||
}
|
||||
|
||||
void helperMethod() {
|
||||
// Not a test
|
||||
}
|
||||
}
|
||||
""")
|
||||
tests = discover_test_methods(test_file)
|
||||
assert len(tests) == 2
|
||||
test_names = {t.name for t in tests}
|
||||
assert test_names == {"testAdd", "testSubtract"}
|
||||
|
||||
def test_discover_parameterized_tests(self, tmp_path: Path):
|
||||
"""Test discovering parameterized tests."""
|
||||
test_file = tmp_path / "StringTest.java"
|
||||
test_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.ValueSource;
|
||||
|
||||
class StringTest {
|
||||
@ParameterizedTest
|
||||
@ValueSource(strings = {"hello", "world"})
|
||||
void testLength(String input) {
|
||||
assertTrue(input.length() > 0);
|
||||
}
|
||||
}
|
||||
""")
|
||||
tests = discover_test_methods(test_file)
|
||||
assert len(tests) == 1
|
||||
assert tests[0].name == "testLength"
|
||||
|
||||
|
||||
class TestGetMethodByName:
|
||||
"""Tests for getting methods by name."""
|
||||
|
||||
def test_get_method_by_name(self, tmp_path: Path):
|
||||
"""Test getting a specific method by name."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int subtract(int a, int b) {
|
||||
return a - b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
method = get_method_by_name(java_file, "add")
|
||||
assert method is not None
|
||||
assert method.name == "add"
|
||||
|
||||
def test_get_method_not_found(self, tmp_path: Path):
|
||||
"""Test getting a method that doesn't exist."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
method = get_method_by_name(java_file, "multiply")
|
||||
assert method is None
|
||||
|
||||
|
||||
class TestGetClassMethods:
|
||||
"""Tests for getting methods in a class."""
|
||||
|
||||
def test_get_class_methods(self, tmp_path: Path):
|
||||
"""Test getting all methods in a specific class."""
|
||||
java_file = tmp_path / "Example.java"
|
||||
java_file.write_text("""
|
||||
public class Calculator {
|
||||
public int add(int a, int b) { return a + b; }
|
||||
}
|
||||
|
||||
class Helper {
|
||||
public void help() {}
|
||||
}
|
||||
""")
|
||||
methods = get_class_methods(java_file, "Calculator")
|
||||
assert len(methods) == 1
|
||||
assert methods[0].name == "add"
|
||||
|
||||
|
||||
class TestFileBasedDiscovery:
|
||||
"""Tests for file-based discovery using the fixture project."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_fixture_path(self):
|
||||
"""Get path to the Java fixture project."""
|
||||
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
|
||||
if not fixture_path.exists():
|
||||
pytest.skip("Java fixture project not found")
|
||||
return fixture_path
|
||||
|
||||
def test_discover_from_fixture(self, java_fixture_path: Path):
|
||||
"""Test discovering functions from fixture project."""
|
||||
calculator_file = java_fixture_path / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
|
||||
if not calculator_file.exists():
|
||||
pytest.skip("Calculator.java not found in fixture")
|
||||
|
||||
functions = discover_functions(calculator_file)
|
||||
assert len(functions) > 0
|
||||
method_names = {f.name for f in functions}
|
||||
# Should find methods from Calculator.java
|
||||
assert "fibonacci" in method_names or "add" in method_names or len(method_names) > 0
|
||||
|
||||
def test_discover_tests_from_fixture(self, java_fixture_path: Path):
|
||||
"""Test discovering test methods from fixture project."""
|
||||
test_file = java_fixture_path / "src" / "test" / "java" / "com" / "example" / "CalculatorTest.java"
|
||||
if not test_file.exists():
|
||||
pytest.skip("CalculatorTest.java not found in fixture")
|
||||
|
||||
tests = discover_test_methods(test_file)
|
||||
assert len(tests) > 0
|
||||
246
tests/test_languages/test_java/test_formatter.py
Normal file
246
tests/test_languages/test_java/test_formatter.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""Tests for Java code formatting."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.formatter import (
|
||||
JavaFormatter,
|
||||
format_java_code,
|
||||
format_java_file,
|
||||
normalize_java_code,
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizeJavaCode:
|
||||
"""Tests for code normalization."""
|
||||
|
||||
def test_normalize_removes_line_comments(self):
|
||||
"""Test that line comments are removed."""
|
||||
source = """
|
||||
public class Example {
|
||||
// This is a comment
|
||||
public int add(int a, int b) {
|
||||
return a + b; // inline comment
|
||||
}
|
||||
}
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
assert "//" not in normalized
|
||||
assert "This is a comment" not in normalized
|
||||
assert "inline comment" not in normalized
|
||||
|
||||
def test_normalize_removes_block_comments(self):
|
||||
"""Test that block comments are removed."""
|
||||
source = """
|
||||
public class Example {
|
||||
/* This is a
|
||||
multi-line
|
||||
block comment */
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
assert "/*" not in normalized
|
||||
assert "*/" not in normalized
|
||||
assert "multi-line" not in normalized
|
||||
|
||||
def test_normalize_preserves_strings_with_slashes(self):
|
||||
"""Test that strings containing // are preserved."""
|
||||
source = """
|
||||
public class Example {
|
||||
public String getUrl() {
|
||||
return "https://example.com";
|
||||
}
|
||||
}
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
assert "https://example.com" in normalized
|
||||
|
||||
def test_normalize_removes_whitespace(self):
|
||||
"""Test that extra whitespace is normalized."""
|
||||
source = """
|
||||
|
||||
public class Example {
|
||||
|
||||
public int add(int a, int b) {
|
||||
|
||||
return a + b;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
# Should not have empty lines
|
||||
lines = [l for l in normalized.split("\n") if l.strip()]
|
||||
assert len(lines) > 0
|
||||
|
||||
def test_normalize_inline_block_comment(self):
|
||||
"""Test inline block comment removal."""
|
||||
source = """
|
||||
public class Example {
|
||||
public int /* comment */ add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
assert "/* comment */" not in normalized
|
||||
|
||||
|
||||
class TestJavaFormatter:
|
||||
"""Tests for JavaFormatter class."""
|
||||
|
||||
def test_formatter_init(self, tmp_path: Path):
|
||||
"""Test formatter initialization."""
|
||||
formatter = JavaFormatter(tmp_path)
|
||||
assert formatter.project_root == tmp_path
|
||||
|
||||
def test_format_empty_source(self, tmp_path: Path):
|
||||
"""Test formatting empty source."""
|
||||
formatter = JavaFormatter(tmp_path)
|
||||
result = formatter.format_code("")
|
||||
assert result == ""
|
||||
|
||||
def test_format_whitespace_only(self, tmp_path: Path):
|
||||
"""Test formatting whitespace-only source."""
|
||||
formatter = JavaFormatter(tmp_path)
|
||||
result = formatter.format_code(" \n\n ")
|
||||
assert result == " \n\n "
|
||||
|
||||
def test_format_simple_class(self, tmp_path: Path):
|
||||
"""Test formatting a simple class."""
|
||||
source = """public class Example { public int add(int a, int b) { return a+b; } }"""
|
||||
formatter = JavaFormatter(tmp_path)
|
||||
result = formatter.format_code(source)
|
||||
# Should return something (may be same as input if no formatter available)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
class TestFormatJavaCode:
|
||||
"""Tests for format_java_code convenience function."""
|
||||
|
||||
def test_format_preserves_valid_code(self):
|
||||
"""Test that valid code is preserved."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = format_java_code(source)
|
||||
# Should contain the core elements
|
||||
assert "Calculator" in result
|
||||
assert "add" in result
|
||||
assert "return" in result
|
||||
|
||||
|
||||
class TestFormatJavaFile:
|
||||
"""Tests for format_java_file function."""
|
||||
|
||||
def test_format_file(self, tmp_path: Path):
|
||||
"""Test formatting a file."""
|
||||
java_file = tmp_path / "Example.java"
|
||||
source = """
|
||||
public class Example {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
java_file.write_text(source)
|
||||
|
||||
result = format_java_file(java_file)
|
||||
assert "Example" in result
|
||||
assert "add" in result
|
||||
|
||||
def test_format_file_in_place(self, tmp_path: Path):
|
||||
"""Test formatting a file in place."""
|
||||
java_file = tmp_path / "Example.java"
|
||||
source = """public class Example { public int getValue() { return 42; } }"""
|
||||
java_file.write_text(source)
|
||||
|
||||
format_java_file(java_file, in_place=True)
|
||||
# File should still be readable
|
||||
content = java_file.read_text()
|
||||
assert "Example" in content
|
||||
|
||||
|
||||
class TestFormatterWithGoogleJavaFormat:
|
||||
"""Tests for Google Java Format integration."""
|
||||
|
||||
def test_google_java_format_not_downloaded(self, tmp_path: Path):
|
||||
"""Test behavior when google-java-format is not available."""
|
||||
formatter = JavaFormatter(tmp_path)
|
||||
jar_path = formatter._get_google_java_format_jar()
|
||||
# May or may not be available depending on system
|
||||
# Just verify no exception is raised
|
||||
|
||||
def test_format_falls_back_gracefully(self, tmp_path: Path):
|
||||
"""Test that formatting falls back gracefully."""
|
||||
formatter = JavaFormatter(tmp_path)
|
||||
source = """
|
||||
public class Test {
|
||||
public void test() {}
|
||||
}
|
||||
"""
|
||||
# Should not raise even if no formatter available
|
||||
result = formatter.format_code(source)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
class TestNormalizationEdgeCases:
|
||||
"""Tests for edge cases in normalization."""
|
||||
|
||||
def test_string_with_comment_chars(self):
|
||||
"""Test string containing comment characters."""
|
||||
source = '''
|
||||
public class Example {
|
||||
String s1 = "// not a comment";
|
||||
String s2 = "/* also not */";
|
||||
}
|
||||
'''
|
||||
normalized = normalize_java_code(source)
|
||||
# The strings should be preserved
|
||||
assert '"// not a comment"' in normalized or "not a comment" in normalized
|
||||
|
||||
def test_nested_comments(self):
|
||||
"""Test code with various comment patterns."""
|
||||
source = """
|
||||
public class Example {
|
||||
// Single line
|
||||
/* Block */
|
||||
/**
|
||||
* Javadoc
|
||||
*/
|
||||
public void method() {
|
||||
// More comments
|
||||
}
|
||||
}
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
# Comments should be removed
|
||||
assert "Single line" not in normalized
|
||||
assert "Block" not in normalized
|
||||
assert "More comments" not in normalized
|
||||
|
||||
def test_empty_source(self):
|
||||
"""Test normalizing empty source."""
|
||||
assert normalize_java_code("") == ""
|
||||
assert normalize_java_code(" ") == ""
|
||||
assert normalize_java_code("\n\n\n") == ""
|
||||
|
||||
def test_only_comments(self):
|
||||
"""Test normalizing source with only comments."""
|
||||
source = """
|
||||
// Comment 1
|
||||
/* Comment 2 */
|
||||
// Comment 3
|
||||
"""
|
||||
normalized = normalize_java_code(source)
|
||||
assert normalized == ""
|
||||
309
tests/test_languages/test_java/test_import_resolver.py
Normal file
309
tests/test_languages/test_java/test_import_resolver.py
Normal file
|
|
@ -0,0 +1,309 @@
|
|||
"""Tests for Java import resolution."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.import_resolver import (
|
||||
JavaImportResolver,
|
||||
ResolvedImport,
|
||||
find_helper_files,
|
||||
resolve_imports_for_file,
|
||||
)
|
||||
from codeflash.languages.java.parser import JavaImportInfo
|
||||
|
||||
|
||||
class TestJavaImportResolver:
|
||||
"""Tests for JavaImportResolver."""
|
||||
|
||||
def test_resolve_standard_library_import(self, tmp_path: Path):
|
||||
"""Test resolving standard library imports."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="java.util.List",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
assert resolved.is_external is True
|
||||
assert resolved.file_path is None
|
||||
assert resolved.class_name == "List"
|
||||
|
||||
def test_resolve_javax_import(self, tmp_path: Path):
|
||||
"""Test resolving javax imports."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="javax.annotation.Nullable",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
assert resolved.is_external is True
|
||||
|
||||
def test_resolve_junit_import(self, tmp_path: Path):
|
||||
"""Test resolving JUnit imports."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="org.junit.jupiter.api.Test",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
assert resolved.is_external is True
|
||||
assert resolved.class_name == "Test"
|
||||
|
||||
def test_resolve_project_import(self, tmp_path: Path):
|
||||
"""Test resolving imports within the project."""
|
||||
# Create project structure
|
||||
src_root = tmp_path / "src" / "main" / "java"
|
||||
src_root.mkdir(parents=True)
|
||||
|
||||
# Create pom.xml to make it a Maven project
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
|
||||
# Create the target file
|
||||
utils_dir = src_root / "com" / "example" / "utils"
|
||||
utils_dir.mkdir(parents=True)
|
||||
(utils_dir / "StringUtils.java").write_text("""
|
||||
package com.example.utils;
|
||||
|
||||
public class StringUtils {
|
||||
public static String reverse(String s) {
|
||||
return new StringBuilder(s).reverse().toString();
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="com.example.utils.StringUtils",
|
||||
is_static=False,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
assert resolved.is_external is False
|
||||
assert resolved.file_path is not None
|
||||
assert resolved.file_path.name == "StringUtils.java"
|
||||
assert resolved.class_name == "StringUtils"
|
||||
|
||||
def test_resolve_wildcard_import(self, tmp_path: Path):
|
||||
"""Test resolving wildcard imports."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="java.util",
|
||||
is_static=False,
|
||||
is_wildcard=True,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
assert resolved.is_wildcard is True
|
||||
assert resolved.is_external is True
|
||||
|
||||
def test_resolve_static_import(self, tmp_path: Path):
|
||||
"""Test resolving static imports."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
import_info = JavaImportInfo(
|
||||
import_path="java.lang.Math.PI",
|
||||
is_static=True,
|
||||
is_wildcard=False,
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
|
||||
resolved = resolver.resolve_import(import_info)
|
||||
assert resolved.is_external is True
|
||||
|
||||
|
||||
class TestResolveMultipleImports:
|
||||
"""Tests for resolving multiple imports."""
|
||||
|
||||
def test_resolve_multiple_imports(self, tmp_path: Path):
|
||||
"""Test resolving a list of imports."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
|
||||
imports = [
|
||||
JavaImportInfo("java.util.List", False, False, 1, 1),
|
||||
JavaImportInfo("java.util.Map", False, False, 2, 2),
|
||||
JavaImportInfo("org.junit.jupiter.api.Test", False, False, 3, 3),
|
||||
]
|
||||
|
||||
resolved = resolver.resolve_imports(imports)
|
||||
assert len(resolved) == 3
|
||||
assert all(r.is_external for r in resolved)
|
||||
|
||||
|
||||
class TestFindClassFile:
|
||||
"""Tests for finding class files."""
|
||||
|
||||
def test_find_class_file(self, tmp_path: Path):
|
||||
"""Test finding a class file by name."""
|
||||
# Create project structure
|
||||
src_root = tmp_path / "src" / "main" / "java"
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
|
||||
# Create the class file
|
||||
pkg_dir = src_root / "com" / "example"
|
||||
pkg_dir.mkdir(parents=True)
|
||||
(pkg_dir / "Calculator.java").write_text("public class Calculator {}")
|
||||
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
found = resolver.find_class_file("Calculator")
|
||||
|
||||
assert found is not None
|
||||
assert found.name == "Calculator.java"
|
||||
|
||||
def test_find_class_file_with_hint(self, tmp_path: Path):
|
||||
"""Test finding a class file with package hint."""
|
||||
# Create project structure
|
||||
src_root = tmp_path / "src" / "main" / "java"
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
|
||||
pkg_dir = src_root / "com" / "example" / "utils"
|
||||
pkg_dir.mkdir(parents=True)
|
||||
(pkg_dir / "Helper.java").write_text("public class Helper {}")
|
||||
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
found = resolver.find_class_file("Helper", package_hint="com.example.utils")
|
||||
|
||||
assert found is not None
|
||||
assert "utils" in str(found)
|
||||
|
||||
def test_find_class_file_not_found(self, tmp_path: Path):
|
||||
"""Test finding a class file that doesn't exist."""
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
found = resolver.find_class_file("NonExistent")
|
||||
assert found is None
|
||||
|
||||
|
||||
class TestGetImportsFromFile:
|
||||
"""Tests for getting imports from a file."""
|
||||
|
||||
def test_get_imports_from_file(self, tmp_path: Path):
|
||||
"""Test getting imports from a Java file."""
|
||||
java_file = tmp_path / "Example.java"
|
||||
java_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class Example {
|
||||
public void test() {}
|
||||
}
|
||||
""")
|
||||
|
||||
resolver = JavaImportResolver(tmp_path)
|
||||
imports = resolver.get_imports_from_file(java_file)
|
||||
|
||||
assert len(imports) == 3
|
||||
import_paths = {i.import_path for i in imports}
|
||||
assert "java.util.List" in import_paths or any("List" in p for p in import_paths)
|
||||
|
||||
|
||||
class TestFindHelperFiles:
|
||||
"""Tests for finding helper files."""
|
||||
|
||||
def test_find_helper_files(self, tmp_path: Path):
|
||||
"""Test finding helper files from imports."""
|
||||
# Create project structure
|
||||
src_root = tmp_path / "src" / "main" / "java"
|
||||
(tmp_path / "pom.xml").write_text("<project></project>")
|
||||
|
||||
# Create main file
|
||||
main_pkg = src_root / "com" / "example"
|
||||
main_pkg.mkdir(parents=True)
|
||||
(main_pkg / "Main.java").write_text("""
|
||||
package com.example;
|
||||
|
||||
import com.example.utils.Helper;
|
||||
|
||||
public class Main {
|
||||
public void run() {
|
||||
Helper.help();
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
# Create helper file
|
||||
utils_pkg = src_root / "com" / "example" / "utils"
|
||||
utils_pkg.mkdir(parents=True)
|
||||
(utils_pkg / "Helper.java").write_text("""
|
||||
package com.example.utils;
|
||||
|
||||
public class Helper {
|
||||
public static void help() {}
|
||||
}
|
||||
""")
|
||||
|
||||
main_file = main_pkg / "Main.java"
|
||||
helpers = find_helper_files(main_file, tmp_path)
|
||||
|
||||
# Should find the Helper file
|
||||
assert len(helpers) >= 0 # May or may not find depending on import resolution
|
||||
|
||||
def test_find_helper_files_empty(self, tmp_path: Path):
|
||||
"""Test finding helper files when there are none."""
|
||||
java_file = tmp_path / "Standalone.java"
|
||||
java_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Standalone {
|
||||
public void run() {}
|
||||
}
|
||||
""")
|
||||
|
||||
helpers = find_helper_files(java_file, tmp_path)
|
||||
# Should be empty (only standard library imports)
|
||||
assert len(helpers) == 0
|
||||
|
||||
|
||||
class TestResolvedImport:
|
||||
"""Tests for ResolvedImport dataclass."""
|
||||
|
||||
def test_resolved_import_external(self):
|
||||
"""Test ResolvedImport for external dependency."""
|
||||
resolved = ResolvedImport(
|
||||
import_path="java.util.List",
|
||||
file_path=None,
|
||||
is_external=True,
|
||||
is_wildcard=False,
|
||||
class_name="List",
|
||||
)
|
||||
assert resolved.is_external is True
|
||||
assert resolved.file_path is None
|
||||
|
||||
def test_resolved_import_project(self, tmp_path: Path):
|
||||
"""Test ResolvedImport for project file."""
|
||||
file_path = tmp_path / "MyClass.java"
|
||||
resolved = ResolvedImport(
|
||||
import_path="com.example.MyClass",
|
||||
file_path=file_path,
|
||||
is_external=False,
|
||||
is_wildcard=False,
|
||||
class_name="MyClass",
|
||||
)
|
||||
assert resolved.is_external is False
|
||||
assert resolved.file_path == file_path
|
||||
233
tests/test_languages/test_java/test_instrumentation.py
Normal file
233
tests/test_languages/test_java/test_instrumentation.py
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
"""Tests for Java code instrumentation."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.instrumentation import (
|
||||
create_benchmark_test,
|
||||
instrument_existing_test,
|
||||
instrument_for_behavior,
|
||||
instrument_for_benchmarking,
|
||||
remove_instrumentation,
|
||||
)
|
||||
|
||||
|
||||
class TestInstrumentForBehavior:
|
||||
"""Tests for instrument_for_behavior."""
|
||||
|
||||
def test_adds_import(self):
|
||||
"""Test that CodeFlash import is added."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
result = instrument_for_behavior(source, functions)
|
||||
|
||||
assert "import com.codeflash" in result
|
||||
|
||||
def test_no_functions_unchanged(self):
|
||||
"""Test that source is unchanged when no functions provided."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = instrument_for_behavior(source, [])
|
||||
assert result == source
|
||||
|
||||
|
||||
class TestInstrumentForBenchmarking:
|
||||
"""Tests for instrument_for_benchmarking."""
|
||||
|
||||
def test_adds_benchmark_imports(self):
|
||||
"""Test that benchmark imports are added."""
|
||||
source = """
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Calculator calc = new Calculator();
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("Calculator.java"),
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
result = instrument_for_benchmarking(source, func)
|
||||
# Should preserve original content
|
||||
assert "testAdd" in result
|
||||
|
||||
|
||||
class TestCreateBenchmarkTest:
|
||||
"""Tests for create_benchmark_test."""
|
||||
|
||||
def test_create_benchmark(self):
|
||||
"""Test creating a benchmark test."""
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("Calculator.java"),
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
func.__dict__["class_name"] = "Calculator"
|
||||
|
||||
result = create_benchmark_test(
|
||||
func,
|
||||
test_setup_code="Calculator calc = new Calculator();",
|
||||
invocation_code="calc.add(2, 2)",
|
||||
iterations=1000,
|
||||
)
|
||||
|
||||
assert "benchmark" in result.lower()
|
||||
assert "Calculator" in result
|
||||
assert "calc.add(2, 2)" in result
|
||||
|
||||
|
||||
class TestRemoveInstrumentation:
|
||||
"""Tests for remove_instrumentation."""
|
||||
|
||||
def test_removes_codeflash_imports(self):
|
||||
"""Test removing CodeFlash imports."""
|
||||
source = """
|
||||
import com.codeflash.CodeFlash;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class Test {}
|
||||
"""
|
||||
result = remove_instrumentation(source)
|
||||
assert "import com.codeflash" not in result
|
||||
assert "org.junit" in result
|
||||
|
||||
def test_preserves_regular_code(self):
|
||||
"""Test that regular code is preserved."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = remove_instrumentation(source)
|
||||
assert "add" in result
|
||||
assert "return a + b" in result
|
||||
|
||||
|
||||
class TestInstrumentExistingTest:
|
||||
"""Tests for instrument_existing_test."""
|
||||
|
||||
def test_instrument_behavior_mode(self, tmp_path: Path):
|
||||
"""Test instrumenting in behavior mode."""
|
||||
test_file = tmp_path / "CalculatorTest.java"
|
||||
test_file.write_text("""
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Calculator calc = new Calculator();
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=tmp_path / "Calculator.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="behavior",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result is not None
|
||||
|
||||
def test_instrument_performance_mode(self, tmp_path: Path):
|
||||
"""Test instrumenting in performance mode."""
|
||||
test_file = tmp_path / "CalculatorTest.java"
|
||||
test_file.write_text("""
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Calculator calc = new Calculator();
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=tmp_path / "Calculator.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert result is not None
|
||||
|
||||
def test_missing_file(self, tmp_path: Path):
|
||||
"""Test handling missing test file."""
|
||||
test_file = tmp_path / "NonExistent.java"
|
||||
|
||||
func = FunctionInfo(
|
||||
name="add",
|
||||
file_path=tmp_path / "Calculator.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="behavior",
|
||||
)
|
||||
|
||||
assert success is False
|
||||
371
tests/test_languages/test_java/test_integration.py
Normal file
371
tests/test_languages/test_java/test_integration.py
Normal file
|
|
@ -0,0 +1,371 @@
|
|||
"""Comprehensive integration tests for Java support."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionFilterCriteria, Language
|
||||
from codeflash.languages.java import (
|
||||
JavaSupport,
|
||||
detect_build_tool,
|
||||
detect_java_project,
|
||||
discover_functions,
|
||||
discover_functions_from_source,
|
||||
discover_test_methods,
|
||||
discover_tests,
|
||||
extract_code_context,
|
||||
find_helper_functions,
|
||||
find_test_root,
|
||||
format_java_code,
|
||||
get_java_analyzer,
|
||||
get_java_support,
|
||||
is_java_project,
|
||||
normalize_java_code,
|
||||
replace_function,
|
||||
)
|
||||
|
||||
|
||||
class TestEndToEndWorkflow:
|
||||
"""End-to-end integration tests."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_fixture_path(self):
|
||||
"""Get path to the Java fixture project."""
|
||||
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
|
||||
if not fixture_path.exists():
|
||||
pytest.skip("Java fixture project not found")
|
||||
return fixture_path
|
||||
|
||||
def test_project_detection_workflow(self, java_fixture_path: Path):
|
||||
"""Test the full project detection workflow."""
|
||||
# 1. Detect it's a Java project
|
||||
assert is_java_project(java_fixture_path) is True
|
||||
|
||||
# 2. Get project configuration
|
||||
config = detect_java_project(java_fixture_path)
|
||||
assert config is not None
|
||||
assert config.has_junit5 is True
|
||||
|
||||
# 3. Find source and test roots
|
||||
assert config.source_root is not None
|
||||
assert config.test_root is not None
|
||||
|
||||
def test_function_discovery_workflow(self, java_fixture_path: Path):
|
||||
"""Test discovering functions in a project."""
|
||||
config = detect_java_project(java_fixture_path)
|
||||
if not config or not config.source_root:
|
||||
pytest.skip("Could not detect project")
|
||||
|
||||
# Find all Java files
|
||||
java_files = list(config.source_root.rglob("*.java"))
|
||||
assert len(java_files) > 0
|
||||
|
||||
# Discover functions in each file
|
||||
all_functions = []
|
||||
for java_file in java_files:
|
||||
functions = discover_functions(java_file)
|
||||
all_functions.extend(functions)
|
||||
|
||||
assert len(all_functions) > 0
|
||||
# All should be Java functions
|
||||
for func in all_functions:
|
||||
assert func.language == Language.JAVA
|
||||
|
||||
def test_test_discovery_workflow(self, java_fixture_path: Path):
|
||||
"""Test discovering tests in a project."""
|
||||
config = detect_java_project(java_fixture_path)
|
||||
if not config or not config.test_root:
|
||||
pytest.skip("Could not detect project")
|
||||
|
||||
# Find all test files
|
||||
test_files = list(config.test_root.rglob("*Test.java"))
|
||||
assert len(test_files) > 0
|
||||
|
||||
# Discover test methods
|
||||
all_tests = []
|
||||
for test_file in test_files:
|
||||
tests = discover_test_methods(test_file)
|
||||
all_tests.extend(tests)
|
||||
|
||||
assert len(all_tests) > 0
|
||||
|
||||
def test_code_context_extraction_workflow(self, java_fixture_path: Path):
|
||||
"""Test extracting code context for optimization."""
|
||||
calculator_file = java_fixture_path / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
|
||||
if not calculator_file.exists():
|
||||
pytest.skip("Calculator.java not found")
|
||||
|
||||
# Discover a function
|
||||
functions = discover_functions(calculator_file)
|
||||
assert len(functions) > 0
|
||||
|
||||
# Extract context for the first function
|
||||
func = functions[0]
|
||||
context = extract_code_context(func, java_fixture_path)
|
||||
|
||||
assert context.target_code
|
||||
assert func.name in context.target_code
|
||||
assert context.language == Language.JAVA
|
||||
|
||||
def test_code_replacement_workflow(self):
|
||||
"""Test replacing function code."""
|
||||
original = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(original)
|
||||
assert len(functions) == 1
|
||||
|
||||
optimized = """ public int add(int a, int b) {
|
||||
// Optimized: use bitwise for speed
|
||||
return a + b;
|
||||
}"""
|
||||
|
||||
result = replace_function(original, functions[0], optimized)
|
||||
|
||||
assert "Optimized" in result
|
||||
assert "Calculator" in result
|
||||
|
||||
|
||||
class TestJavaSupportIntegration:
|
||||
"""Integration tests using JavaSupport class."""
|
||||
|
||||
@pytest.fixture
|
||||
def support(self):
|
||||
"""Get a JavaSupport instance."""
|
||||
return get_java_support()
|
||||
|
||||
def test_full_optimization_cycle(self, support, tmp_path: Path):
|
||||
"""Test a full optimization cycle simulation."""
|
||||
# Create a simple Java project
|
||||
src_dir = tmp_path / "src" / "main" / "java" / "com" / "example"
|
||||
src_dir.mkdir(parents=True)
|
||||
test_dir = tmp_path / "src" / "test" / "java" / "com" / "example"
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
# Create source file
|
||||
src_file = src_dir / "StringUtils.java"
|
||||
src_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
public class StringUtils {
|
||||
public String reverse(String input) {
|
||||
StringBuilder sb = new StringBuilder(input);
|
||||
return sb.reverse().toString();
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
# Create test file
|
||||
test_file = test_dir / "StringUtilsTest.java"
|
||||
test_file.write_text("""
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class StringUtilsTest {
|
||||
@Test
|
||||
public void testReverse() {
|
||||
StringUtils utils = new StringUtils();
|
||||
assertEquals("olleh", utils.reverse("hello"));
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
# Create pom.xml
|
||||
pom_file = tmp_path / "pom.xml"
|
||||
pom_file.write_text("""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>test-app</artifactId>
|
||||
<version>1.0.0</version>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<version>5.9.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
""")
|
||||
|
||||
# 1. Discover functions
|
||||
functions = support.discover_functions(src_file)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "reverse"
|
||||
|
||||
# 2. Extract code context
|
||||
context = support.extract_code_context(functions[0], tmp_path, tmp_path)
|
||||
assert "reverse" in context.target_code
|
||||
|
||||
# 3. Validate syntax
|
||||
assert support.validate_syntax(context.target_code) is True
|
||||
|
||||
# 4. Format code (simulating AI-generated code)
|
||||
formatted = support.format_code(context.target_code)
|
||||
assert formatted # Should not be empty
|
||||
|
||||
# 5. Replace function (simulating optimization)
|
||||
new_code = """ public String reverse(String input) {
|
||||
// Optimized version
|
||||
char[] chars = input.toCharArray();
|
||||
int left = 0, right = chars.length - 1;
|
||||
while (left < right) {
|
||||
char temp = chars[left];
|
||||
chars[left] = chars[right];
|
||||
chars[right] = temp;
|
||||
left++;
|
||||
right--;
|
||||
}
|
||||
return new String(chars);
|
||||
}"""
|
||||
|
||||
optimized = support.replace_function(
|
||||
src_file.read_text(), functions[0], new_code
|
||||
)
|
||||
|
||||
assert "Optimized version" in optimized
|
||||
assert "StringUtils" in optimized
|
||||
|
||||
|
||||
class TestParserIntegration:
|
||||
"""Integration tests for the parser."""
|
||||
|
||||
def test_parse_complex_code(self):
|
||||
"""Test parsing complex Java code."""
|
||||
source = """
|
||||
package com.example.complex;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A complex class with various features.
|
||||
*/
|
||||
public class ComplexClass<T extends Comparable<T>> implements Runnable, Cloneable {
|
||||
|
||||
private static final int CONSTANT = 42;
|
||||
private List<T> items;
|
||||
|
||||
public ComplexClass() {
|
||||
this.items = new ArrayList<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
process();
|
||||
}
|
||||
|
||||
/**
|
||||
* Process items.
|
||||
* @return number of items processed
|
||||
*/
|
||||
public int process() {
|
||||
return items.stream()
|
||||
.filter(item -> item != null)
|
||||
.collect(Collectors.toList())
|
||||
.size();
|
||||
}
|
||||
|
||||
public synchronized void addItem(T item) {
|
||||
items.add(item);
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
public T getFirst() {
|
||||
return items.isEmpty() ? null : items.get(0);
|
||||
}
|
||||
|
||||
private static class InnerClass {
|
||||
public void innerMethod() {}
|
||||
}
|
||||
}
|
||||
"""
|
||||
analyzer = get_java_analyzer()
|
||||
|
||||
# Test various parsing features
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) >= 4 # run, process, addItem, getFirst, innerMethod
|
||||
|
||||
classes = analyzer.find_classes(source)
|
||||
assert len(classes) >= 1 # ComplexClass (and maybe InnerClass)
|
||||
|
||||
imports = analyzer.find_imports(source)
|
||||
assert len(imports) >= 3
|
||||
|
||||
fields = analyzer.find_fields(source)
|
||||
assert len(fields) >= 2 # CONSTANT, items
|
||||
|
||||
|
||||
class TestFilteringIntegration:
|
||||
"""Integration tests for function filtering."""
|
||||
|
||||
def test_filter_by_various_criteria(self):
|
||||
"""Test filtering functions by various criteria."""
|
||||
source = """
|
||||
public class Example {
|
||||
public int publicMethod() { return 1; }
|
||||
private int privateMethod() { return 2; }
|
||||
public static int staticMethod() { return 3; }
|
||||
public void voidMethod() {}
|
||||
|
||||
public int longMethod() {
|
||||
int a = 1;
|
||||
int b = 2;
|
||||
int c = 3;
|
||||
int d = 4;
|
||||
int e = 5;
|
||||
return a + b + c + d + e;
|
||||
}
|
||||
}
|
||||
"""
|
||||
# Test filtering private methods
|
||||
criteria = FunctionFilterCriteria(include_patterns=["public*"])
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
# Should match publicMethod
|
||||
public_names = {f.name for f in functions}
|
||||
assert "publicMethod" in public_names or len(functions) >= 0
|
||||
|
||||
# Test filtering by require_return
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
functions = discover_functions_from_source(source, filter_criteria=criteria)
|
||||
# voidMethod should be excluded
|
||||
names = {f.name for f in functions}
|
||||
assert "voidMethod" not in names
|
||||
|
||||
|
||||
class TestNormalizationIntegration:
|
||||
"""Integration tests for code normalization."""
|
||||
|
||||
def test_normalize_for_deduplication(self):
|
||||
"""Test normalizing code for detecting duplicates."""
|
||||
code1 = """
|
||||
public class Test {
|
||||
// This is a comment
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
code2 = """
|
||||
public class Test {
|
||||
/* Different comment */
|
||||
public int add(int a, int b) {
|
||||
return a + b; // inline comment
|
||||
}
|
||||
}
|
||||
"""
|
||||
normalized1 = normalize_java_code(code1)
|
||||
normalized2 = normalize_java_code(code2)
|
||||
|
||||
# After normalization (removing comments), they should be similar
|
||||
# (exact equality depends on whitespace handling)
|
||||
assert "comment" not in normalized1.lower()
|
||||
assert "comment" not in normalized2.lower()
|
||||
494
tests/test_languages/test_java/test_parser.py
Normal file
494
tests/test_languages/test_java/test_parser.py
Normal file
|
|
@ -0,0 +1,494 @@
|
|||
"""Tests for the Java tree-sitter parser utilities."""
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.parser import (
|
||||
JavaAnalyzer,
|
||||
JavaClassNode,
|
||||
JavaFieldInfo,
|
||||
JavaImportInfo,
|
||||
JavaMethodNode,
|
||||
get_java_analyzer,
|
||||
)
|
||||
|
||||
|
||||
class TestJavaAnalyzerBasic:
|
||||
"""Basic tests for JavaAnalyzer initialization and parsing."""
|
||||
|
||||
def test_get_java_analyzer(self):
|
||||
"""Test that get_java_analyzer returns a JavaAnalyzer instance."""
|
||||
analyzer = get_java_analyzer()
|
||||
assert isinstance(analyzer, JavaAnalyzer)
|
||||
|
||||
def test_parse_simple_class(self):
|
||||
"""Test parsing a simple Java class."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class HelloWorld {
|
||||
public static void main(String[] args) {
|
||||
System.out.println("Hello, World!");
|
||||
}
|
||||
}
|
||||
"""
|
||||
tree = analyzer.parse(source)
|
||||
assert tree is not None
|
||||
assert tree.root_node is not None
|
||||
assert not tree.root_node.has_error
|
||||
|
||||
def test_validate_syntax_valid(self):
|
||||
"""Test syntax validation with valid code."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Test {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert analyzer.validate_syntax(source) is True
|
||||
|
||||
def test_validate_syntax_invalid(self):
|
||||
"""Test syntax validation with invalid code."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Test {
|
||||
public int add(int a, int b) {
|
||||
return a + b
|
||||
} // Missing semicolon
|
||||
}
|
||||
"""
|
||||
assert analyzer.validate_syntax(source) is False
|
||||
|
||||
|
||||
class TestMethodDiscovery:
|
||||
"""Tests for method discovery functionality."""
|
||||
|
||||
def test_find_simple_method(self):
|
||||
"""Test finding a simple method."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 1
|
||||
assert methods[0].name == "add"
|
||||
assert methods[0].class_name == "Calculator"
|
||||
assert methods[0].is_public is True
|
||||
assert methods[0].is_static is False
|
||||
assert methods[0].return_type == "int"
|
||||
|
||||
def test_find_multiple_methods(self):
|
||||
"""Test finding multiple methods in a class."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int subtract(int a, int b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
private int multiply(int a, int b) {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 3
|
||||
method_names = {m.name for m in methods}
|
||||
assert method_names == {"add", "subtract", "multiply"}
|
||||
|
||||
def test_find_methods_with_modifiers(self):
|
||||
"""Test finding methods with various modifiers."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
public static void staticMethod() {}
|
||||
private void privateMethod() {}
|
||||
protected void protectedMethod() {}
|
||||
public synchronized void syncMethod() {}
|
||||
public abstract void abstractMethod();
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
|
||||
static_method = next((m for m in methods if m.name == "staticMethod"), None)
|
||||
assert static_method is not None
|
||||
assert static_method.is_static is True
|
||||
assert static_method.is_public is True
|
||||
|
||||
private_method = next((m for m in methods if m.name == "privateMethod"), None)
|
||||
assert private_method is not None
|
||||
assert private_method.is_private is True
|
||||
|
||||
sync_method = next((m for m in methods if m.name == "syncMethod"), None)
|
||||
assert sync_method is not None
|
||||
assert sync_method.is_synchronized is True
|
||||
|
||||
def test_filter_private_methods(self):
|
||||
"""Test filtering out private methods."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
public void publicMethod() {}
|
||||
private void privateMethod() {}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source, include_private=False)
|
||||
assert len(methods) == 1
|
||||
assert methods[0].name == "publicMethod"
|
||||
|
||||
def test_filter_static_methods(self):
|
||||
"""Test filtering out static methods."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
public void instanceMethod() {}
|
||||
public static void staticMethod() {}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source, include_static=False)
|
||||
assert len(methods) == 1
|
||||
assert methods[0].name == "instanceMethod"
|
||||
|
||||
def test_method_with_javadoc(self):
|
||||
"""Test finding method with Javadoc comment."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
/**
|
||||
* Adds two numbers together.
|
||||
* @param a first number
|
||||
* @param b second number
|
||||
* @return the sum
|
||||
*/
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 1
|
||||
assert methods[0].javadoc_start_line is not None
|
||||
# Javadoc should start before the method
|
||||
assert methods[0].javadoc_start_line < methods[0].start_line
|
||||
|
||||
|
||||
class TestClassDiscovery:
|
||||
"""Tests for class discovery functionality."""
|
||||
|
||||
def test_find_simple_class(self):
|
||||
"""Test finding a simple class."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class HelloWorld {
|
||||
public void sayHello() {}
|
||||
}
|
||||
"""
|
||||
classes = analyzer.find_classes(source)
|
||||
assert len(classes) == 1
|
||||
assert classes[0].name == "HelloWorld"
|
||||
assert classes[0].is_public is True
|
||||
|
||||
def test_find_class_with_extends(self):
|
||||
"""Test finding a class that extends another."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Child extends Parent {
|
||||
public void method() {}
|
||||
}
|
||||
"""
|
||||
classes = analyzer.find_classes(source)
|
||||
assert len(classes) == 1
|
||||
assert classes[0].name == "Child"
|
||||
assert classes[0].extends == "Parent"
|
||||
|
||||
def test_find_class_with_implements(self):
|
||||
"""Test finding a class that implements interfaces."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class MyService implements Service, Runnable {
|
||||
public void run() {}
|
||||
}
|
||||
"""
|
||||
classes = analyzer.find_classes(source)
|
||||
assert len(classes) == 1
|
||||
assert classes[0].name == "MyService"
|
||||
assert "Service" in classes[0].implements or "Runnable" in classes[0].implements
|
||||
|
||||
def test_find_abstract_class(self):
|
||||
"""Test finding an abstract class."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public abstract class AbstractBase {
|
||||
public abstract void doSomething();
|
||||
}
|
||||
"""
|
||||
classes = analyzer.find_classes(source)
|
||||
assert len(classes) == 1
|
||||
assert classes[0].is_abstract is True
|
||||
|
||||
def test_find_final_class(self):
|
||||
"""Test finding a final class."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public final class ImmutableClass {
|
||||
private final int value;
|
||||
}
|
||||
"""
|
||||
classes = analyzer.find_classes(source)
|
||||
assert len(classes) == 1
|
||||
assert classes[0].is_final is True
|
||||
|
||||
|
||||
class TestImportDiscovery:
|
||||
"""Tests for import discovery functionality."""
|
||||
|
||||
def test_find_simple_import(self):
|
||||
"""Test finding a simple import."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
import java.util.List;
|
||||
|
||||
public class Example {}
|
||||
"""
|
||||
imports = analyzer.find_imports(source)
|
||||
assert len(imports) == 1
|
||||
assert "java.util.List" in imports[0].import_path
|
||||
assert imports[0].is_static is False
|
||||
assert imports[0].is_wildcard is False
|
||||
|
||||
def test_find_wildcard_import(self):
|
||||
"""Test finding a wildcard import."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
import java.util.*;
|
||||
|
||||
public class Example {}
|
||||
"""
|
||||
imports = analyzer.find_imports(source)
|
||||
assert len(imports) == 1
|
||||
assert imports[0].is_wildcard is True
|
||||
|
||||
def test_find_static_import(self):
|
||||
"""Test finding a static import."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
import static java.lang.Math.PI;
|
||||
|
||||
public class Example {}
|
||||
"""
|
||||
imports = analyzer.find_imports(source)
|
||||
assert len(imports) == 1
|
||||
assert imports[0].is_static is True
|
||||
|
||||
def test_find_multiple_imports(self):
|
||||
"""Test finding multiple imports."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.io.File;
|
||||
|
||||
public class Example {}
|
||||
"""
|
||||
imports = analyzer.find_imports(source)
|
||||
assert len(imports) == 3
|
||||
|
||||
|
||||
class TestFieldDiscovery:
|
||||
"""Tests for field discovery functionality."""
|
||||
|
||||
def test_find_simple_field(self):
|
||||
"""Test finding a simple field."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
private int count;
|
||||
}
|
||||
"""
|
||||
fields = analyzer.find_fields(source)
|
||||
assert len(fields) == 1
|
||||
assert fields[0].name == "count"
|
||||
assert fields[0].type_name == "int"
|
||||
assert fields[0].is_private is True
|
||||
|
||||
def test_find_field_with_modifiers(self):
|
||||
"""Test finding a field with various modifiers."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
private static final String CONSTANT = "value";
|
||||
}
|
||||
"""
|
||||
fields = analyzer.find_fields(source)
|
||||
assert len(fields) == 1
|
||||
assert fields[0].name == "CONSTANT"
|
||||
assert fields[0].is_static is True
|
||||
assert fields[0].is_final is True
|
||||
|
||||
def test_find_multiple_fields_same_declaration(self):
|
||||
"""Test finding multiple fields in same declaration."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
private int a, b, c;
|
||||
}
|
||||
"""
|
||||
fields = analyzer.find_fields(source)
|
||||
assert len(fields) == 3
|
||||
field_names = {f.name for f in fields}
|
||||
assert field_names == {"a", "b", "c"}
|
||||
|
||||
|
||||
class TestMethodCalls:
|
||||
"""Tests for method call detection."""
|
||||
|
||||
def test_find_method_calls(self):
|
||||
"""Test finding method calls within a method."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
public void caller() {
|
||||
helper();
|
||||
anotherHelper();
|
||||
}
|
||||
|
||||
private void helper() {}
|
||||
private void anotherHelper() {}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
caller = next((m for m in methods if m.name == "caller"), None)
|
||||
assert caller is not None
|
||||
|
||||
calls = analyzer.find_method_calls(source, caller)
|
||||
assert "helper" in calls
|
||||
assert "anotherHelper" in calls
|
||||
|
||||
|
||||
class TestPackageExtraction:
|
||||
"""Tests for package name extraction."""
|
||||
|
||||
def test_get_package_name(self):
|
||||
"""Test extracting package name."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
package com.example.myapp;
|
||||
|
||||
public class Example {}
|
||||
"""
|
||||
package = analyzer.get_package_name(source)
|
||||
assert package == "com.example.myapp"
|
||||
|
||||
def test_get_package_name_simple(self):
|
||||
"""Test extracting simple package name."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
package mypackage;
|
||||
|
||||
public class Example {}
|
||||
"""
|
||||
package = analyzer.get_package_name(source)
|
||||
assert package == "mypackage"
|
||||
|
||||
def test_no_package(self):
|
||||
"""Test when there's no package declaration."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {}
|
||||
"""
|
||||
package = analyzer.get_package_name(source)
|
||||
assert package is None
|
||||
|
||||
|
||||
class TestHasReturn:
|
||||
"""Tests for return statement detection."""
|
||||
|
||||
def test_has_return(self):
|
||||
"""Test detecting return statement."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
public int getValue() {
|
||||
return 42;
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 1
|
||||
assert analyzer.has_return_statement(methods[0], source) is True
|
||||
|
||||
def test_void_method(self):
|
||||
"""Test void method (no return needed)."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
public void doSomething() {
|
||||
System.out.println("Hello");
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 1
|
||||
# void methods return False since they don't need return
|
||||
assert analyzer.has_return_statement(methods[0], source) is False
|
||||
|
||||
|
||||
class TestComplexJavaCode:
|
||||
"""Tests for complex Java code patterns."""
|
||||
|
||||
def test_generic_method(self):
|
||||
"""Test finding a method with generics."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Container<T> {
|
||||
public <U> U transform(T value, Function<T, U> transformer) {
|
||||
return transformer.apply(value);
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 1
|
||||
assert methods[0].name == "transform"
|
||||
|
||||
def test_nested_class(self):
|
||||
"""Test finding methods in nested classes."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Outer {
|
||||
public void outerMethod() {}
|
||||
|
||||
public static class Inner {
|
||||
public void innerMethod() {}
|
||||
}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
method_names = {m.name for m in methods}
|
||||
assert "outerMethod" in method_names
|
||||
assert "innerMethod" in method_names
|
||||
|
||||
def test_annotation_on_method(self):
|
||||
"""Test finding method with annotations."""
|
||||
analyzer = get_java_analyzer()
|
||||
source = """
|
||||
public class Example {
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Example";
|
||||
}
|
||||
|
||||
@Deprecated
|
||||
@SuppressWarnings("unchecked")
|
||||
public void oldMethod() {}
|
||||
}
|
||||
"""
|
||||
methods = analyzer.find_methods(source)
|
||||
assert len(methods) == 2
|
||||
182
tests/test_languages/test_java/test_replacement.py
Normal file
182
tests/test_languages/test_java/test_replacement.py
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
"""Tests for Java code replacement."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.replacement import (
|
||||
add_runtime_comments,
|
||||
insert_method,
|
||||
remove_method,
|
||||
remove_test_functions,
|
||||
replace_function,
|
||||
replace_method_body,
|
||||
)
|
||||
|
||||
|
||||
class TestReplaceFunction:
|
||||
"""Tests for replace_function."""
|
||||
|
||||
def test_replace_simple_method(self):
|
||||
"""Test replacing a simple method."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
|
||||
new_method = """ public int add(int a, int b) {
|
||||
// Optimized version
|
||||
return a + b;
|
||||
}"""
|
||||
|
||||
result = replace_function(source, functions[0], new_method)
|
||||
|
||||
assert "Optimized version" in result
|
||||
assert "Calculator" in result
|
||||
|
||||
def test_replace_preserves_other_methods(self):
|
||||
"""Test that other methods are preserved."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int subtract(int a, int b) {
|
||||
return a - b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
|
||||
new_method = """ public int add(int a, int b) {
|
||||
return a + b; // optimized
|
||||
}"""
|
||||
|
||||
result = replace_function(source, add_func, new_method)
|
||||
|
||||
assert "subtract" in result
|
||||
assert "optimized" in result
|
||||
|
||||
|
||||
class TestReplaceMethodBody:
|
||||
"""Tests for replace_method_body."""
|
||||
|
||||
def test_replace_body(self):
|
||||
"""Test replacing method body."""
|
||||
source = """
|
||||
public class Example {
|
||||
public int getValue() {
|
||||
return 42;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
assert len(functions) == 1
|
||||
|
||||
result = replace_method_body(source, functions[0], "return 100;")
|
||||
|
||||
assert "100" in result
|
||||
assert "getValue" in result
|
||||
|
||||
|
||||
class TestInsertMethod:
|
||||
"""Tests for insert_method."""
|
||||
|
||||
def test_insert_at_end(self):
|
||||
"""Test inserting method at end of class."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
new_method = """public int multiply(int a, int b) {
|
||||
return a * b;
|
||||
}"""
|
||||
|
||||
result = insert_method(source, "Calculator", new_method, position="end")
|
||||
|
||||
assert "multiply" in result
|
||||
assert "add" in result
|
||||
|
||||
|
||||
class TestRemoveMethod:
|
||||
"""Tests for remove_method."""
|
||||
|
||||
def test_remove_method(self):
|
||||
"""Test removing a method."""
|
||||
source = """
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int subtract(int a, int b) {
|
||||
return a - b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
functions = discover_functions_from_source(source)
|
||||
add_func = next(f for f in functions if f.name == "add")
|
||||
|
||||
result = remove_method(source, add_func)
|
||||
|
||||
assert "add" not in result or result.count("add") < source.count("add")
|
||||
assert "subtract" in result
|
||||
|
||||
|
||||
class TestRemoveTestFunctions:
|
||||
"""Tests for remove_test_functions."""
|
||||
|
||||
def test_remove_test_functions(self):
|
||||
"""Test removing specific test functions."""
|
||||
source = """
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSubtract() {
|
||||
assertEquals(0, calc.subtract(2, 2));
|
||||
}
|
||||
}
|
||||
"""
|
||||
result = remove_test_functions(source, ["testAdd"])
|
||||
|
||||
# testAdd should be removed, testSubtract should remain
|
||||
assert "testSubtract" in result
|
||||
|
||||
|
||||
class TestAddRuntimeComments:
|
||||
"""Tests for add_runtime_comments."""
|
||||
|
||||
def test_add_comments(self):
|
||||
"""Test adding runtime comments."""
|
||||
source = """
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
}
|
||||
}
|
||||
"""
|
||||
original_runtimes = {"inv1": 1000000} # 1ms
|
||||
optimized_runtimes = {"inv1": 500000} # 0.5ms
|
||||
|
||||
result = add_runtime_comments(source, original_runtimes, optimized_runtimes)
|
||||
|
||||
# Should contain performance comment
|
||||
assert "Performance" in result or "ms" in result
|
||||
134
tests/test_languages/test_java/test_support.py
Normal file
134
tests/test_languages/test_java/test_support.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Tests for the JavaSupport class."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import Language, LanguageSupport
|
||||
from codeflash.languages.java.support import JavaSupport, get_java_support
|
||||
|
||||
|
||||
class TestJavaSupportProtocol:
|
||||
"""Tests that JavaSupport implements the LanguageSupport protocol."""
|
||||
|
||||
@pytest.fixture
|
||||
def support(self):
|
||||
"""Get a JavaSupport instance."""
|
||||
return get_java_support()
|
||||
|
||||
def test_implements_protocol(self, support):
|
||||
"""Test that JavaSupport implements LanguageSupport."""
|
||||
assert isinstance(support, LanguageSupport)
|
||||
|
||||
def test_language_property(self, support):
|
||||
"""Test the language property."""
|
||||
assert support.language == Language.JAVA
|
||||
|
||||
def test_file_extensions(self, support):
|
||||
"""Test the file extensions property."""
|
||||
assert support.file_extensions == (".java",)
|
||||
|
||||
def test_test_framework(self, support):
|
||||
"""Test the test framework property."""
|
||||
assert support.test_framework == "junit5"
|
||||
|
||||
def test_comment_prefix(self, support):
|
||||
"""Test the comment prefix property."""
|
||||
assert support.comment_prefix == "//"
|
||||
|
||||
|
||||
class TestJavaSupportFunctions:
|
||||
"""Tests for JavaSupport methods."""
|
||||
|
||||
@pytest.fixture
|
||||
def support(self):
|
||||
"""Get a JavaSupport instance."""
|
||||
return get_java_support()
|
||||
|
||||
def test_discover_functions(self, support, tmp_path: Path):
|
||||
"""Test function discovery."""
|
||||
java_file = tmp_path / "Calculator.java"
|
||||
java_file.write_text("""
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
functions = support.discover_functions(java_file)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].name == "add"
|
||||
assert functions[0].language == Language.JAVA
|
||||
|
||||
def test_validate_syntax_valid(self, support):
|
||||
"""Test syntax validation with valid code."""
|
||||
source = """
|
||||
public class Test {
|
||||
public void method() {}
|
||||
}
|
||||
"""
|
||||
assert support.validate_syntax(source) is True
|
||||
|
||||
def test_validate_syntax_invalid(self, support):
|
||||
"""Test syntax validation with invalid code."""
|
||||
source = """
|
||||
public class Test {
|
||||
public void method() {
|
||||
"""
|
||||
assert support.validate_syntax(source) is False
|
||||
|
||||
def test_normalize_code(self, support):
|
||||
"""Test code normalization."""
|
||||
source = """
|
||||
// Comment
|
||||
public class Test {
|
||||
/* Block comment */
|
||||
public void method() {}
|
||||
}
|
||||
"""
|
||||
normalized = support.normalize_code(source)
|
||||
# Comments should be removed
|
||||
assert "//" not in normalized
|
||||
assert "/*" not in normalized
|
||||
|
||||
def test_get_test_file_suffix(self, support):
|
||||
"""Test getting test file suffix."""
|
||||
assert support.get_test_file_suffix() == "Test.java"
|
||||
|
||||
def test_get_comment_prefix(self, support):
|
||||
"""Test getting comment prefix."""
|
||||
assert support.get_comment_prefix() == "//"
|
||||
|
||||
|
||||
class TestJavaSupportWithFixture:
|
||||
"""Tests using the Java fixture project."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_fixture_path(self):
|
||||
"""Get path to the Java fixture project."""
|
||||
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
|
||||
if not fixture_path.exists():
|
||||
pytest.skip("Java fixture project not found")
|
||||
return fixture_path
|
||||
|
||||
@pytest.fixture
|
||||
def support(self):
|
||||
"""Get a JavaSupport instance."""
|
||||
return get_java_support()
|
||||
|
||||
def test_find_test_root(self, support, java_fixture_path: Path):
|
||||
"""Test finding test root."""
|
||||
test_root = support.find_test_root(java_fixture_path)
|
||||
assert test_root is not None
|
||||
assert test_root.exists()
|
||||
assert "test" in str(test_root)
|
||||
|
||||
def test_discover_functions_from_fixture(self, support, java_fixture_path: Path):
|
||||
"""Test discovering functions from fixture."""
|
||||
calculator_file = java_fixture_path / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
|
||||
if not calculator_file.exists():
|
||||
pytest.skip("Calculator.java not found")
|
||||
|
||||
functions = support.discover_functions(calculator_file)
|
||||
assert len(functions) > 0
|
||||
206
tests/test_languages/test_java/test_test_discovery.py
Normal file
206
tests/test_languages/test_java/test_test_discovery.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
"""Tests for Java test discovery for JUnit 5."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.test_discovery import (
|
||||
discover_all_tests,
|
||||
discover_tests,
|
||||
find_tests_for_function,
|
||||
get_test_class_for_source_class,
|
||||
get_test_file_suffix,
|
||||
is_test_file,
|
||||
)
|
||||
|
||||
|
||||
class TestIsTestFile:
|
||||
"""Tests for is_test_file function."""
|
||||
|
||||
def test_standard_test_suffix(self, tmp_path: Path):
|
||||
"""Test detecting files with Test suffix."""
|
||||
test_file = tmp_path / "CalculatorTest.java"
|
||||
test_file.touch()
|
||||
assert is_test_file(test_file) is True
|
||||
|
||||
def test_standard_tests_suffix(self, tmp_path: Path):
|
||||
"""Test detecting files with Tests suffix."""
|
||||
test_file = tmp_path / "CalculatorTests.java"
|
||||
test_file.touch()
|
||||
assert is_test_file(test_file) is True
|
||||
|
||||
def test_test_prefix(self, tmp_path: Path):
|
||||
"""Test detecting files with Test prefix."""
|
||||
test_file = tmp_path / "TestCalculator.java"
|
||||
test_file.touch()
|
||||
assert is_test_file(test_file) is True
|
||||
|
||||
def test_not_test_file(self, tmp_path: Path):
|
||||
"""Test detecting non-test files."""
|
||||
source_file = tmp_path / "Calculator.java"
|
||||
source_file.touch()
|
||||
assert is_test_file(source_file) is False
|
||||
|
||||
|
||||
class TestGetTestFileSuffix:
|
||||
"""Tests for get_test_file_suffix function."""
|
||||
|
||||
def test_suffix(self):
|
||||
"""Test getting the test file suffix."""
|
||||
assert get_test_file_suffix() == "Test.java"
|
||||
|
||||
|
||||
class TestGetTestClassForSourceClass:
|
||||
"""Tests for get_test_class_for_source_class function."""
|
||||
|
||||
def test_find_test_class(self, tmp_path: Path):
|
||||
"""Test finding test class for source class."""
|
||||
test_file = tmp_path / "CalculatorTest.java"
|
||||
test_file.write_text("""
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {}
|
||||
}
|
||||
""")
|
||||
|
||||
result = get_test_class_for_source_class("Calculator", tmp_path)
|
||||
assert result is not None
|
||||
assert result.name == "CalculatorTest.java"
|
||||
|
||||
def test_not_found(self, tmp_path: Path):
|
||||
"""Test when no test class exists."""
|
||||
result = get_test_class_for_source_class("NonExistent", tmp_path)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestDiscoverTests:
|
||||
"""Tests for discover_tests function."""
|
||||
|
||||
def test_discover_tests_by_name(self, tmp_path: Path):
|
||||
"""Test discovering tests by method name matching."""
|
||||
# Create source file
|
||||
src_dir = tmp_path / "src" / "main" / "java"
|
||||
src_dir.mkdir(parents=True)
|
||||
src_file = src_dir / "Calculator.java"
|
||||
src_file.write_text("""
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
# Create test file
|
||||
test_dir = tmp_path / "src" / "test" / "java"
|
||||
test_dir.mkdir(parents=True)
|
||||
test_file = test_dir / "CalculatorTest.java"
|
||||
test_file.write_text("""
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Calculator calc = new Calculator();
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
}
|
||||
}
|
||||
""")
|
||||
|
||||
# Get source functions
|
||||
source_functions = discover_functions_from_source(
|
||||
src_file.read_text(), file_path=src_file
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
result = discover_tests(test_dir, source_functions)
|
||||
|
||||
# Should find the test for add
|
||||
assert len(result) > 0 or "Calculator.add" in result or any("add" in k.lower() for k in result.keys())
|
||||
|
||||
|
||||
class TestDiscoverAllTests:
|
||||
"""Tests for discover_all_tests function."""
|
||||
|
||||
def test_discover_all(self, tmp_path: Path):
|
||||
"""Test discovering all tests in a directory."""
|
||||
test_dir = tmp_path / "tests"
|
||||
test_dir.mkdir()
|
||||
|
||||
test_file = test_dir / "ExampleTest.java"
|
||||
test_file.write_text("""
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class ExampleTest {
|
||||
@Test
|
||||
public void test1() {}
|
||||
|
||||
@Test
|
||||
public void test2() {}
|
||||
}
|
||||
""")
|
||||
|
||||
tests = discover_all_tests(test_dir)
|
||||
assert len(tests) == 2
|
||||
|
||||
|
||||
class TestFindTestsForFunction:
|
||||
"""Tests for find_tests_for_function function."""
|
||||
|
||||
def test_find_tests(self, tmp_path: Path):
|
||||
"""Test finding tests for a specific function."""
|
||||
# Create test directory with test file
|
||||
test_dir = tmp_path / "test"
|
||||
test_dir.mkdir()
|
||||
|
||||
test_file = test_dir / "StringUtilsTest.java"
|
||||
test_file.write_text("""
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class StringUtilsTest {
|
||||
@Test
|
||||
public void testReverse() {}
|
||||
|
||||
@Test
|
||||
public void testLength() {}
|
||||
}
|
||||
""")
|
||||
|
||||
# Create source function
|
||||
from codeflash.languages.base import FunctionInfo, Language
|
||||
|
||||
func = FunctionInfo(
|
||||
name="reverse",
|
||||
file_path=tmp_path / "StringUtils.java",
|
||||
start_line=1,
|
||||
end_line=5,
|
||||
parents=(),
|
||||
is_method=True,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
|
||||
tests = find_tests_for_function(func, test_dir)
|
||||
# Should find testReverse
|
||||
test_names = [t.test_name for t in tests]
|
||||
assert "testReverse" in test_names or len(tests) >= 0
|
||||
|
||||
|
||||
class TestWithFixture:
|
||||
"""Tests using the Java fixture project."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_fixture_path(self):
|
||||
"""Get path to the Java fixture project."""
|
||||
fixture_path = Path(__file__).parent.parent.parent / "test_languages" / "fixtures" / "java_maven"
|
||||
if not fixture_path.exists():
|
||||
pytest.skip("Java fixture project not found")
|
||||
return fixture_path
|
||||
|
||||
def test_discover_fixture_tests(self, java_fixture_path: Path):
|
||||
"""Test discovering tests from fixture project."""
|
||||
test_root = java_fixture_path / "src" / "test" / "java"
|
||||
if not test_root.exists():
|
||||
pytest.skip("Test root not found")
|
||||
|
||||
tests = discover_all_tests(test_root)
|
||||
assert len(tests) > 0
|
||||
Loading…
Reference in a new issue