mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge branch 'omni-java' of github.com:codeflash-ai/codeflash into omni-java
This commit is contained in:
commit
c40798fa73
5 changed files with 734 additions and 0 deletions
70
.github/workflows/java-e2e-tests.yml
vendored
Normal file
70
.github/workflows/java-e2e-tests.yml
vendored
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
name: Java E2E Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- omni-java
|
||||
paths:
|
||||
- 'codeflash/languages/java/**'
|
||||
- 'tests/test_languages/test_java*.py'
|
||||
- 'code_to_optimize/java/**'
|
||||
- '.github/workflows/java-e2e-tests.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'codeflash/languages/java/**'
|
||||
- 'tests/test_languages/test_java*.py'
|
||||
- 'code_to_optimize/java/**'
|
||||
- '.github/workflows/java-e2e-tests.yml'
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
java-e2e:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set up JDK 11
|
||||
uses: actions/setup-java@v4
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'temurin'
|
||||
cache: maven
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Set up Python environment
|
||||
run: |
|
||||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Verify Java installation
|
||||
run: |
|
||||
java -version
|
||||
mvn --version
|
||||
|
||||
- name: Build Java sample project
|
||||
run: |
|
||||
cd code_to_optimize/java
|
||||
mvn compile -q
|
||||
|
||||
- name: Run Java sample project tests
|
||||
run: |
|
||||
cd code_to_optimize/java
|
||||
mvn test -q
|
||||
|
||||
- name: Run Java E2E tests
|
||||
run: |
|
||||
uv run pytest tests/test_languages/test_java_e2e.py -v --tb=short
|
||||
|
||||
- name: Run Java unit tests
|
||||
run: |
|
||||
uv run pytest tests/test_languages/test_java/ -v --tb=short -x
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -166,6 +166,8 @@ cython_debug/
|
|||
*.xml
|
||||
# Allow pom.xml in test fixtures for Maven project detection
|
||||
!tests/test_languages/fixtures/**/pom.xml
|
||||
# Allow pom.xml in Java sample project
|
||||
!code_to_optimize/java/pom.xml
|
||||
*.pem
|
||||
|
||||
# Ruff cache
|
||||
|
|
|
|||
67
code_to_optimize/java/pom.xml
Normal file
67
code_to_optimize/java/pom.xml
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<groupId>com.example</groupId>
|
||||
<artifactId>codeflash-java-sample</artifactId>
|
||||
<version>1.0.0</version>
|
||||
<packaging>jar</packaging>
|
||||
|
||||
<name>Codeflash Java Sample Project</name>
|
||||
<description>Sample Java project for testing Codeflash optimization</description>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>11</maven.compiler.source>
|
||||
<maven.compiler.target>11</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<junit.jupiter.version>5.10.0</junit.jupiter.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<version>${junit.jupiter.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter-params</artifactId>
|
||||
<version>${junit.jupiter.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<!-- SQLite JDBC for Codeflash instrumentation -->
|
||||
<dependency>
|
||||
<groupId>org.xerial</groupId>
|
||||
<artifactId>sqlite-jdbc</artifactId>
|
||||
<version>3.42.0.0</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
<plugins>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-compiler-plugin</artifactId>
|
||||
<version>3.11.0</version>
|
||||
<configuration>
|
||||
<source>11</source>
|
||||
<target>11</target>
|
||||
</configuration>
|
||||
</plugin>
|
||||
<plugin>
|
||||
<groupId>org.apache.maven.plugins</groupId>
|
||||
<artifactId>maven-surefire-plugin</artifactId>
|
||||
<version>3.1.2</version>
|
||||
<configuration>
|
||||
<includes>
|
||||
<include>**/*Test.java</include>
|
||||
</includes>
|
||||
</configuration>
|
||||
</plugin>
|
||||
</plugins>
|
||||
</build>
|
||||
</project>
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for Java test result comparison."""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
|
@ -13,6 +14,12 @@ from codeflash.languages.java.comparator import (
|
|||
)
|
||||
from codeflash.models.models import TestDiffScope
|
||||
|
||||
# Skip tests that require Java runtime if Java is not available
|
||||
requires_java = pytest.mark.skipif(
|
||||
shutil.which("java") is None,
|
||||
reason="Java not found - skipping Comparator integration tests",
|
||||
)
|
||||
|
||||
|
||||
class TestDirectComparison:
|
||||
"""Tests for direct Python-based comparison."""
|
||||
|
|
@ -308,3 +315,241 @@ class TestEdgeCases:
|
|||
|
||||
equivalent, diffs = compare_invocations_directly(original, candidate)
|
||||
assert equivalent is True
|
||||
|
||||
|
||||
@requires_java
|
||||
class TestTestResultsTableSchema:
|
||||
"""Tests for Java Comparator reading from test_results table schema.
|
||||
|
||||
This validates the schema integration between instrumentation (which writes
|
||||
to test_results) and the Comparator (which reads from test_results).
|
||||
|
||||
These tests require Java to be installed to run the actual Comparator.jar.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_results_db(self):
|
||||
"""Create a test SQLite database with test_results table (actual schema used by instrumentation)."""
|
||||
|
||||
def _create(path: Path, results: list[dict]):
|
||||
conn = sqlite3.connect(path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create test_results table matching instrumentation schema
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE test_results (
|
||||
test_module_path TEXT,
|
||||
test_class_name TEXT,
|
||||
test_function_name TEXT,
|
||||
function_getting_tested TEXT,
|
||||
loop_index INTEGER,
|
||||
iteration_id TEXT,
|
||||
runtime INTEGER,
|
||||
return_value TEXT,
|
||||
verification_type TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
for result in results:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO test_results
|
||||
(test_module_path, test_class_name, test_function_name,
|
||||
function_getting_tested, loop_index, iteration_id,
|
||||
runtime, return_value, verification_type)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
result.get("test_module_path", "TestModule"),
|
||||
result.get("test_class_name", "TestClass"),
|
||||
result.get("test_function_name", "testMethod"),
|
||||
result.get("function_getting_tested", "targetMethod"),
|
||||
result.get("loop_index", 1),
|
||||
result.get("iteration_id", "1_0"),
|
||||
result.get("runtime", 1000000),
|
||||
result.get("return_value"),
|
||||
result.get("verification_type", "function_call"),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return path
|
||||
|
||||
return _create
|
||||
|
||||
def test_comparator_reads_test_results_table_identical(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
"""Test that Comparator correctly reads test_results table with identical results."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
||||
# Create databases with identical results
|
||||
results = [
|
||||
{
|
||||
"test_class_name": "CalculatorTest",
|
||||
"function_getting_tested": "add",
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": '{"value": 42}',
|
||||
},
|
||||
{
|
||||
"test_class_name": "CalculatorTest",
|
||||
"function_getting_tested": "add",
|
||||
"loop_index": 1,
|
||||
"iteration_id": "2_0",
|
||||
"return_value": '{"value": 100}',
|
||||
},
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, results)
|
||||
create_test_results_db(candidate_path, results)
|
||||
|
||||
# Compare using Java Comparator
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_reads_test_results_table_different_values(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
"""Test that Comparator detects different return values from test_results table."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
||||
original_results = [
|
||||
{
|
||||
"test_class_name": "StringUtilsTest",
|
||||
"function_getting_tested": "reverse",
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": '"olleh"',
|
||||
},
|
||||
]
|
||||
|
||||
candidate_results = [
|
||||
{
|
||||
"test_class_name": "StringUtilsTest",
|
||||
"function_getting_tested": "reverse",
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": '"wrong"', # Different result
|
||||
},
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, original_results)
|
||||
create_test_results_db(candidate_path, candidate_results)
|
||||
|
||||
# Compare using Java Comparator
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) == 1
|
||||
assert diffs[0].scope == TestDiffScope.RETURN_VALUE
|
||||
|
||||
def test_comparator_handles_multiple_loop_iterations(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
"""Test that Comparator correctly handles multiple loop iterations."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
||||
# Simulate multiple benchmark loops
|
||||
results = []
|
||||
for loop in range(1, 4): # 3 loops
|
||||
for iteration in range(1, 3): # 2 iterations per loop
|
||||
results.append(
|
||||
{
|
||||
"test_class_name": "AlgorithmTest",
|
||||
"function_getting_tested": "fibonacci",
|
||||
"loop_index": loop,
|
||||
"iteration_id": f"{iteration}_0",
|
||||
"return_value": str(loop * iteration),
|
||||
}
|
||||
)
|
||||
|
||||
create_test_results_db(original_path, results)
|
||||
create_test_results_db(candidate_path, results)
|
||||
|
||||
# Compare using Java Comparator
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_iteration_id_parsing(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
"""Test that Comparator correctly parses iteration_id format 'iter_testIteration'."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
||||
# Test various iteration_id formats
|
||||
results = [
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0", # Standard format
|
||||
"return_value": '{"result": 1}',
|
||||
},
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "2_5", # With test iteration
|
||||
"return_value": '{"result": 2}',
|
||||
},
|
||||
{
|
||||
"loop_index": 2,
|
||||
"iteration_id": "1_0", # Different loop
|
||||
"return_value": '{"result": 3}',
|
||||
},
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, results)
|
||||
create_test_results_db(candidate_path, results)
|
||||
|
||||
# Compare using Java Comparator
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is True
|
||||
assert len(diffs) == 0
|
||||
|
||||
def test_comparator_missing_result_in_candidate(
|
||||
self, tmp_path: Path, create_test_results_db
|
||||
):
|
||||
"""Test that Comparator detects missing results in candidate."""
|
||||
original_path = tmp_path / "original.db"
|
||||
candidate_path = tmp_path / "candidate.db"
|
||||
|
||||
original_results = [
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": '{"value": 1}',
|
||||
},
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "2_0",
|
||||
"return_value": '{"value": 2}',
|
||||
},
|
||||
]
|
||||
|
||||
candidate_results = [
|
||||
{
|
||||
"loop_index": 1,
|
||||
"iteration_id": "1_0",
|
||||
"return_value": '{"value": 1}',
|
||||
},
|
||||
# Missing second iteration
|
||||
]
|
||||
|
||||
create_test_results_db(original_path, original_results)
|
||||
create_test_results_db(candidate_path, candidate_results)
|
||||
|
||||
# Compare using Java Comparator
|
||||
equivalent, diffs = compare_test_results(original_path, candidate_path)
|
||||
|
||||
assert equivalent is False
|
||||
assert len(diffs) >= 1 # Should detect missing invocation
|
||||
|
|
|
|||
350
tests/test_languages/test_java_e2e.py
Normal file
350
tests/test_languages/test_java_e2e.py
Normal file
|
|
@ -0,0 +1,350 @@
|
|||
"""End-to-end integration tests for Java pipeline.
|
||||
|
||||
Tests the full optimization pipeline for Java:
|
||||
- Function discovery
|
||||
- Code context extraction
|
||||
- Test discovery
|
||||
- Code replacement
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file, get_files_for_language
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
|
||||
class TestJavaFunctionDiscovery:
|
||||
"""Tests for Java function discovery in the main pipeline."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_project_dir(self):
|
||||
"""Get the Java sample project directory."""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
java_dir = project_root / "code_to_optimize" / "java"
|
||||
if not java_dir.exists():
|
||||
pytest.skip("code_to_optimize/java directory not found")
|
||||
return java_dir
|
||||
|
||||
def test_discover_functions_in_bubble_sort(self, java_project_dir):
|
||||
"""Test discovering functions in BubbleSort.java."""
|
||||
sort_file = java_project_dir / "src" / "main" / "java" / "com" / "example" / "BubbleSort.java"
|
||||
if not sort_file.exists():
|
||||
pytest.skip("BubbleSort.java not found")
|
||||
|
||||
functions = find_all_functions_in_file(sort_file)
|
||||
|
||||
assert sort_file in functions
|
||||
func_list = functions[sort_file]
|
||||
|
||||
# Should find the sorting methods
|
||||
func_names = {f.function_name for f in func_list}
|
||||
assert "bubbleSort" in func_names
|
||||
assert "bubbleSortDescending" in func_names
|
||||
assert "insertionSort" in func_names
|
||||
assert "selectionSort" in func_names
|
||||
assert "isSorted" in func_names
|
||||
|
||||
# All should be Java methods
|
||||
for func in func_list:
|
||||
assert func.language == "java"
|
||||
|
||||
def test_discover_functions_in_calculator(self, java_project_dir):
|
||||
"""Test discovering functions in Calculator.java."""
|
||||
calc_file = java_project_dir / "src" / "main" / "java" / "com" / "example" / "Calculator.java"
|
||||
if not calc_file.exists():
|
||||
pytest.skip("Calculator.java not found")
|
||||
|
||||
functions = find_all_functions_in_file(calc_file)
|
||||
|
||||
assert calc_file in functions
|
||||
func_list = functions[calc_file]
|
||||
|
||||
func_names = {f.function_name for f in func_list}
|
||||
assert "add" in func_names or len(func_names) > 0 # Should find at least some methods
|
||||
|
||||
def test_get_java_files(self, java_project_dir):
|
||||
"""Test getting Java files from directory."""
|
||||
source_dir = java_project_dir / "src" / "main" / "java"
|
||||
files = get_files_for_language(source_dir, Language.JAVA)
|
||||
|
||||
# Should find .java files
|
||||
java_files = [f for f in files if f.suffix == ".java"]
|
||||
assert len(java_files) >= 5 # BubbleSort, Calculator, etc.
|
||||
|
||||
|
||||
class TestJavaCodeContext:
|
||||
"""Tests for Java code context extraction."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_project_dir(self):
|
||||
"""Get the Java sample project directory."""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
java_dir = project_root / "code_to_optimize" / "java"
|
||||
if not java_dir.exists():
|
||||
pytest.skip("code_to_optimize/java directory not found")
|
||||
return java_dir
|
||||
|
||||
def test_extract_code_context_for_java(self, java_project_dir):
|
||||
"""Test extracting code context for a Java method."""
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
# Force set language to Java for proper context extraction routing
|
||||
lang_current._current_language = Language.JAVA
|
||||
|
||||
sort_file = java_project_dir / "src" / "main" / "java" / "com" / "example" / "BubbleSort.java"
|
||||
if not sort_file.exists():
|
||||
pytest.skip("BubbleSort.java not found")
|
||||
|
||||
functions = find_all_functions_in_file(sort_file)
|
||||
func_list = functions[sort_file]
|
||||
|
||||
# Find the bubbleSort method
|
||||
bubble_func = next((f for f in func_list if f.function_name == "bubbleSort"), None)
|
||||
assert bubble_func is not None
|
||||
|
||||
# Extract code context
|
||||
context = get_code_optimization_context(bubble_func, java_project_dir)
|
||||
|
||||
# Verify context structure
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "java"
|
||||
assert len(context.read_writable_code.code_strings) > 0
|
||||
|
||||
# The code should contain the method
|
||||
code = context.read_writable_code.code_strings[0].code
|
||||
assert "bubbleSort" in code
|
||||
|
||||
|
||||
class TestJavaCodeReplacement:
|
||||
"""Tests for Java code replacement."""
|
||||
|
||||
def test_replace_method_in_java_file(self):
|
||||
"""Test replacing a method in a Java file."""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
|
||||
original_source = """package com.example;
|
||||
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int multiply(int a, int b) {
|
||||
return a * b;
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
new_method = """public int add(int a, int b) {
|
||||
// Optimized version
|
||||
return a + b;
|
||||
}"""
|
||||
|
||||
java_support = get_language_support(Language.JAVA)
|
||||
|
||||
# Create FunctionInfo for the add method with parent class
|
||||
func_info = FunctionInfo(
|
||||
name="add",
|
||||
file_path=Path("/tmp/Calculator.java"),
|
||||
start_line=4,
|
||||
end_line=6,
|
||||
language=Language.JAVA,
|
||||
parents=(ParentInfo(name="Calculator", type="ClassDef"),),
|
||||
)
|
||||
|
||||
result = java_support.replace_function(original_source, func_info, new_method)
|
||||
|
||||
# Verify the method was replaced
|
||||
assert "// Optimized version" in result
|
||||
assert "multiply" in result # Other method should still be there
|
||||
|
||||
|
||||
class TestJavaTestDiscovery:
|
||||
"""Tests for Java test discovery."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_project_dir(self):
|
||||
"""Get the Java sample project directory."""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
java_dir = project_root / "code_to_optimize" / "java"
|
||||
if not java_dir.exists():
|
||||
pytest.skip("code_to_optimize/java directory not found")
|
||||
return java_dir
|
||||
|
||||
def test_discover_junit_tests(self, java_project_dir):
|
||||
"""Test discovering JUnit tests for Java methods."""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
|
||||
java_support = get_language_support(Language.JAVA)
|
||||
test_root = java_project_dir / "src" / "test" / "java"
|
||||
|
||||
if not test_root.exists():
|
||||
pytest.skip("test directory not found")
|
||||
|
||||
# Create FunctionInfo for bubbleSort method with parent class
|
||||
sort_file = java_project_dir / "src" / "main" / "java" / "com" / "example" / "BubbleSort.java"
|
||||
func_info = FunctionInfo(
|
||||
name="bubbleSort",
|
||||
file_path=sort_file,
|
||||
start_line=14,
|
||||
end_line=37,
|
||||
language=Language.JAVA,
|
||||
parents=(ParentInfo(name="BubbleSort", type="ClassDef"),),
|
||||
)
|
||||
|
||||
# Discover tests
|
||||
tests = java_support.discover_tests(test_root, [func_info])
|
||||
|
||||
# Should find tests for bubbleSort
|
||||
assert func_info.qualified_name in tests or "bubbleSort" in str(tests)
|
||||
|
||||
|
||||
class TestJavaPipelineIntegration:
|
||||
"""Integration tests for the full Java pipeline."""
|
||||
|
||||
def test_function_to_optimize_has_correct_fields(self):
|
||||
"""Test that FunctionToOptimize from Java has all required fields."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".java", mode="w", delete=False) as f:
|
||||
f.write("""package com.example;
|
||||
|
||||
public class Calculator {
|
||||
public int add(int a, int b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
public int subtract(int a, int b) {
|
||||
return a - b;
|
||||
}
|
||||
|
||||
public static int multiply(int x, int y) {
|
||||
return x * y;
|
||||
}
|
||||
}
|
||||
""")
|
||||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = find_all_functions_in_file(file_path)
|
||||
|
||||
# Should find class methods
|
||||
assert len(functions.get(file_path, [])) >= 3
|
||||
|
||||
# Check instance method
|
||||
add_fn = next((fn for fn in functions[file_path] if fn.function_name == "add"), None)
|
||||
assert add_fn is not None
|
||||
assert add_fn.language == "java"
|
||||
assert len(add_fn.parents) == 1
|
||||
assert add_fn.parents[0].name == "Calculator"
|
||||
|
||||
# Check static method
|
||||
multiply_fn = next((fn for fn in functions[file_path] if fn.function_name == "multiply"), None)
|
||||
assert multiply_fn is not None
|
||||
assert multiply_fn.language == "java"
|
||||
|
||||
def test_code_strings_markdown_uses_java_tag(self):
|
||||
"""Test that CodeStringsMarkdown uses java for code blocks."""
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
code_strings = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code="public int add(int a, int b) { return a + b; }",
|
||||
file_path=Path("Calculator.java"),
|
||||
language="java",
|
||||
)
|
||||
],
|
||||
language="java",
|
||||
)
|
||||
|
||||
markdown = code_strings.markdown
|
||||
assert "```java" in markdown
|
||||
|
||||
|
||||
class TestJavaProjectDetection:
|
||||
"""Tests for Java project detection."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_project_dir(self):
|
||||
"""Get the Java sample project directory."""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
java_dir = project_root / "code_to_optimize" / "java"
|
||||
if not java_dir.exists():
|
||||
pytest.skip("code_to_optimize/java directory not found")
|
||||
return java_dir
|
||||
|
||||
def test_detect_maven_project(self, java_project_dir):
|
||||
"""Test detecting Maven project structure."""
|
||||
from codeflash.languages.java.config import detect_java_project
|
||||
|
||||
config = detect_java_project(java_project_dir)
|
||||
|
||||
assert config is not None
|
||||
assert config.source_root is not None
|
||||
assert config.test_root is not None
|
||||
assert config.has_junit5 is True
|
||||
|
||||
|
||||
class TestJavaCompilation:
|
||||
"""Tests for Java compilation."""
|
||||
|
||||
@pytest.fixture
|
||||
def java_project_dir(self):
|
||||
"""Get the Java sample project directory."""
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
java_dir = project_root / "code_to_optimize" / "java"
|
||||
if not java_dir.exists():
|
||||
pytest.skip("code_to_optimize/java directory not found")
|
||||
return java_dir
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_compile_java_project(self, java_project_dir):
|
||||
"""Test that the sample Java project compiles successfully."""
|
||||
import subprocess
|
||||
|
||||
# Check if Maven is available
|
||||
try:
|
||||
result = subprocess.run(["mvn", "--version"], capture_output=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
pytest.skip("Maven not available")
|
||||
except FileNotFoundError:
|
||||
pytest.skip("Maven not installed")
|
||||
|
||||
# Compile the project
|
||||
result = subprocess.run(
|
||||
["mvn", "compile", "-q"],
|
||||
cwd=java_project_dir,
|
||||
capture_output=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, f"Compilation failed: {result.stderr.decode()}"
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_run_java_tests(self, java_project_dir):
|
||||
"""Test that the sample Java tests run successfully."""
|
||||
import subprocess
|
||||
|
||||
# Check if Maven is available
|
||||
try:
|
||||
result = subprocess.run(["mvn", "--version"], capture_output=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
pytest.skip("Maven not available")
|
||||
except FileNotFoundError:
|
||||
pytest.skip("Maven not installed")
|
||||
|
||||
# Run tests
|
||||
result = subprocess.run(
|
||||
["mvn", "test", "-q"],
|
||||
cwd=java_project_dir,
|
||||
capture_output=True,
|
||||
timeout=180,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, f"Tests failed: {result.stderr.decode()}"
|
||||
Loading…
Reference in a new issue