Merge branch 'main' into fix/trigger_cc_on_multiple_commits

This commit is contained in:
Sarthak Agarwal 2026-03-20 00:20:22 +05:30 committed by GitHub
commit ead7fadac5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
47 changed files with 6871 additions and 5470 deletions


@@ -92,14 +92,14 @@ jobs:
 Before doing any work, assess the PR scope:
 1. Run `gh pr diff ${{ github.event.pull_request.number }} --name-only` to get changed files.
-2. Classify as TRIVIAL if ALL changed files are:
-- Config/CI files (.github/, .tessl/, *.toml, *.lock, *.json, *.yml, *.yaml)
-- Documentation (*.md, docs/)
-- Non-production code (demos/, experiments/, code_to_optimize/)
-- Only whitespace, formatting, or comment changes
+2. Run `gh pr diff ${{ github.event.pull_request.number }} --stat` to get the total lines changed.
+3. Classify the PR into one of these categories:
+- TRIVIAL: ALL changed files are config/CI (.github/, .tessl/, *.toml, *.lock, *.json, *.yml, *.yaml), documentation (*.md, docs/), non-production code (demos/, experiments/, code_to_optimize/), or only whitespace/formatting/comment changes.
+- SMALL: ≤ 50 lines of production code changed (excluding tests, config, lock files)
+- LARGE: > 50 lines of production code changed
 If TRIVIAL: post a single comment "No substantive code changes to review." and stop — do not execute any further steps.
-Otherwise: continue with the full review below.
+Otherwise: record the size (SMALL or LARGE) and continue. Later steps will adapt their depth based on this.
</step>
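The TRIVIAL/SMALL/LARGE rule above is mechanical enough to sketch in code. This is an illustrative Python sketch, not part of the workflow; the helper name and inputs are hypothetical, while the path patterns and the 50-line threshold come from the step text.

```python
# Hypothetical sketch of the classification rule above; patterns and the
# 50-line threshold are taken from the step text, names are illustrative.
TRIVIAL_PREFIXES = (".github/", ".tessl/", "docs/", "demos/", "experiments/",
                    "code_to_optimize/")
TRIVIAL_SUFFIXES = (".md", ".toml", ".lock", ".json", ".yml", ".yaml")

def classify_pr(changed_files: list[str], production_lines_changed: int) -> str:
    """Return TRIVIAL, SMALL, or LARGE per the rules in the step above."""
    if all(f.startswith(TRIVIAL_PREFIXES) or f.endswith(TRIVIAL_SUFFIXES)
           for f in changed_files):
        return "TRIVIAL"
    return "SMALL" if production_lines_changed <= 50 else "LARGE"
```

In the real workflow these inputs would come from the two `gh pr diff` invocations listed in steps 1 and 2.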
<step name="lint_and_typecheck">
@@ -127,24 +127,47 @@ jobs:
 2. For each unresolved thread:
 a. Read the file at that path to check if the issue still exists
-b. If fixed → resolve it: `gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "<THREAD_ID>"}) { thread { isResolved } } }'`
+b. If fixed → resolve the thread silently (do NOT post a reply comment like "✅ Fixed"):
+`gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "<THREAD_ID>"}) { thread { isResolved } } }'`
 c. If still present → leave it
 Read the actual code before deciding. If there are no unresolved threads, skip to the next step.
</step>
<step name="review">
-Review the diff (`gh pr diff ${{ github.event.pull_request.number }}`) for:
+Review the diff (`gh pr diff ${{ github.event.pull_request.number }}`).
SCOPE RULES:
- Only read files that appear in the diff. Do NOT explore the broader codebase unless a specific function call in the diff requires understanding its definition.
- Match review depth to PR size: SMALL PRs get a focused correctness check; LARGE PRs get the full review below.
- For codeflash-ai[bot] optimization PRs: focus on whether the optimization is correct and the speedup claim is credible. Keep the review concise — a short correctness verdict, not a multi-paragraph essay.
Check for:
1. Bugs that will crash at runtime
2. Security vulnerabilities
3. Breaking API changes
4. Design issues (LARGE PRs only):
- Is logic placed in the right module? (e.g., language-specific logic belongs in `languages/<lang>/`, not in the base optimizer)
- Are fixes addressing the root cause or just papering over symptoms? (e.g., hardcoding a list of imports vs having the AI service handle it)
- Is language-agnostic config being stored in a language-specific file? (e.g., storing general config in `pyproject.toml` which is Python-exclusive)
5. Accidental file inclusions — flag if the diff contains:
- Binary files (`.jar`, `.whl`, `.so`, `.dll`, `.exe`)
- Version file changes (`version.py`) that look auto-generated (e.g., `.post<N>.dev0+<hash>`)
- Internal planning/design docs committed to `docs/` (those belong in Linear or a separate location)
- Changes to `codeflash-benchmark/` version files
Ignore style issues, type hints, and log message wording.
Record findings for the summary comment. Refer to CLAUDE.md for project conventions.
</step>
<step name="duplicate_detection">
-Check whether this PR introduces code that duplicates logic already present elsewhere in the repository — including across languages. Focus on finding true duplicates, not just similar-looking code.
+Check whether this PR introduces code that duplicates logic already present elsewhere in the repository.
Depth depends on PR size:
- SMALL PRs: Quick check — search for any new function names defined elsewhere using Grep. Only flag exact or near-exact duplicates. Skip the cross-module deep dive.
- LARGE PRs: Full analysis as described below.
Full analysis (LARGE PRs only):
1. Get changed source files (excluding tests and config):
`git diff --name-only origin/main...HEAD -- '*.py' '*.js' '*.ts' '*.java' | grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' | grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)'`
@@ -171,6 +194,8 @@ jobs:
</step>
<step name="coverage">
Skip this step for SMALL PRs — only run for LARGE PRs.
Analyze test coverage for changed files:
1. Get changed Python files (excluding tests): `git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
@@ -178,6 +203,7 @@ jobs:
3. Get per-file coverage: `uv run coverage report --include="<changed_files>"`
4. Compare with main: checkout main, run coverage, checkout back
5. Flag: new files below 75%, decreased coverage, untested changed lines
6. If the PR adds new public functions or methods without any corresponding tests, explicitly call this out and request tests be added. New production logic should have tests.
</step>
<step name="summary_comment">
@@ -186,6 +212,7 @@ jobs:
## PR Review Summary
### Prek Checks
### Code Review
(Include bugs, security, breaking changes, and design issues. Omit subsections with no findings.)
### Duplicate Detection
### Test Coverage
---
@@ -198,20 +225,25 @@ jobs:
 For each PR:
 - If CI passes and the PR is mergeable → merge with `--squash --delete-branch`
-- If CI is failing:
-1. Check out the PR branch and inspect the failing tests
-2. Attempt to fix the failures (the optimization may have broken tests or introduced issues)
-3. If fixed: commit, push, and leave a comment explaining what was fixed
-4. If unfixable: close with `gh pr close <number> --comment "Closing: CI checks are failing — <describe the specific failures and why they can't be auto-fixed>." --delete-branch`
-- Close the PR (without attempting fixes) if ANY of these apply:
-- Older than 7 days
-- Has merge conflicts (mergeable state is "CONFLICTING")
+- If CI is failing, first determine whether the failures are CAUSED BY the PR or PRE-EXISTING on the base branch:
+1. Run `gh pr checks <number>` to identify failing checks
+2. Check out the PR's BASE branch and check if the same tests/checks fail there too
+3. If failures are PRE-EXISTING on the base branch (not caused by the PR):
+- Do NOT close the PR
+- Leave a comment: "CI failures are pre-existing on the base branch (not caused by this PR): <list failing checks>. Leaving open for merge once base branch CI is fixed."
+4. If failures are CAUSED BY the PR's changes:
+a. Check out the PR branch and attempt to fix (lint issues, duplicate definitions, type errors, etc.)
+b. If fixed: commit, push, and leave a comment explaining what was fixed
+c. If unfixable: close with `gh pr close <number> --comment "Closing: this PR introduces CI failures that cannot be auto-fixed — <describe the specific failures>." --delete-branch`
+- Close the PR (without attempting fixes) ONLY if ANY of these apply:
+- The optimized function no longer exists in the target file (check the diff)
+- Has merge conflicts (mergeable state is "CONFLICTING") AND the PR is older than 3 days (to give time for the base branch to stabilize before giving up)
 Close with: `gh pr close <number> --comment "<reason>" --delete-branch`
 where <reason> explains WHY the PR is being closed. Examples:
-- "Closing: PR is older than 7 days without being merged."
-- "Closing: merge conflicts with the target branch."
+- "Closing: merge conflicts with the target branch and PR is older than 3 days."
 - "Closing: the optimized function no longer exists in the target file."
+- NEVER close a PR solely because it is old. Age alone is not a valid reason to close.
+- NEVER mass-close multiple PRs without individually evaluating each one.
</step>
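The pre-existing-vs-caused triage above boils down to a set comparison between the PR's failing checks and the base branch's. A minimal sketch, assuming the failing check names have already been collected (e.g. via `gh pr checks`); names are illustrative:

```python
# Hedged sketch of the CI-failure triage above. Inputs are hypothetical sets
# of failing check names for the PR branch and its base branch.
def triage(pr_failures: set[str], base_failures: set[str]) -> str:
    """Classify failing checks relative to the base branch."""
    caused = pr_failures - base_failures
    if not pr_failures:
        return "merge"
    if not caused:
        return "leave-open"  # all failures pre-exist on the base branch
    return "fix-or-close"    # at least one failure introduced by the PR
```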
<verification>

.gitignore

@@ -10,6 +10,7 @@ __pycache__/
 # Distribution / packaging
 .Python
 build/
+.gradle/
develop-eggs/
cli/dist/
downloads/


@@ -0,0 +1,2 @@
.gradle/
build/


@@ -0,0 +1,29 @@
plugins {
java
jacoco
}
group = "com.example"
version = "1.0.0"
java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}
repositories {
mavenCentral()
mavenLocal()
}
dependencies {
testImplementation("org.junit.jupiter:junit-jupiter:5.10.0")
testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.0")
testImplementation("org.xerial:sqlite-jdbc:3.42.0.0")
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
testImplementation(files("/Users/heshammohamed/Work/codeflash/code_to_optimize/java-gradle/libs/codeflash-runtime-1.0.0.jar")) // codeflash-runtime
}
tasks.test {
useJUnitPlatform()
}


@@ -0,0 +1,4 @@
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"
formatter-cmds = []


@@ -0,0 +1 @@
rootProject.name = "codeflash-java-gradle-sample"


@@ -0,0 +1,14 @@
package com.example;
public class Fibonacci {
public static long fibonacci(int n) {
if (n < 0) {
throw new IllegalArgumentException("n must be non-negative");
}
if (n <= 1) {
return n;
}
return fibonacci(n - 1) + fibonacci(n - 2);
}
}
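For context, the recursive `fibonacci` above is an intentionally slow sample (exponential time) for the optimizer to target. An equivalent linear-time version, shown here as an illustrative Python sketch rather than anything from the PR:

```python
# Illustrative only: a linear-time equivalent of the deliberately slow
# recursive Java sample above. Mirrors its contract, including the
# rejection of negative inputs.
def fibonacci(n: int) -> int:
    if n < 0:
        raise ValueError("n must be non-negative")
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a
```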


@@ -0,0 +1,27 @@
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class FibonacciTest {
@Test
void testBaseCase() {
assertEquals(0, Fibonacci.fibonacci(0));
assertEquals(1, Fibonacci.fibonacci(1));
}
@Test
void testSmallValues() {
assertEquals(1, Fibonacci.fibonacci(2));
assertEquals(2, Fibonacci.fibonacci(3));
assertEquals(5, Fibonacci.fibonacci(5));
assertEquals(8, Fibonacci.fibonacci(6));
assertEquals(55, Fibonacci.fibonacci(10));
}
@Test
void testNegative() {
assertThrows(IllegalArgumentException.class, () -> Fibonacci.fibonacci(-1));
}
}


@@ -0,0 +1,65 @@
plugins {
java
id("com.gradleup.shadow") version "9.0.0-beta12"
}
group = "com.codeflash"
version = "1.0.0"
java {
sourceCompatibility = JavaVersion.VERSION_11
targetCompatibility = JavaVersion.VERSION_11
}
repositories {
mavenCentral()
}
dependencies {
implementation("com.google.code.gson:gson:2.10.1")
implementation("com.esotericsoftware:kryo:5.6.2")
implementation("org.objenesis:objenesis:3.4")
implementation("org.xerial:sqlite-jdbc:3.45.0.0")
implementation("org.ow2.asm:asm:9.7.1")
implementation("org.ow2.asm:asm-commons:9.7.1")
testImplementation("org.junit.jupiter:junit-jupiter:5.10.1")
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
}
tasks.test {
useJUnitPlatform()
jvmArgs(
"--add-opens", "java.base/java.util=ALL-UNNAMED",
"--add-opens", "java.base/java.lang=ALL-UNNAMED",
"--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED",
"--add-opens", "java.base/java.math=ALL-UNNAMED",
"--add-opens", "java.base/java.io=ALL-UNNAMED",
"--add-opens", "java.base/java.net=ALL-UNNAMED",
"--add-opens", "java.base/java.time=ALL-UNNAMED",
)
}
tasks.shadowJar {
archiveBaseName.set("codeflash-runtime")
archiveVersion.set("1.0.0")
archiveClassifier.set("")
relocate("org.objectweb.asm", "com.codeflash.asm")
manifest {
attributes(
"Main-Class" to "com.codeflash.Comparator",
"Premain-Class" to "com.codeflash.profiler.ProfilerAgent",
"Can-Retransform-Classes" to "true",
)
}
exclude("META-INF/*.SF")
exclude("META-INF/*.DSA")
exclude("META-INF/*.RSA")
}
tasks.build {
dependsOn(tasks.shadowJar)
}


@@ -0,0 +1 @@
rootProject.name = "codeflash-runtime"


@@ -54,8 +54,8 @@ class ComparatorCorrectnessTest {
 KryoPlaceholder.create(new Object(), "unserializable", "root")
 );
-insertRow(originalDb, "iter_1_0", 1, placeholderBytes);
-insertRow(candidateDb, "iter_1_1", 1, placeholderBytes);
+insertRow(originalDb, "1", 1, placeholderBytes);
+insertRow(candidateDb, "1", 1, placeholderBytes);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -74,8 +74,8 @@ class ComparatorCorrectnessTest {
 // Insert corrupted byte data that will fail Kryo deserialization
 byte[] corruptedBytes = new byte[]{0x01, 0x02, 0x03, (byte) 0xFF, (byte) 0xFE};
-insertRow(originalDb, "iter_1_0", 1, corruptedBytes);
-insertRow(candidateDb, "iter_1_1", 1, corruptedBytes);
+insertRow(originalDb, "1", 1, corruptedBytes);
+insertRow(candidateDb, "1", 1, corruptedBytes);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -97,12 +97,12 @@ class ComparatorCorrectnessTest {
 KryoPlaceholder.create(new Object(), "unserializable", "root")
 );
-insertRow(originalDb, "iter_1_0", 1, realBytes1);
-insertRow(candidateDb, "iter_1_1", 1, realBytes1);
-insertRow(originalDb, "iter_2_0", 1, realBytes2);
-insertRow(candidateDb, "iter_2_1", 1, realBytes2);
-insertRow(originalDb, "iter_3_0", 1, placeholderBytes);
-insertRow(candidateDb, "iter_3_1", 1, placeholderBytes);
+insertRow(originalDb, "1", 1, realBytes1);
+insertRow(candidateDb, "1", 1, realBytes1);
+insertRow(originalDb, "2", 1, realBytes2);
+insertRow(candidateDb, "2", 1, realBytes2);
+insertRow(originalDb, "3", 1, placeholderBytes);
+insertRow(candidateDb, "3", 1, placeholderBytes);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -121,10 +121,10 @@ class ComparatorCorrectnessTest {
 byte[] bytes1 = Serializer.serialize(100);
 byte[] bytes2 = Serializer.serialize("world");
-insertRow(originalDb, "iter_1_0", 1, bytes1);
-insertRow(candidateDb, "iter_1_1", 1, bytes1);
-insertRow(originalDb, "iter_2_0", 1, bytes2);
-insertRow(candidateDb, "iter_2_1", 1, bytes2);
+insertRow(originalDb, "1", 1, bytes1);
+insertRow(candidateDb, "1", 1, bytes1);
+insertRow(originalDb, "2", 1, bytes2);
+insertRow(candidateDb, "2", 1, bytes2);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -144,8 +144,8 @@ class ComparatorCorrectnessTest {
 byte[] origBytes = Serializer.serialize(42);
 byte[] candBytes = Serializer.serialize(99);
-insertRow(originalDb, "iter_1_0", 1, origBytes);
-insertRow(candidateDb, "iter_1_1", 1, candBytes);
+insertRow(originalDb, "1", 1, origBytes);
+insertRow(candidateDb, "1", 1, candBytes);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -161,8 +161,8 @@ class ComparatorCorrectnessTest {
 createTestDb(candidateDb);
 // Insert rows with NULL return_value (void methods)
-insertRow(originalDb, "iter_1_0", 1, null);
-insertRow(candidateDb, "iter_1_1", 1, null);
+insertRow(originalDb, "1", 1, null);
+insertRow(candidateDb, "1", 1, null);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -178,7 +178,7 @@ class ComparatorCorrectnessTest {
 createTestDb(candidateDb);
 byte[] bytes = Serializer.serialize(42);
-insertRow(originalDb, "iter_1_0", 1, bytes);
+insertRow(originalDb, "1", 1, bytes);
// candidateDb has no rows
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
@@ -222,7 +222,7 @@ class ComparatorCorrectnessTest {
 + "iteration_id TEXT NOT NULL, "
 + "loop_index INTEGER NOT NULL, "
 + "return_value BLOB, "
-+ "PRIMARY KEY (iteration_id, loop_index))");
++ "verification_type TEXT)");
}
}


@@ -338,6 +338,19 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
return full_qualified_name[len(module_name) + 1 :]
_PARAMETERIZED_INDEX_RE = re.compile(r"\[(\d+)")
def extract_parameterized_test_index(test_name: str) -> int:
"""Extract the numeric index from a parameterized test name.
Handles formats like ``test[ 0 ]``, ``test[1]``, and
``test[1] input=foo, expected=bar``. Returns 1 when no numeric index is found.
"""
m = _PARAMETERIZED_INDEX_RE.search(test_name)
return int(m.group(1)) if m else 1
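The new helper above is small enough to exercise directly. A quick usage check, reproducing the regex and fallback verbatim from the diff:

```python
import re

# Reproduction of extract_parameterized_test_index from the diff above,
# shown with a few usage examples. The regex and the fallback of 1 are
# copied verbatim; only this surrounding scaffold is added.
_PARAMETERIZED_INDEX_RE = re.compile(r"\[(\d+)")

def extract_parameterized_test_index(test_name: str) -> int:
    m = _PARAMETERIZED_INDEX_RE.search(test_name)
    return int(m.group(1)) if m else 1

assert extract_parameterized_test_index("test[3]") == 3
assert extract_parameterized_test_index("test[12] input=foo, expected=bar") == 12
assert extract_parameterized_test_index("test_plain") == 1
```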
def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
try:
relative_path = file_path.resolve().relative_to(project_root_path.resolve())


@@ -13,7 +13,6 @@ from rich.prompt import Confirm
 from unidiff import PatchSet
 from codeflash.cli_cmds.console import logger
-from codeflash.languages.registry import get_supported_extensions
if TYPE_CHECKING:
from git import Repo
@@ -26,13 +25,15 @@ def get_git_diff(
 uncommitted_changes: bool = False,
 since_commit: Optional[str] = None,
 ) -> dict[str, list[int]]:
-from codeflash.languages.current import current_language_support
+from codeflash.languages.registry import get_supported_extensions
 if repo_directory is None:
 repo_directory = Path.cwd()
 repository = git.Repo(repo_directory, search_parent_directories=True)
 commit = repository.head.commit
-supported_extensions = current_language_support().file_extensions
+# Use all registered extensions (Python + JS/TS + Java etc.) rather than
+# current_language_support() which defaults to Python before language detection runs.
+supported_extensions = set(get_supported_extensions())
if since_commit:
# Diff from a base commit to HEAD — captures all changes across multiple commits
uni_diff_text = repository.git.diff(
@@ -48,7 +49,6 @@ def get_git_diff(
 uni_diff_text = repository.git.diff(
 commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True
 )
-supported_extensions = set(get_supported_extensions())
patch_set = PatchSet(StringIO(uni_diff_text))
change_list: dict[str, list[int]] = {} # list of changes
for patched_file in patch_set:


@@ -688,14 +688,21 @@ class LanguageSupport(Protocol):
 # === Test Result Comparison ===
 def compare_test_results(
-self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
-) -> tuple[bool, list[Any]]:
+self,
+original_results_path: Path,
+candidate_results_path: Path,
+project_root: Path | None = None,
+project_classpath: str | None = None,
+) -> tuple[bool, list]:
"""Compare test results between original and candidate code.
Args:
original_results_path: Path to original test results (e.g., SQLite DB).
candidate_results_path: Path to candidate test results.
project_root: Project root directory (for finding node_modules, etc.).
project_classpath: Full project classpath string (Java only). Needed so
the Comparator JVM can resolve project-specific classes during Kryo
deserialization.
Returns:
Tuple of (are_equivalent, list of TestDiff objects).


@@ -23,7 +23,9 @@ if TYPE_CHECKING:
 _SOURCE_CRITERIA = FunctionFilterCriteria(require_return=False, require_export=False)
-def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
+def get_optimized_code_for_module(
+relative_path: Path, optimized_code: CodeStringsMarkdown, allow_fallback: bool = True
+) -> str:
from codeflash.languages.current import is_python
file_to_code_context = optimized_code.file_to_path()
@@ -31,6 +33,9 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
 if module_optimized_code is not None:
 return module_optimized_code
+if not allow_fallback:
+return ""
# Fallback 1: single code block with no file path
if "None" in file_to_code_context and len(file_to_code_context) == 1:
logger.debug(f"Using code block with None file_path for {relative_path}")
@@ -72,7 +77,10 @@ def replace_function_definitions_for_language(
 and LanguageSupport.discover_functions.
 """
 original_source_code: str = module_abspath.read_text(encoding="utf8")
-code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
+is_target_file = function_to_optimize is not None and function_to_optimize.file_path == module_abspath
+code_to_apply = get_optimized_code_for_module(
+module_abspath.relative_to(project_root_path), optimized_code, allow_fallback=is_target_file
+)
if not code_to_apply.strip():
return False


@@ -5,21 +5,15 @@ test execution, and optimization using tree-sitter for parsing and
Maven/Gradle for build operations.
"""
from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, get_strategy
from codeflash.languages.java.build_tools import (
BuildTool,
JavaProjectInfo,
MavenTestResult,
add_codeflash_dependency_to_pom,
compile_maven_project,
detect_build_tool,
find_gradle_executable,
find_maven_executable,
find_source_root,
find_test_root,
get_classpath,
get_project_info,
install_codeflash_runtime,
run_maven_tests,
)
from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results
from codeflash.languages.java.config import (
@@ -58,6 +52,7 @@ from codeflash.languages.java.instrumentation import (
instrument_generated_java_test,
remove_instrumentation,
)
from codeflash.languages.java.maven_strategy import add_codeflash_dependency, install_codeflash_runtime
from codeflash.languages.java.parser import (
JavaAnalyzer,
JavaClassNode,
@@ -103,6 +98,7 @@ from codeflash.languages.java.test_runner import (
__all__ = [
# Build tools
"BuildTool",
"BuildToolStrategy",
# Parser
"JavaAnalyzer",
# Assertion removal
@@ -124,7 +120,7 @@ __all__ = [
 "JavaTestRunResult",
 "MavenTestResult",
 "ResolvedImport",
-"add_codeflash_dependency_to_pom",
+"add_codeflash_dependency",
# Replacement
"add_runtime_comments",
# Test discovery
@@ -132,7 +128,6 @@ __all__ = [
 # Comparator
 "compare_invocations_directly",
 "compare_test_results",
-"compile_maven_project",
# Instrumentation
"create_benchmark_test",
"detect_build_tool",
@@ -148,21 +143,19 @@ __all__ = [
"extract_code_context",
"extract_function_source",
"extract_read_only_context",
"find_gradle_executable",
"find_helper_files",
"find_helper_functions",
"find_maven_executable",
"find_source_root",
"find_test_root",
"find_tests_for_function",
"format_java_code",
"format_java_file",
"get_class_methods",
"get_classpath",
"get_java_analyzer",
"get_java_support",
"get_method_by_name",
"get_project_info",
"get_strategy",
"get_test_class_for_source_class",
"get_test_class_pattern",
"get_test_file_pattern",
@@ -189,7 +182,6 @@ __all__ = [
 "resolve_imports_for_file",
 "run_behavioral_tests",
 "run_benchmarking_tests",
-"run_maven_tests",
"run_tests",
"transform_java_assertions",
]


@@ -0,0 +1,191 @@
"""Abstract build tool strategy for Java projects.
Defines the interface for build-tool-specific operations (compilation,
classpath extraction, test execution, coverage). Concrete implementations
live in maven_strategy.py and gradle_strategy.py.
"""
from __future__ import annotations
import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
import subprocess
logger = logging.getLogger(__name__)
_RUNTIME_JAR_NAME = "codeflash-runtime-1.0.0.jar"
_JAVA_RUNTIME_DIR = Path(__file__).parent.parent.parent.parent / "codeflash-java-runtime"
def module_to_dir(test_module: str) -> str:
"""Convert a build-tool module name to a filesystem-relative path.
Gradle uses ``:`` as the module separator (``connect:runtime``), while Maven
uses the directory name directly. On the filesystem the separator is always
``/`` (or ``os.sep``).
"""
return test_module.replace(":", os.sep)
class BuildToolStrategy(ABC):
"""Strategy interface for Java build tool operations.
Only methods that genuinely differ between Maven and Gradle belong here.
Shared logic (direct JVM execution, JUnit XML parsing) stays in test_runner.py.
"""
@property
@abstractmethod
def name(self) -> str:
"""Human-readable name for log messages (e.g. 'Maven', 'Gradle')."""
...
def find_runtime_jar(self) -> Path | None:
"""Find the codeflash-runtime JAR file.
Checks package resources and development build directories.
Subclasses should override to prepend tool-specific cache paths
and fall back to super().find_runtime_jar().
"""
resources_jar = Path(__file__).parent / "resources" / _RUNTIME_JAR_NAME
if resources_jar.exists():
return resources_jar
dev_jar_maven = _JAVA_RUNTIME_DIR / "target" / _RUNTIME_JAR_NAME
if dev_jar_maven.exists():
return dev_jar_maven
dev_jar_gradle = _JAVA_RUNTIME_DIR / "build" / "libs" / _RUNTIME_JAR_NAME
if dev_jar_gradle.exists():
return dev_jar_gradle
return None
@abstractmethod
def find_executable(self, build_root: Path) -> str | None:
"""Find the build tool executable, searching up parent directories if needed."""
...
@abstractmethod
def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
"""Install codeflash-runtime JAR and register it as a project dependency."""
...
@abstractmethod
def install_multi_module_deps(self, build_root: Path, test_module: str | None, env: dict[str, str]) -> bool:
"""Pre-install multi-module dependencies so later invocations skip recompilation."""
...
@abstractmethod
def compile_tests(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
) -> subprocess.CompletedProcess[str]:
"""Compile test code without running tests."""
...
@abstractmethod
def compile_source_only(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
) -> subprocess.CompletedProcess[str]:
"""Compile only main source code (not tests). Used when test classes are already compiled."""
...
@abstractmethod
def get_classpath(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
) -> str | None:
"""Return the full test classpath string. Caching is an implementation detail."""
...
@abstractmethod
def get_reports_dir(self, build_root: Path, test_module: str | None) -> Path:
"""Return the directory containing JUnit XML test reports."""
...
@abstractmethod
def get_build_output_dir(self, build_root: Path, test_module: str | None) -> Path:
"""Return the build output directory (e.g. target/ for Maven, build/ for Gradle)."""
...
@abstractmethod
def run_tests_via_build_tool(
self,
build_root: Path,
test_paths: Any,
env: dict[str, str],
timeout: int,
mode: str,
test_module: str | None,
javaagent_arg: str | None = None,
enable_coverage: bool = False,
) -> subprocess.CompletedProcess[str]:
"""Run tests via the build tool (e.g. Maven Surefire). Used as fallback when direct JVM fails."""
...
@abstractmethod
def run_benchmarking_via_build_tool(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None,
project_root: Path | None,
min_loops: int,
max_loops: int,
target_duration_seconds: float,
inner_iterations: int,
) -> tuple[Path, Any]:
"""Run benchmarking loop via build tool (fallback when direct JVM fails)."""
...
@abstractmethod
def run_tests_with_coverage(
self,
build_root: Path,
test_module: str | None,
test_paths: Any,
run_env: dict[str, str],
timeout: int,
candidate_index: int,
) -> tuple[subprocess.CompletedProcess[str], Path, Path | None]:
"""Run tests with coverage enabled. Returns (result, junit_xml_path, coverage_xml_path)."""
...
@abstractmethod
def setup_coverage(self, build_root: Path, test_module: str | None, project_root: Path) -> Path | None:
"""Configure coverage tool (e.g. JaCoCo) and return expected XML report path."""
...
@abstractmethod
def get_test_run_command(self, project_root: Path, test_classes: list[str] | None = None) -> list[str]:
"""Return the shell command to run tests, including any test class filters."""
...
def _build_strategy_registry() -> dict[str, type[BuildToolStrategy]]:
"""Lazily import and return the {BuildTool.value -> class} mapping."""
from codeflash.languages.java.gradle_strategy import GradleStrategy
from codeflash.languages.java.maven_strategy import MavenStrategy
return {"maven": MavenStrategy, "gradle": GradleStrategy}
def get_strategy(project_root: Path) -> BuildToolStrategy:
"""Detect build tool and return the appropriate strategy."""
from codeflash.languages.java.build_tools import detect_build_tool
build_tool = detect_build_tool(project_root)
registry = _build_strategy_registry()
strategy_cls = registry.get(build_tool.value)
if strategy_cls is not None:
return strategy_cls()
supported = ", ".join(registry)
msg = f"No supported build tool found in {project_root}. Expected one of: {supported}."
raise ValueError(msg)
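The registry pattern used by `get_strategy` above can be illustrated in isolation. The stand-in classes below are hypothetical; the real mapping imports `MavenStrategy` and `GradleStrategy` lazily, as the code above shows:

```python
# Hedged sketch of the strategy-registry lookup above. _Maven and _Gradle
# are hypothetical stand-ins for the real MavenStrategy/GradleStrategy.
class _Maven:
    name = "Maven"

class _Gradle:
    name = "Gradle"

_REGISTRY = {"maven": _Maven, "gradle": _Gradle}

def get_strategy(build_tool: str):
    """Return a strategy instance for the detected build tool, or raise."""
    cls = _REGISTRY.get(build_tool)
    if cls is None:
        supported = ", ".join(_REGISTRY)
        raise ValueError(f"No supported build tool found. Expected one of: {supported}.")
    return cls()
```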


@@ -7,15 +7,10 @@ This module provides functionality to detect and work with Java build tools
from __future__ import annotations
import logging
import os
import re
import shutil
import subprocess
import urllib.request
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
-from pathlib import Path
+from pathlib import Path # noqa: TC003 — used at runtime
logger = logging.getLogger(__name__)
@@ -25,71 +20,6 @@ CODEFLASH_RUNTIME_JAR_NAME = f"codeflash-runtime-{CODEFLASH_RUNTIME_VERSION}.jar
JACOCO_PLUGIN_VERSION = "0.8.13"
GITHUB_RELEASE_URL = (
"https://github.com/codeflash-ai/codeflash/releases/download"
f"/runtime-v{CODEFLASH_RUNTIME_VERSION}/{CODEFLASH_RUNTIME_JAR_NAME}"
)
CODEFLASH_CACHE_DIR = Path.home() / ".cache" / "codeflash"
def download_from_github_releases() -> Path | None:
"""Download codeflash-runtime JAR from GitHub Releases.
Downloads to ~/.cache/codeflash/ and returns the path to the downloaded JAR.
Returns None if the download fails (e.g., no release published yet, network error).
This serves as a fallback when Maven Central resolution fails, for example
when the user's project doesn't have Maven installed or Maven Central is unreachable.
Requires a GitHub Release tagged 'runtime-v{version}' with the JAR as an asset.
"""
cache_jar = CODEFLASH_CACHE_DIR / CODEFLASH_RUNTIME_JAR_NAME
if cache_jar.exists():
logger.info("Found cached codeflash-runtime JAR: %s", cache_jar)
return cache_jar
try:
CODEFLASH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
logger.info("Downloading codeflash-runtime from GitHub Releases: %s", GITHUB_RELEASE_URL)
urllib.request.urlretrieve(GITHUB_RELEASE_URL, cache_jar) # noqa: S310
logger.info("Downloaded codeflash-runtime to %s", cache_jar)
return cache_jar
except Exception as e:
logger.debug("GitHub Releases download failed: %s", e)
cache_jar.unlink(missing_ok=True)
return None
def resolve_from_maven_central(maven_root: Path) -> bool:
"""Ask Maven to resolve codeflash-runtime from Maven Central.
This downloads the JAR to ~/.m2/repository/ automatically.
Only works once the JAR is published to Maven Central.
Returns True if Maven successfully resolved the artifact.
"""
mvn = find_maven_executable()
if not mvn:
return False
cmd = [
mvn,
"dependency:resolve",
f"-Dartifact=com.codeflash:codeflash-runtime:{CODEFLASH_RUNTIME_VERSION}",
"-B",
"-q",
]
try:
result = subprocess.run(cmd, check=False, cwd=maven_root, capture_output=True, text=True, timeout=60)
if result.returncode == 0:
logger.info("Resolved codeflash-runtime %s from Maven Central", CODEFLASH_RUNTIME_VERSION)
return True
logger.debug("Maven Central resolution failed: %s", result.stderr)
return False
except Exception as e:
logger.debug("Maven Central resolution error: %s", e)
return False
def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
"""Safely parse an XML file with protections against XXE attacks.
@@ -363,191 +293,9 @@ def _get_gradle_project_info(project_root: Path) -> JavaProjectInfo | None:
)
def find_maven_executable(project_root: Path | None = None) -> str | None:
"""Find the Maven executable.
Returns:
Path to mvn executable, or None if not found.
"""
# Check for Maven wrapper in project root first
if project_root is not None:
mvnw_path = project_root / "mvnw"
if mvnw_path.exists():
return str(mvnw_path)
mvnw_cmd_path = project_root / "mvnw.cmd"
if mvnw_cmd_path.exists():
return str(mvnw_cmd_path)
# Check for Maven wrapper in current directory
if Path("mvnw").exists():
return "./mvnw"
if Path("mvnw.cmd").exists():
return "mvnw.cmd"
# Check system Maven
mvn_path = shutil.which("mvn")
if mvn_path:
return mvn_path
return None
def find_gradle_executable(project_root: Path | None = None) -> str | None:
"""Find the Gradle executable.
Checks for Gradle wrapper in the project root and current directory,
then falls back to system Gradle.
Args:
project_root: Optional project root directory to search for Gradle wrapper.
Returns:
Path to gradle executable, or None if not found.
"""
# Check for Gradle wrapper in project root first
if project_root is not None:
gradlew_path = project_root / "gradlew"
if gradlew_path.exists():
return str(gradlew_path)
gradlew_bat_path = project_root / "gradlew.bat"
if gradlew_bat_path.exists():
return str(gradlew_bat_path)
# Check for Gradle wrapper in current directory
if Path("gradlew").exists():
return "./gradlew"
if Path("gradlew.bat").exists():
return "gradlew.bat"
# Check system Gradle
gradle_path = shutil.which("gradle")
if gradle_path:
return gradle_path
return None
def run_maven_tests(
project_root: Path,
test_classes: list[str] | None = None,
test_methods: list[str] | None = None,
env: dict[str, str] | None = None,
timeout: int = 300,
skip_compilation: bool = False,
) -> MavenTestResult:
"""Run Maven tests using Surefire.
Args:
project_root: Root directory of the Maven project.
test_classes: Optional list of test class names to run.
test_methods: Optional list of specific test methods (format: ClassName#methodName).
env: Optional environment variables.
timeout: Maximum time in seconds for test execution.
skip_compilation: Whether to skip compilation (useful when only running tests).
Returns:
MavenTestResult with test execution results.
"""
mvn = find_maven_executable(project_root)
if not mvn:
logger.error("Maven not found. Please install Maven or use Maven wrapper.")
return MavenTestResult(
success=False,
tests_run=0,
failures=0,
errors=0,
skipped=0,
surefire_reports_dir=None,
stdout="",
stderr="Maven not found",
returncode=-1,
)
# Build Maven command
cmd = [mvn]
if skip_compilation:
cmd.append("-Dmaven.test.skip=false")
cmd.append("-DskipTests=false")
cmd.append("surefire:test")
else:
cmd.append("test")
# Add test filtering
if test_classes or test_methods:
if test_methods:
# Format: -Dtest=ClassName#method1+method2,OtherClass#method3
tests = ",".join(test_methods)
elif test_classes:
tests = ",".join(test_classes)
cmd.extend(["-Dtest=" + tests])
# Fail at end to run all tests; -B for batch mode (no ANSI colors)
cmd.extend(["-fae", "-B"])
# Use full environment with optional overrides
run_env = os.environ.copy()
if env:
run_env.update(env)
try:
result = subprocess.run(
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)
# Parse test results from Surefire reports
surefire_dir = project_root / "target" / "surefire-reports"
tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir)
return MavenTestResult(
success=result.returncode == 0,
tests_run=tests_run,
failures=failures,
errors=errors,
skipped=skipped,
surefire_reports_dir=surefire_dir if surefire_dir.exists() else None,
stdout=result.stdout,
stderr=result.stderr,
returncode=result.returncode,
)
except subprocess.TimeoutExpired:
logger.exception("Maven test execution timed out after %d seconds", timeout)
return MavenTestResult(
success=False,
tests_run=0,
failures=0,
errors=0,
skipped=0,
surefire_reports_dir=None,
stdout="",
stderr=f"Test execution timed out after {timeout} seconds",
returncode=-2,
)
except Exception as e:
logger.exception("Maven test execution failed: %s", e)
return MavenTestResult(
success=False,
tests_run=0,
failures=0,
errors=0,
skipped=0,
surefire_reports_dir=None,
stdout="",
stderr=str(e),
returncode=-1,
)
def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
"""Parse Surefire XML reports to get test counts.
Args:
surefire_dir: Directory containing Surefire XML reports.
Returns:
Tuple of (tests_run, failures, errors, skipped).
@ -564,8 +312,9 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
try:
tree = _safe_parse_xml(xml_file)
root = tree.getroot()
if root is None:
continue
# Safely parse numeric attributes with validation
try:
tests_run += int(root.get("tests", "0"))
except (ValueError, TypeError):
@ -594,411 +343,6 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
return tests_run, failures, errors, skipped
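The defensive attribute parsing above can be sketched in isolation against a minimal Surefire-style report (the XML below is a hypothetical example, not taken from a real run):

```python
import xml.etree.ElementTree as ET

# Hypothetical minimal Surefire report with the attributes the parser reads.
SAMPLE = """<?xml version="1.0" encoding="UTF-8"?>
<testsuite name="com.example.FooTest" tests="3" failures="1" errors="0" skipped="1">
  <testcase name="ok" classname="com.example.FooTest"/>
  <testcase name="bad" classname="com.example.FooTest">
    <failure message="expected 1 but was 2"/>
  </testcase>
  <testcase name="later" classname="com.example.FooTest">
    <skipped/>
  </testcase>
</testsuite>
"""

def parse_counts(xml_text: str) -> tuple[int, int, int, int]:
    root = ET.fromstring(xml_text)

    def attr(name: str) -> int:
        # Fall back to 0 for missing or malformed attributes instead of crashing.
        try:
            return int(root.get(name, "0"))
        except (ValueError, TypeError):
            return 0

    return attr("tests"), attr("failures"), attr("errors"), attr("skipped")

print(parse_counts(SAMPLE))  # (3, 1, 0, 1)
```

The real parser sums these tuples across every XML file in the reports directory.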
def compile_maven_project(
project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300
) -> tuple[bool, str, str]:
"""Compile a Maven project.
Args:
project_root: Root directory of the Maven project.
include_tests: Whether to compile test classes as well.
env: Optional environment variables.
timeout: Maximum time in seconds for compilation.
Returns:
Tuple of (success, stdout, stderr).
"""
mvn = find_maven_executable(project_root)
if not mvn:
return False, "", "Maven not found"
cmd = [mvn]
if include_tests:
cmd.append("test-compile")
else:
cmd.append("compile")
# Skip test execution; -B for batch mode (no ANSI colors)
cmd.extend(["-DskipTests", "-B"])
run_env = os.environ.copy()
if env:
run_env.update(env)
try:
result = subprocess.run(
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)
return result.returncode == 0, result.stdout, result.stderr
except subprocess.TimeoutExpired:
return False, "", f"Compilation timed out after {timeout} seconds"
except Exception as e:
return False, "", str(e)
def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> bool:
"""Install the codeflash runtime JAR to the local Maven repository.
Args:
project_root: Root directory of the Maven project.
runtime_jar_path: Path to the codeflash-runtime.jar file.
Returns:
True if installation succeeded, False otherwise.
"""
mvn = find_maven_executable(project_root)
if not mvn:
logger.error("Maven not found")
return False
if not runtime_jar_path.exists():
logger.error("Runtime JAR not found: %s", runtime_jar_path)
return False
cmd = [
mvn,
"install:install-file",
f"-Dfile={runtime_jar_path}",
"-DgroupId=com.codeflash",
"-DartifactId=codeflash-runtime",
f"-Dversion={CODEFLASH_RUNTIME_VERSION}",
"-Dpackaging=jar",
"-B",
]
try:
result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)
if result.returncode == 0:
logger.info("Successfully installed codeflash-runtime to local Maven repository")
return True
logger.error("Failed to install codeflash-runtime: %s", result.stderr)
return False
except Exception as e:
logger.exception("Failed to install codeflash-runtime: %s", e)
return False
CODEFLASH_DEPENDENCY_SNIPPET = f"""\
<dependency>
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>{CODEFLASH_RUNTIME_VERSION}</version>
<scope>test</scope>
</dependency>
</dependencies>"""
def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
"""Add codeflash-runtime dependency to pom.xml if not present.
Uses string manipulation instead of ElementTree to preserve the original
XML formatting and namespace prefixes (ElementTree rewrites prefixes to
ns0:, which breaks Maven).
Args:
pom_path: Path to the pom.xml file.
Returns:
True if dependency was added or already present, False on error.
"""
if not pom_path.exists():
return False
try:
content = pom_path.read_text(encoding="utf-8")
# Check if already present
if "codeflash-runtime" in content:
# If a previous run left a system-scope dependency, replace it with test scope.
# System-scope dependencies cause Maven warnings and are rejected by some projects.
if "<scope>system</scope>" in content:
# Replace ONLY the codeflash-runtime dependency block that has system scope.
# We find each <dependency>...</dependency> block individually and only replace
# the one containing both "codeflash-runtime" and "<scope>system</scope>".
# The previous regex used [\s\S]*? lookaheads that could match across blocks,
# accidentally replacing every dependency in the file.
def replace_system_dep(match: re.Match) -> str:
block = match.group(0)
if "codeflash-runtime" in block and "<scope>system</scope>" in block:
return (
"<dependency>\n"
" <groupId>com.codeflash</groupId>\n"
" <artifactId>codeflash-runtime</artifactId>\n"
f" <version>{CODEFLASH_RUNTIME_VERSION}</version>\n"
" <scope>test</scope>\n"
" </dependency>"
)
return block
content = re.sub(r"<dependency>[\s\S]*?</dependency>", replace_system_dep, content)
pom_path.write_text(content, encoding="utf-8")
logger.info("Replaced system-scope codeflash-runtime dependency with test scope")
return True
logger.info("codeflash-runtime dependency already present in pom.xml")
return True
# Find </dependencies> closing tag and insert before it
closing_tag = "</dependencies>"
idx = content.find(closing_tag)
if idx == -1:
logger.warning("No </dependencies> tag found in pom.xml, cannot add dependency")
return False
new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
# Skip the original </dependencies> tag since our snippet includes it
new_content += content[idx + len(closing_tag) :]
pom_path.write_text(new_content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to pom.xml")
return True
except Exception as e:
logger.exception("Failed to add dependency to pom.xml: %s", e)
return False
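The per-block replacement strategy described in the comments above can be illustrated standalone; the pom fragment below is a made-up example:

```python
import re

# Hypothetical pom fragment: an unrelated dependency plus a stale
# system-scope codeflash-runtime entry left behind by a previous run.
POM = """<dependencies>
  <dependency>
    <groupId>junit</groupId>
    <artifactId>junit</artifactId>
    <scope>test</scope>
  </dependency>
  <dependency>
    <groupId>com.codeflash</groupId>
    <artifactId>codeflash-runtime</artifactId>
    <scope>system</scope>
  </dependency>
</dependencies>"""

def fix_scope(match: re.Match) -> str:
    # Only rewrite the block that is both codeflash-runtime AND system scope;
    # every other <dependency> block is returned unchanged.
    block = match.group(0)
    if "codeflash-runtime" in block and "<scope>system</scope>" in block:
        return block.replace("<scope>system</scope>", "<scope>test</scope>")
    return block

# The non-greedy pattern matches each block separately, so the replacer
# sees one dependency at a time and cannot clobber its neighbors.
result = re.sub(r"<dependency>[\s\S]*?</dependency>", fix_scope, POM)
print("<scope>system</scope>" in result)  # False
print("junit" in result)                  # True
```

This is exactly the failure mode the docstring warns about: a single greedy (or cross-block non-greedy) match would have rewritten the junit block too.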
def is_jacoco_configured(pom_path: Path) -> bool:
"""Check if JaCoCo plugin is already configured in pom.xml.
Checks both the main build section and any profile build sections.
Args:
pom_path: Path to the pom.xml file.
Returns:
True if JaCoCo plugin is configured anywhere in the pom.xml, False otherwise.
"""
if not pom_path.exists():
return False
try:
tree = _safe_parse_xml(pom_path)
root = tree.getroot()
# Handle Maven namespace
ns_prefix = "{http://maven.apache.org/POM/4.0.0}"
# Check if namespace is used
use_ns = root.tag.startswith("{")
if not use_ns:
ns_prefix = ""
# Search all build/plugins sections (including those in profiles)
# Using .// to search recursively for all plugin elements
for plugin in root.findall(f".//{ns_prefix}plugin" if use_ns else ".//plugin"):
artifact_id = plugin.find(f"{ns_prefix}artifactId" if use_ns else "artifactId")
if artifact_id is not None and artifact_id.text == "jacoco-maven-plugin":
group_id = plugin.find(f"{ns_prefix}groupId" if use_ns else "groupId")
# Verify groupId if present (it's optional for org.jacoco)
if group_id is None or group_id.text == "org.jacoco":
return True
return False
except ET.ParseError as e:
logger.warning("Failed to parse pom.xml for JaCoCo check: %s", e)
return False
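The namespace gymnastics in `is_jacoco_configured` exist because ElementTree qualifies every tag of a namespaced pom. A minimal sketch with a hypothetical pom:

```python
import xml.etree.ElementTree as ET

# Hypothetical namespaced pom: once parsed, every element tag carries the
# POM namespace, so lookups must be qualified the same way.
POM = """<project xmlns="http://maven.apache.org/POM/4.0.0">
  <build>
    <plugins>
      <plugin>
        <groupId>org.jacoco</groupId>
        <artifactId>jacoco-maven-plugin</artifactId>
      </plugin>
    </plugins>
  </build>
</project>"""

root = ET.fromstring(POM)
# root.tag is '{http://maven.apache.org/POM/4.0.0}project' for namespaced poms.
ns = "{http://maven.apache.org/POM/4.0.0}" if root.tag.startswith("{") else ""
found = any(
    plugin.findtext(f"{ns}artifactId") == "jacoco-maven-plugin"
    for plugin in root.iter(f"{ns}plugin")
)
print(found)  # True
```

Dropping the `ns` prefix would make both the `iter` and the `findtext` silently return nothing, which is why the function branches on whether the root tag is namespaced.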
def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
"""Add JaCoCo Maven plugin to pom.xml for coverage collection.
Uses string manipulation to preserve the original XML format and avoid
namespace prefix issues that ElementTree causes.
Args:
pom_path: Path to the pom.xml file.
Returns:
True if plugin was added or already present, False on error.
"""
if not pom_path.exists():
logger.error("pom.xml not found: %s", pom_path)
return False
# Check if already configured
if is_jacoco_configured(pom_path):
logger.info("JaCoCo plugin already configured in pom.xml")
return True
try:
content = pom_path.read_text(encoding="utf-8")
# Basic validation that it's a Maven pom.xml
if "</project>" not in content:
logger.error("Invalid pom.xml: no closing </project> tag found")
return False
# JaCoCo plugin XML to insert (indented for typical pom.xml format)
# Note: For multi-module projects where tests are in a separate module,
# we configure the report to look in multiple directories for classes
jacoco_plugin = f"""
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>{JACOCO_PLUGIN_VERSION}</version>
<executions>
<execution>
<id>prepare-agent</id>
<goals>
<goal>prepare-agent</goal>
</goals>
</execution>
<execution>
<id>report</id>
<phase>verify</phase>
<goals>
<goal>report</goal>
</goals>
<configuration>
<!-- For multi-module projects, include dependency classes -->
<includes>
<include>**/*.class</include>
</includes>
</configuration>
</execution>
</executions>
</plugin>"""
# Find the main <build> section (the one NOT inside <profiles>).
# It can appear before <profiles>, after </profiles>, or anywhere
# when there is no profiles section at all.
profiles_start = content.find("<profiles>")
profiles_end = content.find("</profiles>")
if profiles_start == -1:
# No profiles, any <build> is the main one
build_start = content.find("<build>")
build_end = content.find("</build>")
else:
# Has profiles - find <build> outside of profiles
# Check for <build> before <profiles>
build_before_profiles = content[:profiles_start].rfind("<build>")
# Check for <build> after </profiles>
build_after_profiles = content[profiles_end:].find("<build>") if profiles_end != -1 else -1
if build_after_profiles != -1:
build_after_profiles += profiles_end
if build_before_profiles != -1:
build_start = build_before_profiles
# Find corresponding </build> - need to handle nested builds
build_end = _find_closing_tag(content, build_start, "build")
elif build_after_profiles != -1:
build_start = build_after_profiles
build_end = _find_closing_tag(content, build_start, "build")
else:
build_start = -1
build_end = -1
if build_start != -1 and build_end != -1:
# Found main build section, find plugins within it
build_section = content[build_start : build_end + len("</build>")]
plugins_start_in_build = build_section.find("<plugins>")
plugins_end_in_build = build_section.rfind("</plugins>")
if plugins_start_in_build != -1 and plugins_end_in_build != -1:
# Insert before </plugins> within the main build section
absolute_plugins_end = build_start + plugins_end_in_build
content = content[:absolute_plugins_end] + jacoco_plugin + "\n " + content[absolute_plugins_end:]
else:
# No plugins section in main build, add one before </build>
plugins_section = f"<plugins>{jacoco_plugin}\n </plugins>\n "
content = content[:build_end] + plugins_section + content[build_end:]
else:
# No main build section found, add one before </project>
project_end = content.rfind("</project>")
build_section = f"""
<build>
<plugins>{jacoco_plugin}
</plugins>
</build>
"""
content = content[:project_end] + build_section + content[project_end:]
pom_path.write_text(content, encoding="utf-8")
logger.info("Added JaCoCo plugin to pom.xml")
return True
except Exception as e:
logger.exception("Failed to add JaCoCo plugin to pom.xml: %s", e)
return False
def _find_closing_tag(content: str, start_pos: int, tag_name: str) -> int:
"""Find the position of the closing tag that matches the opening tag at start_pos.
Handles nested tags of the same name.
Args:
content: The XML content.
start_pos: Position of the opening tag.
tag_name: Name of the tag.
Returns:
Position of the closing tag, or -1 if not found.
"""
open_tag = f"<{tag_name}>"
open_tag_short = f"<{tag_name} " # For tags with attributes
close_tag = f"</{tag_name}>"
# Start searching after the opening tag we're matching
depth = 1 # We've already found the opening tag at start_pos
pos = start_pos + len(f"<{tag_name}") # Move past the opening tag
while pos < len(content):
next_open = content.find(open_tag, pos)
next_open_short = content.find(open_tag_short, pos)
next_close = content.find(close_tag, pos)
if next_close == -1:
return -1
# Find the earliest opening tag (if any)
candidates = [x for x in [next_open, next_open_short] if x != -1 and x < next_close]
next_open_any = min(candidates) if candidates else len(content) + 1
if next_open_any < next_close:
# Found opening tag first - nested tag
depth += 1
pos = next_open_any + 1
else:
# Found closing tag first
depth -= 1
if depth == 0:
return next_close
pos = next_close + len(close_tag)
return -1
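The depth-tracking behavior can be seen on a compact re-implementation of the same algorithm, simplified to tags without attributes (a sketch, not the production code):

```python
def find_closing_tag(content: str, start_pos: int, tag_name: str) -> int:
    """Return the index of the </tag> matching the <tag> at start_pos, or -1."""
    open_tag, close_tag = f"<{tag_name}>", f"</{tag_name}>"
    depth = 1
    pos = start_pos + len(open_tag)
    while pos < len(content):
        next_open = content.find(open_tag, pos)
        next_close = content.find(close_tag, pos)
        if next_close == -1:
            return -1
        if next_open != -1 and next_open < next_close:
            depth += 1  # a nested opening tag comes first
            pos = next_open + len(open_tag)
        else:
            depth -= 1  # the closing tag comes first
            if depth == 0:
                return next_close
            pos = next_close + len(close_tag)
    return -1

xml = "<build><build></build></build>"
print(find_closing_tag(xml, 0, "build"))  # 22, the outermost </build>
```

A naive `content.find("</build>")` would have returned 14, the close of the *inner* build, which is precisely the nesting bug the depth counter avoids.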
def get_jacoco_xml_path(project_root: Path) -> Path:
"""Get the expected path to the JaCoCo XML report.
Args:
project_root: Root directory of the Maven project.
Returns:
Path to the JaCoCo XML report file.
"""
return project_root / "target" / "site" / "jacoco" / "jacoco.xml"
def find_test_root(project_root: Path) -> Path | None:
"""Find the test root directory for a Java project.
@ -1049,60 +393,3 @@ def find_source_root(project_root: Path) -> Path | None:
return src_path
return None
def get_classpath(project_root: Path) -> str | None:
"""Get the classpath for a Java project.
For Maven projects, this runs 'mvn dependency:build-classpath'.
Args:
project_root: Root directory of the Java project.
Returns:
Classpath string, or None if unable to determine.
"""
build_tool = detect_build_tool(project_root)
if build_tool == BuildTool.MAVEN:
return _get_maven_classpath(project_root)
if build_tool == BuildTool.GRADLE:
return _get_gradle_classpath(project_root)
return None
def _get_maven_classpath(project_root: Path) -> str | None:
"""Get classpath from Maven."""
mvn = find_maven_executable(project_root)
if not mvn:
return None
try:
result = subprocess.run(
[mvn, "dependency:build-classpath", "-q", "-DincludeScope=test", "-B"],
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=120,
)
if result.returncode == 0:
# The classpath is in stdout
return result.stdout.strip()
except Exception as e:
logger.warning("Failed to get Maven classpath: %s", e)
return None
def _get_gradle_classpath(project_root: Path) -> str | None:
"""Get classpath from Gradle.
Note: This requires a custom task to be added to build.gradle.
Returns None for now as Gradle support is not fully implemented.
"""
return None

View file

@ -136,6 +136,7 @@ def compare_test_results(
candidate_sqlite_path: Path,
comparator_jar: Path | None = None,
project_root: Path | None = None,
project_classpath: str | None = None,
) -> tuple[bool, list]:
"""Compare Java test results using the codeflash-runtime Comparator.
@ -150,6 +151,10 @@ def compare_test_results(
candidate_sqlite_path: Path to SQLite database with candidate code results.
comparator_jar: Optional path to the codeflash-runtime JAR.
project_root: Project root directory.
project_classpath: Full project classpath from the build tool. When provided,
the Comparator JVM uses this classpath so Kryo can resolve project-specific
classes during deserialization. Without it, only the runtime JAR is on the
classpath and any project class causes ClassNotFoundException.
Returns:
Tuple of (all_equivalent, list of TestDiff objects).
@ -180,6 +185,16 @@ def compare_test_results(
cwd = project_root or Path.cwd()
# Build classpath: runtime JAR + project classpath (if available).
# The project classpath is needed so Kryo can resolve project-specific classes
# during deserialization. Without it, Class.forName() fails for any type not
# bundled in the runtime JAR.
cp_separator = ";" if platform.system() == "Windows" else ":"
if project_classpath:
full_cp = f"{jar_path}{cp_separator}{project_classpath}"
else:
full_cp = str(jar_path)
try:
result = subprocess.run(
[
@ -200,7 +215,7 @@ def compare_test_results(
"--add-opens",
"java.base/java.util.zip=ALL-UNNAMED",
"-cp",
str(jar_path),
full_cp,
"com.codeflash.Comparator",
str(original_sqlite_path),
str(candidate_sqlite_path),

View file

@ -145,17 +145,28 @@ class JavaFunctionOptimizer(FunctionOptimizer):
def _get_java_sources_root(self) -> Path:
"""Get the Java sources root directory for test files.
For Java projects, tests_root might include the package path
(e.g., test/src/com/aerospike/test). We need to find the base directory
that should contain the package directories, not the tests_root itself.
For multi-module projects (Kafka, OpenSearch, Spring Boot), the test
directory must correspond to the same module as the source file, not
the globally configured tests_root. When the source file follows the
standard Maven/Gradle ``src/main/java`` layout we derive the test root
by replacing ``main`` with ``test`` in the same module prefix.
This method looks for standard Java package prefixes (com, org, net, io, edu, gov)
in the tests_root path and returns everything before that prefix.
Falls back to the existing tests_root-based heuristics for non-standard
layouts or single-module projects.
Returns:
Path to the Java sources root directory.
"""
file_path_str = str(self.function_to_optimize.file_path)
src_main_java_marker = str(Path("src") / "main" / "java")
idx = file_path_str.find(src_main_java_marker)
if idx != -1:
module_prefix = Path(file_path_str[:idx])
derived_test_root = module_prefix / "src" / "test" / "java"
if derived_test_root.exists():
return derived_test_root
tests_root = self.test_cfg.tests_root
parts = tests_root.parts
@ -326,15 +337,45 @@ class JavaFunctionOptimizer(FunctionOptimizer):
original_sqlite = get_run_tmp_file(Path("test_return_values_0.sqlite"))
candidate_sqlite = get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite"))
if len(baseline_results.behavior_test_results) == 0 or len(candidate_behavior_results) == 0:
return False, []
if original_sqlite.exists() and candidate_sqlite.exists():
match, diffs = self.language_support.compare_test_results(
original_sqlite, candidate_sqlite, project_root=self.project_root
original_sqlite,
candidate_sqlite,
project_root=self.project_root,
project_classpath=self._get_project_classpath(),
)
candidate_sqlite.unlink(missing_ok=True)
else:
match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
return match, diffs
_cached_project_classpath: str | None
def _get_project_classpath(self) -> str | None:
"""Get the project's full classpath from the build tool strategy.
The classpath is cached by the strategy after the first test run,
so this is a cheap dict lookup.
"""
if hasattr(self, "_cached_project_classpath"):
return self._cached_project_classpath
try:
import os
from codeflash.languages.java.build_tool_strategy import get_strategy
strategy = get_strategy(self.project_root)
classpath = strategy.get_classpath(self.project_root, os.environ.copy(), None, timeout=60)
self._cached_project_classpath = classpath
return classpath
except Exception:
logger.debug("Could not get project classpath for Comparator", exc_info=True)
return None
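The `hasattr`-based caching in `_get_project_classpath` is a standard memoization pattern; a stripped-down sketch with a hypothetical expensive lookup:

```python
class Optimizer:
    lookups = 0  # counts how often the expensive path actually runs

    def _compute_classpath(self) -> str:
        Optimizer.lookups += 1
        return "/app/classes:/libs/dep.jar"  # placeholder value

    def get_classpath(self) -> str:
        # First call computes and stores; later calls are a cheap attribute read.
        if hasattr(self, "_cached_classpath"):
            return self._cached_classpath
        self._cached_classpath = self._compute_classpath()
        return self._cached_classpath

opt = Optimizer()
opt.get_classpath()
opt.get_classpath()
print(Optimizer.lookups)  # 1
```

Caching on the instance (rather than a module global) means each optimizer resolves its own project's classpath at most once per run.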
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
return testing_type == TestingMode.BEHAVIOR or optimization_iteration == 0

View file

@ -0,0 +1,815 @@
"""Gradle build tool strategy for Java projects.
Implements BuildToolStrategy for Gradle-based projects, handling compilation,
classpath extraction, test execution, and JaCoCo coverage.
"""
from __future__ import annotations
import logging
import os
import re
import shutil
import subprocess
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any
from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir
_BUILD = "build"
logger = logging.getLogger(__name__)
# Groovy init script that disables validation/analysis plugins.
# Equivalent to Maven's -Dcheckstyle.skip=true, -Dspotbugs.skip=true, etc.
# Using an init script is safe even if the plugins aren't applied — unlike
# `-x taskName` which fails if the task doesn't exist.
_GRADLE_SKIP_VALIDATION_INIT_SCRIPT = """\
gradle.projectsEvaluated {
allprojects {
tasks.matching { task ->
task.name in [
'checkstyleMain', 'checkstyleTest',
'spotbugsMain', 'spotbugsTest',
'pmdMain', 'pmdTest',
'rat', 'japicmp'
]
}.configureEach {
enabled = false
}
tasks.withType(JavaCompile) {
options.compilerArgs.removeAll { it == '-Werror' }
options.compilerArgs.removeAll { it == '-Xlint:all' }
if (options.hasProperty('errorprone')) {
options.errorprone {
enabled = false
}
}
}
}
}
"""
# Lazily-created temp file for the validation-skip init script.
_skip_validation_init_path: str | None = None
def _get_skip_validation_init_script() -> str:
"""Return the path to a persistent temp init script that disables validation tasks."""
global _skip_validation_init_path
if _skip_validation_init_path is None or not Path(_skip_validation_init_path).exists():
fd, path = tempfile.mkstemp(suffix=".gradle", prefix="codeflash_skip_validation_")
with os.fdopen(fd, "w", encoding="utf-8") as f:
f.write(_GRADLE_SKIP_VALIDATION_INIT_SCRIPT)
_skip_validation_init_path = path
return _skip_validation_init_path
# Cache for classpath strings — keyed on (gradle_root, test_module).
_classpath_cache: dict[tuple[Path, str | None], str] = {}
# Cache for multi-module dependency installs — keyed on (gradle_root, test_module).
_multimodule_deps_installed: set[tuple[Path, str]] = set()
# Gradle init script that prints the test runtime classpath.
# Uses projectsEvaluated to avoid triggering configuration of unrelated subprojects.
_CLASSPATH_INIT_SCRIPT = """\
gradle.projectsEvaluated {
allprojects {
tasks.register("codeflashPrintClasspath") {
doLast {
def cp = configurations.findByName('testRuntimeClasspath')
if (cp != null && cp.isCanBeResolved()) {
println "CODEFLASH_CP_START"
println cp.asPath
println "CODEFLASH_CP_END"
}
}
}
}
}
"""
# Gradle init script that applies JaCoCo plugin for coverage collection.
# Uses projectsEvaluated to avoid triggering configuration of unrelated subprojects.
_JACOCO_INIT_SCRIPT = """\
gradle.projectsEvaluated {
allprojects {
apply plugin: 'jacoco'
jacocoTestReport {
reports {
xml.required = true
html.required = false
}
}
}
}
"""
def find_gradle_build_file(project_root: Path) -> Path | None:
kts = project_root / "build.gradle.kts"
if kts.exists():
return kts
groovy = project_root / "build.gradle"
if groovy.exists():
return groovy
return None
def _find_top_level_dependencies_block(build_file: Path, content: str) -> int | None:
"""Find the insert position (before closing }) of the top-level dependencies block using tree-sitter.
Returns the byte offset of the closing brace, or None if no top-level dependencies block exists.
Only matches `dependencies { }` at the root level; blocks nested inside
`buildscript`, `subprojects`, `allprojects`, etc. are ignored.
"""
import tree_sitter as ts
is_kts = build_file.name.endswith(".kts")
source_bytes = content.encode("utf-8")
if is_kts:
import tree_sitter_kotlin as tsk
parser = ts.Parser(ts.Language(tsk.language()))
else:
import tree_sitter_groovy as tsg
parser = ts.Parser(ts.Language(tsg.language()))
tree = parser.parse(source_bytes)
# Walk only direct children of root to find top-level `dependencies { }`
for child in tree.root_node.children:
# Groovy: expression_statement > method_invocation(identifier="dependencies", closure)
# Kotlin: call_expression(identifier="dependencies", annotated_lambda)
node = child
if node.type == "expression_statement" and node.child_count > 0:
node = node.children[0]
if node.type not in ("method_invocation", "call_expression"):
continue
name_node = None
body_node = None
for c in node.children:
if c.type == "identifier":
name_node = c
elif c.type in ("closure", "annotated_lambda", "lambda_literal"):
body_node = c
if name_node is None or body_node is None:
continue
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf-8")
if name != "dependencies":
continue
# Find the closing brace of this block
closing_brace = body_node.children[-1] if body_node.children else None
# For Kotlin, annotated_lambda wraps lambda_literal
if closing_brace is not None and closing_brace.type == "lambda_literal":
closing_brace = closing_brace.children[-1] if closing_brace.children else None
if closing_brace is not None and closing_brace.type == "}":
return closing_brace.start_byte
return None
def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
if not build_file.exists():
return False
try:
content = build_file.read_text(encoding="utf-8")
if "codeflash-runtime" in content:
logger.info("codeflash-runtime dependency already present in %s", build_file.name)
return True
is_kts = build_file.name.endswith(".kts")
jar_str = str(runtime_jar_path).replace("\\", "/")
if is_kts:
dep_line = f' testImplementation(files("{jar_str}")) // codeflash-runtime\n'
else:
dep_line = f" testImplementation files('{jar_str}') // codeflash-runtime\n"
# Use tree-sitter to find the top-level dependencies block
insert_pos = _find_top_level_dependencies_block(build_file, content)
if insert_pos is not None:
content = content[:insert_pos] + dep_line + content[insert_pos:]
build_file.write_text(content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to %s (tree-sitter)", build_file.name)
return True
# No existing dependencies block — append one
if is_kts:
content += f'\ndependencies {{\n testImplementation(files("{jar_str}")) // codeflash-runtime\n}}\n'
else:
content += f"\ndependencies {{\n testImplementation files('{jar_str}') // codeflash-runtime\n}}\n"
build_file.write_text(content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to %s (new block)", build_file.name)
return True
except Exception as e:
logger.exception("Failed to add dependency to %s: %s", build_file.name, e)
return False
def _normalize_gradle_xml_reports(reports_dir: Path) -> None:
"""Normalize Gradle JUnit XML reports to match Maven Surefire format.
Gradle's JUnit Platform XML differs from Maven Surefire in ways that
can crash the downstream parser:
1. <failure>/<error> elements may omit the ``message`` attribute,
which Maven always sets.
2. Timeout information may only appear in the element body text,
not in the ``message`` attribute.
This function rewrites the XML files in-place so they conform to the
Maven Surefire contract the parser expects.
"""
if not reports_dir.exists():
return
for xml_file in reports_dir.glob("TEST-*.xml"):
try:
tree = ET.parse(xml_file)
root = tree.getroot()
modified = False
for tag in ("failure", "error"):
for elem in root.iter(tag):
if elem.get("message") is None:
body = (elem.text or "").strip()
first_line = body.split("\n", 1)[0] if body else ""
elem.set("message", first_line)
modified = True
if modified:
tree.write(xml_file, encoding="unicode", xml_declaration=True)
except ET.ParseError:
logger.debug("Failed to normalize Gradle XML report %s", xml_file)
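The normalization itself is small enough to demonstrate directly; the report below is a hypothetical Gradle JUnit XML, not real output:

```python
import xml.etree.ElementTree as ET

# Hypothetical Gradle JUnit report: the <failure> carries no message attribute.
REPORT = (
    '<testsuite name="FooTest" tests="1" failures="1">'
    '<testcase name="bad" classname="FooTest">'
    "<failure>java.lang.AssertionError: expected 1 but was 2\n"
    "\tat FooTest.bad(FooTest.java:12)</failure>"
    "</testcase></testsuite>"
)

root = ET.fromstring(REPORT)
for elem in root.iter("failure"):
    if elem.get("message") is None:
        body = (elem.text or "").strip()
        # Promote the first line of the body to the message attribute,
        # matching what Maven Surefire would have written.
        elem.set("message", body.split("\n", 1)[0] if body else "")

print(root.find(".//failure").get("message"))
# java.lang.AssertionError: expected 1 but was 2
```

After this rewrite the downstream Surefire parser can read `message` unconditionally on both Maven and Gradle reports.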
class GradleStrategy(BuildToolStrategy):
"""Gradle-specific build tool operations."""
@property
def name(self) -> str:
return "Gradle"
def find_executable(self, build_root: Path) -> str | None:
# Walk up from build_root to find gradlew — for multi-module projects
# the wrapper lives at the repo root, which may be a parent of build_root.
current = build_root.resolve()
while True:
gradlew_path = current / "gradlew"
if gradlew_path.exists():
return str(gradlew_path)
gradlew_bat_path = current / "gradlew.bat"
if gradlew_bat_path.exists():
return str(gradlew_bat_path)
parent = current.parent
if parent == current:
break
current = parent
# Fall back to system Gradle
return shutil.which("gradle")
def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
runtime_jar = self.find_runtime_jar()
if runtime_jar is None:
logger.error("codeflash-runtime JAR not found. Generated tests will fail to compile.")
return False
if test_module:
module_root = build_root / module_to_dir(test_module)
else:
module_root = build_root
libs_dir = module_root / "libs"
libs_dir.mkdir(parents=True, exist_ok=True)
dest_jar = libs_dir / "codeflash-runtime-1.0.0.jar"
if not dest_jar.exists():
logger.info("Copying codeflash-runtime JAR to %s", dest_jar)
shutil.copy2(runtime_jar, dest_jar)
build_file = find_gradle_build_file(module_root)
if build_file is None:
logger.warning("No build.gradle(.kts) found at %s, cannot add codeflash-runtime dependency", module_root)
return False
if not add_codeflash_dependency(build_file, dest_jar):
logger.error("Failed to add codeflash-runtime dependency to %s", build_file)
return False
return True
def install_multi_module_deps(self, build_root: Path, test_module: str | None, env: dict[str, str]) -> bool:
from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout
if not test_module:
return True
cache_key = (build_root, test_module)
if cache_key in _multimodule_deps_installed:
logger.debug("Multi-module deps already installed for %s:%s, skipping", build_root, test_module)
return True
gradle = self.find_executable(build_root)
if not gradle:
logger.error("Gradle not found — cannot pre-install multi-module dependencies")
return False
cmd = [gradle, f":{test_module}:classes", "-x", "test", "--build-cache", "--no-daemon"]
cmd.extend(["--init-script", _get_skip_validation_init_script()])
logger.info("Pre-installing multi-module dependencies: %s (module: %s)", build_root, test_module)
logger.debug("Running: %s", " ".join(cmd))
try:
result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=300)
if result.returncode != 0:
logger.error(
"Failed to pre-install multi-module deps (exit %d).\nstdout: %s\nstderr: %s",
result.returncode,
result.stdout[-2000:] if result.stdout else "",
result.stderr[-2000:] if result.stderr else "",
)
return False
except Exception:
logger.exception("Exception during multi-module dependency install")
return False
_multimodule_deps_installed.add(cache_key)
logger.info("Multi-module dependencies installed successfully for %s:%s", build_root, test_module)
return True
def compile_tests(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
) -> subprocess.CompletedProcess[str]:
from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout
gradle = self.find_executable(build_root)
if not gradle:
logger.error("Gradle not found")
return subprocess.CompletedProcess(args=["gradle"], returncode=-1, stdout="", stderr="Gradle not found")
if test_module:
cmd = [gradle, f":{test_module}:testClasses", "--no-daemon"]
else:
cmd = [gradle, "testClasses", "--no-daemon"]
cmd.extend(["--init-script", _get_skip_validation_init_script()])
logger.debug("Compiling tests: %s in %s", " ".join(cmd), build_root)
try:
return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
except Exception as e:
logger.exception("Gradle compilation failed: %s", e)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def compile_source_only(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
) -> subprocess.CompletedProcess[str]:
from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout
gradle = self.find_executable(build_root)
if not gradle:
logger.error("Gradle not found")
return subprocess.CompletedProcess(args=["gradle"], returncode=-1, stdout="", stderr="Gradle not found")
if test_module:
cmd = [gradle, f":{test_module}:classes", "--no-daemon"]
else:
cmd = [gradle, "classes", "--no-daemon"]
cmd.extend(["--init-script", _get_skip_validation_init_script()])
logger.debug("Compiling source only: %s in %s", " ".join(cmd), build_root)
try:
return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
except Exception as e:
logger.exception("Gradle source compilation failed: %s", e)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def get_classpath(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
) -> str | None:
key = (build_root, test_module)
cached = _classpath_cache.get(key)
if cached is not None:
logger.debug("Using cached classpath for (%s, %s)", build_root, test_module)
return cached
result = self._get_classpath_uncached(build_root, env, test_module, timeout)
if result is not None:
_classpath_cache[key] = result
return result
def _get_classpath_uncached(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
) -> str | None:
from codeflash.languages.java.test_runner import _find_junit_console_standalone, _run_cmd_kill_pg_on_timeout
gradle = self.find_executable(build_root)
if not gradle:
return None
# Write init script to a temp file
init_script_fd, init_script_path = tempfile.mkstemp(suffix=".gradle", prefix="codeflash_cp_")
try:
with os.fdopen(init_script_fd, "w", encoding="utf-8") as f:
f.write(_CLASSPATH_INIT_SCRIPT)
if test_module:
task = f":{test_module}:codeflashPrintClasspath"
else:
task = "codeflashPrintClasspath"
cmd = [gradle, "--init-script", init_script_path, task, "-q", "--no-daemon"]
logger.debug("Getting classpath: %s", " ".join(cmd))
result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
if result.returncode != 0:
logger.error("Failed to get classpath: %s", result.stderr)
return None
classpath = self._parse_classpath_output(result.stdout)
if not classpath:
logger.error("Classpath not found in Gradle output")
return None
if test_module:
module_path = build_root / module_to_dir(test_module)
else:
module_path = build_root
test_classes = module_path / "build" / "classes" / "java" / "test"
main_classes = module_path / "build" / "classes" / "java" / "main"
cp_parts = [classpath]
if test_classes.exists():
cp_parts.append(str(test_classes))
if main_classes.exists():
cp_parts.append(str(main_classes))
if test_module:
module_dir_name = module_to_dir(test_module)
for module_dir in build_root.iterdir():
if module_dir.is_dir() and module_dir.name != module_dir_name:
module_classes = module_dir / "build" / "classes" / "java" / "main"
if module_classes.exists():
logger.debug("Adding multi-module classpath: %s", module_classes)
cp_parts.append(str(module_classes))
if "console-standalone" not in classpath and "ConsoleLauncher" not in classpath:
console_jar = _find_junit_console_standalone()
if console_jar:
logger.debug("Adding JUnit Console Standalone to classpath: %s", console_jar)
cp_parts.append(str(console_jar))
return os.pathsep.join(cp_parts)
except Exception as e:
logger.exception("Failed to get classpath: %s", e)
return None
finally:
Path(init_script_path).unlink(missing_ok=True)
@staticmethod
def _parse_classpath_output(stdout: str) -> str | None:
in_cp = False
for line in stdout.splitlines():
if line.strip() == "CODEFLASH_CP_START":
in_cp = True
continue
if line.strip() == "CODEFLASH_CP_END":
break
if in_cp and line.strip():
return line.strip()
return None
def get_reports_dir(self, build_root: Path, test_module: str | None) -> Path:
build_dir = self.get_build_output_dir(build_root, test_module)
return build_dir / "test-results" / "test"
def get_build_output_dir(self, build_root: Path, test_module: str | None) -> Path:
if test_module:
return build_root.joinpath(module_to_dir(test_module), _BUILD)
return build_root.joinpath(_BUILD)
def run_tests_via_build_tool(
self,
build_root: Path,
test_paths: Any,
env: dict[str, str],
timeout: int,
mode: str,
test_module: str | None,
javaagent_arg: str | None = None,
enable_coverage: bool = False,
) -> subprocess.CompletedProcess[str]:
from codeflash.languages.java.test_runner import _build_test_filter, _run_cmd_kill_pg_on_timeout
gradle = self.find_executable(build_root)
if not gradle:
logger.error("Gradle not found")
return subprocess.CompletedProcess(args=["gradle"], returncode=-1, stdout="", stderr="Gradle not found")
test_filter = _build_test_filter(test_paths, mode=mode)
logger.debug("Built test filter for mode=%s: '%s' (empty=%s)", mode, test_filter, not test_filter)
if test_module:
task = f":{test_module}:test"
else:
task = "test"
# Write an init script that configures JVM args for the test task.
# -Dorg.gradle.jvmargs only affects the Gradle daemon, NOT the forked test JVM.
add_opens = [
"--add-opens",
"java.base/java.util=ALL-UNNAMED",
"--add-opens",
"java.base/java.lang=ALL-UNNAMED",
"--add-opens",
"java.base/java.lang.reflect=ALL-UNNAMED",
"--add-opens",
"java.base/java.io=ALL-UNNAMED",
"--add-opens",
"java.base/java.math=ALL-UNNAMED",
"--add-opens",
"java.base/java.net=ALL-UNNAMED",
"--add-opens",
"java.base/java.util.zip=ALL-UNNAMED",
]
all_jvm_args = list(add_opens)
if javaagent_arg:
all_jvm_args.insert(0, javaagent_arg)
per_test_timeout = max(timeout // 3, 10)
quoted_args = ", ".join(f'"{a}"' for a in all_jvm_args)
init_script_content = (
f"gradle.projectsEvaluated {{\n"
f" allprojects {{\n"
f" tasks.withType(Test) {{\n"
f" jvmArgs({quoted_args})\n"
f' systemProperty "junit.jupiter.execution.timeout.default", "{per_test_timeout}s"\n'
f" reports.junitXml.outputPerTestCase = true\n"
f" filter.failOnNoMatchingTests = false\n"
f" }}\n"
f" }}\n"
f"}}\n"
)
if not test_filter:
error_msg = (
f"Test filter is EMPTY for mode={mode}! "
f"Gradle will run ALL tests instead of the specified tests. "
f"This indicates a problem with test file instrumentation or path resolution."
)
logger.error(error_msg)
raise ValueError(error_msg)
init_fd, init_path = tempfile.mkstemp(suffix=".gradle", prefix="codeflash_jvmargs_")
try:
with os.fdopen(init_fd, "w", encoding="utf-8") as f:
f.write(init_script_content)
cmd = [gradle, task, "--no-daemon", "--rerun", "--init-script", init_path]
cmd.extend(["--init-script", _get_skip_validation_init_script()])
# --continue ensures Gradle keeps going even if some tests fail.
# For coverage: needed so jacocoTestReport runs even after test failures
# (matches Maven's -Dmaven.test.failure.ignore=true).
# Note: multi-module --tests filtering is handled by
# filter.failOnNoMatchingTests = false in the init script above
# (matches Maven's -DfailIfNoTests=false).
if enable_coverage:
cmd.append("--continue")
for class_filter in test_filter.split(","):
class_filter = class_filter.strip()
if class_filter:
cmd.extend(["--tests", class_filter])
logger.debug("Added --tests filters to Gradle command")
# Append jacocoTestReport AFTER --tests so Gradle doesn't try to apply --tests to it
if enable_coverage:
cmd.append("jacocoTestReport")
logger.debug("Running Gradle command: %s in %s", " ".join(cmd), build_root)
result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
# Normalize XML reports so <failure>/<error> always have a message
# attribute — Maven Surefire always sets it, Gradle may omit it.
reports_dir = self.get_reports_dir(build_root, test_module)
_normalize_gradle_xml_reports(reports_dir)
if result.returncode != 0:
compilation_error_indicators = [
"Compilation failed",
"COMPILATION ERROR",
"cannot find symbol",
"error: package",
]
combined_output = (result.stdout or "") + (result.stderr or "")
has_compilation_error = any(
indicator.lower() in combined_output.lower() for indicator in compilation_error_indicators
)
if has_compilation_error:
logger.error(
"Gradle compilation failed for %s tests. "
"Check that generated test code is syntactically valid Java. "
"Return code: %s",
mode,
result.returncode,
)
output_lines = combined_output.split("\n")
error_context = "\n".join(output_lines[:50]) if len(output_lines) > 50 else combined_output
logger.error("Gradle compilation error output:\n%s", error_context)
return result
except Exception as e:
logger.exception("Gradle test execution failed: %s", e)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
finally:
Path(init_path).unlink(missing_ok=True)
def run_benchmarking_via_build_tool(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None,
project_root: Path | None,
min_loops: int,
max_loops: int,
target_duration_seconds: float,
inner_iterations: int,
) -> tuple[Path, Any]:
import time
from codeflash.languages.java.test_runner import _find_multi_module_root, _get_combined_junit_xml
project_root = project_root or cwd
gradle_root, test_module = _find_multi_module_root(project_root, test_paths)
all_stdout: list[str] = []
all_stderr: list[str] = []
total_start_time = time.time()
loop_count = 0
last_result = None
per_loop_timeout = max(timeout or 0, 120, 60 + inner_iterations)
logger.debug("Using Gradle-based benchmarking (fallback mode)")
for loop_idx in range(1, max_loops + 1):
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_LOOP_INDEX"] = str(loop_idx)
run_env["CODEFLASH_MODE"] = "performance"
run_env["CODEFLASH_TEST_ITERATION"] = "0"
if "CODEFLASH_INNER_ITERATIONS" not in run_env:
run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)
result = self.run_tests_via_build_tool(
gradle_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
)
last_result = result
loop_count = loop_idx
if result.stdout:
all_stdout.append(result.stdout)
if result.stderr:
all_stderr.append(result.stderr)
elapsed = time.time() - total_start_time
if loop_idx >= min_loops and elapsed >= target_duration_seconds:
logger.debug("Stopping Gradle benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
break
if result.returncode != 0:
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
if not has_timing_markers:
logger.warning("Tests failed in Gradle loop %d with no timing markers, stopping", loop_idx)
break
logger.debug("Some tests failed in Gradle loop %d but timing markers present, continuing", loop_idx)
combined_stdout = "\n".join(all_stdout)
combined_stderr = "\n".join(all_stderr)
total_iterations = loop_count * inner_iterations
logger.debug(
"Gradle fallback: %d loops x %d iterations = %d total in %.2fs",
loop_count,
inner_iterations,
total_iterations,
time.time() - total_start_time,
)
combined_result = subprocess.CompletedProcess(
args=last_result.args if last_result else ["gradle", "test"],
returncode=last_result.returncode if last_result else -1,
stdout=combined_stdout,
stderr=combined_stderr,
)
reports_dir = self.get_reports_dir(gradle_root, test_module)
result_xml_path = _get_combined_junit_xml(reports_dir, -1)
return result_xml_path, combined_result
def run_tests_with_coverage(
self,
build_root: Path,
test_module: str | None,
test_paths: Any,
run_env: dict[str, str],
timeout: int,
candidate_index: int,
) -> tuple[subprocess.CompletedProcess[str], Path, Path | None]:
from codeflash.languages.java.test_runner import _get_combined_junit_xml
coverage_xml_path = self.setup_coverage(build_root, test_module, build_root)
result = self.run_tests_via_build_tool(
build_root,
test_paths,
run_env,
timeout=timeout,
mode="behavior",
enable_coverage=True,
test_module=test_module,
)
reports_dir = self.get_reports_dir(build_root, test_module)
result_xml_path = _get_combined_junit_xml(reports_dir, candidate_index)
return result, result_xml_path, coverage_xml_path
def setup_coverage(self, build_root: Path, test_module: str | None, project_root: Path) -> Path | None:
if test_module:
module_root = build_root / module_to_dir(test_module)
else:
module_root = project_root
build_file = find_gradle_build_file(module_root)
if build_file is None:
logger.warning("No build.gradle(.kts) found at %s, cannot setup JaCoCo", module_root)
return None
content = build_file.read_text(encoding="utf-8")
if "jacoco" not in content.lower():
logger.info("Adding JaCoCo plugin to %s for coverage collection", build_file.name)
is_kts = build_file.name.endswith(".kts")
if is_kts:
plugin_line = "plugins {\n jacoco\n}\n"
else:
plugin_line = "apply plugin: 'jacoco'\n"
if "plugins {" in content or "plugins{" in content:
# Insert jacoco inside existing plugins block
plugins_idx = content.find("plugins")
brace_depth = 0
for i in range(plugins_idx, len(content)):
if content[i] == "{":
brace_depth += 1
elif content[i] == "}":
brace_depth -= 1
if brace_depth == 0:
insert = " jacoco\n" if is_kts else " id 'jacoco'\n"
content = content[:i] + insert + content[i:]
break
else:
content = plugin_line + content
build_file.write_text(content, encoding="utf-8")
return module_root / "build" / "reports" / "jacoco" / "test" / "jacocoTestReport.xml"
def get_test_run_command(self, project_root: Path, test_classes: list[str] | None = None) -> list[str]:
from codeflash.languages.java.test_runner import _validate_java_class_name
if test_classes:
for test_class in test_classes:
if not _validate_java_class_name(test_class):
msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
raise ValueError(msg)
gradle = self.find_executable(project_root) or "gradle"
cmd = [gradle, "test", "--no-daemon"]
if test_classes:
for cls in test_classes:
cmd.extend(["--tests", cls])
return cmd
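The `codeflashPrintClasspath` task injected by `_CLASSPATH_INIT_SCRIPT` prints the classpath between `CODEFLASH_CP_START`/`CODEFLASH_CP_END` sentinels, and `_parse_classpath_output` keeps the first non-empty line between them. A minimal standalone sketch of that protocol (the sample Gradle output below is invented for illustration):

```python
def parse_classpath_output(stdout):
    """Return the first non-empty line between the sentinel markers, else None."""
    in_cp = False
    for line in stdout.splitlines():
        if line.strip() == "CODEFLASH_CP_START":
            in_cp = True
            continue
        if line.strip() == "CODEFLASH_CP_END":
            break
        if in_cp and line.strip():
            return line.strip()
    return None

# Even with -q, Gradle may interleave its own lines with the task output,
# so the sentinels make the classpath line unambiguous.
sample = "\n".join([
    "> Task :codeflashPrintClasspath",
    "CODEFLASH_CP_START",
    "/repo/libs/junit-platform-console-standalone-1.10.0.jar:/repo/build/classes/java/main",
    "CODEFLASH_CP_END",
])
print(parse_classpath_output(sample))
```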


@@ -0,0 +1,865 @@
"""Maven build tool strategy for Java projects.
Implements BuildToolStrategy for Maven-based projects, handling compilation,
classpath extraction, test execution via Surefire, and JaCoCo coverage.
"""
from __future__ import annotations
import logging
import os
import re
import shutil
import subprocess
import urllib.request
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any
from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir
from codeflash.languages.java.build_tools import CODEFLASH_RUNTIME_JAR_NAME, CODEFLASH_RUNTIME_VERSION
_TARGET = "target"
logger = logging.getLogger(__name__)
# Skip validation/analysis plugins that reject generated instrumented files
_MAVEN_VALIDATION_SKIP_FLAGS = [
"-Drat.skip=true",
"-Dcheckstyle.skip=true",
"-Dspotbugs.skip=true",
"-Dpmd.skip=true",
"-Denforcer.skip=true",
"-Djapicmp.skip=true",
"-Derrorprone.skip=true",
"-Dmaven.compiler.failOnWarning=false",
"-Dmaven.compiler.showWarnings=false",
]
# Cache for classpath strings — keyed on (maven_root, test_module).
_classpath_cache: dict[tuple[Path, str | None], str] = {}
# Cache for multi-module dependency installs — keyed on (maven_root, test_module).
_multimodule_deps_installed: set[tuple[Path, str]] = set()
JACOCO_PLUGIN_VERSION = "0.8.13"
GITHUB_RELEASE_URL = (
"https://github.com/codeflash-ai/codeflash/releases/download"
f"/runtime-v{CODEFLASH_RUNTIME_VERSION}/{CODEFLASH_RUNTIME_JAR_NAME}"
)
CODEFLASH_CACHE_DIR = Path.home() / ".cache" / "codeflash"
CODEFLASH_DEPENDENCY_SNIPPET = """\
<dependency>
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>1.0.0</version>
<scope>test</scope>
</dependency>
</dependencies>"""
def download_from_github_releases() -> Path | None:
"""Download codeflash-runtime JAR from GitHub Releases.
Downloads to ~/.cache/codeflash/ and returns the path to the downloaded JAR.
Returns None if the download fails (e.g., no release published yet, network error).
"""
cache_jar = CODEFLASH_CACHE_DIR / CODEFLASH_RUNTIME_JAR_NAME
if cache_jar.exists():
logger.info("Found cached codeflash-runtime JAR: %s", cache_jar)
return cache_jar
try:
CODEFLASH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
logger.info("Downloading codeflash-runtime from GitHub Releases: %s", GITHUB_RELEASE_URL)
urllib.request.urlretrieve(GITHUB_RELEASE_URL, cache_jar) # noqa: S310
logger.info("Downloaded codeflash-runtime to %s", cache_jar)
return cache_jar
except Exception as e:
logger.debug("GitHub Releases download failed: %s", e)
cache_jar.unlink(missing_ok=True)
return None
def resolve_from_maven_central(maven_root: Path) -> bool:
"""Ask Maven to resolve codeflash-runtime from Maven Central.
Downloads the JAR to ~/.m2/repository/ automatically.
Returns True if Maven successfully resolved the artifact.
"""
mvn = shutil.which("mvn")
if not mvn:
return False
cmd = [
mvn,
"dependency:resolve",
f"-Dartifact=com.codeflash:codeflash-runtime:{CODEFLASH_RUNTIME_VERSION}",
"-B",
"-q",
]
try:
result = subprocess.run(cmd, check=False, cwd=maven_root, capture_output=True, text=True, timeout=60)
if result.returncode == 0:
logger.info("Resolved codeflash-runtime %s from Maven Central", CODEFLASH_RUNTIME_VERSION)
return True
logger.debug("Maven Central resolution failed: %s", result.stderr)
return False
except Exception as e:
logger.debug("Maven Central resolution error: %s", e)
return False
def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
"""Parse an XML file from its UTF-8 text content (rather than ET.parse on the path)."""
content = file_path.read_text(encoding="utf-8")
root = ET.fromstring(content)
return ET.ElementTree(root)
def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path, mvn: str | None = None) -> bool:
if not mvn:
mvn = shutil.which("mvn")
if not mvn:
logger.error("Maven not found")
return False
if not runtime_jar_path.exists():
logger.error("Runtime JAR not found: %s", runtime_jar_path)
return False
cmd = [
mvn,
"install:install-file",
f"-Dfile={runtime_jar_path}",
"-DgroupId=com.codeflash",
"-DartifactId=codeflash-runtime",
"-Dversion=1.0.0",
"-Dpackaging=jar",
"-B",
]
try:
result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)
if result.returncode == 0:
logger.info("Successfully installed codeflash-runtime to local Maven repository")
return True
logger.error("Failed to install codeflash-runtime: %s", result.stderr)
return False
except Exception as e:
logger.exception("Failed to install codeflash-runtime: %s", e)
return False
def add_codeflash_dependency(pom_path: Path) -> bool:
if not pom_path.exists():
return False
try:
content = pom_path.read_text(encoding="utf-8")
if "codeflash-runtime" in content:
if "<scope>system</scope>" in content:
def replace_system_dep(match: re.Match[str]) -> str:
block: str = match.group(0)
if "codeflash-runtime" in block and "<scope>system</scope>" in block:
return (
"<dependency>\n"
" <groupId>com.codeflash</groupId>\n"
" <artifactId>codeflash-runtime</artifactId>\n"
" <version>1.0.0</version>\n"
" <scope>test</scope>\n"
" </dependency>"
)
return block
content = re.sub(r"<dependency>[\s\S]*?</dependency>", replace_system_dep, content)
pom_path.write_text(content, encoding="utf-8")
logger.info("Replaced system-scope codeflash-runtime dependency with test scope")
return True
logger.info("codeflash-runtime dependency already present in pom.xml")
return True
closing_tag = "</dependencies>"
idx = content.find(closing_tag)
if idx == -1:
logger.warning("No </dependencies> tag found in pom.xml, cannot add dependency")
return False
new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
new_content += content[idx + len(closing_tag) :]
pom_path.write_text(new_content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to pom.xml")
return True
except Exception as e:
logger.exception("Failed to add dependency to pom.xml: %s", e)
return False
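`add_codeflash_dependency` splices the dependency in by cutting the pom at the first `</dependencies>` and appending everything from that tag onward; the snippet carries its own closing tag, so the result stays balanced. A hedged sketch of that splice against an in-memory pom (the minimal pom.xml below is invented):

```python
SNIPPET = """\
    <dependency>
      <groupId>com.codeflash</groupId>
      <artifactId>codeflash-runtime</artifactId>
      <version>1.0.0</version>
      <scope>test</scope>
    </dependency>
  </dependencies>"""

def splice_dependency(pom):
    # Cut just before the closing tag, insert the snippet (which ends with
    # its own </dependencies>), then keep everything after the old tag.
    closing = "</dependencies>"
    idx = pom.find(closing)
    if idx == -1:
        return None  # no <dependencies> block to extend
    return pom[:idx] + SNIPPET + pom[idx + len(closing):]

pom = "<project>\n  <dependencies>\n  </dependencies>\n</project>\n"
patched = splice_dependency(pom)
print(patched)
```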
def is_jacoco_configured(pom_path: Path) -> bool:
if not pom_path.exists():
return False
try:
tree = _safe_parse_xml(pom_path)
root = tree.getroot()
if root is None:
return False
ns_prefix = "{http://maven.apache.org/POM/4.0.0}"
use_ns = root.tag.startswith("{")
if not use_ns:
ns_prefix = ""
for plugin in root.findall(f".//{ns_prefix}plugin" if use_ns else ".//plugin"):
artifact_id = plugin.find(f"{ns_prefix}artifactId" if use_ns else "artifactId")
if artifact_id is not None and artifact_id.text == "jacoco-maven-plugin":
group_id = plugin.find(f"{ns_prefix}groupId" if use_ns else "groupId")
if group_id is None or group_id.text == "org.jacoco":
return True
return False
except ET.ParseError as e:
logger.warning("Failed to parse pom.xml for JaCoCo check: %s", e)
return False
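`is_jacoco_configured` has to handle both poms that declare the Maven POM namespace and bare poms, which is why it builds the tag prefix conditionally. The same namespace-aware lookup, exercised in isolation (the pom fragments below are invented):

```python
import xml.etree.ElementTree as ET

def has_jacoco_plugin(pom_xml):
    root = ET.fromstring(pom_xml)
    # Namespaced poms expose tags as "{uri}plugin"; bare poms as plain "plugin".
    ns = "{http://maven.apache.org/POM/4.0.0}" if root.tag.startswith("{") else ""
    for plugin in root.findall(f".//{ns}plugin"):
        artifact = plugin.find(f"{ns}artifactId")
        if artifact is not None and artifact.text == "jacoco-maven-plugin":
            return True
    return False

bare = (
    "<project><build><plugins><plugin>"
    "<artifactId>jacoco-maven-plugin</artifactId>"
    "</plugin></plugins></build></project>"
)
namespaced = (
    '<project xmlns="http://maven.apache.org/POM/4.0.0">'
    "<build><plugins><plugin><artifactId>jacoco-maven-plugin</artifactId>"
    "</plugin></plugins></build></project>"
)
print(has_jacoco_plugin(bare), has_jacoco_plugin(namespaced))
```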
def _find_closing_tag(content: str, start_pos: int, tag_name: str) -> int:
open_tag = f"<{tag_name}>"
open_tag_short = f"<{tag_name} "
close_tag = f"</{tag_name}>"
depth = 1
pos = start_pos + len(f"<{tag_name}")
while pos < len(content):
next_open = content.find(open_tag, pos)
next_open_short = content.find(open_tag_short, pos)
next_close = content.find(close_tag, pos)
if next_close == -1:
return -1
candidates = [x for x in [next_open, next_open_short] if x != -1 and x < next_close]
next_open_any = min(candidates) if candidates else len(content) + 1
if next_open_any < next_close:
depth += 1
pos = next_open_any + 1
else:
depth -= 1
if depth == 0:
return next_close
pos = next_close + len(close_tag)
return -1
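`_find_closing_tag` does a depth-counting scan so that a nested same-name tag (for example a `<build>` inside `<profiles>`) cannot be mistaken for the closing tag of the outer block. The same scan, reproduced as a standalone sketch and exercised on a small invented fragment:

```python
def find_closing_tag(content, start_pos, tag_name):
    """Return the index of the matching close tag, or -1 if unbalanced."""
    open_tag = f"<{tag_name}>"
    open_tag_short = f"<{tag_name} "
    close_tag = f"</{tag_name}>"
    depth = 1
    pos = start_pos + len(f"<{tag_name}")
    while pos < len(content):
        next_open = content.find(open_tag, pos)
        next_open_short = content.find(open_tag_short, pos)
        next_close = content.find(close_tag, pos)
        if next_close == -1:
            return -1
        candidates = [x for x in (next_open, next_open_short) if x != -1 and x < next_close]
        next_open_any = min(candidates) if candidates else len(content) + 1
        if next_open_any < next_close:
            depth += 1          # another same-name tag opened before the close
            pos = next_open_any + 1
        else:
            depth -= 1
            if depth == 0:
                return next_close
            pos = next_close + len(close_tag)
    return -1

# The inner <build> must not steal the outer tag's close.
doc = "<build><build></build></build>"
print(find_closing_tag(doc, 0, "build"))  # index of the OUTER </build> (22 here)
```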
def add_jacoco_plugin(pom_path: Path) -> bool:
if not pom_path.exists():
logger.error("pom.xml not found: %s", pom_path)
return False
if is_jacoco_configured(pom_path):
logger.info("JaCoCo plugin already configured in pom.xml")
return True
try:
content = pom_path.read_text(encoding="utf-8")
if "</project>" not in content:
logger.error("Invalid pom.xml: no closing </project> tag found")
return False
jacoco_plugin = f"""
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>{JACOCO_PLUGIN_VERSION}</version>
<executions>
<execution>
<id>prepare-agent</id>
<goals>
<goal>prepare-agent</goal>
</goals>
</execution>
<execution>
<id>report</id>
<phase>verify</phase>
<goals>
<goal>report</goal>
</goals>
<configuration>
<!-- For multi-module projects, include dependency classes -->
<includes>
<include>**/*.class</include>
</includes>
</configuration>
</execution>
</executions>
</plugin>"""
profiles_start = content.find("<profiles>")
profiles_end = content.find("</profiles>")
if profiles_start == -1:
build_start = content.find("<build>")
build_end = content.find("</build>")
else:
build_before_profiles = content[:profiles_start].rfind("<build>")
build_after_profiles = content[profiles_end:].find("<build>") if profiles_end != -1 else -1
if build_after_profiles != -1:
build_after_profiles += profiles_end
if build_before_profiles != -1:
build_start = build_before_profiles
build_end = _find_closing_tag(content, build_start, "build")
elif build_after_profiles != -1:
build_start = build_after_profiles
build_end = _find_closing_tag(content, build_start, "build")
else:
build_start = -1
build_end = -1
if build_start != -1 and build_end != -1:
build_section = content[build_start : build_end + len("</build>")]
plugins_start_in_build = build_section.find("<plugins>")
plugins_end_in_build = build_section.rfind("</plugins>")
if plugins_start_in_build != -1 and plugins_end_in_build != -1:
absolute_plugins_end = build_start + plugins_end_in_build
content = content[:absolute_plugins_end] + jacoco_plugin + "\n " + content[absolute_plugins_end:]
else:
plugins_section = f"<plugins>{jacoco_plugin}\n </plugins>\n "
content = content[:build_end] + plugins_section + content[build_end:]
else:
project_end = content.rfind("</project>")
build_section = f"""
<build>
<plugins>{jacoco_plugin}
</plugins>
</build>
"""
content = content[:project_end] + build_section + content[project_end:]
pom_path.write_text(content, encoding="utf-8")
logger.info("Added JaCoCo plugin to pom.xml")
return True
except Exception as e:
logger.exception("Failed to add JaCoCo plugin to pom.xml: %s", e)
return False
def get_jacoco_report_path(project_root: Path) -> Path:
return project_root / "target" / "site" / "jacoco" / "jacoco.xml"
class MavenStrategy(BuildToolStrategy):
"""Maven-specific build tool operations."""
_M2_JAR = (
Path.home()
/ ".m2"
/ "repository"
/ "com"
/ "codeflash"
/ "codeflash-runtime"
/ "1.0.0"
/ "codeflash-runtime-1.0.0.jar"
)
@property
def name(self) -> str:
return "Maven"
def find_executable(self, build_root: Path) -> str | None:
mvnw_path = build_root / "mvnw"
if mvnw_path.exists():
return str(mvnw_path)
mvnw_cmd_path = build_root / "mvnw.cmd"
if mvnw_cmd_path.exists():
return str(mvnw_cmd_path)
# Fall back to a wrapper in the current working directory, then to system Maven
if Path("mvnw").exists():
return "./mvnw"
if Path("mvnw.cmd").exists():
return "mvnw.cmd"
return shutil.which("mvn")
def find_runtime_jar(self) -> Path | None:
if self._M2_JAR.exists():
return self._M2_JAR
return super().find_runtime_jar()
def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
if not self._M2_JAR.exists():
if resolve_from_maven_central(build_root):
logger.info("Resolved codeflash-runtime from Maven Central")
else:
runtime_jar = self.find_runtime_jar()
if runtime_jar is None:
runtime_jar = download_from_github_releases()
if runtime_jar is None:
logger.error(
"codeflash-runtime JAR not found. Maven Central resolution failed and "
"GitHub Releases download failed. Generated tests will fail to compile."
)
return False
logger.info("Installing codeflash-runtime JAR to local Maven repository from %s", runtime_jar)
if not install_codeflash_runtime(build_root, runtime_jar, mvn=self.find_executable(build_root)):
logger.error("Failed to install codeflash-runtime to local Maven repository")
return False
if test_module:
pom_path = build_root / module_to_dir(test_module) / "pom.xml"
else:
pom_path = build_root / "pom.xml"
if pom_path.exists():
if not add_codeflash_dependency(pom_path):
logger.error("Failed to add codeflash-runtime dependency to %s", pom_path)
return False
else:
logger.warning("pom.xml not found at %s, cannot add codeflash-runtime dependency", pom_path)
return False
return True
def install_multi_module_deps(self, build_root: Path, test_module: str | None, env: dict[str, str]) -> bool:
from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout
if not test_module:
return True
cache_key = (build_root, test_module)
if cache_key in _multimodule_deps_installed:
logger.debug("Multi-module deps already installed for %s:%s, skipping", build_root, test_module)
return True
mvn = self.find_executable(build_root)
if not mvn:
logger.error("Maven not found — cannot pre-install multi-module dependencies")
return False
cmd = [mvn, "install", "-DskipTests", "-B", "-pl", module_to_dir(test_module), "-am"]
cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)
logger.info("Pre-installing multi-module dependencies: %s (module: %s)", build_root, test_module)
logger.debug("Running: %s", " ".join(cmd))
try:
result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=300)
if result.returncode != 0:
logger.error(
"Failed to pre-install multi-module deps (exit %d).\nstdout: %s\nstderr: %s",
result.returncode,
result.stdout[-2000:] if result.stdout else "",
result.stderr[-2000:] if result.stderr else "",
)
return False
except Exception:
logger.exception("Exception during multi-module dependency install")
return False
_multimodule_deps_installed.add(cache_key)
logger.info("Multi-module dependencies installed successfully for %s:%s", build_root, test_module)
return True
def compile_tests(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
) -> subprocess.CompletedProcess[str]:
from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout
mvn = self.find_executable(build_root)
if not mvn:
logger.error("Maven not found")
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
cmd = [mvn, "test-compile", "-e", "-B"]
cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)
if test_module:
cmd.extend(["-pl", module_to_dir(test_module)])
logger.debug("Compiling tests: %s in %s", " ".join(cmd), build_root)
try:
return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
except Exception as e:
logger.exception("Maven compilation failed: %s", e)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def compile_source_only(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
) -> subprocess.CompletedProcess[str]:
from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout
mvn = self.find_executable(build_root)
if not mvn:
logger.error("Maven not found")
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
cmd = [mvn, "compile", "-e", "-B"]
cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)
if test_module:
cmd.extend(["-pl", module_to_dir(test_module)])
logger.debug("Compiling source only: %s in %s", " ".join(cmd), build_root)
try:
return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
except Exception as e:
logger.exception("Maven source compilation failed: %s", e)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def get_classpath(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
) -> str | None:
key = (build_root, test_module)
cached = _classpath_cache.get(key)
if cached is not None:
logger.debug("Using cached classpath for (%s, %s)", build_root, test_module)
return cached
result = self._get_classpath_uncached(build_root, env, test_module, timeout)
if result is not None:
_classpath_cache[key] = result
return result
def _get_classpath_uncached(
self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
) -> str | None:
from codeflash.languages.java.test_runner import _find_junit_console_standalone, _run_cmd_kill_pg_on_timeout
mvn = self.find_executable(build_root)
if not mvn:
return None
cp_file = build_root / ".codeflash_classpath.txt"
cmd = [mvn, "dependency:build-classpath", "-DincludeScope=test", f"-Dmdep.outputFile={cp_file}", "-q", "-B"]
if test_module:
cmd.extend(["-pl", module_to_dir(test_module)])
logger.debug("Getting classpath: %s", " ".join(cmd))
try:
result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
if result.returncode != 0:
logger.error("Failed to get classpath: %s", result.stderr)
return None
if not cp_file.exists():
logger.error("Classpath file not created")
return None
classpath = cp_file.read_text(encoding="utf-8").strip()
if test_module:
module_path = build_root / module_to_dir(test_module)
else:
module_path = build_root
test_classes = module_path / "target" / "test-classes"
main_classes = module_path / "target" / "classes"
cp_parts = [classpath]
if test_classes.exists():
cp_parts.append(str(test_classes))
if main_classes.exists():
cp_parts.append(str(main_classes))
if test_module:
module_dir_name = module_to_dir(test_module)
for module_dir in build_root.iterdir():
if module_dir.is_dir() and module_dir.name != module_dir_name:
module_classes = module_dir / "target" / "classes"
if module_classes.exists():
logger.debug("Adding multi-module classpath: %s", module_classes)
cp_parts.append(str(module_classes))
if "console-standalone" not in classpath and "ConsoleLauncher" not in classpath:
console_jar = _find_junit_console_standalone()
if console_jar:
logger.debug("Adding JUnit Console Standalone to classpath: %s", console_jar)
cp_parts.append(str(console_jar))
return os.pathsep.join(cp_parts)
except Exception as e:
logger.exception("Failed to get classpath: %s", e)
return None
finally:
if cp_file.exists():
cp_file.unlink()
def get_reports_dir(self, build_root: Path, test_module: str | None) -> Path:
target_dir = self.get_build_output_dir(build_root, test_module)
return target_dir / "surefire-reports"
def get_build_output_dir(self, build_root: Path, test_module: str | None) -> Path:
if test_module:
return build_root.joinpath(module_to_dir(test_module), _TARGET)
return build_root.joinpath(_TARGET)
def run_tests_via_build_tool(
self,
build_root: Path,
test_paths: Any,
env: dict[str, str],
timeout: int,
mode: str,
test_module: str | None,
javaagent_arg: str | None = None,
enable_coverage: bool = False,
) -> subprocess.CompletedProcess[str]:
from codeflash.languages.java.test_runner import (
_build_test_filter,
_run_cmd_kill_pg_on_timeout,
_validate_test_filter,
)
mvn = self.find_executable(build_root)
if not mvn:
logger.error("Maven not found")
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
test_filter = _build_test_filter(test_paths, mode=mode)
logger.debug("Built test filter for mode=%s: '%s' (empty=%s)", mode, test_filter, not test_filter)
maven_goal = "verify" if enable_coverage else "test"
cmd = [mvn, maven_goal, "-fae", "-B"]
cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)
add_opens_flags = (
"--add-opens java.base/java.util=ALL-UNNAMED"
" --add-opens java.base/java.lang=ALL-UNNAMED"
" --add-opens java.base/java.lang.reflect=ALL-UNNAMED"
" --add-opens java.base/java.io=ALL-UNNAMED"
" --add-opens java.base/java.math=ALL-UNNAMED"
" --add-opens java.base/java.net=ALL-UNNAMED"
" --add-opens java.base/java.util.zip=ALL-UNNAMED"
)
if javaagent_arg:
cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}")
else:
cmd.append(f"-DargLine={add_opens_flags}")
if mode == "performance":
cmd.append("-Dsurefire.useFile=false")
if enable_coverage:
cmd.append("-Dmaven.test.failure.ignore=true")
if test_module:
cmd.extend(
[
"-pl",
module_to_dir(test_module),
"-DfailIfNoTests=false",
"-Dsurefire.failIfNoSpecifiedTests=false",
"-DskipTests=false",
]
)
if test_filter:
validated_filter = _validate_test_filter(test_filter)
cmd.append(f"-Dtest={validated_filter}")
logger.debug("Added -Dtest=%s to Maven command", validated_filter)
else:
error_msg = (
f"Test filter is EMPTY for mode={mode}! "
f"Maven will run ALL tests instead of the specified tests. "
f"This indicates a problem with test file instrumentation or path resolution."
)
logger.error(error_msg)
raise ValueError(error_msg)
logger.debug("Running Maven command: %s in %s", " ".join(cmd), build_root)
try:
result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
if result.returncode != 0:
compilation_error_indicators = [
"[ERROR] COMPILATION ERROR",
"[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin",
"compilation failure",
"cannot find symbol",
"package .* does not exist",
]
combined_output = (result.stdout or "") + (result.stderr or "")
has_compilation_error = any(
indicator.lower() in combined_output.lower() for indicator in compilation_error_indicators
)
if has_compilation_error:
logger.error(
"Maven compilation failed for %s tests. "
"Check that generated test code is syntactically valid Java. "
"Return code: %s",
mode,
result.returncode,
)
output_lines = combined_output.split("\n")
error_context = "\n".join(output_lines[:50]) if len(output_lines) > 50 else combined_output
logger.error("Maven compilation error output:\n%s", error_context)
return result
except Exception as e:
logger.exception("Maven test execution failed: %s", e)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def run_benchmarking_via_build_tool(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None,
project_root: Path | None,
min_loops: int,
max_loops: int,
target_duration_seconds: float,
inner_iterations: int,
) -> tuple[Path, Any]:
import time
from codeflash.languages.java.test_runner import _find_multi_module_root, _get_combined_junit_xml
project_root = project_root or cwd
maven_root, test_module = _find_multi_module_root(project_root, test_paths)
all_stdout: list[str] = []
all_stderr: list[str] = []
total_start_time = time.time()
loop_count = 0
last_result = None
per_loop_timeout = max(timeout or 0, 120, 60 + inner_iterations)
logger.debug("Using Maven-based benchmarking (fallback mode)")
for loop_idx in range(1, max_loops + 1):
run_env = os.environ.copy()
run_env.update(test_env)
run_env["CODEFLASH_LOOP_INDEX"] = str(loop_idx)
run_env["CODEFLASH_MODE"] = "performance"
run_env["CODEFLASH_TEST_ITERATION"] = "0"
if "CODEFLASH_INNER_ITERATIONS" not in run_env:
run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)
result = self.run_tests_via_build_tool(
maven_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
)
last_result = result
loop_count = loop_idx
if result.stdout:
all_stdout.append(result.stdout)
if result.stderr:
all_stderr.append(result.stderr)
elapsed = time.time() - total_start_time
if loop_idx >= min_loops and elapsed >= target_duration_seconds:
logger.debug("Stopping Maven benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
break
if result.returncode != 0:
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
if not has_timing_markers:
logger.warning("Tests failed in Maven loop %d with no timing markers, stopping", loop_idx)
break
logger.debug("Some tests failed in Maven loop %d but timing markers present, continuing", loop_idx)
combined_stdout = "\n".join(all_stdout)
combined_stderr = "\n".join(all_stderr)
total_iterations = loop_count * inner_iterations
logger.debug(
"Maven fallback: %d loops x %d iterations = %d total in %.2fs",
loop_count,
inner_iterations,
total_iterations,
time.time() - total_start_time,
)
combined_result = subprocess.CompletedProcess(
args=last_result.args if last_result else ["mvn", "test"],
returncode=last_result.returncode if last_result else -1,
stdout=combined_stdout,
stderr=combined_stderr,
)
reports_dir = self.get_reports_dir(maven_root, test_module)
result_xml_path = _get_combined_junit_xml(reports_dir, -1)
return result_xml_path, combined_result
def run_tests_with_coverage(
self,
build_root: Path,
test_module: str | None,
test_paths: Any,
run_env: dict[str, str],
timeout: int,
candidate_index: int,
) -> tuple[subprocess.CompletedProcess[str], Path, Path | None]:
from codeflash.languages.java.test_runner import _get_combined_junit_xml
coverage_xml_path = self.setup_coverage(build_root, test_module, build_root)
result = self.run_tests_via_build_tool(
build_root,
test_paths,
run_env,
timeout=timeout,
mode="behavior",
enable_coverage=True,
test_module=test_module,
)
reports_dir = self.get_reports_dir(build_root, test_module)
result_xml_path = _get_combined_junit_xml(reports_dir, candidate_index)
return result, result_xml_path, coverage_xml_path
def setup_coverage(self, build_root: Path, test_module: str | None, project_root: Path) -> Path | None:
if test_module:
test_module_pom = build_root / module_to_dir(test_module) / "pom.xml"
if test_module_pom.exists():
if not is_jacoco_configured(test_module_pom):
logger.info("Adding JaCoCo plugin to test module pom.xml: %s", test_module_pom)
add_jacoco_plugin(test_module_pom)
return get_jacoco_report_path(build_root / module_to_dir(test_module))
else:
pom_path = project_root / "pom.xml"
if pom_path.exists():
if not is_jacoco_configured(pom_path):
logger.info("Adding JaCoCo plugin to pom.xml for coverage collection")
add_jacoco_plugin(pom_path)
return get_jacoco_report_path(project_root)
return None
def get_test_run_command(self, project_root: Path, test_classes: list[str] | None = None) -> list[str]:
from codeflash.languages.java.test_runner import _validate_java_class_name
if test_classes:
for test_class in test_classes:
if not _validate_java_class_name(test_class):
msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
raise ValueError(msg)
mvn = self.find_executable(project_root) or "mvn"
cmd = [mvn, "test", "-B"]
if test_classes:
cmd.append(f"-Dtest={','.join(test_classes)}")
return cmd
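The `run_benchmarking_via_build_tool` loop above stops only once both a minimum loop count and a target wall-clock duration are satisfied, capped at a hard maximum. A minimal stdlib-only sketch of that stopping rule, with illustrative names rather than the module's API:

```python
import time


def run_adaptive_loops(run_once, min_loops=3, max_loops=10, target_duration=0.0):
    """Run `run_once` repeatedly; stop once both the minimum loop count and
    the target elapsed duration are met, never exceeding max_loops."""
    start = time.monotonic()
    loops = 0
    for i in range(1, max_loops + 1):
        run_once(i)
        loops = i
        # Same predicate as `loop_idx >= min_loops and elapsed >= target_duration_seconds`
        if i >= min_loops and time.monotonic() - start >= target_duration:
            break
    return loops
```

The real loop adds one wrinkle: on a non-zero return code it keeps going only if timing markers are present in stdout, since a failed candidate with no measurements cannot yield useful benchmark data.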


@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
from junitparser.xunit2 import JUnitXml
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.code_utils.code_utils import extract_parameterized_test_index, module_name_from_file_path
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
if TYPE_CHECKING:
@ -128,7 +128,9 @@ def parse_java_test_xml(
if class_name is not None and class_name.startswith(test_module_path):
test_class = class_name[len(test_module_path) + 1 :]
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
loop_index = (
extract_parameterized_test_index(testcase.name) if testcase.name and "[" in testcase.name else 1
)
timed_out = False
if len(testcase.result) > 1:
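The replaced `split("[ ")[-1][:-2]` parsing raises `ValueError` for any bracketed name that does not use the exact `"[ N ]"` spacing (e.g. pytest-style `test[a-b]` IDs). The implementation of `extract_parameterized_test_index` is not shown in this diff; a hypothetical equivalent would use a regex with a safe fallback:

```python
import re

# Trailing bracketed integer, tolerating optional whitespace: "name[ 3 ]" or "name[3]"
_INDEX_RE = re.compile(r"\[\s*(\d+)\s*\]\s*$")


def extract_index(test_name: str, default: int = 1) -> int:
    """Pull a trailing loop index like 'testFoo[ 3 ]' -> 3; fall back to
    `default` for non-numeric parameterized IDs such as 'test[a-b]'."""
    match = _INDEX_RE.search(test_name)
    return int(match.group(1)) if match else default
```

The fallback matters because both the Java and Python XML parsers guard only on `"[" in testcase.name`, which is true for parameterized IDs that carry no loop index at all.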


@ -16,7 +16,7 @@ import re
import textwrap
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import get_java_analyzer
@ -208,7 +208,7 @@ def _insert_class_members(
if not fields and not helpers_before_target and not helpers_after_target:
return source
def get_target_class_and_body(src: str): # type: ignore[return]
def get_target_class_and_body(src: str) -> tuple[Any, Any]:
for cls in analyzer.find_classes(src):
if cls.name == class_name:
body = cls.node.child_by_field_name("body")


@ -66,6 +66,7 @@ class JavaSupport(LanguageSupport):
self.line_profiler_agent_arg: str | None = None
self.line_profiler_warmup_iterations: int = 0
self._language_version: str | None = None
self._test_framework: str = "junit5"
@property
def language(self) -> Language:
@ -79,8 +80,8 @@ class JavaSupport(LanguageSupport):
@property
def test_framework(self) -> str:
"""Primary test framework name."""
return "junit5"
"""Primary test framework name, detected from project build config."""
return self._test_framework
@property
def comment_prefix(self) -> str:
@ -368,10 +369,19 @@ class JavaSupport(LanguageSupport):
# === Test Result Comparison ===
def compare_test_results(
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
project_classpath: str | None = None,
) -> tuple[bool, list[Any]]:
"""Compare test results between original and candidate code."""
return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
return _compare_test_results(
original_results_path,
candidate_results_path,
project_root=project_root,
project_classpath=project_classpath,
)
# === Reference Finding ===
@ -394,9 +404,10 @@ class JavaSupport(LanguageSupport):
return None
def setup_test_config(self, test_cfg: Any, file_path: Path, current_worktree: Path | None = None) -> None:
return None
# === Configuration ===
"""Detect test framework from project build config (pom.xml / build.gradle)."""
config = detect_java_project(test_cfg.project_root_path)
if config is not None:
self._test_framework = config.test_framework
def adjust_test_config_for_discovery(self, test_cfg: Any) -> None:
"""Adjust test config before test discovery for Java.
@ -534,8 +545,8 @@ class JavaSupport(LanguageSupport):
if self._language_version is None:
self._detect_java_version()
# For now, assume the runtime is available
# A full implementation would check/install the JAR
self._test_framework = config.test_framework
return True
def _detect_java_version(self) -> None:
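`setup_test_config` now delegates framework detection to `detect_java_project`, whose internals are not part of this diff. A plausible sketch of such a heuristic, keyed on `pom.xml` dependency artifactIds (the helper name and the substring heuristic are assumptions, not the project's actual logic):

```python
import re


def detect_test_framework(pom_xml: str) -> str:
    """Heuristic: JUnit 5 projects depend on junit-jupiter/junit-platform
    artifacts, JUnit 4 on the bare 'junit' artifact, TestNG on its own.
    Defaults to junit5 when nothing is recognizable."""
    if "junit-jupiter" in pom_xml or "junit-platform" in pom_xml:
        return "junit5"
    if "testng" in pom_xml:
        return "testng"
    if re.search(r"<artifactId>\s*junit\s*</artifactId>", pom_xml):
        return "junit4"
    return "junit5"
```

Defaulting to `junit5` mirrors the new `self._test_framework: str = "junit5"` initializer, so the property stays well-defined even when no build config is found.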

File diff suppressed because it is too large


@ -1904,7 +1904,11 @@ class JavaScriptSupport:
# === Test Result Comparison ===
def compare_test_results(
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
project_classpath: str | None = None,
) -> tuple[bool, list]:
"""Compare test results between original and candidate code.


@ -861,6 +861,47 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, st
return False, False, False
_ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str:
resolved = import_aliases.get(expr_name)
if resolved is not None:
return resolved
first_part, sep, rest = expr_name.partition(".")
if sep:
root_resolved = import_aliases.get(first_part)
if root_resolved is not None:
return f"{root_resolved}.{rest}"
return expr_name
def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
for decorator in class_node.decorator_list:
expr_name = _get_expr_name(decorator)
if expr_name is None:
continue
resolved = _resolve_decorator_name(expr_name, import_aliases)
parts = resolved.split(".")
if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
continue
init_enabled = True
kw_only = False
if isinstance(decorator, ast.Call):
for keyword in decorator.keywords:
literal_value = _bool_literal(keyword.value)
if literal_value is None:
continue
if keyword.arg == "init":
init_enabled = literal_value
elif keyword.arg == "kw_only":
kw_only = literal_value
return True, init_enabled, kw_only
return False, False, False
def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation
return _expr_matches_name(annotation_root, import_aliases, "ClassVar")
@ -885,10 +926,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs:
return set()
if is_dataclass and not dataclass_init_enabled:
return set()
if is_attrs and not attrs_init_enabled:
return set()
names = set[str]()
for item in class_node.body:
@ -939,9 +983,9 @@ def _extract_synthetic_init_parameters(
kw_only = literal_value
elif keyword.arg == "default":
default_value = _get_node_source(keyword.value, module_source)
elif keyword.arg == "default_factory":
# Default factories still imply an optional constructor parameter, but
# the generated __init__ does not use the field() call directly.
elif keyword.arg in {"default_factory", "factory"}:
# Default factories (dataclass default_factory= / attrs factory=) still imply
# an optional constructor parameter.
default_value = "..."
else:
default_value = _get_node_source(item.value, module_source)
@ -960,13 +1004,17 @@ def _build_synthetic_init_stub(
) -> str | None:
is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
if not is_namedtuple and not is_dataclass:
is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases)
if not is_namedtuple and not is_dataclass and not is_attrs:
return None
if is_dataclass and not dataclass_init_enabled:
return None
if is_attrs and not attrs_init_enabled:
return None
kw_only_by_default = dataclass_kw_only or attrs_kw_only
parameters = _extract_synthetic_init_parameters(
class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default
)
if not parameters:
return None
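The `_resolve_decorator_name` helper above exists so `@at.define` under `import attrs as at` resolves to `attrs.define` before the namespace/decorator-name check. A self-contained restatement of that logic with a quick exercise:

```python
def resolve_decorator_name(expr_name: str, import_aliases: dict) -> str:
    """Map a decorator expression's alias (or its dotted head) back to the
    canonical name it was imported as."""
    resolved = import_aliases.get(expr_name)
    if resolved is not None:
        return resolved
    # For dotted names, try resolving only the leading segment.
    first_part, sep, rest = expr_name.partition(".")
    if sep:
        root_resolved = import_aliases.get(first_part)
        if root_resolved is not None:
            return f"{root_resolved}.{rest}"
    return expr_name
```

After resolution, `_get_attrs_config` only has to split on dots and compare the last two segments against `_ATTRS_NAMESPACES` and `_ATTRS_DECORATOR_NAMES`.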


@ -2,10 +2,11 @@ from __future__ import annotations
import ast
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports
from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -80,6 +81,7 @@ class InitDecorator(ast.NodeTransformer):
self.has_import = False
self.tests_root = tests_root
self.inserted_decorator = False
self._attrs_classes_to_patch: dict[str, ast.Call] = {}
# Precompute decorator components to avoid reconstructing on every node visit
# Only the `function_name` field changes per class
@ -118,6 +120,21 @@ class InitDecorator(ast.NodeTransformer):
defaults=[],
)
# Pre-build reusable AST nodes for _build_attrs_patch_block
self._load_ctx = ast.Load()
self._store_ctx = ast.Store()
self._args_name_load = ast.Name(id="args", ctx=self._load_ctx)
self._kwargs_name_load = ast.Name(id="kwargs", ctx=self._load_ctx)
self._self_arg_node = ast.arg(arg="self")
self._args_arg_node = ast.arg(arg="args")
self._kwargs_arg_node = ast.arg(arg="kwargs")
self._self_name_load = ast.Name(id="self", ctx=self._load_ctx)
self._starred_args = ast.Starred(value=self._args_name_load, ctx=self._load_ctx)
self._kwargs_keyword = ast.keyword(arg=None, value=self._kwargs_name_load)
# Pre-parse the import statement to avoid repeated parsing in visit_Module
self._import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0]
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
# Check if our import already exists
if node.module == "codeflash.verification.codeflash_capture" and any(
@ -128,10 +145,20 @@ class InitDecorator(ast.NodeTransformer):
def visit_Module(self, node: ast.Module) -> ast.Module:
self.generic_visit(node)
# Insert module-level monkey-patch wrappers for attrs classes immediately after their
# class definitions. We do this before inserting the import so indices stay stable.
if self._attrs_classes_to_patch:
new_body: list[ast.stmt] = []
for stmt in node.body:
new_body.append(stmt)
if isinstance(stmt, ast.ClassDef) and stmt.name in self._attrs_classes_to_patch:
new_body.extend(self._build_attrs_patch_block(stmt.name, self._attrs_classes_to_patch[stmt.name]))
node.body = new_body
# Add import statement
if not self.has_import and self.inserted_decorator:
import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0]
node.body.insert(0, import_stmt)
node.body.insert(0, self._import_stmt)
return node
@ -171,6 +198,8 @@ class InitDecorator(ast.NodeTransformer):
item.decorator_list.insert(0, decorator)
self.inserted_decorator = True
break
if not has_init:
# Skip dataclasses — their __init__ is auto-generated at class creation time and isn't in the AST.
# The synthetic __init__ with super().__init__(*args, **kwargs) overrides it and fails because
@ -181,6 +210,18 @@ class InitDecorator(ast.NodeTransformer):
dec_name = self._expr_name(dec)
if dec_name is not None and dec_name.endswith("dataclass"):
return node
if dec_name is not None:
parts = dec_name.split(".")
if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
if isinstance(dec, ast.Call):
for kw in dec.keywords:
if kw.arg == "init" and isinstance(kw.value, ast.Constant) and kw.value.value is False:
return node
self._attrs_classes_to_patch[node.name] = decorator
self.inserted_decorator = True
return node
# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
# Skip NamedTuples — their __init__ is synthesized and cannot be overwritten.
for base in node.bases:
@ -202,6 +243,60 @@ class InitDecorator(ast.NodeTransformer):
return node
def _build_attrs_patch_block(self, class_name: str, decorator: ast.Call) -> list[ast.stmt]:
orig_name = f"_codeflash_orig_{class_name}_init"
patched_name = f"_codeflash_patched_{class_name}_init"
# Create class name nodes once
class_name_load = ast.Name(id=class_name, ctx=self._load_ctx)
# _codeflash_orig_ClassName_init = ClassName.__init__
save_orig = ast.Assign(
targets=[ast.Name(id=orig_name, ctx=self._store_ctx)],
value=ast.Attribute(value=class_name_load, attr="__init__", ctx=self._load_ctx),
)
# def _codeflash_patched_ClassName_init(self, *args, **kwargs):
# return _codeflash_orig_ClassName_init(self, *args, **kwargs)
patched_func = ast.FunctionDef(
name=patched_name,
args=ast.arguments(
posonlyargs=[],
args=[self._self_arg_node],
vararg=self._args_arg_node,
kwonlyargs=[],
kw_defaults=[],
kwarg=self._kwargs_arg_node,
defaults=[],
),
body=cast(
"list[ast.stmt]",
[
ast.Return(
value=ast.Call(
func=ast.Name(id=orig_name, ctx=self._load_ctx),
args=[self._self_name_load, self._starred_args],
keywords=[self._kwargs_keyword],
)
)
],
),
decorator_list=cast("list[ast.expr]", []),
returns=None,
)
# ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init)
assign_patched = ast.Assign(
targets=[
ast.Attribute(value=ast.Name(id=class_name, ctx=self._load_ctx), attr="__init__", ctx=self._store_ctx)
],
value=ast.Call(func=decorator, args=[ast.Name(id=patched_name, ctx=self._load_ctx)], keywords=[]),
)
return [save_orig, patched_func, assign_patched]
def _expr_name(self, node: ast.AST) -> str | None:
if isinstance(node, ast.Name):
return node.id
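Because `attrs.define(slots=True)` replaces the original class object at decoration time, injecting a synthetic `__init__` into the class body would break; the transformer instead emits the module-level patch block built by `_build_attrs_patch_block`. At runtime that block behaves like this stdlib-only sketch, where `trace` stands in for `codeflash_capture` and attrs itself is omitted:

```python
calls = []


def trace(func):
    """Stand-in for codeflash_capture: record each __init__ invocation."""
    def wrapper(self, *args, **kwargs):
        calls.append(args)
        return func(self, *args, **kwargs)
    return wrapper


class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


# The same three statements the transformer emits after the class body:
_orig_Point_init = Point.__init__


def _patched_Point_init(self, *args, **kwargs):
    return _orig_Point_init(self, *args, **kwargs)


Point.__init__ = trace(_patched_Point_init)
```

Saving the original under a module-level name before reassigning `__init__` avoids any reliance on the implicit `__class__` cell, which is exactly what goes wrong when the decorated method lives inside a class that attrs later replaces.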


@ -14,7 +14,11 @@ from typing import TYPE_CHECKING
from junitparser.xunit2 import JUnitXml
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import file_path_from_module_name, module_name_from_file_path
from codeflash.code_utils.code_utils import (
extract_parameterized_test_index,
file_path_from_module_name,
module_name_from_file_path,
)
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults
if TYPE_CHECKING:
@ -140,13 +144,15 @@ def parse_python_test_xml(
if class_name is not None and class_name.startswith(test_module_path):
test_class = class_name[len(test_module_path) + 1 :]
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
loop_index = (
extract_parameterized_test_index(testcase.name) if testcase.name and "[" in testcase.name else 1
)
timed_out = False
if len(testcase.result) > 1:
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
if len(testcase.result) == 1:
message = testcase.result[0].message.lower()
message = (testcase.result[0].message or "").lower()
if "failed: timeout >" in message or "timed out" in message:
timed_out = True


@ -848,7 +848,11 @@ class PythonSupport:
# === Test Result Comparison ===
def compare_test_results(
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
project_classpath: str | None = None,
) -> tuple[bool, list]:
"""Compare test results between original and candidate code.
@ -856,6 +860,7 @@ class PythonSupport:
original_results_path: Path to original test results.
candidate_results_path: Path to candidate test results.
project_root: Project root directory.
project_classpath: Unused (Java only).
Returns:
Tuple of (are_equivalent, list of TestDiff objects).


@ -714,6 +714,10 @@ class Optimizer:
cleanup_paths(paths_to_cleanup)
def cleanup_temporary_paths(self) -> None:
from codeflash.languages.java.test_runner import CompilationCache
CompilationCache.clear()
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
del get_run_tmp_file.tmpdir


@ -26,6 +26,8 @@ dependencies = [
"tree-sitter-javascript>=0.23.0",
"tree-sitter-typescript>=0.23.0",
"tree-sitter-java>=0.23.0",
"tree-sitter-groovy>=0.1.0",
"tree-sitter-kotlin>=1.0.0",
"pytest-timeout>=2.1.0",
"tomlkit>=0.11.7",
"junitparser>=3.1.0",


@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str:
assert "qualified_name: str" in combined
def test_extract_init_stub_attrs_define(tmp_path: Path) -> None:
"""extract_init_stub_from_class produces a synthetic __init__ stub for @attrs.define classes."""
source = """
import attrs
from attrs.validators import instance_of
@attrs.define(frozen=True)
class ImportCST:
module: str = attrs.field(converter=str.lower)
name: str = attrs.field(validator=[instance_of(str)])
as_name: str = attrs.field(validator=[instance_of(str)])
def to_str(self) -> str:
return f"from {self.module} import {self.name}"
"""
expected = "class ImportCST:\n def __init__(self, module: str, name: str, as_name: str):\n ..."
tree = ast.parse(source)
stub = extract_init_stub_from_class("ImportCST", source, tree)
assert stub == expected
def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None:
"""Fields using attrs factory= keyword should appear as optional (= ...) in the stub."""
source = """
import attrs
@attrs.define
class ClassCST:
name: str = attrs.field()
methods: list = attrs.field(factory=list)
imports: set = attrs.field(factory=set)
def compute(self) -> int:
return len(self.methods)
"""
expected = "class ClassCST:\n def __init__(self, name: str, methods: list = ..., imports: set = ...):\n ..."
tree = ast.parse(source)
stub = extract_init_stub_from_class("ClassCST", source, tree)
assert stub == expected
def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None:
"""When @attrs.define(init=False) but with explicit __init__, the explicit body is returned."""
source = """
import attrs
@attrs.define(init=False)
class NoAutoInit:
x: int = attrs.field()
def __init__(self, x: int):
self.x = x * 2
def get(self) -> int:
return self.x
"""
expected = "class NoAutoInit:\n def __init__(self, x: int):\n self.x = x * 2"
tree = ast.parse(source)
stub = extract_init_stub_from_class("NoAutoInit", source, tree)
assert stub == expected
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
"""Third-party classes should produce compact __init__ stubs, not full class source."""
# Use a real third-party package (pydantic) so jedi can actually resolve it


@ -303,7 +303,7 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase):
reset_current_language()
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_java_diff_ignored_when_language_is_python(self, mock_repo_cls):
def test_java_diff_found_regardless_of_current_language(self, mock_repo_cls):
from codeflash.languages.current import reset_current_language, set_current_language
repo = mock_repo_cls.return_value
@ -311,15 +311,18 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase):
repo.working_dir = "/repo"
repo.git.diff.return_value = JAVA_ADDITION_DIFF
# get_git_diff uses all registered extensions, not just the current language's
set_current_language("python")
try:
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
assert len(result) == 0
assert len(result) == 1
key = list(result.keys())[0]
assert str(key).endswith("Fibonacci.java")
finally:
reset_current_language()
@patch("codeflash.code_utils.git_utils.git.Repo")
def test_mixed_lang_diff_filters_by_current_language(self, mock_repo_cls):
def test_mixed_lang_diff_returns_all_supported_extensions(self, mock_repo_cls):
from codeflash.languages.current import reset_current_language, set_current_language
repo = mock_repo_cls.return_value
@ -327,23 +330,14 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase):
repo.working_dir = "/repo"
repo.git.diff.return_value = MIXED_LANG_DIFF
# When language is Python, only .py file should be found
# All supported extensions are returned regardless of current language
set_current_language("python")
try:
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
assert len(result) == 1
key = list(result.keys())[0]
assert str(key).endswith("utils.py")
finally:
reset_current_language()
# When language is Java, only .java file should be found
set_current_language("java")
try:
result = get_git_diff(repo_directory=None, uncommitted_changes=True)
assert len(result) == 1
key = list(result.keys())[0]
assert str(key).endswith("App.java")
assert len(result) == 2
paths = [str(k) for k in result.keys()]
assert any(p.endswith("utils.py") for p in paths)
assert any(p.endswith("App.java") for p in paths)
finally:
reset_current_language()
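These tests pin down the new contract: `get_git_diff` filters changed paths by the union of all registered language extensions, regardless of the currently selected language. A hedged sketch of that kind of filtering over unified-diff headers (the extension set and parsing are illustrative, not the function's actual implementation):

```python
import re

# Illustrative union of registered language extensions
SUPPORTED_EXTENSIONS = {".py", ".java", ".js", ".ts"}


def changed_supported_files(diff_text: str) -> list:
    """Extract the b/ path from each 'diff --git' header, keeping only
    files whose extension is in the supported set."""
    paths = re.findall(r"^diff --git a/\S+ b/(\S+)$", diff_text, flags=re.M)
    return [p for p in paths if any(p.endswith(ext) for ext in SUPPORTED_EXTENSIONS)]
```

Under this contract a mixed Python/Java diff yields both files, matching the updated assertions above.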


@ -2,8 +2,8 @@ from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.models.models import FunctionParent
def test_add_codeflash_capture():
@ -499,6 +499,184 @@ class MyTuple(typing.NamedTuple):
test_path.unlink(missing_ok=True)
def test_attrs_define_patched_via_module_wrapper():
"""@attrs.define classes must NOT get a synthetic body __init__; instead a module-level
monkey-patch block is emitted after the class to avoid the __class__ cell TypeError
that arises when attrs.define(slots=True) replaces the original class object.
"""
original_code = """
import attrs
from attrs.validators import instance_of
@attrs.define
class MyAttrsClass:
x: int = attrs.field(validator=[instance_of(int)])
y: str = attrs.field(default="hello")
def compute(self):
return self.x
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
expected = f"""import attrs
from attrs.validators import instance_of
from codeflash.verification.codeflash_capture import codeflash_capture
@attrs.define
class MyAttrsClass:
x: int = attrs.field(validator=[instance_of(int)])
y: str = attrs.field(default='hello')
def compute(self):
return self.x
_codeflash_orig_MyAttrsClass_init = MyAttrsClass.__init__
def _codeflash_patched_MyAttrsClass_init(self, *args, **kwargs):
return _codeflash_orig_MyAttrsClass_init(self, *args, **kwargs)
MyAttrsClass.__init__ = codeflash_capture(function_name='MyAttrsClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_MyAttrsClass_init)
"""
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrsClass")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
test_path.unlink(missing_ok=True)
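The module-level wrapper pattern asserted in the expected output above can be sketched in isolation. Here `capture` is a simplified stand-in for `codeflash_capture`, and the class is plain rather than attrs-decorated; the point is only that wrapping `__init__` after class creation sidesteps the stale `__class__` cell that a synthetic in-body `__init__` would capture once `attrs.define(slots=True)` rebuilds the class:

```python
def capture(fn):
    """Simplified stand-in for codeflash_capture: counts __init__ calls."""
    def wrapper(self, *args, **kwargs):
        wrapper.calls += 1
        return fn(self, *args, **kwargs)
    wrapper.calls = 0
    return wrapper

class Point:
    def __init__(self, x):
        self.x = x

# Module-level monkey-patch, mirroring the emitted instrumentation:
_orig_Point_init = Point.__init__
def _patched_Point_init(self, *args, **kwargs):
    return _orig_Point_init(self, *args, **kwargs)
Point.__init__ = capture(_patched_Point_init)

p = Point(3)
```

Because the patch runs after the decorator has produced the final class object, it works equally for slotted, frozen, and `attr.s` classes, which is exactly what the three tests above verify.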
def test_attrs_define_frozen_patched_via_module_wrapper():
"""@attrs.define(frozen=True) should also be monkey-patched at module level."""
original_code = """
import attrs
@attrs.define(frozen=True)
class FrozenPoint:
x: float = attrs.field()
y: float = attrs.field()
def distance(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
expected = f"""import attrs
from codeflash.verification.codeflash_capture import codeflash_capture
@attrs.define(frozen=True)
class FrozenPoint:
x: float = attrs.field()
y: float = attrs.field()
def distance(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
_codeflash_orig_FrozenPoint_init = FrozenPoint.__init__
def _codeflash_patched_FrozenPoint_init(self, *args, **kwargs):
return _codeflash_orig_FrozenPoint_init(self, *args, **kwargs)
FrozenPoint.__init__ = codeflash_capture(function_name='FrozenPoint.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_FrozenPoint_init)
"""
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="distance", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
test_path.unlink(missing_ok=True)
def test_attr_s_patched_via_module_wrapper():
"""@attr.s classes should also be monkey-patched at module level."""
original_code = """
import attr
@attr.s
class MyAttrClass:
x: int = attr.ib()
def display(self):
return self.x
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
expected = f"""import attr
from codeflash.verification.codeflash_capture import codeflash_capture
@attr.s
class MyAttrClass:
x: int = attr.ib()
def display(self):
return self.x
_codeflash_orig_MyAttrClass_init = MyAttrClass.__init__
def _codeflash_patched_MyAttrClass_init(self, *args, **kwargs):
return _codeflash_orig_MyAttrClass_init(self, *args, **kwargs)
MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_MyAttrClass_init)
"""
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrClass")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
test_path.unlink(missing_ok=True)
def test_attrs_define_init_false_skipped():
"""@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__."""
original_code = """
import attrs
@attrs.define(init=False)
class ManualInit:
x: int = attrs.field()
def compute(self):
return self.x
"""
expected = """import attrs
@attrs.define(init=False)
class ManualInit:
x: int = attrs.field()
def compute(self):
return self.x
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
test_path.unlink(missing_ok=True)
def test_dataclass_with_explicit_init_still_instrumented():
"""A dataclass that defines its own __init__ should still be instrumented normally."""
original_code = """

View file

@ -1,4 +1,4 @@
"""Tests for ensure_multi_module_deps_installed in Java test runner."""
"""Tests for MavenStrategy.install_multi_module_deps."""
import subprocess
from pathlib import Path
@ -6,7 +6,7 @@ from unittest.mock import patch
import pytest
from codeflash.languages.java.test_runner import _multimodule_deps_installed, ensure_multi_module_deps_installed
from codeflash.languages.java.maven_strategy import MavenStrategy, _multimodule_deps_installed
@pytest.fixture(autouse=True)
@ -17,21 +17,26 @@ def clear_cache():
_multimodule_deps_installed.clear()
def test_skipped_for_single_module():
@pytest.fixture()
def strategy():
return MavenStrategy()
def test_skipped_for_single_module(strategy):
"""Single-module projects (test_module=None) should be a no-op."""
result = ensure_multi_module_deps_installed(Path("/fake"), None, {})
result = strategy.install_multi_module_deps(Path("/fake"), None, {})
assert result is True
assert len(_multimodule_deps_installed) == 0
@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
@patch.object(MavenStrategy, "find_executable", return_value="mvn")
@patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
def test_runs_install_command_with_correct_args(mock_run, mock_mvn):
def test_runs_install_command_with_correct_args(mock_run, mock_mvn, strategy):
"""Should run mvn install -DskipTests -pl <module> -am with validation skip flags."""
mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=0, stdout="", stderr="")
root = Path("/project")
result = ensure_multi_module_deps_installed(root, "guava-tests", {"JAVA_HOME": "/jdk"})
result = strategy.install_multi_module_deps(root, "guava-tests", {"JAVA_HOME": "/jdk"})
assert result is True
mock_run.assert_called_once()
@ -43,55 +48,52 @@ def test_runs_install_command_with_correct_args(mock_run, mock_mvn):
assert "guava-tests" in cmd
assert "-am" in cmd
assert "-B" in cmd
# Validation skip flags should be present
assert "-Drat.skip=true" in cmd
assert "-Dcheckstyle.skip=true" in cmd
# cwd should be maven_root
assert mock_run.call_args[1]["cwd"] == root
@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
@patch.object(MavenStrategy, "find_executable", return_value="mvn")
@patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
def test_caches_and_does_not_rerun(mock_run, mock_mvn):
def test_caches_and_does_not_rerun(mock_run, mock_mvn, strategy):
"""Second call with same (root, module) should be cached — no Maven invocation."""
mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=0, stdout="", stderr="")
root = Path("/project")
ensure_multi_module_deps_installed(root, "guava-tests", {})
strategy.install_multi_module_deps(root, "guava-tests", {})
assert mock_run.call_count == 1
# Second call — should be cached
result = ensure_multi_module_deps_installed(root, "guava-tests", {})
result = strategy.install_multi_module_deps(root, "guava-tests", {})
assert result is True
assert mock_run.call_count == 1 # NOT called again
assert mock_run.call_count == 1
@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
@patch.object(MavenStrategy, "find_executable", return_value="mvn")
@patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
def test_different_modules_not_cached(mock_run, mock_mvn):
def test_different_modules_not_cached(mock_run, mock_mvn, strategy):
"""Different test modules should each trigger their own install."""
mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=0, stdout="", stderr="")
root = Path("/project")
ensure_multi_module_deps_installed(root, "module-a", {})
ensure_multi_module_deps_installed(root, "module-b", {})
strategy.install_multi_module_deps(root, "module-a", {})
strategy.install_multi_module_deps(root, "module-b", {})
assert mock_run.call_count == 2
@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
@patch.object(MavenStrategy, "find_executable", return_value="mvn")
@patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
def test_returns_false_on_maven_failure(mock_run, mock_mvn):
def test_returns_false_on_maven_failure(mock_run, mock_mvn, strategy):
"""Non-zero exit code should return False and NOT cache."""
mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=1, stdout="", stderr="BUILD FAILURE")
root = Path("/project")
result = ensure_multi_module_deps_installed(root, "guava-tests", {})
result = strategy.install_multi_module_deps(root, "guava-tests", {})
assert result is False
assert len(_multimodule_deps_installed) == 0
@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value=None)
def test_returns_false_when_maven_not_found(mock_mvn):
@patch.object(MavenStrategy, "find_executable", return_value=None)
def test_returns_false_when_maven_not_found(mock_mvn, strategy):
"""Should return False if Maven executable is not found."""
result = ensure_multi_module_deps_installed(Path("/fake"), "module", {})
result = strategy.install_multi_module_deps(Path("/fake"), "module", {})
assert result is False

View file

@ -6,38 +6,36 @@ from unittest.mock import patch
import pytest
from codeflash.languages.java.test_runner import _build_test_filter, _run_maven_tests
from codeflash.languages.java.maven_strategy import MavenStrategy
from codeflash.languages.java.test_runner import _build_test_filter
from codeflash.models.models import TestFile, TestFiles, TestType
def test_build_test_filter_with_none_benchmarking_paths():
"""Test that _build_test_filter handles None benchmarking paths correctly."""
# Create TestFiles with None benchmarking_file_path
test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=Path("/tmp/test1__perfinstrumented.java"),
benchmarking_file_path=None, # None path!
benchmarking_file_path=None,
original_file_path=Path("/tmp/test1.java"),
test_type=TestType.EXISTING_UNIT_TEST,
),
TestFile(
instrumented_behavior_file_path=Path("/tmp/test2__perfinstrumented.java"),
benchmarking_file_path=None, # None path!
benchmarking_file_path=None,
original_file_path=Path("/tmp/test2.java"),
test_type=TestType.EXISTING_UNIT_TEST,
),
]
)
# In performance mode with None paths, filter should be empty
result = _build_test_filter(test_files, mode="performance")
assert result == "", f"Expected empty filter but got: {result}"
def test_build_test_filter_with_valid_paths():
"""Test that _build_test_filter works correctly with valid paths."""
# Create TestFiles with valid paths
test_files = TestFiles(
test_files=[
TestFile(
@ -49,50 +47,46 @@ def test_build_test_filter_with_valid_paths():
]
)
# Should produce valid filter
result = _build_test_filter(test_files, mode="performance")
assert result != "", "Expected non-empty filter"
assert "Test1__perfonlyinstrumented" in result
def test_run_maven_tests_raises_on_empty_filter():
"""Test that _run_maven_tests raises ValueError when filter is empty."""
def test_run_tests_via_build_tool_raises_on_empty_filter():
"""Test that MavenStrategy.run_tests_via_build_tool raises ValueError when filter is empty."""
strategy = MavenStrategy()
project_root = Path("/tmp/test_project")
env = {}
# Create TestFiles with None paths (will produce empty filter)
test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=Path("/tmp/test__perfinstrumented.java"),
benchmarking_file_path=None, # Will cause empty filter in performance mode
benchmarking_file_path=None,
original_file_path=Path("/tmp/test.java"),
test_type=TestType.EXISTING_UNIT_TEST,
)
]
)
# Mock Maven executable
with patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven:
mock_maven.return_value = "mvn"
# Should raise ValueError due to empty filter
with patch.object(MavenStrategy, "find_executable", return_value="mvn"):
with pytest.raises(ValueError, match="Test filter is EMPTY"):
_run_maven_tests(
strategy.run_tests_via_build_tool(
project_root,
test_files,
env,
timeout=60,
mode="performance", # Performance mode with None benchmarking_file_path
mode="performance",
test_module=None,
)
def test_run_maven_tests_succeeds_with_valid_filter():
"""Test that _run_maven_tests works correctly when filter is not empty."""
def test_run_tests_via_build_tool_succeeds_with_valid_filter():
"""Test that MavenStrategy.run_tests_via_build_tool works correctly when filter is not empty."""
strategy = MavenStrategy()
project_root = Path("/tmp/test_project")
env = {}
# Create TestFiles with valid paths
test_files = TestFiles(
test_files=[
TestFile(
@ -104,20 +98,18 @@ def test_run_maven_tests_succeeds_with_valid_filter():
]
)
# Mock Maven executable and _run_cmd_kill_pg_on_timeout (which replaced subprocess.run)
with (
patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven,
patch.object(MavenStrategy, "find_executable", return_value="mvn"),
patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout") as mock_run,
):
mock_maven.return_value = "mvn"
mock_run.return_value = subprocess.CompletedProcess(
args=[], returncode=0, stdout="Tests run: 1, Failures: 0, Errors: 0, Skipped: 0", stderr=""
)
# Should not raise - filter is valid
result = _run_maven_tests(project_root, test_files, env, timeout=60, mode="performance")
result = strategy.run_tests_via_build_tool(
project_root, test_files, env, timeout=60, mode="performance", test_module=None
)
# Verify Maven was called with -Dtest parameter
assert mock_run.called
cmd = mock_run.call_args[0][0]
assert any("-Dtest=" in arg for arg in cmd), f"Expected -Dtest parameter in command: {cmd}"

View file

@ -2,16 +2,17 @@
import os
from pathlib import Path
from unittest.mock import patch
from codeflash.languages.java.build_tools import (
BuildTool,
add_codeflash_dependency_to_pom,
detect_build_tool,
find_maven_executable,
find_source_root,
find_test_root,
get_project_info,
)
from codeflash.languages.java.gradle_strategy import GradleStrategy
from codeflash.languages.java.maven_strategy import MavenStrategy, add_codeflash_dependency
from codeflash.languages.java.test_runner import _extract_modules_from_pom_content
@ -175,8 +176,8 @@ class TestMavenExecutable:
def test_find_maven_executable_system(self):
"""Test finding system Maven."""
# This test may pass or fail depending on whether Maven is installed
mvn = find_maven_executable()
strategy = MavenStrategy()
mvn = strategy.find_executable(Path())
# We can't assert it exists, just that the function doesn't crash
if mvn:
assert "mvn" in mvn.lower() or "maven" in mvn.lower()
@ -188,10 +189,8 @@ class TestMavenExecutable:
mvnw_path.write_text("#!/bin/bash\necho 'Maven Wrapper'")
mvnw_path.chmod(0o755)
# Change to tmp_path
monkeypatch.chdir(tmp_path)
mvn = find_maven_executable()
strategy = MavenStrategy()
mvn = strategy.find_executable(tmp_path)
# Should find the wrapper
assert mvn is not None
@ -377,24 +376,27 @@ class TestMavenProfiles:
class TestMavenExecutableWithProjectRoot:
"""Tests for find_maven_executable with project_root parameter."""
"""Tests for MavenStrategy.find_executable with project_root parameter."""
def test_find_wrapper_in_project_root(self, tmp_path):
mvnw_path = tmp_path / "mvnw"
mvnw_path.write_text("#!/bin/bash\necho Maven Wrapper")
mvnw_path.chmod(0o755)
result = find_maven_executable(project_root=tmp_path)
strategy = MavenStrategy()
result = strategy.find_executable(tmp_path)
assert result is not None
assert str(tmp_path / "mvnw") in result
def test_fallback_to_cwd_when_no_project_root(self):
result = find_maven_executable()
# Should not crash even without project_root
def test_fallback_to_cwd(self, tmp_path):
strategy = MavenStrategy()
result = strategy.find_executable(tmp_path)
# Should not crash even with a dir that has no wrapper; may return a system mvn path or None
assert result is None or isinstance(result, str)
def test_project_root_none_uses_cwd(self):
result = find_maven_executable(project_root=None)
# Should not crash
def test_with_nonexistent_wrapper(self, tmp_path):
strategy = MavenStrategy()
result = strategy.find_executable(tmp_path)
# Should not crash, may return system mvn or None
assert result is None or isinstance(result, str)
class TestCustomSourceDirectoryDetection:
@ -460,7 +462,7 @@ class TestCustomSourceDirectoryDetection:
class TestAddCodeflashDependencyToPom:
"""Tests for add_codeflash_dependency_to_pom, including stale system-scope replacement."""
"""Tests for add_codeflash_dependency, including stale system-scope replacement."""
def test_adds_dependency_to_clean_pom(self, tmp_path):
pom = tmp_path / "pom.xml"
@ -477,7 +479,7 @@ class TestAddCodeflashDependencyToPom:
"</project>\n",
encoding="utf-8",
)
assert add_codeflash_dependency_to_pom(pom) is True
assert add_codeflash_dependency(pom) is True
content = pom.read_text(encoding="utf-8")
assert "codeflash-runtime" in content
assert "<scope>test</scope>" in content
@ -499,7 +501,7 @@ class TestAddCodeflashDependencyToPom:
"</project>\n",
encoding="utf-8",
)
assert add_codeflash_dependency_to_pom(pom) is True
assert add_codeflash_dependency(pom) is True
content = pom.read_text(encoding="utf-8")
assert "<scope>test</scope>" in content
assert "<scope>system</scope>" not in content
@ -523,7 +525,7 @@ class TestAddCodeflashDependencyToPom:
"</project>\n",
encoding="utf-8",
)
assert add_codeflash_dependency_to_pom(pom) is True
assert add_codeflash_dependency(pom) is True
content = pom.read_text(encoding="utf-8")
assert "<scope>test</scope>" in content
assert "<scope>system</scope>" not in content
@ -545,17 +547,97 @@ class TestAddCodeflashDependencyToPom:
"</project>\n",
encoding="utf-8",
)
assert add_codeflash_dependency_to_pom(pom) is True
assert add_codeflash_dependency(pom) is True
content = pom.read_text(encoding="utf-8")
assert content.count("codeflash-runtime") == 1
def test_returns_false_for_missing_pom(self, tmp_path):
pom = tmp_path / "pom.xml"
assert add_codeflash_dependency_to_pom(pom) is False
assert add_codeflash_dependency(pom) is False
def test_returns_false_when_no_dependencies_tag(self, tmp_path):
pom = tmp_path / "pom.xml"
pom.write_text(
'<?xml version="1.0"?>\n<project><modelVersion>4.0.0</modelVersion></project>\n', encoding="utf-8"
)
assert add_codeflash_dependency_to_pom(pom) is False
assert add_codeflash_dependency(pom) is False
class TestGradleEnsureRuntimeMultiModule:
"""Tests that ensure_runtime adds the dependency to the correct module build file."""
def _make_multi_module_project(self, tmp_path):
"""Create a multi-module Gradle project with submodule build files."""
# Root
(tmp_path / "build.gradle.kts").write_text("// root build\n", encoding="utf-8")
(tmp_path / "settings.gradle.kts").write_text('include("clients", "streams")', encoding="utf-8")
(tmp_path / "gradlew").write_text("#!/bin/sh\necho gradle", encoding="utf-8")
(tmp_path / "gradlew").chmod(0o755)
# Submodule build files with a dependencies block
for module in ["clients", "streams"]:
module_dir = tmp_path / module
module_dir.mkdir()
(module_dir / "build.gradle.kts").write_text(
'plugins {\n java\n}\n\ndependencies {\n testImplementation("junit:junit:4.13.2")\n}\n',
encoding="utf-8",
)
return tmp_path
def test_adds_dependency_to_correct_module_build_file(self, tmp_path):
"""When test_module='streams', the dependency must be added to streams/build.gradle.kts."""
project = self._make_multi_module_project(tmp_path)
strategy = GradleStrategy()
# Provide a fake runtime JAR
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04") # minimal zip header
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module="streams")
assert result is True
# Dependency should be in streams/build.gradle.kts
streams_build = (project / "streams" / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in streams_build
# And NOT in clients/build.gradle.kts or root build.gradle.kts
clients_build = (project / "clients" / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" not in clients_build
root_build = (project / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" not in root_build
def test_adds_dependency_to_root_when_no_module(self, tmp_path):
"""When test_module=None, the dependency is added to the root build file."""
project = self._make_multi_module_project(tmp_path)
strategy = GradleStrategy()
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04")
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module=None)
assert result is True
root_build = (project / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in root_build
def test_adds_dependency_to_nested_module(self, tmp_path):
"""When test_module='connect:runtime', the dep goes to connect/runtime/build.gradle.kts."""
project = self._make_multi_module_project(tmp_path)
# Add nested module
nested = tmp_path / "connect" / "runtime"
nested.mkdir(parents=True)
(nested / "build.gradle.kts").write_text(
'plugins {\n java\n}\n\ndependencies {\n testImplementation("junit:junit:4.13.2")\n}\n',
encoding="utf-8",
)
strategy = GradleStrategy()
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04")
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module="connect:runtime")
assert result is True
nested_build = (nested / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in nested_build

View file

@ -4,10 +4,10 @@ from __future__ import annotations
from pathlib import Path
from codeflash.languages.java.build_tools import (
from codeflash.languages.java.maven_strategy import (
JACOCO_PLUGIN_VERSION,
add_jacoco_plugin_to_pom,
get_jacoco_xml_path,
add_jacoco_plugin,
get_jacoco_report_path,
is_jacoco_configured,
)
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, CoverageStatus, FunctionSource
@ -470,7 +470,7 @@ class TestJacocoPluginAddition:
pom_path.write_text(POM_MINIMAL)
# Add JaCoCo plugin
result = add_jacoco_plugin_to_pom(pom_path)
result = add_jacoco_plugin(pom_path)
assert result is True
# Verify it's now configured
@ -483,13 +483,13 @@ class TestJacocoPluginAddition:
assert "prepare-agent" in content
assert "report" in content
def test_add_jacoco_plugin_to_pom_with_build(self, tmp_path: Path) -> None:
def test_add_jacoco_plugin_with_build(self, tmp_path: Path) -> None:
"""Test adding JaCoCo to pom.xml that has a build section."""
pom_path = tmp_path / "pom.xml"
pom_path.write_text(POM_WITHOUT_JACOCO)
# Add JaCoCo plugin
result = add_jacoco_plugin_to_pom(pom_path)
result = add_jacoco_plugin(pom_path)
assert result is True
# Verify it's now configured
@ -501,7 +501,7 @@ class TestJacocoPluginAddition:
pom_path.write_text(POM_WITH_JACOCO)
# Try to add JaCoCo plugin
result = add_jacoco_plugin_to_pom(pom_path)
result = add_jacoco_plugin(pom_path)
assert result is True # Should succeed (already present)
# Verify it's still configured
@ -513,7 +513,7 @@ class TestJacocoPluginAddition:
pom_path.write_text(POM_NO_NAMESPACE)
# Add JaCoCo plugin
result = add_jacoco_plugin_to_pom(pom_path)
result = add_jacoco_plugin(pom_path)
assert result is True
# Verify it's now configured
@ -523,7 +523,7 @@ class TestJacocoPluginAddition:
"""Test adding JaCoCo when pom.xml doesn't exist."""
pom_path = tmp_path / "pom.xml"
result = add_jacoco_plugin_to_pom(pom_path)
result = add_jacoco_plugin(pom_path)
assert result is False
def test_add_jacoco_plugin_invalid_xml(self, tmp_path: Path) -> None:
@ -531,16 +531,16 @@ class TestJacocoPluginAddition:
pom_path = tmp_path / "pom.xml"
pom_path.write_text("this is not valid xml")
result = add_jacoco_plugin_to_pom(pom_path)
result = add_jacoco_plugin(pom_path)
assert result is False
class TestJacocoXmlPath:
"""Tests for JaCoCo XML path resolution."""
def test_get_jacoco_xml_path(self, tmp_path: Path) -> None:
def test_get_jacoco_report_path(self, tmp_path: Path) -> None:
"""Test getting the expected JaCoCo XML path."""
path = get_jacoco_xml_path(tmp_path)
path = get_jacoco_report_path(tmp_path)
assert path == tmp_path / "target" / "site" / "jacoco" / "jacoco.xml"

View file

@ -22,7 +22,7 @@ os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
from codeflash.languages.java.build_tools import find_maven_executable
from codeflash.languages.java.maven_strategy import MavenStrategy
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.instrumentation import (
_add_behavior_instrumentation,
@ -1968,7 +1968,7 @@ public class AccentTest {
# Skip all E2E tests if Maven is not available
requires_maven = pytest.mark.skipif(
find_maven_executable() is None, reason="Maven not found - skipping execution tests"
MavenStrategy().find_executable(Path(".")) is None, reason="Maven not found - skipping execution tests"
)

View file

@ -3,7 +3,14 @@
from pathlib import Path
from unittest.mock import MagicMock
from codeflash.languages.java.test_runner import _extract_source_dirs_from_pom, _path_to_class_name
from codeflash.languages.java.build_tool_strategy import module_to_dir
from codeflash.languages.java.test_runner import (
_extract_custom_source_dirs,
_extract_modules_from_settings_gradle,
_find_multi_module_root,
_match_module_from_rel_path,
_path_to_class_name,
)
class TestGetJavaSourcesRoot:
@ -13,12 +20,12 @@ class TestGetJavaSourcesRoot:
"""Create a mock FunctionOptimizer with the given tests_root."""
from codeflash.languages.java.function_optimizer import JavaFunctionOptimizer
# Create a minimal mock
mock_optimizer = MagicMock(spec=JavaFunctionOptimizer)
mock_optimizer.test_cfg = MagicMock()
mock_optimizer.test_cfg.tests_root = Path(tests_root)
mock_optimizer.function_to_optimize = MagicMock()
mock_optimizer.function_to_optimize.file_path = Path("/nonexistent/Foo.java")
# Bind the actual method to the mock
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
return mock_optimizer
@ -87,6 +94,168 @@ class TestGetJavaSourcesRoot:
assert result == Path("/Users/test/Work/aerospike-client-java/test/src")
class TestGetJavaSourcesRootMultiModule:
"""Tests for _get_java_sources_root with multi-module projects."""
def _create_mock_optimizer(self, tests_root: str, file_path: str):
from codeflash.languages.java.function_optimizer import JavaFunctionOptimizer
mock_optimizer = MagicMock(spec=JavaFunctionOptimizer)
mock_optimizer.test_cfg = MagicMock()
mock_optimizer.test_cfg.tests_root = Path(tests_root)
mock_optimizer.function_to_optimize = MagicMock()
mock_optimizer.function_to_optimize.file_path = Path(file_path)
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
return mock_optimizer
def test_kafka_streams_module(self, tmp_path):
"""Kafka: function in streams module should use streams test dir, not clients."""
(tmp_path / "streams" / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "streams" / "src" / "test" / "java").mkdir(parents=True)
(tmp_path / "clients" / "src" / "test" / "java").mkdir(parents=True)
optimizer = self._create_mock_optimizer(
tests_root=str(tmp_path / "clients" / "src" / "test" / "java"),
file_path=str(
tmp_path
/ "streams"
/ "src"
/ "main"
/ "java"
/ "org"
/ "apache"
/ "kafka"
/ "streams"
/ "query"
/ "QueryConfig.java"
),
)
result = optimizer._get_java_sources_root()
assert result == tmp_path / "streams" / "src" / "test" / "java"
def test_kafka_connect_module(self, tmp_path):
"""Kafka: function in connect/runtime should use connect/runtime test dir."""
(tmp_path / "connect" / "runtime" / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "connect" / "runtime" / "src" / "test" / "java").mkdir(parents=True)
(tmp_path / "clients" / "src" / "test" / "java").mkdir(parents=True)
optimizer = self._create_mock_optimizer(
tests_root=str(tmp_path / "clients" / "src" / "test" / "java"),
file_path=str(
tmp_path
/ "connect"
/ "runtime"
/ "src"
/ "main"
/ "java"
/ "org"
/ "apache"
/ "kafka"
/ "connect"
/ "runtime"
/ "Worker.java"
),
)
result = optimizer._get_java_sources_root()
assert result == tmp_path / "connect" / "runtime" / "src" / "test" / "java"
def test_kafka_clients_module_same_as_config(self, tmp_path):
"""Kafka: function in clients module should still use clients test dir."""
(tmp_path / "clients" / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "clients" / "src" / "test" / "java").mkdir(parents=True)
optimizer = self._create_mock_optimizer(
tests_root=str(tmp_path / "clients" / "src" / "test" / "java"),
file_path=str(
tmp_path
/ "clients"
/ "src"
/ "main"
/ "java"
/ "org"
/ "apache"
/ "kafka"
/ "common"
/ "utils"
/ "Bytes.java"
),
)
result = optimizer._get_java_sources_root()
assert result == tmp_path / "clients" / "src" / "test" / "java"
def test_opensearch_libs_module(self, tmp_path):
"""OpenSearch: function in libs/core should use libs/core test dir."""
(tmp_path / "libs" / "core" / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "libs" / "core" / "src" / "test" / "java").mkdir(parents=True)
(tmp_path / "server" / "src" / "test" / "java").mkdir(parents=True)
optimizer = self._create_mock_optimizer(
tests_root=str(tmp_path / "server" / "src" / "test" / "java"),
file_path=str(
tmp_path
/ "libs"
/ "core"
/ "src"
/ "main"
/ "java"
/ "org"
/ "opensearch"
/ "core"
/ "common"
/ "Strings.java"
),
)
result = optimizer._get_java_sources_root()
assert result == tmp_path / "libs" / "core" / "src" / "test" / "java"
def test_spring_boot_subproject(self, tmp_path):
"""Spring Boot: function in autoconfigure should use autoconfigure test dir."""
(tmp_path / "spring-boot-autoconfigure" / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "spring-boot-autoconfigure" / "src" / "test" / "java").mkdir(parents=True)
(tmp_path / "spring-boot" / "src" / "test" / "java").mkdir(parents=True)
optimizer = self._create_mock_optimizer(
tests_root=str(tmp_path / "spring-boot" / "src" / "test" / "java"),
file_path=str(
tmp_path
/ "spring-boot-autoconfigure"
/ "src"
/ "main"
/ "java"
/ "org"
/ "springframework"
/ "boot"
/ "autoconfigure"
/ "web"
/ "ServerProperties.java"
),
)
result = optimizer._get_java_sources_root()
assert result == tmp_path / "spring-boot-autoconfigure" / "src" / "test" / "java"
def test_fallback_when_derived_test_dir_missing(self, tmp_path):
"""When derived test dir doesn't exist, fall back to tests_root logic."""
(tmp_path / "module-a" / "src" / "main" / "java").mkdir(parents=True)
# Deliberately NOT creating module-a/src/test/java
tests_root = tmp_path / "src" / "test" / "java"
tests_root.mkdir(parents=True)
optimizer = self._create_mock_optimizer(
tests_root=str(tests_root),
file_path=str(tmp_path / "module-a" / "src" / "main" / "java" / "com" / "example" / "Foo.java"),
)
result = optimizer._get_java_sources_root()
assert result == tests_root
def test_non_standard_layout_falls_through(self, tmp_path):
"""Non-standard layout (no src/main/java) falls through to existing logic."""
optimizer = self._create_mock_optimizer(
tests_root=str(tmp_path / "custom" / "tests"), file_path=str(tmp_path / "custom" / "src" / "Foo.java")
)
result = optimizer._get_java_sources_root()
assert result == tmp_path / "custom" / "tests"
class TestFixJavaTestPathsIntegration:
"""Integration tests for _fix_java_test_paths with the path fix."""
@ -97,8 +266,9 @@ class TestFixJavaTestPathsIntegration:
mock_optimizer = MagicMock(spec=JavaFunctionOptimizer)
mock_optimizer.test_cfg = MagicMock()
mock_optimizer.test_cfg.tests_root = Path(tests_root)
mock_optimizer.function_to_optimize = MagicMock()
mock_optimizer.function_to_optimize.file_path = Path("/nonexistent/Foo.java")
# Bind the actual methods
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
mock_optimizer._fix_java_test_paths = lambda behavior_source, perf_source, used_paths, display_source="": (
JavaFunctionOptimizer._fix_java_test_paths(
@@ -245,7 +415,7 @@ class TestExtractSourceDirsFromPom:
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
-dirs = _extract_source_dirs_from_pom(tmp_path)
+dirs = _extract_custom_source_dirs(tmp_path)
assert "src/main/custom" in dirs
assert "src/test/custom" in dirs
@@ -259,11 +429,11 @@ class TestExtractSourceDirsFromPom:
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
-dirs = _extract_source_dirs_from_pom(tmp_path)
+dirs = _extract_custom_source_dirs(tmp_path)
assert dirs == []
def test_no_pom_returns_empty(self, tmp_path):
-dirs = _extract_source_dirs_from_pom(tmp_path)
+dirs = _extract_custom_source_dirs(tmp_path)
assert dirs == []
def test_pom_without_build_section(self, tmp_path):
@@ -273,10 +443,191 @@ class TestExtractSourceDirsFromPom:
</project>
"""
(tmp_path / "pom.xml").write_text(pom_content)
-dirs = _extract_source_dirs_from_pom(tmp_path)
+dirs = _extract_custom_source_dirs(tmp_path)
assert dirs == []
def test_malformed_xml(self, tmp_path):
(tmp_path / "pom.xml").write_text("this is not valid xml <<<<")
-dirs = _extract_source_dirs_from_pom(tmp_path)
+dirs = _extract_custom_source_dirs(tmp_path)
assert dirs == []
class TestMatchModuleFromRelPath:
"""Tests for _match_module_from_rel_path."""
def test_simple_module(self):
assert _match_module_from_rel_path(Path("streams/src/test/java/Test.java"), ["streams", "clients"]) == "streams"
def test_nested_module(self):
result = _match_module_from_rel_path(
Path("connect/runtime/src/test/java/Test.java"), ["connect:runtime", "streams"]
)
assert result == "connect:runtime"
def test_no_match(self):
assert _match_module_from_rel_path(Path("unknown/src/Test.java"), ["streams", "clients"]) is None
def test_partial_name_no_false_match(self):
"""'streams-ng' should not match module 'streams'."""
assert _match_module_from_rel_path(Path("streams-ng/src/Test.java"), ["streams"]) is None
class TestModuleToDir:
"""Tests for module_to_dir."""
def test_simple(self):
assert module_to_dir("streams") == "streams"
def test_nested(self):
result = module_to_dir("connect:runtime")
assert result in ("connect/runtime", "connect\\runtime")
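Accepting either path separator in the nested-module assertion suggests `module_to_dir` simply maps Gradle's `:` separator onto the platform's path separator. A plausible sketch:

```python
import os


def module_to_dir(module: str) -> str:
    # Gradle module path 'connect:runtime' -> filesystem path 'connect/runtime'
    # (os.sep keeps the result correct on Windows as well)
    return module.replace(":", os.sep)
```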
class TestExtractModulesFromSettingsGradle:
"""Tests for _extract_modules_from_settings_gradle."""
def test_simple_top_level_modules(self):
content = """include("streams", "clients", "tools")"""
modules = _extract_modules_from_settings_gradle(content)
assert "streams" in modules
assert "clients" in modules
assert "tools" in modules
def test_nested_gradle_modules(self):
"""Nested modules like connect:runtime should be extracted."""
content = """include("connect:runtime", "connect:api", "streams")"""
modules = _extract_modules_from_settings_gradle(content)
assert "connect:runtime" in modules
assert "connect:api" in modules
assert "streams" in modules
def test_leading_colon_stripped(self):
content = """include(":streams", ":clients")"""
modules = _extract_modules_from_settings_gradle(content)
assert "streams" in modules
assert "clients" in modules
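These expectations amount to: collect every quoted name inside `include(...)` and strip Gradle's optional leading colon. A regex-based sketch (hypothetical — the real parser may also handle Groovy-style `include ':x'` statements without parentheses):

```python
import re


def extract_modules_from_settings_gradle(content: str) -> list[str]:
    """Pull quoted module names out of include(...) calls in settings.gradle[.kts].

    Leading ':' is stripped, so include(":streams") yields 'streams'.
    """
    modules = []
    for args in re.findall(r"include\s*\(([^)]*)\)", content):
        for name in re.findall(r"""["']([^"']+)["']""", args):
            modules.append(name.lstrip(":"))
    return modules
```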
class TestFindMultiModuleRoot:
"""Tests for _find_multi_module_root with Gradle multi-module projects."""
def _make_kafka_like_project(self, tmp_path):
"""Create a Kafka-like multi-module Gradle project structure."""
# Root build files
(tmp_path / "build.gradle.kts").write_text("// root build", encoding="utf-8")
(tmp_path / "settings.gradle.kts").write_text(
'include("clients", "streams", "tools", "connect:runtime")', encoding="utf-8"
)
# Module build files and source/test dirs
for module in ["clients", "streams", "tools"]:
(tmp_path / module / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / module / "src" / "test" / "java").mkdir(parents=True)
(tmp_path / module / "build.gradle.kts").write_text(f"// {module} build", encoding="utf-8")
# Nested module
(tmp_path / "connect" / "runtime" / "src" / "main" / "java").mkdir(parents=True)
(tmp_path / "connect" / "runtime" / "src" / "test" / "java").mkdir(parents=True)
(tmp_path / "connect" / "runtime" / "build.gradle.kts").write_text("// connect:runtime build", encoding="utf-8")
def _make_test_paths_mock(self, file_paths: list[Path]):
"""Create a mock test_paths object with test_files."""
mock = MagicMock()
mock.test_files = []
for fp in file_paths:
tf = MagicMock()
tf.benchmarking_file_path = None
tf.instrumented_behavior_file_path = fp
mock.test_files.append(tf)
return mock
def test_streams_tests_return_streams_module(self, tmp_path):
"""When ALL test files are in streams/, should return 'streams' module."""
self._make_kafka_like_project(tmp_path)
test_file = tmp_path / "streams" / "src" / "test" / "java" / "org" / "apache" / "kafka" / "StreamsTest.java"
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.touch()
test_paths = self._make_test_paths_mock([test_file])
build_root, test_module = _find_multi_module_root(tmp_path, test_paths)
assert build_root == tmp_path
assert test_module == "streams", f"Expected 'streams' but got '{test_module}'"
def test_tools_tests_return_tools_module(self, tmp_path):
"""When test files are in tools/, should return 'tools' module."""
self._make_kafka_like_project(tmp_path)
test_file = tmp_path / "tools" / "src" / "test" / "java" / "org" / "apache" / "kafka" / "ToolsTest.java"
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.touch()
test_paths = self._make_test_paths_mock([test_file])
build_root, test_module = _find_multi_module_root(tmp_path, test_paths)
assert build_root == tmp_path
assert test_module == "tools", f"Expected 'tools' but got '{test_module}'"
def test_mixed_modules_majority_wins(self, tmp_path):
"""When tests span multiple modules, the module with the most test files wins."""
self._make_kafka_like_project(tmp_path)
clients_test = tmp_path / "clients" / "src" / "test" / "java" / "com" / "ClientsTest.java"
clients_test.parent.mkdir(parents=True, exist_ok=True)
clients_test.touch()
streams_test_1 = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest1.java"
streams_test_1.parent.mkdir(parents=True, exist_ok=True)
streams_test_1.touch()
streams_test_2 = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest2.java"
streams_test_2.touch()
# 1 clients test + 2 streams tests → streams wins by majority
test_paths = self._make_test_paths_mock([clients_test, streams_test_1, streams_test_2])
build_root, test_module = _find_multi_module_root(tmp_path, test_paths)
assert build_root == tmp_path
assert test_module == "streams"
def test_mixed_modules_equal_count_deterministic(self, tmp_path):
"""When modules are tied, a module is still selected (not None)."""
self._make_kafka_like_project(tmp_path)
clients_test = tmp_path / "clients" / "src" / "test" / "java" / "com" / "ClientsTest.java"
clients_test.parent.mkdir(parents=True, exist_ok=True)
clients_test.touch()
streams_test = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest.java"
streams_test.parent.mkdir(parents=True, exist_ok=True)
streams_test.touch()
test_paths = self._make_test_paths_mock([clients_test, streams_test])
build_root, test_module = _find_multi_module_root(tmp_path, test_paths)
assert build_root == tmp_path
assert test_module in ("clients", "streams")
def test_nested_module_connect_runtime(self, tmp_path):
"""Nested Gradle module 'connect:runtime' (dir connect/runtime/) is matched."""
self._make_kafka_like_project(tmp_path)
test_file = (
tmp_path / "connect" / "runtime" / "src" / "test" / "java" / "org" / "kafka" / "ConnectRuntimeTest.java"
)
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.touch()
test_paths = self._make_test_paths_mock([test_file])
build_root, test_module = _find_multi_module_root(tmp_path, test_paths)
assert build_root == tmp_path
assert test_module == "connect:runtime"
def test_project_root_is_submodule_test_outside(self, tmp_path):
"""When project_root is a submodule (e.g., kafka/clients) and generated
tests are placed in a sibling module (kafka/streams), the function should
walk up to find the repo root and return the correct module.
"""
self._make_kafka_like_project(tmp_path)
submodule_root = tmp_path / "clients"
test_file = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest.java"
test_file.parent.mkdir(parents=True, exist_ok=True)
test_file.touch()
test_paths = self._make_test_paths_mock([test_file])
build_root, test_module = _find_multi_module_root(submodule_root, test_paths)
assert build_root == tmp_path
assert test_module == "streams"

View file

@@ -90,9 +90,10 @@ POM_CONTENT = """<?xml version="1.0" encoding="UTF-8"?>
def skip_if_maven_not_available():
from codeflash.languages.java.build_tools import find_maven_executable
from codeflash.languages.java.maven_strategy import MavenStrategy
-if not MavenStrategy().find_executable(Path(".")):
+if not find_maven_executable():
pytest.skip("Maven not available")

View file

@@ -81,6 +81,7 @@ class TestInputValidation:
def test_get_test_run_command_validates_input(self, tmp_path: Path):
"""Test that get_test_run_command validates test class names."""
(tmp_path / "pom.xml").write_text("<project></project>", encoding="utf-8")
# Valid class names should work
cmd = get_test_run_command(tmp_path, ["MyTest", "OtherTest"])
assert "-Dtest=MyTest,OtherTest" in " ".join(cmd)

View file

@@ -9,7 +9,6 @@ from codeflash.languages.java.line_profiler import JavaLineProfiler
def test_parse_line_profile_results_non_python_java_json():
set_current_language(Language.JAVA)
with TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
source_file = tmp_path / "Util.java"

uv.lock (6108 changes)

File diff suppressed because it is too large