mirror of https://github.com/codeflash-ai/codeflash.git (synced 2026-05-04 18:25:17 +00:00)
Merge branch 'main' into fix/trigger_cc_on_multiple_commits
This commit is contained in commit ead7fadac5.
47 changed files with 6871 additions and 5470 deletions
70 .github/workflows/claude.yml (vendored)
@@ -92,14 +92,14 @@ jobs:
Before doing any work, assess the PR scope:

1. Run `gh pr diff ${{ github.event.pull_request.number }} --name-only` to get changed files.
2. Classify as TRIVIAL if ALL changed files are:
   - Config/CI files (.github/, .tessl/, *.toml, *.lock, *.json, *.yml, *.yaml)
   - Documentation (*.md, docs/)
   - Non-production code (demos/, experiments/, code_to_optimize/)
   - Only whitespace, formatting, or comment changes
2. Run `gh pr diff ${{ github.event.pull_request.number }} --stat` to get the total lines changed.
3. Classify the PR into one of these categories:
   - TRIVIAL: ALL changed files are config/CI (.github/, .tessl/, *.toml, *.lock, *.json, *.yml, *.yaml), documentation (*.md, docs/), non-production code (demos/, experiments/, code_to_optimize/), or only whitespace/formatting/comment changes.
   - SMALL: ≤ 50 lines of production code changed (excluding tests, config, lock files)
   - LARGE: > 50 lines of production code changed

If TRIVIAL: post a single comment "No substantive code changes to review." and stop — do not execute any further steps.
Otherwise: continue with the full review below.
Otherwise: record the size (SMALL or LARGE) and continue. Later steps will adapt their depth based on this.
</step>

<step name="lint_and_typecheck">
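The TRIVIAL/SMALL/LARGE classification described in this step can be sketched in Python. This is a hypothetical illustration only (the workflow itself does this with `gh` commands plus the model's judgment); the function and constant names are invented:

```python
# Hypothetical sketch of the TRIVIAL / SMALL / LARGE classification above.
# Thresholds and path prefixes mirror the workflow prompt; names are illustrative.
from fnmatch import fnmatch

TRIVIAL_GLOBS = ["*.toml", "*.lock", "*.json", "*.yml", "*.yaml", "*.md"]
TRIVIAL_PREFIXES = (".github/", ".tessl/", "docs/", "demos/", "experiments/", "code_to_optimize/")


def is_trivial_file(path: str) -> bool:
    """A file is trivial if it is config/CI, docs, or non-production code."""
    return path.startswith(TRIVIAL_PREFIXES) or any(fnmatch(path, g) for g in TRIVIAL_GLOBS)


def classify_pr(changed_files: list[str], production_lines_changed: int) -> str:
    """TRIVIAL if every file is trivial; otherwise size by production lines changed."""
    if all(is_trivial_file(f) for f in changed_files):
        return "TRIVIAL"
    return "SMALL" if production_lines_changed <= 50 else "LARGE"


print(classify_pr([".github/workflows/ci.yml", "README.md"], 0))  # TRIVIAL
print(classify_pr(["codeflash/optimizer.py"], 10))                # SMALL
print(classify_pr(["codeflash/optimizer.py"], 200))               # LARGE
```

Note this sketch cannot capture the "only whitespace/formatting/comment changes" clause, which needs the diff contents rather than file names.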
@@ -127,24 +127,47 @@ jobs:

2. For each unresolved thread:
   a. Read the file at that path to check if the issue still exists
   b. If fixed → resolve it: `gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "<THREAD_ID>"}) { thread { isResolved } } }'`
   b. If fixed → resolve the thread silently (do NOT post a reply comment like "✅ Fixed"):
      `gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "<THREAD_ID>"}) { thread { isResolved } } }'`
   c. If still present → leave it

Read the actual code before deciding. If there are no unresolved threads, skip to the next step.
</step>

<step name="review">
Review the diff (`gh pr diff ${{ github.event.pull_request.number }}`) for:
Review the diff (`gh pr diff ${{ github.event.pull_request.number }}`).

SCOPE RULES:
- Only read files that appear in the diff. Do NOT explore the broader codebase unless a specific function call in the diff requires understanding its definition.
- Match review depth to PR size: SMALL PRs get a focused correctness check; LARGE PRs get the full review below.
- For codeflash-ai[bot] optimization PRs: focus on whether the optimization is correct and the speedup claim is credible. Keep the review concise — a short correctness verdict, not a multi-paragraph essay.

Check for:
1. Bugs that will crash at runtime
2. Security vulnerabilities
3. Breaking API changes
4. Design issues (LARGE PRs only):
   - Is logic placed in the right module? (e.g., language-specific logic belongs in `languages/<lang>/`, not in the base optimizer)
   - Are fixes addressing the root cause or just papering over symptoms? (e.g., hardcoding a list of imports vs having the AI service handle it)
   - Is language-agnostic config being stored in a language-specific file? (e.g., storing general config in `pyproject.toml`, which is Python-exclusive)
5. Accidental file inclusions — flag if the diff contains:
   - Binary files (`.jar`, `.whl`, `.so`, `.dll`, `.exe`)
   - Version file changes (`version.py`) that look auto-generated (e.g., `.post<N>.dev0+<hash>`)
   - Internal planning/design docs committed to `docs/` (those belong in Linear or a separate location)
   - Changes to `codeflash-benchmark/` version files

Ignore style issues, type hints, and log message wording.
Record findings for the summary comment. Refer to CLAUDE.md for project conventions.
</step>

<step name="duplicate_detection">
Check whether this PR introduces code that duplicates logic already present elsewhere in the repository — including across languages. Focus on finding true duplicates, not just similar-looking code.
Check whether this PR introduces code that duplicates logic already present elsewhere in the repository.

Depth depends on PR size:
- SMALL PRs: Quick check — search for any new function names defined elsewhere using Grep. Only flag exact or near-exact duplicates. Skip the cross-module deep dive.
- LARGE PRs: Full analysis as described below.

Full analysis (LARGE PRs only):

1. Get changed source files (excluding tests and config):
   `git diff --name-only origin/main...HEAD -- '*.py' '*.js' '*.ts' '*.java' | grep -v -E '(test_|_test\.(py|js|ts)|\.test\.(js|ts)|\.spec\.(js|ts)|conftest\.py|/tests/|/test/|/__tests__/)' | grep -v -E '^(\.github/|code_to_optimize/|\.tessl/|node_modules/)'`
@@ -171,6 +194,8 @@ jobs:
</step>

<step name="coverage">
Skip this step for SMALL PRs — only run for LARGE PRs.

Analyze test coverage for changed files:

1. Get changed Python files (excluding tests): `git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
@@ -178,6 +203,7 @@ jobs:
3. Get per-file coverage: `uv run coverage report --include="<changed_files>"`
4. Compare with main: checkout main, run coverage, checkout back
5. Flag: new files below 75%, decreased coverage, untested changed lines
6. If the PR adds new public functions or methods without any corresponding tests, explicitly call this out and request tests be added. New production logic should have tests.
</step>

<step name="summary_comment">
@@ -186,6 +212,7 @@ jobs:
## PR Review Summary
### Prek Checks
### Code Review
(Include bugs, security, breaking changes, and design issues. Omit subsections with no findings.)
### Duplicate Detection
### Test Coverage
---
@@ -198,20 +225,25 @@ jobs:

For each PR:
- If CI passes and the PR is mergeable → merge with `--squash --delete-branch`
- If CI is failing:
  1. Check out the PR branch and inspect the failing tests
  2. Attempt to fix the failures (the optimization may have broken tests or introduced issues)
  3. If fixed: commit, push, and leave a comment explaining what was fixed
  4. If unfixable: close with `gh pr close <number> --comment "Closing: CI checks are failing — <describe the specific failures and why they can't be auto-fixed>." --delete-branch`
- Close the PR (without attempting fixes) if ANY of these apply:
  - Older than 7 days
  - Has merge conflicts (mergeable state is "CONFLICTING")
- If CI is failing, first determine whether the failures are CAUSED BY the PR or PRE-EXISTING on the base branch:
  1. Run `gh pr checks <number>` to identify failing checks
  2. Check out the PR's BASE branch and check if the same tests/checks fail there too
  3. If failures are PRE-EXISTING on the base branch (not caused by the PR):
     - Do NOT close the PR
     - Leave a comment: "CI failures are pre-existing on the base branch (not caused by this PR): <list failing checks>. Leaving open for merge once base branch CI is fixed."
  4. If failures are CAUSED BY the PR's changes:
     a. Check out the PR branch and attempt to fix (lint issues, duplicate definitions, type errors, etc.)
     b. If fixed: commit, push, and leave a comment explaining what was fixed
     c. If unfixable: close with `gh pr close <number> --comment "Closing: this PR introduces CI failures that cannot be auto-fixed — <describe the specific failures>." --delete-branch`
- Close the PR (without attempting fixes) ONLY if ANY of these apply:
  - The optimized function no longer exists in the target file (check the diff)
  - Has merge conflicts (mergeable state is "CONFLICTING") AND the PR is older than 3 days (to give time for the base branch to stabilize before giving up)
  Close with: `gh pr close <number> --comment "<reason>" --delete-branch`
  where <reason> explains WHY the PR is being closed. Examples:
  - "Closing: PR is older than 7 days without being merged."
  - "Closing: merge conflicts with the target branch."
  - "Closing: merge conflicts with the target branch and PR is older than 3 days."
  - "Closing: the optimized function no longer exists in the target file."
- NEVER close a PR solely because it is old. Age alone is not a valid reason to close.
- NEVER mass-close multiple PRs without individually evaluating each one.
</step>

<verification>
1 .gitignore (vendored)
@@ -10,6 +10,7 @@ __pycache__/
# Distribution / packaging
.Python
build/
.gradle/
develop-eggs/
cli/dist/
downloads/
2 code_to_optimize/java-gradle/.gitignore (vendored, Normal file)
@@ -0,0 +1,2 @@
.gradle/
build/
29 code_to_optimize/java-gradle/build.gradle.kts (Normal file)
@@ -0,0 +1,29 @@
plugins {
    java
    jacoco
}

group = "com.example"
version = "1.0.0"

java {
    sourceCompatibility = JavaVersion.VERSION_11
    targetCompatibility = JavaVersion.VERSION_11
}

repositories {
    mavenCentral()
    mavenLocal()
}

dependencies {
    testImplementation("org.junit.jupiter:junit-jupiter:5.10.0")
    testImplementation("org.junit.jupiter:junit-jupiter-params:5.10.0")
    testImplementation("org.xerial:sqlite-jdbc:3.42.0.0")
    testRuntimeOnly("org.junit.platform:junit-platform-launcher")
    testImplementation(files("/Users/heshammohamed/Work/codeflash/code_to_optimize/java-gradle/libs/codeflash-runtime-1.0.0.jar")) // codeflash-runtime
}

tasks.test {
    useJUnitPlatform()
}
4 code_to_optimize/java-gradle/codeflash.toml (Normal file)
@@ -0,0 +1,4 @@
[tool.codeflash]
module-root = "src/main/java"
tests-root = "src/test/java"
formatter-cmds = []
BIN code_to_optimize/java-gradle/libs/codeflash-runtime-1.0.0.jar (Normal file)
Binary file not shown.
1 code_to_optimize/java-gradle/settings.gradle.kts (Normal file)
@@ -0,0 +1 @@
rootProject.name = "codeflash-java-gradle-sample"
@@ -0,0 +1,14 @@
package com.example;

public class Fibonacci {

    public static long fibonacci(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative");
        }
        if (n <= 1) {
            return n;
        }
        return fibonacci(n - 1) + fibonacci(n - 2);
    }
}
@@ -0,0 +1,27 @@
package com.example;

import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;

class FibonacciTest {

    @Test
    void testBaseCase() {
        assertEquals(0, Fibonacci.fibonacci(0));
        assertEquals(1, Fibonacci.fibonacci(1));
    }

    @Test
    void testSmallValues() {
        assertEquals(1, Fibonacci.fibonacci(2));
        assertEquals(2, Fibonacci.fibonacci(3));
        assertEquals(5, Fibonacci.fibonacci(5));
        assertEquals(8, Fibonacci.fibonacci(6));
        assertEquals(55, Fibonacci.fibonacci(10));
    }

    @Test
    void testNegative() {
        assertThrows(IllegalArgumentException.class, () -> Fibonacci.fibonacci(-1));
    }
}
65 codeflash-java-runtime/build.gradle.kts (Normal file)
@@ -0,0 +1,65 @@
plugins {
    java
    id("com.gradleup.shadow") version "9.0.0-beta12"
}

group = "com.codeflash"
version = "1.0.0"

java {
    sourceCompatibility = JavaVersion.VERSION_11
    targetCompatibility = JavaVersion.VERSION_11
}

repositories {
    mavenCentral()
}

dependencies {
    implementation("com.google.code.gson:gson:2.10.1")
    implementation("com.esotericsoftware:kryo:5.6.2")
    implementation("org.objenesis:objenesis:3.4")
    implementation("org.xerial:sqlite-jdbc:3.45.0.0")
    implementation("org.ow2.asm:asm:9.7.1")
    implementation("org.ow2.asm:asm-commons:9.7.1")

    testImplementation("org.junit.jupiter:junit-jupiter:5.10.1")
    testRuntimeOnly("org.junit.platform:junit-platform-launcher")
}

tasks.test {
    useJUnitPlatform()
    jvmArgs(
        "--add-opens", "java.base/java.util=ALL-UNNAMED",
        "--add-opens", "java.base/java.lang=ALL-UNNAMED",
        "--add-opens", "java.base/java.lang.reflect=ALL-UNNAMED",
        "--add-opens", "java.base/java.math=ALL-UNNAMED",
        "--add-opens", "java.base/java.io=ALL-UNNAMED",
        "--add-opens", "java.base/java.net=ALL-UNNAMED",
        "--add-opens", "java.base/java.time=ALL-UNNAMED",
    )
}

tasks.shadowJar {
    archiveBaseName.set("codeflash-runtime")
    archiveVersion.set("1.0.0")
    archiveClassifier.set("")

    relocate("org.objectweb.asm", "com.codeflash.asm")

    manifest {
        attributes(
            "Main-Class" to "com.codeflash.Comparator",
            "Premain-Class" to "com.codeflash.profiler.ProfilerAgent",
            "Can-Retransform-Classes" to "true",
        )
    }

    exclude("META-INF/*.SF")
    exclude("META-INF/*.DSA")
    exclude("META-INF/*.RSA")
}

tasks.build {
    dependsOn(tasks.shadowJar)
}
1 codeflash-java-runtime/settings.gradle.kts (Normal file)
@@ -0,0 +1 @@
rootProject.name = "codeflash-runtime"
@@ -54,8 +54,8 @@ class ComparatorCorrectnessTest {
KryoPlaceholder.create(new Object(), "unserializable", "root")
);

insertRow(originalDb, "iter_1_0", 1, placeholderBytes);
insertRow(candidateDb, "iter_1_1", 1, placeholderBytes);
insertRow(originalDb, "1", 1, placeholderBytes);
insertRow(candidateDb, "1", 1, placeholderBytes);

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -74,8 +74,8 @@ class ComparatorCorrectnessTest {
// Insert corrupted byte data that will fail Kryo deserialization
byte[] corruptedBytes = new byte[]{0x01, 0x02, 0x03, (byte) 0xFF, (byte) 0xFE};

insertRow(originalDb, "iter_1_0", 1, corruptedBytes);
insertRow(candidateDb, "iter_1_1", 1, corruptedBytes);
insertRow(originalDb, "1", 1, corruptedBytes);
insertRow(candidateDb, "1", 1, corruptedBytes);

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -97,12 +97,12 @@ class ComparatorCorrectnessTest {
KryoPlaceholder.create(new Object(), "unserializable", "root")
);

insertRow(originalDb, "iter_1_0", 1, realBytes1);
insertRow(candidateDb, "iter_1_1", 1, realBytes1);
insertRow(originalDb, "iter_2_0", 1, realBytes2);
insertRow(candidateDb, "iter_2_1", 1, realBytes2);
insertRow(originalDb, "iter_3_0", 1, placeholderBytes);
insertRow(candidateDb, "iter_3_1", 1, placeholderBytes);
insertRow(originalDb, "1", 1, realBytes1);
insertRow(candidateDb, "1", 1, realBytes1);
insertRow(originalDb, "2", 1, realBytes2);
insertRow(candidateDb, "2", 1, realBytes2);
insertRow(originalDb, "3", 1, placeholderBytes);
insertRow(candidateDb, "3", 1, placeholderBytes);

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -121,10 +121,10 @@ class ComparatorCorrectnessTest {
byte[] bytes1 = Serializer.serialize(100);
byte[] bytes2 = Serializer.serialize("world");

insertRow(originalDb, "iter_1_0", 1, bytes1);
insertRow(candidateDb, "iter_1_1", 1, bytes1);
insertRow(originalDb, "iter_2_0", 1, bytes2);
insertRow(candidateDb, "iter_2_1", 1, bytes2);
insertRow(originalDb, "1", 1, bytes1);
insertRow(candidateDb, "1", 1, bytes1);
insertRow(originalDb, "2", 1, bytes2);
insertRow(candidateDb, "2", 1, bytes2);

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -144,8 +144,8 @@ class ComparatorCorrectnessTest {
byte[] origBytes = Serializer.serialize(42);
byte[] candBytes = Serializer.serialize(99);

insertRow(originalDb, "iter_1_0", 1, origBytes);
insertRow(candidateDb, "iter_1_1", 1, candBytes);
insertRow(originalDb, "1", 1, origBytes);
insertRow(candidateDb, "1", 1, candBytes);

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -161,8 +161,8 @@ class ComparatorCorrectnessTest {
createTestDb(candidateDb);

// Insert rows with NULL return_value (void methods)
insertRow(originalDb, "iter_1_0", 1, null);
insertRow(candidateDb, "iter_1_1", 1, null);
insertRow(originalDb, "1", 1, null);
insertRow(candidateDb, "1", 1, null);

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
@@ -178,7 +178,7 @@ class ComparatorCorrectnessTest {
createTestDb(candidateDb);

byte[] bytes = Serializer.serialize(42);
insertRow(originalDb, "iter_1_0", 1, bytes);
insertRow(originalDb, "1", 1, bytes);
// candidateDb has no rows

String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
@@ -222,7 +222,7 @@ class ComparatorCorrectnessTest {
+ "iteration_id TEXT NOT NULL, "
+ "loop_index INTEGER NOT NULL, "
+ "return_value BLOB, "
+ "PRIMARY KEY (iteration_id, loop_index))");
+ "verification_type TEXT)");
}
}
@@ -338,6 +338,19 @@ def get_qualified_name(module_name: str, full_qualified_name: str) -> str:
    return full_qualified_name[len(module_name) + 1 :]


_PARAMETERIZED_INDEX_RE = re.compile(r"\[(\d+)")


def extract_parameterized_test_index(test_name: str) -> int:
    """Extract the numeric index from a parameterized test name.

    Handles formats like ``test[0]``, ``test[1]``, and
    ``test[1] input=foo, expected=bar``. Returns 1 when no numeric index is found.
    """
    m = _PARAMETERIZED_INDEX_RE.search(test_name)
    return int(m.group(1)) if m else 1


def module_name_from_file_path(file_path: Path, project_root_path: Path, *, traverse_up: bool = False) -> str:
    try:
        relative_path = file_path.resolve().relative_to(project_root_path.resolve())
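A quick standalone check of the `extract_parameterized_test_index` helper added in the hunk above (the regex and function are reproduced here so the snippet is self-contained):

```python
import re

# Reproduction of the helper above for a standalone check.
_PARAMETERIZED_INDEX_RE = re.compile(r"\[(\d+)")


def extract_parameterized_test_index(test_name: str) -> int:
    """Extract the numeric index from a parameterized test name; 1 if absent."""
    m = _PARAMETERIZED_INDEX_RE.search(test_name)
    return int(m.group(1)) if m else 1


print(extract_parameterized_test_index("test_add[3]"))                           # 3
print(extract_parameterized_test_index("test_add[12] input=foo, expected=bar"))  # 12
print(extract_parameterized_test_index("test_add"))                              # 1
```

Note the digits must immediately follow the `[`: a name like `test_add[ 0 ]` with spaces inside the brackets would not match this pattern and falls back to 1.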
@@ -13,7 +13,6 @@ from rich.prompt import Confirm
from unidiff import PatchSet

from codeflash.cli_cmds.console import logger
from codeflash.languages.registry import get_supported_extensions

if TYPE_CHECKING:
    from git import Repo
@@ -26,13 +25,15 @@ def get_git_diff(
    uncommitted_changes: bool = False,
    since_commit: Optional[str] = None,
) -> dict[str, list[int]]:
    from codeflash.languages.current import current_language_support
    from codeflash.languages.registry import get_supported_extensions

    if repo_directory is None:
        repo_directory = Path.cwd()
    repository = git.Repo(repo_directory, search_parent_directories=True)
    commit = repository.head.commit
    supported_extensions = current_language_support().file_extensions
    # Use all registered extensions (Python + JS/TS + Java etc.) rather than
    # current_language_support() which defaults to Python before language detection runs.
    supported_extensions = set(get_supported_extensions())
    if since_commit:
        # Diff from a base commit to HEAD — captures all changes across multiple commits
        uni_diff_text = repository.git.diff(
@@ -48,7 +49,6 @@ def get_git_diff(
        uni_diff_text = repository.git.diff(
            commit.hexsha + "^1", commit.hexsha, ignore_blank_lines=True, ignore_space_at_eol=True
        )
    supported_extensions = set(get_supported_extensions())
    patch_set = PatchSet(StringIO(uni_diff_text))
    change_list: dict[str, list[int]] = {}  # list of changes
    for patched_file in patch_set:
@@ -688,14 +688,21 @@ class LanguageSupport(Protocol):
    # === Test Result Comparison ===

    def compare_test_results(
        self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
    ) -> tuple[bool, list[Any]]:
        self,
        original_results_path: Path,
        candidate_results_path: Path,
        project_root: Path | None = None,
        project_classpath: str | None = None,
    ) -> tuple[bool, list]:
        """Compare test results between original and candidate code.

        Args:
            original_results_path: Path to original test results (e.g., SQLite DB).
            candidate_results_path: Path to candidate test results.
            project_root: Project root directory (for finding node_modules, etc.).
            project_classpath: Full project classpath string (Java only). Needed so
                the Comparator JVM can resolve project-specific classes during Kryo
                deserialization.

        Returns:
            Tuple of (are_equivalent, list of TestDiff objects).
@@ -23,7 +23,9 @@ if TYPE_CHECKING:
_SOURCE_CRITERIA = FunctionFilterCriteria(require_return=False, require_export=False)


def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
def get_optimized_code_for_module(
    relative_path: Path, optimized_code: CodeStringsMarkdown, allow_fallback: bool = True
) -> str:
    from codeflash.languages.current import is_python

    file_to_code_context = optimized_code.file_to_path()
@@ -31,6 +33,9 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
    if module_optimized_code is not None:
        return module_optimized_code

    if not allow_fallback:
        return ""

    # Fallback 1: single code block with no file path
    if "None" in file_to_code_context and len(file_to_code_context) == 1:
        logger.debug(f"Using code block with None file_path for {relative_path}")
@@ -72,7 +77,10 @@ def replace_function_definitions_for_language(
    and LanguageSupport.discover_functions.
    """
    original_source_code: str = module_abspath.read_text(encoding="utf8")
    code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
    is_target_file = function_to_optimize is not None and function_to_optimize.file_path == module_abspath
    code_to_apply = get_optimized_code_for_module(
        module_abspath.relative_to(project_root_path), optimized_code, allow_fallback=is_target_file
    )

    if not code_to_apply.strip():
        return False
@@ -5,21 +5,15 @@ test execution, and optimization using tree-sitter for parsing and
Maven/Gradle for build operations.
"""

from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, get_strategy
from codeflash.languages.java.build_tools import (
    BuildTool,
    JavaProjectInfo,
    MavenTestResult,
    add_codeflash_dependency_to_pom,
    compile_maven_project,
    detect_build_tool,
    find_gradle_executable,
    find_maven_executable,
    find_source_root,
    find_test_root,
    get_classpath,
    get_project_info,
    install_codeflash_runtime,
    run_maven_tests,
)
from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results
from codeflash.languages.java.config import (
@@ -58,6 +52,7 @@ from codeflash.languages.java.instrumentation import (
    instrument_generated_java_test,
    remove_instrumentation,
)
from codeflash.languages.java.maven_strategy import add_codeflash_dependency, install_codeflash_runtime
from codeflash.languages.java.parser import (
    JavaAnalyzer,
    JavaClassNode,
@@ -103,6 +98,7 @@ from codeflash.languages.java.test_runner import (
__all__ = [
    # Build tools
    "BuildTool",
    "BuildToolStrategy",
    # Parser
    "JavaAnalyzer",
    # Assertion removal
@@ -124,7 +120,7 @@ __all__ = [
    "JavaTestRunResult",
    "MavenTestResult",
    "ResolvedImport",
    "add_codeflash_dependency_to_pom",
    "add_codeflash_dependency",
    # Replacement
    "add_runtime_comments",
    # Test discovery
@@ -132,7 +128,6 @@ __all__ = [
    # Comparator
    "compare_invocations_directly",
    "compare_test_results",
    "compile_maven_project",
    # Instrumentation
    "create_benchmark_test",
    "detect_build_tool",
@@ -148,21 +143,19 @@ __all__ = [
    "extract_code_context",
    "extract_function_source",
    "extract_read_only_context",
    "find_gradle_executable",
    "find_helper_files",
    "find_helper_functions",
    "find_maven_executable",
    "find_source_root",
    "find_test_root",
    "find_tests_for_function",
    "format_java_code",
    "format_java_file",
    "get_class_methods",
    "get_classpath",
    "get_java_analyzer",
    "get_java_support",
    "get_method_by_name",
    "get_project_info",
    "get_strategy",
    "get_test_class_for_source_class",
    "get_test_class_pattern",
    "get_test_file_pattern",
|
|||
"resolve_imports_for_file",
|
||||
"run_behavioral_tests",
|
||||
"run_benchmarking_tests",
|
||||
"run_maven_tests",
|
||||
"run_tests",
|
||||
"transform_java_assertions",
|
||||
]
|
||||
|
|
|
|||
191 codeflash/languages/java/build_tool_strategy.py (Normal file)
@@ -0,0 +1,191 @@
"""Abstract build tool strategy for Java projects.

Defines the interface for build-tool-specific operations (compilation,
classpath extraction, test execution, coverage). Concrete implementations
live in maven_strategy.py and gradle_strategy.py.
"""

from __future__ import annotations

import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import subprocess

logger = logging.getLogger(__name__)

_RUNTIME_JAR_NAME = "codeflash-runtime-1.0.0.jar"
_JAVA_RUNTIME_DIR = Path(__file__).parent.parent.parent.parent / "codeflash-java-runtime"


def module_to_dir(test_module: str) -> str:
    """Convert a build-tool module name to a filesystem-relative path.

    Gradle uses ``:`` as the module separator (``connect:runtime``), while Maven
    uses the directory name directly. On the filesystem the separator is always
    ``/`` (or ``os.sep``).
    """
    return test_module.replace(":", os.sep)


class BuildToolStrategy(ABC):
    """Strategy interface for Java build tool operations.

    Only methods that genuinely differ between Maven and Gradle belong here.
    Shared logic (direct JVM execution, JUnit XML parsing) stays in test_runner.py.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable name for log messages (e.g. 'Maven', 'Gradle')."""
        ...

    def find_runtime_jar(self) -> Path | None:
        """Find the codeflash-runtime JAR file.

        Checks package resources and development build directories.
        Subclasses should override to prepend tool-specific cache paths
        and fall back to super().find_runtime_jar().
        """
        resources_jar = Path(__file__).parent / "resources" / _RUNTIME_JAR_NAME
        if resources_jar.exists():
            return resources_jar

        dev_jar_maven = _JAVA_RUNTIME_DIR / "target" / _RUNTIME_JAR_NAME
        if dev_jar_maven.exists():
            return dev_jar_maven

        dev_jar_gradle = _JAVA_RUNTIME_DIR / "build" / "libs" / _RUNTIME_JAR_NAME
        if dev_jar_gradle.exists():
            return dev_jar_gradle

        return None

    @abstractmethod
    def find_executable(self, build_root: Path) -> str | None:
        """Find the build tool executable, searching up parent directories if needed."""
        ...

    @abstractmethod
    def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
        """Install codeflash-runtime JAR and register it as a project dependency."""
        ...

    @abstractmethod
    def install_multi_module_deps(self, build_root: Path, test_module: str | None, env: dict[str, str]) -> bool:
        """Pre-install multi-module dependencies so later invocations skip recompilation."""
        ...

    @abstractmethod
    def compile_tests(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
    ) -> subprocess.CompletedProcess[str]:
        """Compile test code without running tests."""
        ...

    @abstractmethod
    def compile_source_only(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
    ) -> subprocess.CompletedProcess[str]:
        """Compile only main source code (not tests). Used when test classes are already compiled."""
        ...

    @abstractmethod
    def get_classpath(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
    ) -> str | None:
        """Return the full test classpath string. Caching is an implementation detail."""
        ...

    @abstractmethod
    def get_reports_dir(self, build_root: Path, test_module: str | None) -> Path:
        """Return the directory containing JUnit XML test reports."""
        ...

    @abstractmethod
    def get_build_output_dir(self, build_root: Path, test_module: str | None) -> Path:
        """Return the build output directory (e.g. target/ for Maven, build/ for Gradle)."""
        ...

    @abstractmethod
    def run_tests_via_build_tool(
        self,
        build_root: Path,
        test_paths: Any,
        env: dict[str, str],
        timeout: int,
        mode: str,
        test_module: str | None,
        javaagent_arg: str | None = None,
        enable_coverage: bool = False,
    ) -> subprocess.CompletedProcess[str]:
        """Run tests via the build tool (e.g. Maven Surefire). Used as fallback when direct JVM fails."""
        ...

    @abstractmethod
    def run_benchmarking_via_build_tool(
        self,
        test_paths: Any,
        test_env: dict[str, str],
        cwd: Path,
        timeout: int | None,
        project_root: Path | None,
        min_loops: int,
        max_loops: int,
        target_duration_seconds: float,
        inner_iterations: int,
    ) -> tuple[Path, Any]:
        """Run benchmarking loop via build tool (fallback when direct JVM fails)."""
        ...

    @abstractmethod
    def run_tests_with_coverage(
        self,
        build_root: Path,
        test_module: str | None,
        test_paths: Any,
        run_env: dict[str, str],
        timeout: int,
        candidate_index: int,
    ) -> tuple[subprocess.CompletedProcess[str], Path, Path | None]:
        """Run tests with coverage enabled. Returns (result, junit_xml_path, coverage_xml_path)."""
        ...

    @abstractmethod
    def setup_coverage(self, build_root: Path, test_module: str | None, project_root: Path) -> Path | None:
        """Configure coverage tool (e.g. JaCoCo) and return expected XML report path."""
        ...

    @abstractmethod
    def get_test_run_command(self, project_root: Path, test_classes: list[str] | None = None) -> list[str]:
        """Return the shell command to run tests, including any test class filters."""
        ...


def _build_strategy_registry() -> dict[str, type[BuildToolStrategy]]:
|
||||
"""Lazily import and return the {BuildTool.value -> class} mapping."""
|
||||
from codeflash.languages.java.gradle_strategy import GradleStrategy
|
||||
from codeflash.languages.java.maven_strategy import MavenStrategy
|
||||
|
||||
return {"maven": MavenStrategy, "gradle": GradleStrategy}
|
||||
|
||||
|
||||
def get_strategy(project_root: Path) -> BuildToolStrategy:
|
||||
"""Detect build tool and return the appropriate strategy."""
|
||||
from codeflash.languages.java.build_tools import detect_build_tool
|
||||
|
||||
build_tool = detect_build_tool(project_root)
|
||||
registry = _build_strategy_registry()
|
||||
|
||||
strategy_cls = registry.get(build_tool.value)
|
||||
if strategy_cls is not None:
|
||||
return strategy_cls()
|
||||
|
||||
supported = ", ".join(registry)
|
||||
msg = f"No supported build tool found in {project_root}. Expected one of: {supported}."
|
||||
raise ValueError(msg)
|
||||
|
|
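The registry-plus-factory dispatch used by `get_strategy` can be sketched in isolation. The strategy classes below are hypothetical stand-ins for the real `MavenStrategy`/`GradleStrategy`, which the real code imports lazily inside the registry function:

```python
# Minimal sketch of lazy-registry strategy dispatch (stand-in classes,
# not the real MavenStrategy/GradleStrategy).

class FakeMavenStrategy:
    name = "maven"

class FakeGradleStrategy:
    name = "gradle"

def build_registry():
    # The real code performs its imports here, inside the function,
    # to avoid import cycles at module load time.
    return {"maven": FakeMavenStrategy, "gradle": FakeGradleStrategy}

def pick_strategy(build_tool_value: str):
    registry = build_registry()
    strategy_cls = registry.get(build_tool_value)
    if strategy_cls is not None:
        return strategy_cls()
    supported = ", ".join(registry)
    raise ValueError(f"No supported build tool. Expected one of: {supported}.")

print(pick_strategy("maven").name)  # maven
```

Unknown tool values fail loudly with the list of supported tools, mirroring the `ValueError` path above.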
@@ -7,15 +7,10 @@ This module provides functionality to detect and work with Java build tools
from __future__ import annotations

import logging
import os
import re
import shutil
import subprocess
import urllib.request
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from pathlib import Path  # noqa: TC003 — used at runtime

logger = logging.getLogger(__name__)

@@ -25,71 +20,6 @@ CODEFLASH_RUNTIME_JAR_NAME = f"codeflash-runtime-{CODEFLASH_RUNTIME_VERSION}.jar
JACOCO_PLUGIN_VERSION = "0.8.13"


GITHUB_RELEASE_URL = (
    "https://github.com/codeflash-ai/codeflash/releases/download"
    f"/runtime-v{CODEFLASH_RUNTIME_VERSION}/{CODEFLASH_RUNTIME_JAR_NAME}"
)

CODEFLASH_CACHE_DIR = Path.home() / ".cache" / "codeflash"


def download_from_github_releases() -> Path | None:
    """Download codeflash-runtime JAR from GitHub Releases.

    Downloads to ~/.cache/codeflash/ and returns the path to the downloaded JAR.
    Returns None if the download fails (e.g., no release published yet, network error).

    This serves as a fallback when Maven Central resolution fails — for example,
    when the user's project doesn't have Maven installed or Maven Central is unreachable.
    Requires a GitHub Release tagged 'runtime-v{version}' with the JAR as an asset.
    """
    cache_jar = CODEFLASH_CACHE_DIR / CODEFLASH_RUNTIME_JAR_NAME
    if cache_jar.exists():
        logger.info("Found cached codeflash-runtime JAR: %s", cache_jar)
        return cache_jar

    try:
        CODEFLASH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading codeflash-runtime from GitHub Releases: %s", GITHUB_RELEASE_URL)
        urllib.request.urlretrieve(GITHUB_RELEASE_URL, cache_jar)  # noqa: S310
        logger.info("Downloaded codeflash-runtime to %s", cache_jar)
        return cache_jar
    except Exception as e:
        logger.debug("GitHub Releases download failed: %s", e)
        cache_jar.unlink(missing_ok=True)
        return None


def resolve_from_maven_central(maven_root: Path) -> bool:
    """Ask Maven to resolve codeflash-runtime from Maven Central.

    This downloads the JAR to ~/.m2/repository/ automatically.
    Only works once the JAR is published to Maven Central.

    Returns True if Maven successfully resolved the artifact.
    """
    mvn = find_maven_executable()
    if not mvn:
        return False
    cmd = [
        mvn,
        "dependency:resolve",
        f"-Dartifact=com.codeflash:codeflash-runtime:{CODEFLASH_RUNTIME_VERSION}",
        "-B",
        "-q",
    ]
    try:
        result = subprocess.run(cmd, check=False, cwd=maven_root, capture_output=True, text=True, timeout=60)
        if result.returncode == 0:
            logger.info("Resolved codeflash-runtime %s from Maven Central", CODEFLASH_RUNTIME_VERSION)
            return True
        logger.debug("Maven Central resolution failed: %s", result.stderr)
        return False
    except Exception as e:
        logger.debug("Maven Central resolution error: %s", e)
        return False


def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
    """Safely parse an XML file with protections against XXE attacks.

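The body of `_safe_parse_xml` is cut off by the diff, so the exact protections are not visible here. A minimal guard in the same spirit, assuming the goal is to refuse documents that carry a DTD (which blocks both external-entity and internal entity-expansion payloads), might look like this. The helper name is hypothetical:

```python
import xml.etree.ElementTree as ET

def safe_parse_xml_string(text: str) -> ET.Element:
    # Crude XXE guard: stdlib ElementTree never fetches external entities,
    # but rejecting any DTD outright also blocks internal entity-expansion
    # ("billion laughs") payloads before expat sees them.
    lowered = text.lower()
    if "<!doctype" in lowered or "<!entity" in lowered:
        raise ValueError("refusing to parse XML containing a DTD")
    return ET.fromstring(text)

root = safe_parse_xml_string("<project><version>1.0</version></project>")
print(root.find("version").text)  # 1.0
```

Projects that can take a dependency often reach for `defusedxml` instead of hand-rolling such checks.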
@@ -363,191 +293,9 @@ def _get_gradle_project_info(project_root: Path) -> JavaProjectInfo | None:
    )


def find_maven_executable(project_root: Path | None = None) -> str | None:
    """Find the Maven executable.

    Returns:
        Path to mvn executable, or None if not found.

    """
    # Check for Maven wrapper in project root first
    if project_root is not None:
        mvnw_path = project_root / "mvnw"
        if mvnw_path.exists():
            return str(mvnw_path)
        mvnw_cmd_path = project_root / "mvnw.cmd"
        if mvnw_cmd_path.exists():
            return str(mvnw_cmd_path)

    # Check for Maven wrapper in current directory
    if Path("mvnw").exists():
        return "./mvnw"
    if Path("mvnw.cmd").exists():
        return "mvnw.cmd"

    # Check system Maven
    mvn_path = shutil.which("mvn")
    if mvn_path:
        return mvn_path

    return None


def find_gradle_executable(project_root: Path | None = None) -> str | None:
    """Find the Gradle executable.

    Checks for Gradle wrapper in the project root and current directory,
    then falls back to system Gradle.

    Args:
        project_root: Optional project root directory to search for Gradle wrapper.

    Returns:
        Path to gradle executable, or None if not found.

    """
    # Check for Gradle wrapper in project root first
    if project_root is not None:
        gradlew_path = project_root / "gradlew"
        if gradlew_path.exists():
            return str(gradlew_path)
        gradlew_bat_path = project_root / "gradlew.bat"
        if gradlew_bat_path.exists():
            return str(gradlew_bat_path)

    # Check for Gradle wrapper in current directory
    if Path("gradlew").exists():
        return "./gradlew"
    if Path("gradlew.bat").exists():
        return "gradlew.bat"

    # Check system Gradle
    gradle_path = shutil.which("gradle")
    if gradle_path:
        return gradle_path

    return None


def run_maven_tests(
    project_root: Path,
    test_classes: list[str] | None = None,
    test_methods: list[str] | None = None,
    env: dict[str, str] | None = None,
    timeout: int = 300,
    skip_compilation: bool = False,
) -> MavenTestResult:
    """Run Maven tests using Surefire.

    Args:
        project_root: Root directory of the Maven project.
        test_classes: Optional list of test class names to run.
        test_methods: Optional list of specific test methods (format: ClassName#methodName).
        env: Optional environment variables.
        timeout: Maximum time in seconds for test execution.
        skip_compilation: Whether to skip compilation (useful when only running tests).

    Returns:
        MavenTestResult with test execution results.

    """
    mvn = find_maven_executable()
    if not mvn:
        logger.error("Maven not found. Please install Maven or use Maven wrapper.")
        return MavenTestResult(
            success=False,
            tests_run=0,
            failures=0,
            errors=0,
            skipped=0,
            surefire_reports_dir=None,
            stdout="",
            stderr="Maven not found",
            returncode=-1,
        )

    # Build Maven command
    cmd = [mvn]

    if skip_compilation:
        cmd.append("-Dmaven.test.skip=false")
        cmd.append("-DskipTests=false")
        cmd.append("surefire:test")
    else:
        cmd.append("test")

    # Add test filtering
    if test_classes or test_methods:
        if test_methods:
            # Format: -Dtest=ClassName#method1+method2,OtherClass#method3
            tests = ",".join(test_methods)
        elif test_classes:
            tests = ",".join(test_classes)
        cmd.extend(["-Dtest=" + tests])

    # Fail at end to run all tests; -B for batch mode (no ANSI colors)
    cmd.extend(["-fae", "-B"])

    # Use full environment with optional overrides
    run_env = os.environ.copy()
    if env:
        run_env.update(env)

    try:
        result = subprocess.run(
            cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
        )

        # Parse test results from Surefire reports
        surefire_dir = project_root / "target" / "surefire-reports"
        tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir)

        return MavenTestResult(
            success=result.returncode == 0,
            tests_run=tests_run,
            failures=failures,
            errors=errors,
            skipped=skipped,
            surefire_reports_dir=surefire_dir if surefire_dir.exists() else None,
            stdout=result.stdout,
            stderr=result.stderr,
            returncode=result.returncode,
        )

    except subprocess.TimeoutExpired:
        logger.exception("Maven test execution timed out after %d seconds", timeout)
        return MavenTestResult(
            success=False,
            tests_run=0,
            failures=0,
            errors=0,
            skipped=0,
            surefire_reports_dir=None,
            stdout="",
            stderr=f"Test execution timed out after {timeout} seconds",
            returncode=-2,
        )
    except Exception as e:
        logger.exception("Maven test execution failed: %s", e)
        return MavenTestResult(
            success=False,
            tests_run=0,
            failures=0,
            errors=0,
            skipped=0,
            surefire_reports_dir=None,
            stdout="",
            stderr=str(e),
            returncode=-1,
        )


def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
    """Parse Surefire XML reports to get test counts.

    Args:
        surefire_dir: Directory containing Surefire XML reports.

    Returns:
        Tuple of (tests_run, failures, errors, skipped).

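The Surefire filter construction in `run_maven_tests` can be sketched on its own. The helper name below is hypothetical; it only builds the `-Dtest=` argument, with method filters taking precedence over class filters as in the code above:

```python
def build_test_filter(test_classes=None, test_methods=None):
    # Surefire accepts -Dtest=Class#m1+m2,OtherClass#m3; method-level
    # filters win over whole-class filters, mirroring run_maven_tests.
    if test_methods:
        return "-Dtest=" + ",".join(test_methods)
    if test_classes:
        return "-Dtest=" + ",".join(test_classes)
    return None

print(build_test_filter(test_methods=["FooTest#a", "BarTest#b"]))
# -Dtest=FooTest#a,BarTest#b
```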
@@ -564,8 +312,9 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
        try:
            tree = _safe_parse_xml(xml_file)
            root = tree.getroot()
            if root is None:
                continue

            # Safely parse numeric attributes with validation
            try:
                tests_run += int(root.get("tests", "0"))
            except (ValueError, TypeError):
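The defensive attribute parsing above can be shown end to end on a single report. This is a sketch with a hypothetical helper name; like the code above, a malformed `tests` attribute counts as zero rather than aborting the whole scan:

```python
import xml.etree.ElementTree as ET

def count_tests(report_xml: str) -> int:
    # Mirrors the validated int() parsing of Surefire's "tests" attribute.
    root = ET.fromstring(report_xml)
    try:
        return int(root.get("tests", "0"))
    except (ValueError, TypeError):
        return 0

print(count_tests('<testsuite tests="7"/>'))     # 7
print(count_tests('<testsuite tests="oops"/>'))  # 0
```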
@@ -594,411 +343,6 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
    return tests_run, failures, errors, skipped


def compile_maven_project(
    project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300
) -> tuple[bool, str, str]:
    """Compile a Maven project.

    Args:
        project_root: Root directory of the Maven project.
        include_tests: Whether to compile test classes as well.
        env: Optional environment variables.
        timeout: Maximum time in seconds for compilation.

    Returns:
        Tuple of (success, stdout, stderr).

    """
    mvn = find_maven_executable()
    if not mvn:
        return False, "", "Maven not found"

    cmd = [mvn]

    if include_tests:
        cmd.append("test-compile")
    else:
        cmd.append("compile")

    # Skip test execution; -B for batch mode (no ANSI colors)
    cmd.extend(["-DskipTests", "-B"])

    run_env = os.environ.copy()
    if env:
        run_env.update(env)

    try:
        result = subprocess.run(
            cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
        )

        return result.returncode == 0, result.stdout, result.stderr

    except subprocess.TimeoutExpired:
        return False, "", f"Compilation timed out after {timeout} seconds"
    except Exception as e:
        return False, "", str(e)

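The command construction in `compile_maven_project` reduces to a goal choice plus two fixed flags. A sketch with a hypothetical helper name:

```python
def maven_compile_command(mvn: str, include_tests: bool) -> list[str]:
    # test-compile also compiles main sources; -DskipTests keeps Surefire
    # from running, and -B is batch mode (no ANSI color output).
    goal = "test-compile" if include_tests else "compile"
    return [mvn, goal, "-DskipTests", "-B"]

print(maven_compile_command("mvn", include_tests=True))
# ['mvn', 'test-compile', '-DskipTests', '-B']
```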
def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> bool:
    """Install the codeflash runtime JAR to the local Maven repository.

    Args:
        project_root: Root directory of the Maven project.
        runtime_jar_path: Path to the codeflash-runtime.jar file.

    Returns:
        True if installation succeeded, False otherwise.

    """
    mvn = find_maven_executable()
    if not mvn:
        logger.error("Maven not found")
        return False

    if not runtime_jar_path.exists():
        logger.error("Runtime JAR not found: %s", runtime_jar_path)
        return False

    cmd = [
        mvn,
        "install:install-file",
        f"-Dfile={runtime_jar_path}",
        "-DgroupId=com.codeflash",
        "-DartifactId=codeflash-runtime",
        f"-Dversion={CODEFLASH_RUNTIME_VERSION}",
        "-Dpackaging=jar",
        "-B",
    ]

    try:
        result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)

        if result.returncode == 0:
            logger.info("Successfully installed codeflash-runtime to local Maven repository")
            return True
        logger.error("Failed to install codeflash-runtime: %s", result.stderr)
        return False

    except Exception as e:
        logger.exception("Failed to install codeflash-runtime: %s", e)
        return False

CODEFLASH_DEPENDENCY_SNIPPET = f"""\
        <dependency>
            <groupId>com.codeflash</groupId>
            <artifactId>codeflash-runtime</artifactId>
            <version>{CODEFLASH_RUNTIME_VERSION}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>"""


def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
    """Add codeflash-runtime dependency to pom.xml if not present.

    Uses string manipulation instead of ElementTree to preserve the original
    XML formatting and namespace prefixes (ElementTree rewrites ns0: prefixes
    which breaks Maven).

    Args:
        pom_path: Path to the pom.xml file.

    Returns:
        True if dependency was added or already present, False on error.

    """
    if not pom_path.exists():
        return False

    try:
        content = pom_path.read_text(encoding="utf-8")

        # Check if already present
        if "codeflash-runtime" in content:
            # If a previous run left a system-scope dependency, replace it with test scope.
            # System-scope dependencies cause Maven warnings and are rejected by some projects.
            if "<scope>system</scope>" in content:
                # Replace ONLY the codeflash-runtime dependency block that has system scope.
                # We find each <dependency>...</dependency> block individually and only replace
                # the one containing both "codeflash-runtime" and "<scope>system</scope>".
                # The previous regex used [\s\S]*? lookaheads that could match across blocks,
                # accidentally replacing every dependency in the file.
                def replace_system_dep(match: re.Match) -> str:
                    block = match.group(0)
                    if "codeflash-runtime" in block and "<scope>system</scope>" in block:
                        return (
                            "<dependency>\n"
                            " <groupId>com.codeflash</groupId>\n"
                            " <artifactId>codeflash-runtime</artifactId>\n"
                            f" <version>{CODEFLASH_RUNTIME_VERSION}</version>\n"
                            " <scope>test</scope>\n"
                            " </dependency>"
                        )
                    return block

                content = re.sub(r"<dependency>[\s\S]*?</dependency>", replace_system_dep, content)
                pom_path.write_text(content, encoding="utf-8")
                logger.info("Replaced system-scope codeflash-runtime dependency with test scope")
                return True
            logger.info("codeflash-runtime dependency already present in pom.xml")
            return True

        # Find </dependencies> closing tag and insert before it
        closing_tag = "</dependencies>"
        idx = content.find(closing_tag)
        if idx == -1:
            logger.warning("No </dependencies> tag found in pom.xml, cannot add dependency")
            return False

        new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
        # Skip the original </dependencies> tag since our snippet includes it
        new_content += content[idx + len(closing_tag) :]

        pom_path.write_text(new_content, encoding="utf-8")
        logger.info("Added codeflash-runtime dependency to pom.xml")
        return True

    except Exception as e:
        logger.exception("Failed to add dependency to pom.xml: %s", e)
        return False

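The string-surgery insertion used by `add_codeflash_dependency_to_pom` can be demonstrated on a tiny in-memory POM. The snippet text and helper name below are simplified stand-ins; the key trick is the same: the inserted snippet carries its own `</dependencies>` tag, so the original closing tag is skipped:

```python
SNIPPET = """\
    <dependency>
      <groupId>com.codeflash</groupId>
      <artifactId>codeflash-runtime</artifactId>
      <scope>test</scope>
    </dependency>
  </dependencies>"""

def add_dependency(pom: str) -> str:
    # Splice the snippet in where the original </dependencies> tag was;
    # the snippet supplies the replacement closing tag itself.
    closing = "</dependencies>"
    idx = pom.find(closing)
    if idx == -1:
        raise ValueError("no </dependencies> tag")
    return pom[:idx] + SNIPPET + pom[idx + len(closing):]

pom = "<project>\n  <dependencies>\n  </dependencies>\n</project>"
patched = add_dependency(pom)
print("codeflash-runtime" in patched)  # True
```

Working on the raw string rather than a parsed tree preserves the file's original formatting and namespace prefixes, which is exactly why the real function avoids ElementTree.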
def is_jacoco_configured(pom_path: Path) -> bool:
    """Check if JaCoCo plugin is already configured in pom.xml.

    Checks both the main build section and any profile build sections.

    Args:
        pom_path: Path to the pom.xml file.

    Returns:
        True if JaCoCo plugin is configured anywhere in the pom.xml, False otherwise.

    """
    if not pom_path.exists():
        return False

    try:
        tree = _safe_parse_xml(pom_path)
        root = tree.getroot()

        # Handle Maven namespace
        ns_prefix = "{http://maven.apache.org/POM/4.0.0}"

        # Check if namespace is used
        use_ns = root.tag.startswith("{")
        if not use_ns:
            ns_prefix = ""

        # Search all build/plugins sections (including those in profiles)
        # Using .// to search recursively for all plugin elements
        for plugin in root.findall(f".//{ns_prefix}plugin" if use_ns else ".//plugin"):
            artifact_id = plugin.find(f"{ns_prefix}artifactId" if use_ns else "artifactId")
            if artifact_id is not None and artifact_id.text == "jacoco-maven-plugin":
                group_id = plugin.find(f"{ns_prefix}groupId" if use_ns else "groupId")
                # Verify groupId if present (it's optional for org.jacoco)
                if group_id is None or group_id.text == "org.jacoco":
                    return True

        return False

    except ET.ParseError as e:
        logger.warning("Failed to parse pom.xml for JaCoCo check: %s", e)
        return False

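The namespace handling in `is_jacoco_configured` is the subtle part: a POM that declares the Maven namespace yields tags like `{http://maven.apache.org/POM/4.0.0}plugin`, so every lookup must be prefixed. A self-contained sketch (hypothetical helper name, same search strategy):

```python
import xml.etree.ElementTree as ET

NS = "{http://maven.apache.org/POM/4.0.0}"

def has_jacoco(pom_xml: str) -> bool:
    # Recursive search over all <plugin> elements, including those nested
    # in profile build sections, with optional Maven POM namespace.
    root = ET.fromstring(pom_xml)
    prefix = NS if root.tag.startswith("{") else ""
    for plugin in root.iter(f"{prefix}plugin"):
        artifact = plugin.find(f"{prefix}artifactId")
        if artifact is not None and artifact.text == "jacoco-maven-plugin":
            return True
    return False

pom = (
    '<project xmlns="http://maven.apache.org/POM/4.0.0">'
    "<build><plugins><plugin>"
    "<artifactId>jacoco-maven-plugin</artifactId>"
    "</plugin></plugins></build></project>"
)
print(has_jacoco(pom))  # True
```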
def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
    """Add JaCoCo Maven plugin to pom.xml for coverage collection.

    Uses string manipulation to preserve the original XML format and avoid
    namespace prefix issues that ElementTree causes.

    Args:
        pom_path: Path to the pom.xml file.

    Returns:
        True if plugin was added or already present, False on error.

    """
    if not pom_path.exists():
        logger.error("pom.xml not found: %s", pom_path)
        return False

    # Check if already configured
    if is_jacoco_configured(pom_path):
        logger.info("JaCoCo plugin already configured in pom.xml")
        return True

    try:
        content = pom_path.read_text(encoding="utf-8")

        # Basic validation that it's a Maven pom.xml
        if "</project>" not in content:
            logger.error("Invalid pom.xml: no closing </project> tag found")
            return False

        # JaCoCo plugin XML to insert (indented for typical pom.xml format)
        # Note: For multi-module projects where tests are in a separate module,
        # we configure the report to look in multiple directories for classes
        jacoco_plugin = f"""
            <plugin>
                <groupId>org.jacoco</groupId>
                <artifactId>jacoco-maven-plugin</artifactId>
                <version>{JACOCO_PLUGIN_VERSION}</version>
                <executions>
                    <execution>
                        <id>prepare-agent</id>
                        <goals>
                            <goal>prepare-agent</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>report</id>
                        <phase>verify</phase>
                        <goals>
                            <goal>report</goal>
                        </goals>
                        <configuration>
                            <!-- For multi-module projects, include dependency classes -->
                            <includes>
                                <include>**/*.class</include>
                            </includes>
                        </configuration>
                    </execution>
                </executions>
            </plugin>"""

        # Find the main <build> section (not inside <profiles>)
        # We need to find a <build> that appears after </profiles> or before <profiles>
        # or if there's no profiles section at all
        profiles_start = content.find("<profiles>")
        profiles_end = content.find("</profiles>")

        # Find all <build> tags

        # Find the main build section - it's the one NOT inside profiles
        # Strategy: Look for <build> that comes after </profiles> or before <profiles> (or no profiles)
        if profiles_start == -1:
            # No profiles, any <build> is the main one
            build_start = content.find("<build>")
            build_end = content.find("</build>")
        else:
            # Has profiles - find <build> outside of profiles
            # Check for <build> before <profiles>
            build_before_profiles = content[:profiles_start].rfind("<build>")
            # Check for <build> after </profiles>
            build_after_profiles = content[profiles_end:].find("<build>") if profiles_end != -1 else -1
            if build_after_profiles != -1:
                build_after_profiles += profiles_end

            if build_before_profiles != -1:
                build_start = build_before_profiles
                # Find corresponding </build> - need to handle nested builds
                build_end = _find_closing_tag(content, build_start, "build")
            elif build_after_profiles != -1:
                build_start = build_after_profiles
                build_end = _find_closing_tag(content, build_start, "build")
            else:
                build_start = -1
                build_end = -1

        if build_start != -1 and build_end != -1:
            # Found main build section, find plugins within it
            build_section = content[build_start : build_end + len("</build>")]
            plugins_start_in_build = build_section.find("<plugins>")
            plugins_end_in_build = build_section.rfind("</plugins>")

            if plugins_start_in_build != -1 and plugins_end_in_build != -1:
                # Insert before </plugins> within the main build section
                absolute_plugins_end = build_start + plugins_end_in_build
                content = content[:absolute_plugins_end] + jacoco_plugin + "\n " + content[absolute_plugins_end:]
            else:
                # No plugins section in main build, add one before </build>
                plugins_section = f"<plugins>{jacoco_plugin}\n </plugins>\n "
                content = content[:build_end] + plugins_section + content[build_end:]
        else:
            # No main build section found, add one before </project>
            project_end = content.rfind("</project>")
            build_section = f"""
    <build>
        <plugins>{jacoco_plugin}
        </plugins>
    </build>
"""
            content = content[:project_end] + build_section + content[project_end:]

        pom_path.write_text(content, encoding="utf-8")
        logger.info("Added JaCoCo plugin to pom.xml")
        return True

    except Exception as e:
        logger.exception("Failed to add JaCoCo plugin to pom.xml: %s", e)
        return False

def _find_closing_tag(content: str, start_pos: int, tag_name: str) -> int:
    """Find the position of the closing tag that matches the opening tag at start_pos.

    Handles nested tags of the same name.

    Args:
        content: The XML content.
        start_pos: Position of the opening tag.
        tag_name: Name of the tag.

    Returns:
        Position of the closing tag, or -1 if not found.

    """
    open_tag = f"<{tag_name}>"
    open_tag_short = f"<{tag_name} "  # For tags with attributes
    close_tag = f"</{tag_name}>"

    # Start searching after the opening tag we're matching
    depth = 1  # We've already found the opening tag at start_pos
    pos = start_pos + len(f"<{tag_name}")  # Move past the opening tag

    while pos < len(content):
        next_open = content.find(open_tag, pos)
        next_open_short = content.find(open_tag_short, pos)
        next_close = content.find(close_tag, pos)

        if next_close == -1:
            return -1

        # Find the earliest opening tag (if any)
        candidates = [x for x in [next_open, next_open_short] if x != -1 and x < next_close]
        next_open_any = min(candidates) if candidates else len(content) + 1

        if next_open_any < next_close:
            # Found opening tag first - nested tag
            depth += 1
            pos = next_open_any + 1
        else:
            # Found closing tag first
            depth -= 1
            if depth == 0:
                return next_close
            pos = next_close + len(close_tag)

    return -1

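The depth-counting idea behind `_find_closing_tag` can be shown in a compact form. This sketch (hypothetical name, no attribute-tag handling) increments a depth counter on each nested opening tag and decrements on each close; the match is the close that brings depth back to zero:

```python
def find_closing(content: str, start: int, tag: str) -> int:
    # Depth counter over same-named tags, in the spirit of _find_closing_tag.
    open_tag, close_tag = f"<{tag}>", f"</{tag}>"
    depth, pos = 1, start + len(open_tag)
    while True:
        nxt_open = content.find(open_tag, pos)
        nxt_close = content.find(close_tag, pos)
        if nxt_close == -1:
            return -1  # unbalanced
        if nxt_open != -1 and nxt_open < nxt_close:
            depth += 1            # a nested <tag> opens first
            pos = nxt_open + len(open_tag)
        else:
            depth -= 1            # a </tag> closes first
            if depth == 0:
                return nxt_close  # matches the tag we started at
            pos = nxt_close + len(close_tag)

xml = "<build><build>inner</build></build>"
print(find_closing(xml, 0, "build"))  # 27, the index of the OUTER </build>
```

A naive `content.find("</build>")` would instead return 19, the inner close, which is exactly the bug the depth counter avoids.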
def get_jacoco_xml_path(project_root: Path) -> Path:
    """Get the expected path to the JaCoCo XML report.

    Args:
        project_root: Root directory of the Maven project.

    Returns:
        Path to the JaCoCo XML report file.

    """
    return project_root / "target" / "site" / "jacoco" / "jacoco.xml"


def find_test_root(project_root: Path) -> Path | None:
    """Find the test root directory for a Java project.

@@ -1049,60 +393,3 @@ def find_source_root(project_root: Path) -> Path | None:
        return src_path

    return None


def get_classpath(project_root: Path) -> str | None:
    """Get the classpath for a Java project.

    For Maven projects, this runs 'mvn dependency:build-classpath'.

    Args:
        project_root: Root directory of the Java project.

    Returns:
        Classpath string, or None if unable to determine.

    """
    build_tool = detect_build_tool(project_root)

    if build_tool == BuildTool.MAVEN:
        return _get_maven_classpath(project_root)
    if build_tool == BuildTool.GRADLE:
        return _get_gradle_classpath(project_root)

    return None


def _get_maven_classpath(project_root: Path) -> str | None:
    """Get classpath from Maven."""
    mvn = find_maven_executable()
    if not mvn:
        return None

    try:
        result = subprocess.run(
            [mvn, "dependency:build-classpath", "-q", "-DincludeScope=test", "-B"],
            check=False,
            cwd=project_root,
            capture_output=True,
            text=True,
            timeout=120,
        )

        if result.returncode == 0:
            # The classpath is in stdout
            return result.stdout.strip()

    except Exception as e:
        logger.warning("Failed to get Maven classpath: %s", e)

    return None


def _get_gradle_classpath(project_root: Path) -> str | None:
    """Get classpath from Gradle.

    Note: This requires a custom task to be added to build.gradle.
    Returns None for now as Gradle support is not fully implemented.
    """
    return None

@@ -136,6 +136,7 @@ def compare_test_results(
    candidate_sqlite_path: Path,
    comparator_jar: Path | None = None,
    project_root: Path | None = None,
    project_classpath: str | None = None,
) -> tuple[bool, list]:
    """Compare Java test results using the codeflash-runtime Comparator.

@@ -150,6 +151,10 @@ def compare_test_results(
        candidate_sqlite_path: Path to SQLite database with candidate code results.
        comparator_jar: Optional path to the codeflash-runtime JAR.
        project_root: Project root directory.
        project_classpath: Full project classpath from the build tool. When provided,
            the Comparator JVM uses this classpath so Kryo can resolve project-specific
            classes during deserialization. Without it, only the runtime JAR is on the
            classpath and any project class causes ClassNotFoundException.

    Returns:
        Tuple of (all_equivalent, list of TestDiff objects).

@@ -180,6 +185,16 @@
    cwd = project_root or Path.cwd()

    # Build classpath: runtime JAR + project classpath (if available).
    # The project classpath is needed so Kryo can resolve project-specific classes
    # during deserialization. Without it, Class.forName() fails for any type not
    # bundled in the runtime JAR.
    cp_separator = ";" if platform.system() == "Windows" else ":"
    if project_classpath:
        full_cp = f"{jar_path}{cp_separator}{project_classpath}"
    else:
        full_cp = str(jar_path)

    try:
        result = subprocess.run(
            [
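The classpath assembly above can be sketched with the stdlib's own separator constant. This is a hypothetical helper; `os.pathsep` is `";"` on Windows and `":"` elsewhere, which is the same split the diff derives from `platform.system()`:

```python
import os

def build_comparator_classpath(jar_path: str, project_classpath=None) -> str:
    # Runtime JAR first, then the project's classpath so the Comparator
    # JVM can resolve project classes during Kryo deserialization.
    if project_classpath:
        return f"{jar_path}{os.pathsep}{project_classpath}"
    return jar_path

print(build_comparator_classpath("codeflash-runtime.jar", "a.jar:b.jar"))
```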
@@ -200,7 +215,7 @@
                "--add-opens",
                "java.base/java.util.zip=ALL-UNNAMED",
                "-cp",
                str(jar_path),
                full_cp,
                "com.codeflash.Comparator",
                str(original_sqlite_path),
                str(candidate_sqlite_path),
@@ -145,17 +145,28 @@ class JavaFunctionOptimizer(FunctionOptimizer):
    def _get_java_sources_root(self) -> Path:
        """Get the Java sources root directory for test files.

        For Java projects, tests_root might include the package path
        (e.g., test/src/com/aerospike/test). We need to find the base directory
        that should contain the package directories, not the tests_root itself.
        For multi-module projects (Kafka, OpenSearch, Spring Boot), the test
        directory must correspond to the same module as the source file, not
        the globally configured tests_root. When the source file follows the
        standard Maven/Gradle ``src/main/java`` layout we derive the test root
        by replacing ``main`` with ``test`` in the same module prefix.

        This method looks for standard Java package prefixes (com, org, net, io, edu, gov)
        in the tests_root path and returns everything before that prefix.
        Falls back to the existing tests_root-based heuristics for non-standard
        layouts or single-module projects.

        Returns:
            Path to the Java sources root directory.

        """
        file_path_str = str(self.function_to_optimize.file_path)
        src_main_java_marker = str(Path("src") / "main" / "java")
        idx = file_path_str.find(src_main_java_marker)
        if idx != -1:
            module_prefix = Path(file_path_str[:idx])
            derived_test_root = module_prefix / "src" / "test" / "java"
            if derived_test_root.exists():
                return derived_test_root

        tests_root = self.test_cfg.tests_root
        parts = tests_root.parts

@@ -326,15 +337,45 @@ class JavaFunctionOptimizer(FunctionOptimizer):
        original_sqlite = get_run_tmp_file(Path("test_return_values_0.sqlite"))
        candidate_sqlite = get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite"))

        if len(baseline_results.behavior_test_results) == 0 or len(candidate_behavior_results) == 0:
            return False, []

        if original_sqlite.exists() and candidate_sqlite.exists():
            match, diffs = self.language_support.compare_test_results(
                original_sqlite, candidate_sqlite, project_root=self.project_root
                original_sqlite,
                candidate_sqlite,
                project_root=self.project_root,
                project_classpath=self._get_project_classpath(),
            )
            candidate_sqlite.unlink(missing_ok=True)
        else:
            match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
        return match, diffs

    _cached_project_classpath: str | None

    def _get_project_classpath(self) -> str | None:
        """Get the project's full classpath from the build tool strategy.

        The classpath is cached by the strategy after the first test run,
        so this is a cheap dict lookup.
        """
        if hasattr(self, "_cached_project_classpath"):
            return self._cached_project_classpath

        try:
            import os

            from codeflash.languages.java.build_tool_strategy import get_strategy

            strategy = get_strategy(self.project_root)
            classpath = strategy.get_classpath(self.project_root, os.environ.copy(), None, timeout=60)
            self._cached_project_classpath = classpath
            return classpath
        except Exception:
            logger.debug("Could not get project classpath for Comparator", exc_info=True)
            return None

    def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
        return testing_type == TestingMode.BEHAVIOR or optimization_iteration == 0
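The `hasattr` check in `_get_project_classpath` above is attribute-based memoization that can also cache a `None` result (a plain `if self._cached is None` guard would retry the expensive lookup on every call after a failure). A generic sketch of the pattern, with hypothetical names:

```python
class CachedLookup:
    """Memoize a zero-argument callable on first use, caching failures as None."""

    def __init__(self, compute):
        self._compute = compute  # the expensive zero-arg callable

    def get(self):
        # hasattr distinguishes "never computed" from a legitimately cached None.
        if hasattr(self, "_cached"):
            return self._cached
        try:
            self._cached = self._compute()
        except Exception:
            self._cached = None  # cache the failure so we never retry
        return self._cached


calls = []
lookup = CachedLookup(lambda: calls.append(1) or "classpath.jar")
print(lookup.get(), lookup.get(), len(calls))  # computed only once
```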
codeflash/languages/java/gradle_strategy.py (new file, 815 lines)

@@ -0,0 +1,815 @@
"""Gradle build tool strategy for Java projects.

Implements BuildToolStrategy for Gradle-based projects, handling compilation,
classpath extraction, test execution, and JaCoCo coverage.
"""

from __future__ import annotations

import logging
import os
import re
import shutil
import subprocess
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any

from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir

_BUILD = "build"

logger = logging.getLogger(__name__)

# Groovy init script that disables validation/analysis plugins.
# Equivalent to Maven's -Dcheckstyle.skip=true, -Dspotbugs.skip=true, etc.
# Using an init script is safe even if the plugins aren't applied — unlike
# `-x taskName` which fails if the task doesn't exist.
_GRADLE_SKIP_VALIDATION_INIT_SCRIPT = """\
gradle.projectsEvaluated {
    allprojects {
        tasks.matching { task ->
            task.name in [
                'checkstyleMain', 'checkstyleTest',
                'spotbugsMain', 'spotbugsTest',
                'pmdMain', 'pmdTest',
                'rat', 'japicmp'
            ]
        }.configureEach {
            enabled = false
        }
        tasks.withType(JavaCompile) {
            options.compilerArgs.removeAll { it == '-Werror' }
            options.compilerArgs.removeAll { it == '-Xlint:all' }
            if (options.hasProperty('errorprone')) {
                options.errorprone {
                    enabled = false
                }
            }
        }
    }
}
"""

# Lazily-created temp file for the validation-skip init script.
_skip_validation_init_path: str | None = None


def _get_skip_validation_init_script() -> str:
    """Return the path to a persistent temp init script that disables validation tasks."""
    global _skip_validation_init_path
    if _skip_validation_init_path is None or not Path(_skip_validation_init_path).exists():
        fd, path = tempfile.mkstemp(suffix=".gradle", prefix="codeflash_skip_validation_")
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(_GRADLE_SKIP_VALIDATION_INIT_SCRIPT)
        _skip_validation_init_path = path
    return _skip_validation_init_path
# Cache for classpath strings — keyed on (gradle_root, test_module).
_classpath_cache: dict[tuple[Path, str | None], str] = {}

# Cache for multi-module dependency installs — keyed on (gradle_root, test_module).
_multimodule_deps_installed: set[tuple[Path, str]] = set()

# Gradle init script that prints the test runtime classpath.
# Uses projectsEvaluated to avoid triggering configuration of unrelated subprojects.
_CLASSPATH_INIT_SCRIPT = """\
gradle.projectsEvaluated {
    allprojects {
        tasks.register("codeflashPrintClasspath") {
            doLast {
                def cp = configurations.findByName('testRuntimeClasspath')
                if (cp != null && cp.isCanBeResolved()) {
                    println "CODEFLASH_CP_START"
                    println cp.asPath
                    println "CODEFLASH_CP_END"
                }
            }
        }
    }
}
"""

# Gradle init script that applies JaCoCo plugin for coverage collection.
# Uses projectsEvaluated to avoid triggering configuration of unrelated subprojects.
_JACOCO_INIT_SCRIPT = """\
gradle.projectsEvaluated {
    allprojects {
        apply plugin: 'jacoco'
        jacocoTestReport {
            reports {
                xml.required = true
                html.required = false
            }
        }
    }
}
"""


def find_gradle_build_file(project_root: Path) -> Path | None:
    kts = project_root / "build.gradle.kts"
    if kts.exists():
        return kts
    groovy = project_root / "build.gradle"
    if groovy.exists():
        return groovy
    return None
def _find_top_level_dependencies_block(build_file: Path, content: str) -> int | None:
    """Find the insert position (before closing }) of the top-level dependencies block using tree-sitter.

    Returns the byte offset of the closing brace, or None if no top-level dependencies block exists.
    Only matches `dependencies { }` at the root level — ignores blocks nested inside
    `buildscript`, `subprojects`, `allprojects`, etc.
    """
    import tree_sitter as ts

    is_kts = build_file.name.endswith(".kts")
    source_bytes = content.encode("utf-8")

    if is_kts:
        import tree_sitter_kotlin as tsk

        parser = ts.Parser(ts.Language(tsk.language()))
    else:
        import tree_sitter_groovy as tsg

        parser = ts.Parser(ts.Language(tsg.language()))

    tree = parser.parse(source_bytes)

    # Walk only direct children of root to find top-level `dependencies { }`
    for child in tree.root_node.children:
        # Groovy: expression_statement > method_invocation(identifier="dependencies", closure)
        # Kotlin: call_expression(identifier="dependencies", annotated_lambda)
        node = child
        if node.type == "expression_statement" and node.child_count > 0:
            node = node.children[0]

        if node.type not in ("method_invocation", "call_expression"):
            continue

        name_node = None
        body_node = None
        for c in node.children:
            if c.type == "identifier":
                name_node = c
            elif c.type in ("closure", "annotated_lambda", "lambda_literal"):
                body_node = c

        if name_node is None or body_node is None:
            continue

        name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf-8")
        if name != "dependencies":
            continue

        # Find the closing brace of this block
        closing_brace = body_node.children[-1] if body_node.children else None
        # For Kotlin, annotated_lambda wraps lambda_literal
        if closing_brace is not None and closing_brace.type == "lambda_literal":
            closing_brace = closing_brace.children[-1] if closing_brace.children else None

        if closing_brace is not None and closing_brace.type == "}":
            return closing_brace.start_byte

    return None
def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
    if not build_file.exists():
        return False

    try:
        content = build_file.read_text(encoding="utf-8")

        if "codeflash-runtime" in content:
            logger.info("codeflash-runtime dependency already present in %s", build_file.name)
            return True

        is_kts = build_file.name.endswith(".kts")
        jar_str = str(runtime_jar_path).replace("\\", "/")

        if is_kts:
            dep_line = f'    testImplementation(files("{jar_str}")) // codeflash-runtime\n'
        else:
            dep_line = f"    testImplementation files('{jar_str}') // codeflash-runtime\n"

        # Use tree-sitter to find the top-level dependencies block
        insert_pos = _find_top_level_dependencies_block(build_file, content)
        if insert_pos is not None:
            content = content[:insert_pos] + dep_line + content[insert_pos:]
            build_file.write_text(content, encoding="utf-8")
            logger.info("Added codeflash-runtime dependency to %s (tree-sitter)", build_file.name)
            return True

        # No existing dependencies block — append one
        if is_kts:
            content += f'\ndependencies {{\n    testImplementation(files("{jar_str}")) // codeflash-runtime\n}}\n'
        else:
            content += f"\ndependencies {{\n    testImplementation files('{jar_str}') // codeflash-runtime\n}}\n"
        build_file.write_text(content, encoding="utf-8")
        logger.info("Added codeflash-runtime dependency to %s (new block)", build_file.name)
        return True

    except Exception as e:
        logger.exception("Failed to add dependency to %s: %s", build_file.name, e)
        return False
def _normalize_gradle_xml_reports(reports_dir: Path) -> None:
    """Normalize Gradle JUnit XML reports to match Maven Surefire format.

    Gradle's JUnit Platform XML differs from Maven Surefire in ways that
    can crash the downstream parser:
    1. <failure>/<error> elements may omit the ``message`` attribute —
       Maven always sets it.
    2. Timeout information may only appear in the element body text,
       not in the ``message`` attribute.

    This function rewrites the XML files in-place so they conform to the
    Maven Surefire contract the parser expects.
    """
    if not reports_dir.exists():
        return
    for xml_file in reports_dir.glob("TEST-*.xml"):
        try:
            tree = ET.parse(xml_file)
            root = tree.getroot()
            modified = False
            for tag in ("failure", "error"):
                for elem in root.iter(tag):
                    if elem.get("message") is None:
                        body = (elem.text or "").strip()
                        first_line = body.split("\n", 1)[0] if body else ""
                        elem.set("message", first_line)
                        modified = True
            if modified:
                tree.write(xml_file, encoding="unicode", xml_declaration=True)
        except ET.ParseError:
            logger.debug("Failed to normalize Gradle XML report %s", xml_file)
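The normalization can be demonstrated on a minimal Gradle-style report; the sample XML below is ours:

```python
import xml.etree.ElementTree as ET

# Gradle may emit <failure> without a message attribute; Maven Surefire never does.
gradle_xml = """<testsuite name="ExampleTest">
  <testcase name="timesOut">
    <failure type="java.util.concurrent.TimeoutException">test timed out after 5 seconds
full stack trace follows...</failure>
  </testcase>
</testsuite>"""

root = ET.fromstring(gradle_xml)
for tag in ("failure", "error"):
    for elem in root.iter(tag):
        if elem.get("message") is None:
            body = (elem.text or "").strip()
            # Promote the first body line into the attribute the parser expects.
            elem.set("message", body.split("\n", 1)[0] if body else "")

print(root.find("testcase/failure").get("message"))  # test timed out after 5 seconds
```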
class GradleStrategy(BuildToolStrategy):
    """Gradle-specific build tool operations."""

    @property
    def name(self) -> str:
        return "Gradle"

    def find_executable(self, build_root: Path) -> str | None:
        # Walk up from build_root to find gradlew — for multi-module projects
        # the wrapper lives at the repo root, which may be a parent of build_root.
        current = build_root.resolve()
        while True:
            gradlew_path = current / "gradlew"
            if gradlew_path.exists():
                return str(gradlew_path)
            gradlew_bat_path = current / "gradlew.bat"
            if gradlew_bat_path.exists():
                return str(gradlew_bat_path)
            parent = current.parent
            if parent == current:
                break
            current = parent
        # Fall back to system Gradle
        return shutil.which("gradle")

    def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
        runtime_jar = self.find_runtime_jar()
        if runtime_jar is None:
            logger.error("codeflash-runtime JAR not found. Generated tests will fail to compile.")
            return False

        if test_module:
            module_root = build_root / module_to_dir(test_module)
        else:
            module_root = build_root

        libs_dir = module_root / "libs"
        libs_dir.mkdir(parents=True, exist_ok=True)
        dest_jar = libs_dir / "codeflash-runtime-1.0.0.jar"

        if not dest_jar.exists():
            logger.info("Copying codeflash-runtime JAR to %s", dest_jar)
            shutil.copy2(runtime_jar, dest_jar)

        build_file = find_gradle_build_file(module_root)
        if build_file is None:
            logger.warning("No build.gradle(.kts) found at %s, cannot add codeflash-runtime dependency", module_root)
            return False

        if not add_codeflash_dependency(build_file, dest_jar):
            logger.error("Failed to add codeflash-runtime dependency to %s", build_file)
            return False

        return True

    def install_multi_module_deps(self, build_root: Path, test_module: str | None, env: dict[str, str]) -> bool:
        from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout

        if not test_module:
            return True

        cache_key = (build_root, test_module)
        if cache_key in _multimodule_deps_installed:
            logger.debug("Multi-module deps already installed for %s:%s, skipping", build_root, test_module)
            return True

        gradle = self.find_executable(build_root)
        if not gradle:
            logger.error("Gradle not found — cannot pre-install multi-module dependencies")
            return False

        cmd = [gradle, f":{test_module}:classes", "-x", "test", "--build-cache", "--no-daemon"]
        cmd.extend(["--init-script", _get_skip_validation_init_script()])

        logger.info("Pre-installing multi-module dependencies: %s (module: %s)", build_root, test_module)
        logger.debug("Running: %s", " ".join(cmd))

        try:
            result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=300)
            if result.returncode != 0:
                logger.error(
                    "Failed to pre-install multi-module deps (exit %d).\nstdout: %s\nstderr: %s",
                    result.returncode,
                    result.stdout[-2000:] if result.stdout else "",
                    result.stderr[-2000:] if result.stderr else "",
                )
                return False
        except Exception:
            logger.exception("Exception during multi-module dependency install")
            return False

        _multimodule_deps_installed.add(cache_key)
        logger.info("Multi-module dependencies installed successfully for %s:%s", build_root, test_module)
        return True
    def compile_tests(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
    ) -> subprocess.CompletedProcess[str]:
        from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout

        gradle = self.find_executable(build_root)
        if not gradle:
            logger.error("Gradle not found")
            return subprocess.CompletedProcess(args=["gradle"], returncode=-1, stdout="", stderr="Gradle not found")

        if test_module:
            cmd = [gradle, f":{test_module}:testClasses", "--no-daemon"]
        else:
            cmd = [gradle, "testClasses", "--no-daemon"]
        cmd.extend(["--init-script", _get_skip_validation_init_script()])

        logger.debug("Compiling tests: %s in %s", " ".join(cmd), build_root)

        try:
            return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
        except Exception as e:
            logger.exception("Gradle compilation failed: %s", e)
            return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))

    def compile_source_only(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
    ) -> subprocess.CompletedProcess[str]:
        from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout

        gradle = self.find_executable(build_root)
        if not gradle:
            logger.error("Gradle not found")
            return subprocess.CompletedProcess(args=["gradle"], returncode=-1, stdout="", stderr="Gradle not found")

        if test_module:
            cmd = [gradle, f":{test_module}:classes", "--no-daemon"]
        else:
            cmd = [gradle, "classes", "--no-daemon"]
        cmd.extend(["--init-script", _get_skip_validation_init_script()])

        logger.debug("Compiling source only: %s in %s", " ".join(cmd), build_root)

        try:
            return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
        except Exception as e:
            logger.exception("Gradle source compilation failed: %s", e)
            return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
    def get_classpath(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
    ) -> str | None:
        key = (build_root, test_module)
        cached = _classpath_cache.get(key)
        if cached is not None:
            logger.debug("Using cached classpath for (%s, %s)", build_root, test_module)
            return cached
        result = self._get_classpath_uncached(build_root, env, test_module, timeout)
        if result is not None:
            _classpath_cache[key] = result
        return result

    def _get_classpath_uncached(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
    ) -> str | None:
        from codeflash.languages.java.test_runner import _find_junit_console_standalone, _run_cmd_kill_pg_on_timeout

        gradle = self.find_executable(build_root)
        if not gradle:
            return None

        # Write init script to a temp file
        init_script_fd, init_script_path = tempfile.mkstemp(suffix=".gradle", prefix="codeflash_cp_")
        try:
            with os.fdopen(init_script_fd, "w", encoding="utf-8") as f:
                f.write(_CLASSPATH_INIT_SCRIPT)

            if test_module:
                task = f":{test_module}:codeflashPrintClasspath"
            else:
                task = "codeflashPrintClasspath"

            cmd = [gradle, "--init-script", init_script_path, task, "-q", "--no-daemon"]

            logger.debug("Getting classpath: %s", " ".join(cmd))

            result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)

            if result.returncode != 0:
                logger.error("Failed to get classpath: %s", result.stderr)
                return None

            classpath = self._parse_classpath_output(result.stdout)
            if not classpath:
                logger.error("Classpath not found in Gradle output")
                return None

            if test_module:
                module_path = build_root / module_to_dir(test_module)
            else:
                module_path = build_root

            test_classes = module_path / "build" / "classes" / "java" / "test"
            main_classes = module_path / "build" / "classes" / "java" / "main"

            cp_parts = [classpath]
            if test_classes.exists():
                cp_parts.append(str(test_classes))
            if main_classes.exists():
                cp_parts.append(str(main_classes))

            if test_module:
                module_dir_name = module_to_dir(test_module)
                for module_dir in build_root.iterdir():
                    if module_dir.is_dir() and module_dir.name != module_dir_name:
                        module_classes = module_dir / "build" / "classes" / "java" / "main"
                        if module_classes.exists():
                            logger.debug("Adding multi-module classpath: %s", module_classes)
                            cp_parts.append(str(module_classes))

            if "console-standalone" not in classpath and "ConsoleLauncher" not in classpath:
                console_jar = _find_junit_console_standalone()
                if console_jar:
                    logger.debug("Adding JUnit Console Standalone to classpath: %s", console_jar)
                    cp_parts.append(str(console_jar))

            return os.pathsep.join(cp_parts)

        except Exception as e:
            logger.exception("Failed to get classpath: %s", e)
            return None
        finally:
            Path(init_script_path).unlink(missing_ok=True)
    @staticmethod
    def _parse_classpath_output(stdout: str) -> str | None:
        in_cp = False
        for line in stdout.splitlines():
            if line.strip() == "CODEFLASH_CP_START":
                in_cp = True
                continue
            if line.strip() == "CODEFLASH_CP_END":
                break
            if in_cp and line.strip():
                return line.strip()
        return None
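The sentinel-delimited parsing above can be exercised directly on sample output; the Gradle log line below is ours:

```python
from __future__ import annotations


def parse_between_sentinels(stdout: str) -> str | None:
    """Return the first non-empty line between the START/END markers."""
    in_cp = False
    for line in stdout.splitlines():
        if line.strip() == "CODEFLASH_CP_START":
            in_cp = True
            continue
        if line.strip() == "CODEFLASH_CP_END":
            break
        if in_cp and line.strip():
            return line.strip()
    return None


sample = "> Task :app:codeflashPrintClasspath\nCODEFLASH_CP_START\n/libs/a.jar:/libs/b.jar\nCODEFLASH_CP_END\n"
print(parse_between_sentinels(sample))  # /libs/a.jar:/libs/b.jar
```

Bracketing the value with unique markers keeps the parser immune to Gradle's own progress output on surrounding lines.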
    def get_reports_dir(self, build_root: Path, test_module: str | None) -> Path:
        build_dir = self.get_build_output_dir(build_root, test_module)
        return build_dir / "test-results" / "test"

    def get_build_output_dir(self, build_root: Path, test_module: str | None) -> Path:
        if test_module:
            return build_root.joinpath(module_to_dir(test_module), _BUILD)
        return build_root.joinpath(_BUILD)
    def run_tests_via_build_tool(
        self,
        build_root: Path,
        test_paths: Any,
        env: dict[str, str],
        timeout: int,
        mode: str,
        test_module: str | None,
        javaagent_arg: str | None = None,
        enable_coverage: bool = False,
    ) -> subprocess.CompletedProcess[str]:
        from codeflash.languages.java.test_runner import _build_test_filter, _run_cmd_kill_pg_on_timeout

        gradle = self.find_executable(build_root)
        if not gradle:
            logger.error("Gradle not found")
            return subprocess.CompletedProcess(args=["gradle"], returncode=-1, stdout="", stderr="Gradle not found")

        test_filter = _build_test_filter(test_paths, mode=mode)
        logger.debug("Built test filter for mode=%s: '%s' (empty=%s)", mode, test_filter, not test_filter)

        if test_module:
            task = f":{test_module}:test"
        else:
            task = "test"

        # Write an init script that configures JVM args for the test task.
        # -Dorg.gradle.jvmargs only affects the Gradle daemon, NOT the forked test JVM.
        add_opens = [
            "--add-opens",
            "java.base/java.util=ALL-UNNAMED",
            "--add-opens",
            "java.base/java.lang=ALL-UNNAMED",
            "--add-opens",
            "java.base/java.lang.reflect=ALL-UNNAMED",
            "--add-opens",
            "java.base/java.io=ALL-UNNAMED",
            "--add-opens",
            "java.base/java.math=ALL-UNNAMED",
            "--add-opens",
            "java.base/java.net=ALL-UNNAMED",
            "--add-opens",
            "java.base/java.util.zip=ALL-UNNAMED",
        ]
        all_jvm_args = list(add_opens)
        if javaagent_arg:
            all_jvm_args.insert(0, javaagent_arg)

        per_test_timeout = max(timeout // 3, 10)
        quoted_args = ", ".join(f'"{a}"' for a in all_jvm_args)
        init_script_content = (
            f"gradle.projectsEvaluated {{\n"
            f"    allprojects {{\n"
            f"        tasks.withType(Test) {{\n"
            f"            jvmArgs({quoted_args})\n"
            f'            systemProperty "junit.jupiter.execution.timeout.default", "{per_test_timeout}s"\n'
            f"            reports.junitXml.outputPerTestCase = true\n"
            f"            filter.failOnNoMatchingTests = false\n"
            f"        }}\n"
            f"    }}\n"
            f"}}\n"
        )

        if not test_filter:
            error_msg = (
                f"Test filter is EMPTY for mode={mode}! "
                f"Gradle will run ALL tests instead of the specified tests. "
                f"This indicates a problem with test file instrumentation or path resolution."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        init_fd, init_path = tempfile.mkstemp(suffix=".gradle", prefix="codeflash_jvmargs_")
        try:
            with os.fdopen(init_fd, "w", encoding="utf-8") as f:
                f.write(init_script_content)

            cmd = [gradle, task, "--no-daemon", "--rerun", "--init-script", init_path]
            cmd.extend(["--init-script", _get_skip_validation_init_script()])

            # --continue ensures Gradle keeps going even if some tests fail.
            # For coverage: needed so jacocoTestReport runs even after test failures
            # (matches Maven's -Dmaven.test.failure.ignore=true).
            # Note: multi-module --tests filtering is handled by
            # filter.failOnNoMatchingTests = false in the init script above
            # (matches Maven's -DfailIfNoTests=false).
            if enable_coverage:
                cmd.append("--continue")

            for class_filter in test_filter.split(","):
                class_filter = class_filter.strip()
                if class_filter:
                    cmd.extend(["--tests", class_filter])
            logger.debug("Added --tests filters to Gradle command")

            # Append jacocoTestReport AFTER --tests so Gradle doesn't try to apply --tests to it
            if enable_coverage:
                cmd.append("jacocoTestReport")

            logger.debug("Running Gradle command: %s in %s", " ".join(cmd), build_root)

            result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)

            # Normalize XML reports so <failure>/<error> always have a message
            # attribute — Maven Surefire always sets it, Gradle may omit it.
            reports_dir = self.get_reports_dir(build_root, test_module)
            _normalize_gradle_xml_reports(reports_dir)

            if result.returncode != 0:
                compilation_error_indicators = [
                    "Compilation failed",
                    "COMPILATION ERROR",
                    "cannot find symbol",
                    "error: package",
                ]
                combined_output = (result.stdout or "") + (result.stderr or "")
                has_compilation_error = any(
                    indicator.lower() in combined_output.lower() for indicator in compilation_error_indicators
                )
                if has_compilation_error:
                    logger.error(
                        "Gradle compilation failed for %s tests. "
                        "Check that generated test code is syntactically valid Java. "
                        "Return code: %s",
                        mode,
                        result.returncode,
                    )
                    output_lines = combined_output.split("\n")
                    error_context = "\n".join(output_lines[:50]) if len(output_lines) > 50 else combined_output
                    logger.error("Gradle compilation error output:\n%s", error_context)

            return result

        except Exception as e:
            logger.exception("Gradle test execution failed: %s", e)
            return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
        finally:
            Path(init_path).unlink(missing_ok=True)
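The comma-separated filter is expanded into repeated `--tests` flags above; extracted as a standalone sketch (the helper name is ours):

```python
def add_test_filters(cmd: list[str], test_filter: str) -> list[str]:
    """Append one --tests flag per non-empty class pattern in the filter."""
    for class_filter in test_filter.split(","):
        class_filter = class_filter.strip()
        if class_filter:
            cmd.extend(["--tests", class_filter])
    return cmd


# Trailing commas and whitespace around patterns are tolerated.
cmd = add_test_filters(["gradle", "test"], "com.example.FooTest, com.example.BarTest, ")
print(cmd)
```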
    def run_benchmarking_via_build_tool(
        self,
        test_paths: Any,
        test_env: dict[str, str],
        cwd: Path,
        timeout: int | None,
        project_root: Path | None,
        min_loops: int,
        max_loops: int,
        target_duration_seconds: float,
        inner_iterations: int,
    ) -> tuple[Path, Any]:
        import time

        from codeflash.languages.java.test_runner import _find_multi_module_root, _get_combined_junit_xml

        project_root = project_root or cwd
        gradle_root, test_module = _find_multi_module_root(project_root, test_paths)

        all_stdout: list[str] = []
        all_stderr: list[str] = []
        total_start_time = time.time()
        loop_count = 0
        last_result = None

        per_loop_timeout = max(timeout or 0, 120, 60 + inner_iterations)

        logger.debug("Using Gradle-based benchmarking (fallback mode)")

        for loop_idx in range(1, max_loops + 1):
            run_env = os.environ.copy()
            run_env.update(test_env)
            run_env["CODEFLASH_LOOP_INDEX"] = str(loop_idx)
            run_env["CODEFLASH_MODE"] = "performance"
            run_env["CODEFLASH_TEST_ITERATION"] = "0"
            if "CODEFLASH_INNER_ITERATIONS" not in run_env:
                run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)

            result = self.run_tests_via_build_tool(
                gradle_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
            )

            last_result = result
            loop_count = loop_idx

            if result.stdout:
                all_stdout.append(result.stdout)
            if result.stderr:
                all_stderr.append(result.stderr)

            elapsed = time.time() - total_start_time
            if loop_idx >= min_loops and elapsed >= target_duration_seconds:
                logger.debug("Stopping Gradle benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
                break

            if result.returncode != 0:
                timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
                has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
                if not has_timing_markers:
                    logger.warning("Tests failed in Gradle loop %d with no timing markers, stopping", loop_idx)
                    break
                logger.debug("Some tests failed in Gradle loop %d but timing markers present, continuing", loop_idx)

        combined_stdout = "\n".join(all_stdout)
        combined_stderr = "\n".join(all_stderr)

        total_iterations = loop_count * inner_iterations
        logger.debug(
            "Gradle fallback: %d loops x %d iterations = %d total in %.2fs",
            loop_count,
            inner_iterations,
            total_iterations,
            time.time() - total_start_time,
        )

        combined_result = subprocess.CompletedProcess(
            args=last_result.args if last_result else ["gradle", "test"],
            returncode=last_result.returncode if last_result else -1,
            stdout=combined_stdout,
            stderr=combined_stderr,
        )

        reports_dir = self.get_reports_dir(gradle_root, test_module)
        result_xml_path = _get_combined_junit_xml(reports_dir, -1)

        return result_xml_path, combined_result
    def run_tests_with_coverage(
        self,
        build_root: Path,
        test_module: str | None,
        test_paths: Any,
        run_env: dict[str, str],
        timeout: int,
        candidate_index: int,
    ) -> tuple[subprocess.CompletedProcess[str], Path, Path | None]:
        from codeflash.languages.java.test_runner import _get_combined_junit_xml

        coverage_xml_path = self.setup_coverage(build_root, test_module, build_root)

        result = self.run_tests_via_build_tool(
            build_root,
            test_paths,
            run_env,
            timeout=timeout,
            mode="behavior",
            enable_coverage=True,
            test_module=test_module,
        )

        reports_dir = self.get_reports_dir(build_root, test_module)
        result_xml_path = _get_combined_junit_xml(reports_dir, candidate_index)

        return result, result_xml_path, coverage_xml_path
def setup_coverage(self, build_root: Path, test_module: str | None, project_root: Path) -> Path | None:
|
||||
if test_module:
|
||||
module_root = build_root / module_to_dir(test_module)
|
||||
else:
|
||||
module_root = project_root
|
||||
|
||||
build_file = find_gradle_build_file(module_root)
|
||||
if build_file is None:
|
||||
logger.warning("No build.gradle(.kts) found at %s, cannot setup JaCoCo", module_root)
|
||||
return None
|
||||
|
||||
content = build_file.read_text(encoding="utf-8")
|
||||
if "jacoco" not in content.lower():
|
||||
logger.info("Adding JaCoCo plugin to %s for coverage collection", build_file.name)
|
||||
is_kts = build_file.name.endswith(".kts")
|
||||
if is_kts:
|
||||
plugin_line = "plugins {\n jacoco\n}\n"
|
||||
else:
|
||||
plugin_line = "apply plugin: 'jacoco'\n"
|
||||
|
||||
if "plugins {" in content or "plugins{" in content:
|
||||
# Insert jacoco inside existing plugins block
|
||||
plugins_idx = content.find("plugins")
|
||||
brace_depth = 0
|
||||
for i in range(plugins_idx, len(content)):
|
||||
if content[i] == "{":
|
||||
brace_depth += 1
|
||||
elif content[i] == "}":
|
||||
brace_depth -= 1
|
||||
if brace_depth == 0:
|
||||
insert = " jacoco\n" if is_kts else " id 'jacoco'\n"
|
||||
content = content[:i] + insert + content[i:]
|
||||
break
|
||||
else:
|
||||
content = plugin_line + content
|
||||
|
||||
build_file.write_text(content, encoding="utf-8")
|
||||
|
||||
return module_root / "build" / "reports" / "jacoco" / "test" / "jacocoTestReport.xml"
|
||||
|
||||
def get_test_run_command(self, project_root: Path, test_classes: list[str] | None = None) -> list[str]:
|
||||
from codeflash.languages.java.test_runner import _validate_java_class_name
|
||||
|
||||
if test_classes:
|
||||
for test_class in test_classes:
|
||||
if not _validate_java_class_name(test_class):
|
||||
msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
|
||||
raise ValueError(msg)
|
||||
|
||||
gradle = self.find_executable(project_root) or "gradle"
|
||||
cmd = [gradle, "test", "--no-daemon"]
|
||||
if test_classes:
|
||||
for cls in test_classes:
|
||||
cmd.extend(["--tests", cls])
|
||||
return cmd
|
||||
865
codeflash/languages/java/maven_strategy.py
Normal file

@@ -0,0 +1,865 @@
"""Maven build tool strategy for Java projects.

Implements BuildToolStrategy for Maven-based projects, handling compilation,
classpath extraction, test execution via Surefire, and JaCoCo coverage.
"""

from __future__ import annotations

import logging
import os
import re
import shutil
import subprocess
import urllib.request
import xml.etree.ElementTree as ET
from pathlib import Path
from typing import Any

from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir
from codeflash.languages.java.build_tools import CODEFLASH_RUNTIME_JAR_NAME, CODEFLASH_RUNTIME_VERSION

_TARGET = "target"

logger = logging.getLogger(__name__)

# Skip validation/analysis plugins that reject generated instrumented files
_MAVEN_VALIDATION_SKIP_FLAGS = [
    "-Drat.skip=true",
    "-Dcheckstyle.skip=true",
    "-Dspotbugs.skip=true",
    "-Dpmd.skip=true",
    "-Denforcer.skip=true",
    "-Djapicmp.skip=true",
    "-Derrorprone.skip=true",
    "-Dmaven.compiler.failOnWarning=false",
    "-Dmaven.compiler.showWarnings=false",
]

# Cache for classpath strings — keyed on (maven_root, test_module).
_classpath_cache: dict[tuple[Path, str | None], str] = {}

# Cache for multi-module dependency installs — keyed on (maven_root, test_module).
_multimodule_deps_installed: set[tuple[Path, str]] = set()

JACOCO_PLUGIN_VERSION = "0.8.13"

GITHUB_RELEASE_URL = (
    "https://github.com/codeflash-ai/codeflash/releases/download"
    f"/runtime-v{CODEFLASH_RUNTIME_VERSION}/{CODEFLASH_RUNTIME_JAR_NAME}"
)

CODEFLASH_CACHE_DIR = Path.home() / ".cache" / "codeflash"

CODEFLASH_DEPENDENCY_SNIPPET = """\
        <dependency>
            <groupId>com.codeflash</groupId>
            <artifactId>codeflash-runtime</artifactId>
            <version>1.0.0</version>
            <scope>test</scope>
        </dependency>
    </dependencies>"""


def download_from_github_releases() -> Path | None:
    """Download codeflash-runtime JAR from GitHub Releases.

    Downloads to ~/.cache/codeflash/ and returns the path to the downloaded JAR.
    Returns None if the download fails (e.g., no release published yet, network error).
    """
    cache_jar = CODEFLASH_CACHE_DIR / CODEFLASH_RUNTIME_JAR_NAME
    if cache_jar.exists():
        logger.info("Found cached codeflash-runtime JAR: %s", cache_jar)
        return cache_jar

    try:
        CODEFLASH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        logger.info("Downloading codeflash-runtime from GitHub Releases: %s", GITHUB_RELEASE_URL)
        urllib.request.urlretrieve(GITHUB_RELEASE_URL, cache_jar)  # noqa: S310
        logger.info("Downloaded codeflash-runtime to %s", cache_jar)
        return cache_jar
    except Exception as e:
        logger.debug("GitHub Releases download failed: %s", e)
        cache_jar.unlink(missing_ok=True)
        return None


def resolve_from_maven_central(maven_root: Path) -> bool:
    """Ask Maven to resolve codeflash-runtime from Maven Central.

    Downloads the JAR to ~/.m2/repository/ automatically.
    Returns True if Maven successfully resolved the artifact.
    """
    mvn = shutil.which("mvn")
    if not mvn:
        return False
    cmd = [
        mvn,
        "dependency:resolve",
        f"-Dartifact=com.codeflash:codeflash-runtime:{CODEFLASH_RUNTIME_VERSION}",
        "-B",
        "-q",
    ]
    try:
        result = subprocess.run(cmd, check=False, cwd=maven_root, capture_output=True, text=True, timeout=60)
        if result.returncode == 0:
            logger.info("Resolved codeflash-runtime %s from Maven Central", CODEFLASH_RUNTIME_VERSION)
            return True
        logger.debug("Maven Central resolution failed: %s", result.stderr)
        return False
    except Exception as e:
        logger.debug("Maven Central resolution error: %s", e)
        return False


def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
    content = file_path.read_text(encoding="utf-8")
    root = ET.fromstring(content)
    return ET.ElementTree(root)


def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path, mvn: str | None = None) -> bool:
    if not mvn:
        mvn = shutil.which("mvn")
    if not mvn:
        logger.error("Maven not found")
        return False

    if not runtime_jar_path.exists():
        logger.error("Runtime JAR not found: %s", runtime_jar_path)
        return False

    cmd = [
        mvn,
        "install:install-file",
        f"-Dfile={runtime_jar_path}",
        "-DgroupId=com.codeflash",
        "-DartifactId=codeflash-runtime",
        "-Dversion=1.0.0",
        "-Dpackaging=jar",
        "-B",
    ]

    try:
        result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)
        if result.returncode == 0:
            logger.info("Successfully installed codeflash-runtime to local Maven repository")
            return True
        logger.error("Failed to install codeflash-runtime: %s", result.stderr)
        return False
    except Exception as e:
        logger.exception("Failed to install codeflash-runtime: %s", e)
        return False


def add_codeflash_dependency(pom_path: Path) -> bool:
    if not pom_path.exists():
        return False

    try:
        content = pom_path.read_text(encoding="utf-8")

        if "codeflash-runtime" in content:
            if "<scope>system</scope>" in content:

                def replace_system_dep(match: re.Match[str]) -> str:
                    block: str = match.group(0)
                    if "codeflash-runtime" in block and "<scope>system</scope>" in block:
                        return (
                            "<dependency>\n"
                            "            <groupId>com.codeflash</groupId>\n"
                            "            <artifactId>codeflash-runtime</artifactId>\n"
                            "            <version>1.0.0</version>\n"
                            "            <scope>test</scope>\n"
                            "        </dependency>"
                        )
                    return block

                content = re.sub(r"<dependency>[\s\S]*?</dependency>", replace_system_dep, content)
                pom_path.write_text(content, encoding="utf-8")
                logger.info("Replaced system-scope codeflash-runtime dependency with test scope")
                return True
            logger.info("codeflash-runtime dependency already present in pom.xml")
            return True

        closing_tag = "</dependencies>"
        idx = content.find(closing_tag)
        if idx == -1:
            logger.warning("No </dependencies> tag found in pom.xml, cannot add dependency")
            return False

        new_content = content[:idx] + CODEFLASH_DEPENDENCY_SNIPPET
        new_content += content[idx + len(closing_tag) :]

        pom_path.write_text(new_content, encoding="utf-8")
        logger.info("Added codeflash-runtime dependency to pom.xml")
        return True

    except Exception as e:
        logger.exception("Failed to add dependency to pom.xml: %s", e)
        return False


def is_jacoco_configured(pom_path: Path) -> bool:
    if not pom_path.exists():
        return False

    try:
        tree = _safe_parse_xml(pom_path)
        root = tree.getroot()
        if root is None:
            return False

        ns_prefix = "{http://maven.apache.org/POM/4.0.0}"
        use_ns = root.tag.startswith("{")
        if not use_ns:
            ns_prefix = ""

        for plugin in root.findall(f".//{ns_prefix}plugin" if use_ns else ".//plugin"):
            artifact_id = plugin.find(f"{ns_prefix}artifactId" if use_ns else "artifactId")
            if artifact_id is not None and artifact_id.text == "jacoco-maven-plugin":
                group_id = plugin.find(f"{ns_prefix}groupId" if use_ns else "groupId")
                if group_id is None or group_id.text == "org.jacoco":
                    return True

        return False

    except ET.ParseError as e:
        logger.warning("Failed to parse pom.xml for JaCoCo check: %s", e)
        return False


def _find_closing_tag(content: str, start_pos: int, tag_name: str) -> int:
    open_tag = f"<{tag_name}>"
    open_tag_short = f"<{tag_name} "
    close_tag = f"</{tag_name}>"

    depth = 1
    pos = start_pos + len(f"<{tag_name}")

    while pos < len(content):
        next_open = content.find(open_tag, pos)
        next_open_short = content.find(open_tag_short, pos)
        next_close = content.find(close_tag, pos)

        if next_close == -1:
            return -1

        candidates = [x for x in [next_open, next_open_short] if x != -1 and x < next_close]
        next_open_any = min(candidates) if candidates else len(content) + 1

        if next_open_any < next_close:
            depth += 1
            pos = next_open_any + 1
        else:
            depth -= 1
            if depth == 0:
                return next_close
            pos = next_close + len(close_tag)

    return -1


def add_jacoco_plugin(pom_path: Path) -> bool:
    if not pom_path.exists():
        logger.error("pom.xml not found: %s", pom_path)
        return False

    if is_jacoco_configured(pom_path):
        logger.info("JaCoCo plugin already configured in pom.xml")
        return True

    try:
        content = pom_path.read_text(encoding="utf-8")

        if "</project>" not in content:
            logger.error("Invalid pom.xml: no closing </project> tag found")
            return False

        jacoco_plugin = f"""
            <plugin>
                <groupId>org.jacoco</groupId>
                <artifactId>jacoco-maven-plugin</artifactId>
                <version>{JACOCO_PLUGIN_VERSION}</version>
                <executions>
                    <execution>
                        <id>prepare-agent</id>
                        <goals>
                            <goal>prepare-agent</goal>
                        </goals>
                    </execution>
                    <execution>
                        <id>report</id>
                        <phase>verify</phase>
                        <goals>
                            <goal>report</goal>
                        </goals>
                        <configuration>
                            <!-- For multi-module projects, include dependency classes -->
                            <includes>
                                <include>**/*.class</include>
                            </includes>
                        </configuration>
                    </execution>
                </executions>
            </plugin>"""

        profiles_start = content.find("<profiles>")
        profiles_end = content.find("</profiles>")

        if profiles_start == -1:
            build_start = content.find("<build>")
            build_end = content.find("</build>")
        else:
            build_before_profiles = content[:profiles_start].rfind("<build>")
            build_after_profiles = content[profiles_end:].find("<build>") if profiles_end != -1 else -1
            if build_after_profiles != -1:
                build_after_profiles += profiles_end

            if build_before_profiles != -1:
                build_start = build_before_profiles
                build_end = _find_closing_tag(content, build_start, "build")
            elif build_after_profiles != -1:
                build_start = build_after_profiles
                build_end = _find_closing_tag(content, build_start, "build")
            else:
                build_start = -1
                build_end = -1

        if build_start != -1 and build_end != -1:
            build_section = content[build_start : build_end + len("</build>")]
            plugins_start_in_build = build_section.find("<plugins>")
            plugins_end_in_build = build_section.rfind("</plugins>")

            if plugins_start_in_build != -1 and plugins_end_in_build != -1:
                absolute_plugins_end = build_start + plugins_end_in_build
                content = content[:absolute_plugins_end] + jacoco_plugin + "\n        " + content[absolute_plugins_end:]
            else:
                plugins_section = f"<plugins>{jacoco_plugin}\n        </plugins>\n    "
                content = content[:build_end] + plugins_section + content[build_end:]
        else:
            project_end = content.rfind("</project>")
            build_section = f"""
    <build>
        <plugins>{jacoco_plugin}
        </plugins>
    </build>
"""
            content = content[:project_end] + build_section + content[project_end:]

        pom_path.write_text(content, encoding="utf-8")
        logger.info("Added JaCoCo plugin to pom.xml")
        return True

    except Exception as e:
        logger.exception("Failed to add JaCoCo plugin to pom.xml: %s", e)
        return False


def get_jacoco_report_path(project_root: Path) -> Path:
    return project_root / "target" / "site" / "jacoco" / "jacoco.xml"


class MavenStrategy(BuildToolStrategy):
    """Maven-specific build tool operations."""

    _M2_JAR = (
        Path.home()
        / ".m2"
        / "repository"
        / "com"
        / "codeflash"
        / "codeflash-runtime"
        / "1.0.0"
        / "codeflash-runtime-1.0.0.jar"
    )

    @property
    def name(self) -> str:
        return "Maven"

    def find_executable(self, build_root: Path) -> str | None:
        mvnw_path = build_root / "mvnw"
        if mvnw_path.exists():
            return str(mvnw_path)
        mvnw_cmd_path = build_root / "mvnw.cmd"
        if mvnw_cmd_path.exists():
            return str(mvnw_cmd_path)
        if Path("mvnw").exists():
            return "./mvnw"
        if Path("mvnw.cmd").exists():
            return "mvnw.cmd"
        return shutil.which("mvn")

    def find_runtime_jar(self) -> Path | None:
        if self._M2_JAR.exists():
            return self._M2_JAR
        return super().find_runtime_jar()

    def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
        if not self._M2_JAR.exists():
            if resolve_from_maven_central(build_root):
                logger.info("Resolved codeflash-runtime from Maven Central")
            else:
                runtime_jar = self.find_runtime_jar()
                if runtime_jar is None:
                    runtime_jar = download_from_github_releases()
                if runtime_jar is None:
                    logger.error(
                        "codeflash-runtime JAR not found. Maven Central resolution failed and "
                        "GitHub Releases download failed. Generated tests will fail to compile."
                    )
                    return False
                logger.info("Installing codeflash-runtime JAR to local Maven repository from %s", runtime_jar)
                if not install_codeflash_runtime(build_root, runtime_jar, mvn=self.find_executable(build_root)):
                    logger.error("Failed to install codeflash-runtime to local Maven repository")
                    return False

        if test_module:
            pom_path = build_root / module_to_dir(test_module) / "pom.xml"
        else:
            pom_path = build_root / "pom.xml"

        if pom_path.exists():
            if not add_codeflash_dependency(pom_path):
                logger.error("Failed to add codeflash-runtime dependency to %s", pom_path)
                return False
        else:
            logger.warning("pom.xml not found at %s, cannot add codeflash-runtime dependency", pom_path)
            return False

        return True

    def install_multi_module_deps(self, build_root: Path, test_module: str | None, env: dict[str, str]) -> bool:
        from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout

        if not test_module:
            return True

        cache_key = (build_root, test_module)
        if cache_key in _multimodule_deps_installed:
            logger.debug("Multi-module deps already installed for %s:%s, skipping", build_root, test_module)
            return True

        mvn = self.find_executable(build_root)
        if not mvn:
            logger.error("Maven not found — cannot pre-install multi-module dependencies")
            return False

        cmd = [mvn, "install", "-DskipTests", "-B", "-pl", module_to_dir(test_module), "-am"]
        cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)

        logger.info("Pre-installing multi-module dependencies: %s (module: %s)", build_root, test_module)
        logger.debug("Running: %s", " ".join(cmd))

        try:
            result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=300)
            if result.returncode != 0:
                logger.error(
                    "Failed to pre-install multi-module deps (exit %d).\nstdout: %s\nstderr: %s",
                    result.returncode,
                    result.stdout[-2000:] if result.stdout else "",
                    result.stderr[-2000:] if result.stderr else "",
                )
                return False
        except Exception:
            logger.exception("Exception during multi-module dependency install")
            return False

        _multimodule_deps_installed.add(cache_key)
        logger.info("Multi-module dependencies installed successfully for %s:%s", build_root, test_module)
        return True

    def compile_tests(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
    ) -> subprocess.CompletedProcess[str]:
        from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout

        mvn = self.find_executable(build_root)
        if not mvn:
            logger.error("Maven not found")
            return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")

        cmd = [mvn, "test-compile", "-e", "-B"]
        cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)

        if test_module:
            cmd.extend(["-pl", module_to_dir(test_module)])

        logger.debug("Compiling tests: %s in %s", " ".join(cmd), build_root)

        try:
            return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
        except Exception as e:
            logger.exception("Maven compilation failed: %s", e)
            return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))

    def compile_source_only(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 120
    ) -> subprocess.CompletedProcess[str]:
        from codeflash.languages.java.test_runner import _run_cmd_kill_pg_on_timeout

        mvn = self.find_executable(build_root)
        if not mvn:
            logger.error("Maven not found")
            return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")

        cmd = [mvn, "compile", "-e", "-B"]
        cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)

        if test_module:
            cmd.extend(["-pl", module_to_dir(test_module)])

        logger.debug("Compiling source only: %s in %s", " ".join(cmd), build_root)

        try:
            return _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)
        except Exception as e:
            logger.exception("Maven source compilation failed: %s", e)
            return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))

    def get_classpath(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
    ) -> str | None:
        key = (build_root, test_module)
        cached = _classpath_cache.get(key)
        if cached is not None:
            logger.debug("Using cached classpath for (%s, %s)", build_root, test_module)
            return cached
        result = self._get_classpath_uncached(build_root, env, test_module, timeout)
        if result is not None:
            _classpath_cache[key] = result
        return result

    def _get_classpath_uncached(
        self, build_root: Path, env: dict[str, str], test_module: str | None, timeout: int = 60
    ) -> str | None:
        from codeflash.languages.java.test_runner import _find_junit_console_standalone, _run_cmd_kill_pg_on_timeout

        mvn = self.find_executable(build_root)
        if not mvn:
            return None

        cp_file = build_root / ".codeflash_classpath.txt"

        cmd = [mvn, "dependency:build-classpath", "-DincludeScope=test", f"-Dmdep.outputFile={cp_file}", "-q", "-B"]

        if test_module:
            cmd.extend(["-pl", module_to_dir(test_module)])

        logger.debug("Getting classpath: %s", " ".join(cmd))

        try:
            result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)

            if result.returncode != 0:
                logger.error("Failed to get classpath: %s", result.stderr)
                return None

            if not cp_file.exists():
                logger.error("Classpath file not created")
                return None

            classpath = cp_file.read_text(encoding="utf-8").strip()

            if test_module:
                module_path = build_root / module_to_dir(test_module)
            else:
                module_path = build_root

            test_classes = module_path / "target" / "test-classes"
            main_classes = module_path / "target" / "classes"

            cp_parts = [classpath]
            if test_classes.exists():
                cp_parts.append(str(test_classes))
            if main_classes.exists():
                cp_parts.append(str(main_classes))

            if test_module:
                module_dir_name = module_to_dir(test_module)
                for module_dir in build_root.iterdir():
                    if module_dir.is_dir() and module_dir.name != module_dir_name:
                        module_classes = module_dir / "target" / "classes"
                        if module_classes.exists():
                            logger.debug("Adding multi-module classpath: %s", module_classes)
                            cp_parts.append(str(module_classes))

            if "console-standalone" not in classpath and "ConsoleLauncher" not in classpath:
                console_jar = _find_junit_console_standalone()
                if console_jar:
                    logger.debug("Adding JUnit Console Standalone to classpath: %s", console_jar)
                    cp_parts.append(str(console_jar))

            return os.pathsep.join(cp_parts)

        except Exception as e:
            logger.exception("Failed to get classpath: %s", e)
            return None
        finally:
            if cp_file.exists():
                cp_file.unlink()

    def get_reports_dir(self, build_root: Path, test_module: str | None) -> Path:
        target_dir = self.get_build_output_dir(build_root, test_module)
        return target_dir / "surefire-reports"

    def get_build_output_dir(self, build_root: Path, test_module: str | None) -> Path:
        if test_module:
            return build_root.joinpath(module_to_dir(test_module), _TARGET)
        return build_root.joinpath(_TARGET)

    def run_tests_via_build_tool(
        self,
        build_root: Path,
        test_paths: Any,
        env: dict[str, str],
        timeout: int,
        mode: str,
        test_module: str | None,
        javaagent_arg: str | None = None,
        enable_coverage: bool = False,
    ) -> subprocess.CompletedProcess[str]:
        from codeflash.languages.java.test_runner import (
            _build_test_filter,
            _run_cmd_kill_pg_on_timeout,
            _validate_test_filter,
        )

        mvn = self.find_executable(build_root)
        if not mvn:
            logger.error("Maven not found")
            return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")

        test_filter = _build_test_filter(test_paths, mode=mode)
        logger.debug("Built test filter for mode=%s: '%s' (empty=%s)", mode, test_filter, not test_filter)

        maven_goal = "verify" if enable_coverage else "test"
        cmd = [mvn, maven_goal, "-fae", "-B"]
        cmd.extend(_MAVEN_VALIDATION_SKIP_FLAGS)

        add_opens_flags = (
            "--add-opens java.base/java.util=ALL-UNNAMED"
            " --add-opens java.base/java.lang=ALL-UNNAMED"
            " --add-opens java.base/java.lang.reflect=ALL-UNNAMED"
            " --add-opens java.base/java.io=ALL-UNNAMED"
            " --add-opens java.base/java.math=ALL-UNNAMED"
            " --add-opens java.base/java.net=ALL-UNNAMED"
            " --add-opens java.base/java.util.zip=ALL-UNNAMED"
        )
        if javaagent_arg:
            cmd.append(f"-DargLine={javaagent_arg} {add_opens_flags}")
        else:
            cmd.append(f"-DargLine={add_opens_flags}")

        if mode == "performance":
            cmd.append("-Dsurefire.useFile=false")

        if enable_coverage:
            cmd.append("-Dmaven.test.failure.ignore=true")

        if test_module:
            cmd.extend(
                [
                    "-pl",
                    module_to_dir(test_module),
                    "-DfailIfNoTests=false",
                    "-Dsurefire.failIfNoSpecifiedTests=false",
                    "-DskipTests=false",
                ]
            )

        if test_filter:
            validated_filter = _validate_test_filter(test_filter)
            cmd.append(f"-Dtest={validated_filter}")
            logger.debug("Added -Dtest=%s to Maven command", validated_filter)
        else:
            error_msg = (
                f"Test filter is EMPTY for mode={mode}! "
                f"Maven will run ALL tests instead of the specified tests. "
                f"This indicates a problem with test file instrumentation or path resolution."
            )
            logger.error(error_msg)
            raise ValueError(error_msg)

        logger.debug("Running Maven command: %s in %s", " ".join(cmd), build_root)

        try:
            result = _run_cmd_kill_pg_on_timeout(cmd, cwd=build_root, env=env, timeout=timeout)

            if result.returncode != 0:
                compilation_error_indicators = [
                    "[ERROR] COMPILATION ERROR",
                    "[ERROR] Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin",
                    "compilation failure",
                    "cannot find symbol",
                    "package .* does not exist",
                ]
                combined_output = (result.stdout or "") + (result.stderr or "")
                has_compilation_error = any(
                    indicator.lower() in combined_output.lower() for indicator in compilation_error_indicators
                )
                if has_compilation_error:
                    logger.error(
                        "Maven compilation failed for %s tests. "
                        "Check that generated test code is syntactically valid Java. "
                        "Return code: %s",
                        mode,
                        result.returncode,
                    )
                    output_lines = combined_output.split("\n")
                    error_context = "\n".join(output_lines[:50]) if len(output_lines) > 50 else combined_output
                    logger.error("Maven compilation error output:\n%s", error_context)

            return result

        except Exception as e:
            logger.exception("Maven test execution failed: %s", e)
            return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))

    def run_benchmarking_via_build_tool(
        self,
        test_paths: Any,
        test_env: dict[str, str],
        cwd: Path,
        timeout: int | None,
        project_root: Path | None,
        min_loops: int,
        max_loops: int,
        target_duration_seconds: float,
        inner_iterations: int,
    ) -> tuple[Path, Any]:
        import time

        from codeflash.languages.java.test_runner import _find_multi_module_root, _get_combined_junit_xml

        project_root = project_root or cwd
        maven_root, test_module = _find_multi_module_root(project_root, test_paths)

        all_stdout: list[str] = []
        all_stderr: list[str] = []
        total_start_time = time.time()
        loop_count = 0
        last_result = None

        per_loop_timeout = max(timeout or 0, 120, 60 + inner_iterations)

        logger.debug("Using Maven-based benchmarking (fallback mode)")

        for loop_idx in range(1, max_loops + 1):
            run_env = os.environ.copy()
            run_env.update(test_env)
            run_env["CODEFLASH_LOOP_INDEX"] = str(loop_idx)
            run_env["CODEFLASH_MODE"] = "performance"
            run_env["CODEFLASH_TEST_ITERATION"] = "0"
            if "CODEFLASH_INNER_ITERATIONS" not in run_env:
                run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)

            result = self.run_tests_via_build_tool(
                maven_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
            )

            last_result = result
            loop_count = loop_idx

            if result.stdout:
                all_stdout.append(result.stdout)
            if result.stderr:
                all_stderr.append(result.stderr)

            elapsed = time.time() - total_start_time
            if loop_idx >= min_loops and elapsed >= target_duration_seconds:
                logger.debug("Stopping Maven benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
                break

            if result.returncode != 0:
                timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
                has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
                if not has_timing_markers:
                    logger.warning("Tests failed in Maven loop %d with no timing markers, stopping", loop_idx)
                    break
                logger.debug("Some tests failed in Maven loop %d but timing markers present, continuing", loop_idx)

        combined_stdout = "\n".join(all_stdout)
        combined_stderr = "\n".join(all_stderr)

        total_iterations = loop_count * inner_iterations
        logger.debug(
            "Maven fallback: %d loops x %d iterations = %d total in %.2fs",
            loop_count,
            inner_iterations,
            total_iterations,
            time.time() - total_start_time,
        )

        combined_result = subprocess.CompletedProcess(
            args=last_result.args if last_result else ["mvn", "test"],
            returncode=last_result.returncode if last_result else -1,
            stdout=combined_stdout,
            stderr=combined_stderr,
        )

        reports_dir = self.get_reports_dir(maven_root, test_module)
        result_xml_path = _get_combined_junit_xml(reports_dir, -1)

        return result_xml_path, combined_result

def run_tests_with_coverage(
|
||||
self,
|
||||
build_root: Path,
|
||||
test_module: str | None,
|
||||
test_paths: Any,
|
||||
run_env: dict[str, str],
|
||||
timeout: int,
|
||||
candidate_index: int,
|
||||
) -> tuple[subprocess.CompletedProcess[str], Path, Path | None]:
|
||||
from codeflash.languages.java.test_runner import _get_combined_junit_xml
|
||||
|
||||
coverage_xml_path = self.setup_coverage(build_root, test_module, build_root)
|
||||
|
||||
result = self.run_tests_via_build_tool(
|
||||
build_root,
|
||||
test_paths,
|
||||
run_env,
|
||||
timeout=timeout,
|
||||
mode="behavior",
|
||||
enable_coverage=True,
|
||||
test_module=test_module,
|
||||
)
|
||||
|
||||
reports_dir = self.get_reports_dir(build_root, test_module)
|
||||
result_xml_path = _get_combined_junit_xml(reports_dir, candidate_index)
|
||||
|
||||
return result, result_xml_path, coverage_xml_path
|
||||
|
||||
def setup_coverage(self, build_root: Path, test_module: str | None, project_root: Path) -> Path | None:
|
||||
if test_module:
|
||||
test_module_pom = build_root / module_to_dir(test_module) / "pom.xml"
|
||||
if test_module_pom.exists():
|
||||
if not is_jacoco_configured(test_module_pom):
|
||||
logger.info("Adding JaCoCo plugin to test module pom.xml: %s", test_module_pom)
|
||||
add_jacoco_plugin(test_module_pom)
|
||||
return get_jacoco_report_path(build_root / module_to_dir(test_module))
|
||||
else:
|
||||
pom_path = project_root / "pom.xml"
|
||||
if pom_path.exists():
|
||||
if not is_jacoco_configured(pom_path):
|
||||
logger.info("Adding JaCoCo plugin to pom.xml for coverage collection")
|
||||
add_jacoco_plugin(pom_path)
|
||||
return get_jacoco_report_path(project_root)
|
||||
return None
|
||||
|
||||
def get_test_run_command(self, project_root: Path, test_classes: list[str] | None = None) -> list[str]:
|
||||
from codeflash.languages.java.test_runner import _validate_java_class_name
|
||||
|
||||
if test_classes:
|
||||
for test_class in test_classes:
|
||||
if not _validate_java_class_name(test_class):
|
||||
msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
|
||||
raise ValueError(msg)
|
||||
|
||||
mvn = self.find_executable(project_root) or "mvn"
|
||||
cmd = [mvn, "test", "-B"]
|
||||
if test_classes:
|
||||
cmd.append(f"-Dtest={','.join(test_classes)}")
|
||||
return cmd
|
||||
|
|
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
 from junitparser.xunit2 import JUnitXml

 from codeflash.cli_cmds.console import console, logger
-from codeflash.code_utils.code_utils import module_name_from_file_path
+from codeflash.code_utils.code_utils import extract_parameterized_test_index, module_name_from_file_path
 from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults

 if TYPE_CHECKING:
@@ -128,7 +128,9 @@ def parse_java_test_xml(
         if class_name is not None and class_name.startswith(test_module_path):
             test_class = class_name[len(test_module_path) + 1 :]

-        loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
+        loop_index = (
+            extract_parameterized_test_index(testcase.name) if testcase.name and "[" in testcase.name else 1
+        )

         timed_out = False
         if len(testcase.result) > 1:
@@ -16,7 +16,7 @@ import re
 import textwrap
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.languages.java.parser import get_java_analyzer
@@ -208,7 +208,7 @@ def _insert_class_members(
     if not fields and not helpers_before_target and not helpers_after_target:
         return source

-    def get_target_class_and_body(src: str):  # type: ignore[return]
+    def get_target_class_and_body(src: str) -> tuple[Any, Any]:
         for cls in analyzer.find_classes(src):
             if cls.name == class_name:
                 body = cls.node.child_by_field_name("body")
@@ -66,6 +66,7 @@ class JavaSupport(LanguageSupport):
         self.line_profiler_agent_arg: str | None = None
         self.line_profiler_warmup_iterations: int = 0
         self._language_version: str | None = None
+        self._test_framework: str = "junit5"

     @property
     def language(self) -> Language:
@@ -79,8 +80,8 @@ class JavaSupport(LanguageSupport):

     @property
     def test_framework(self) -> str:
-        """Primary test framework name."""
-        return "junit5"
+        """Primary test framework name, detected from project build config."""
+        return self._test_framework

     @property
     def comment_prefix(self) -> str:
@@ -368,10 +369,19 @@ class JavaSupport(LanguageSupport):
     # === Test Result Comparison ===

     def compare_test_results(
-        self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
+        self,
+        original_results_path: Path,
+        candidate_results_path: Path,
+        project_root: Path | None = None,
+        project_classpath: str | None = None,
     ) -> tuple[bool, list[Any]]:
         """Compare test results between original and candidate code."""
-        return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
+        return _compare_test_results(
+            original_results_path,
+            candidate_results_path,
+            project_root=project_root,
+            project_classpath=project_classpath,
+        )

     # === Reference Finding ===

@@ -394,9 +404,10 @@ class JavaSupport(LanguageSupport):
         return None

     def setup_test_config(self, test_cfg: Any, file_path: Path, current_worktree: Path | None = None) -> None:
-        return None
-
-    # === Configuration ===
+        """Detect test framework from project build config (pom.xml / build.gradle)."""
+        config = detect_java_project(test_cfg.project_root_path)
+        if config is not None:
+            self._test_framework = config.test_framework

     def adjust_test_config_for_discovery(self, test_cfg: Any) -> None:
         """Adjust test config before test discovery for Java.
@@ -534,8 +545,8 @@ class JavaSupport(LanguageSupport):
         if self._language_version is None:
            self._detect_java_version()

-        # For now, assume the runtime is available
-        # A full implementation would check/install the JAR
+        self._test_framework = config.test_framework

         return True

     def _detect_java_version(self) -> None:
File diff suppressed because it is too large
@@ -1904,7 +1904,11 @@ class JavaScriptSupport:
     # === Test Result Comparison ===

     def compare_test_results(
-        self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
+        self,
+        original_results_path: Path,
+        candidate_results_path: Path,
+        project_root: Path | None = None,
+        project_classpath: str | None = None,
     ) -> tuple[bool, list]:
         """Compare test results between original and candidate code.
@@ -861,6 +861,47 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, str]
     return False, False, False


+_ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
+_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
+
+
+def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str:
+    resolved = import_aliases.get(expr_name)
+    if resolved is not None:
+        return resolved
+    first_part, sep, rest = expr_name.partition(".")
+    if sep:
+        root_resolved = import_aliases.get(first_part)
+        if root_resolved is not None:
+            return f"{root_resolved}.{rest}"
+
+    return expr_name
+
+
+def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
+    for decorator in class_node.decorator_list:
+        expr_name = _get_expr_name(decorator)
+        if expr_name is None:
+            continue
+        resolved = _resolve_decorator_name(expr_name, import_aliases)
+        parts = resolved.split(".")
+        if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
+            continue
+        init_enabled = True
+        kw_only = False
+        if isinstance(decorator, ast.Call):
+            for keyword in decorator.keywords:
+                literal_value = _bool_literal(keyword.value)
+                if literal_value is None:
+                    continue
+                if keyword.arg == "init":
+                    init_enabled = literal_value
+                elif keyword.arg == "kw_only":
+                    kw_only = literal_value
+        return True, init_enabled, kw_only
+    return False, False, False
+
+
 def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
     annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation
     return _expr_matches_name(annotation_root, import_aliases, "ClassVar")
@@ -885,10 +926,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:

 def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
     is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
-    if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
+    is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases)
+    if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs:
         return set()
     if is_dataclass and not dataclass_init_enabled:
         return set()
+    if is_attrs and not attrs_init_enabled:
+        return set()

     names = set[str]()
     for item in class_node.body:
@@ -939,9 +983,9 @@ def _extract_synthetic_init_parameters(
                     kw_only = literal_value
                 elif keyword.arg == "default":
                     default_value = _get_node_source(keyword.value, module_source)
-                elif keyword.arg == "default_factory":
-                    # Default factories still imply an optional constructor parameter, but
-                    # the generated __init__ does not use the field() call directly.
+                elif keyword.arg in {"default_factory", "factory"}:
+                    # Default factories (dataclass default_factory= / attrs factory=) still imply
+                    # an optional constructor parameter.
                     default_value = "..."
         else:
             default_value = _get_node_source(item.value, module_source)
@@ -960,13 +1004,17 @@ def _build_synthetic_init_stub(
 ) -> str | None:
     is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
     is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
-    if not is_namedtuple and not is_dataclass:
+    is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases)
+    if not is_namedtuple and not is_dataclass and not is_attrs:
         return None
     if is_dataclass and not dataclass_init_enabled:
         return None
+    if is_attrs and not attrs_init_enabled:
+        return None

+    kw_only_by_default = dataclass_kw_only or attrs_kw_only
     parameters = _extract_synthetic_init_parameters(
-        class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
+        class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default
     )
     if not parameters:
         return None
@@ -2,10 +2,11 @@ from __future__ import annotations

 import ast
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast

 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.formatter import sort_imports
+from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES

 if TYPE_CHECKING:
     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@@ -80,6 +81,7 @@ class InitDecorator(ast.NodeTransformer):
         self.has_import = False
         self.tests_root = tests_root
         self.inserted_decorator = False
+        self._attrs_classes_to_patch: dict[str, ast.Call] = {}

         # Precompute decorator components to avoid reconstructing on every node visit
         # Only the `function_name` field changes per class
@@ -118,6 +120,21 @@ class InitDecorator(ast.NodeTransformer):
             defaults=[],
         )

+        # Pre-build reusable AST nodes for _build_attrs_patch_block
+        self._load_ctx = ast.Load()
+        self._store_ctx = ast.Store()
+        self._args_name_load = ast.Name(id="args", ctx=self._load_ctx)
+        self._kwargs_name_load = ast.Name(id="kwargs", ctx=self._load_ctx)
+        self._self_arg_node = ast.arg(arg="self")
+        self._args_arg_node = ast.arg(arg="args")
+        self._kwargs_arg_node = ast.arg(arg="kwargs")
+        self._self_name_load = ast.Name(id="self", ctx=self._load_ctx)
+        self._starred_args = ast.Starred(value=self._args_name_load, ctx=self._load_ctx)
+        self._kwargs_keyword = ast.keyword(arg=None, value=self._kwargs_name_load)
+
+        # Pre-parse the import statement to avoid repeated parsing in visit_Module
+        self._import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0]
+
     def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom:
         # Check if our import already exists
         if node.module == "codeflash.verification.codeflash_capture" and any(
@@ -128,10 +145,20 @@ class InitDecorator(ast.NodeTransformer):

     def visit_Module(self, node: ast.Module) -> ast.Module:
         self.generic_visit(node)

+        # Insert module-level monkey-patch wrappers for attrs classes immediately after their
+        # class definitions. We do this before inserting the import so indices stay stable.
+        if self._attrs_classes_to_patch:
+            new_body: list[ast.stmt] = []
+            for stmt in node.body:
+                new_body.append(stmt)
+                if isinstance(stmt, ast.ClassDef) and stmt.name in self._attrs_classes_to_patch:
+                    new_body.extend(self._build_attrs_patch_block(stmt.name, self._attrs_classes_to_patch[stmt.name]))
+            node.body = new_body
+
         # Add import statement
         if not self.has_import and self.inserted_decorator:
-            import_stmt = ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0]
-            node.body.insert(0, import_stmt)
+            node.body.insert(0, self._import_stmt)

         return node

@@ -171,6 +198,8 @@ class InitDecorator(ast.NodeTransformer):
                 item.decorator_list.insert(0, decorator)
                 self.inserted_decorator = True

                 break

         if not has_init:
             # Skip dataclasses — their __init__ is auto-generated at class creation time and isn't in the AST.
             # The synthetic __init__ with super().__init__(*args, **kwargs) overrides it and fails because
@@ -181,6 +210,18 @@ class InitDecorator(ast.NodeTransformer):
             dec_name = self._expr_name(dec)
             if dec_name is not None and dec_name.endswith("dataclass"):
                 return node
+            if dec_name is not None:
+                parts = dec_name.split(".")
+                if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
+                    if isinstance(dec, ast.Call):
+                        for kw in dec.keywords:
+                            if kw.arg == "init" and isinstance(kw.value, ast.Constant) and kw.value.value is False:
+                                return node
+                    self._attrs_classes_to_patch[node.name] = decorator
+                    self.inserted_decorator = True
+                    return node

         # Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)

         # Skip NamedTuples — their __init__ is synthesized and cannot be overwritten.
         for base in node.bases:
@@ -202,6 +243,60 @@ class InitDecorator(ast.NodeTransformer):

         return node

+    def _build_attrs_patch_block(self, class_name: str, decorator: ast.Call) -> list[ast.stmt]:
+        orig_name = f"_codeflash_orig_{class_name}_init"
+        patched_name = f"_codeflash_patched_{class_name}_init"
+
+        # _codeflash_orig_ClassName_init = ClassName.__init__
+
+        # Create class name nodes once
+        class_name_load = ast.Name(id=class_name, ctx=self._load_ctx)
+
+        # _codeflash_orig_ClassName_init = ClassName.__init__
+        save_orig = ast.Assign(
+            targets=[ast.Name(id=orig_name, ctx=self._store_ctx)],
+            value=ast.Attribute(value=class_name_load, attr="__init__", ctx=self._load_ctx),
+        )
+
+        # def _codeflash_patched_ClassName_init(self, *args, **kwargs):
+        #     return _codeflash_orig_ClassName_init(self, *args, **kwargs)
+        patched_func = ast.FunctionDef(
+            name=patched_name,
+            args=ast.arguments(
+                posonlyargs=[],
+                args=[self._self_arg_node],
+                vararg=self._args_arg_node,
+                kwonlyargs=[],
+                kw_defaults=[],
+                kwarg=self._kwargs_arg_node,
+                defaults=[],
+            ),
+            body=cast(
+                "list[ast.stmt]",
+                [
+                    ast.Return(
+                        value=ast.Call(
+                            func=ast.Name(id=orig_name, ctx=self._load_ctx),
+                            args=[self._self_name_load, self._starred_args],
+                            keywords=[self._kwargs_keyword],
+                        )
+                    )
+                ],
+            ),
+            decorator_list=cast("list[ast.expr]", []),
+            returns=None,
+        )
+
+        # ClassName.__init__ = codeflash_capture(...)(_codeflash_patched_ClassName_init)
+        assign_patched = ast.Assign(
+            targets=[
+                ast.Attribute(value=ast.Name(id=class_name, ctx=self._load_ctx), attr="__init__", ctx=self._store_ctx)
+            ],
+            value=ast.Call(func=decorator, args=[ast.Name(id=patched_name, ctx=self._load_ctx)], keywords=[]),
+        )
+
+        return [save_orig, patched_func, assign_patched]
+
     def _expr_name(self, node: ast.AST) -> str | None:
         if isinstance(node, ast.Name):
             return node.id
@@ -14,7 +14,11 @@ from typing import TYPE_CHECKING
 from junitparser.xunit2 import JUnitXml

 from codeflash.cli_cmds.console import console, logger
-from codeflash.code_utils.code_utils import file_path_from_module_name, module_name_from_file_path
+from codeflash.code_utils.code_utils import (
+    extract_parameterized_test_index,
+    file_path_from_module_name,
+    module_name_from_file_path,
+)
 from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults

 if TYPE_CHECKING:
@@ -140,13 +144,15 @@ def parse_python_test_xml(
         if class_name is not None and class_name.startswith(test_module_path):
             test_class = class_name[len(test_module_path) + 1 :]

-        loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
+        loop_index = (
+            extract_parameterized_test_index(testcase.name) if testcase.name and "[" in testcase.name else 1
+        )

         timed_out = False
         if len(testcase.result) > 1:
             logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
         if len(testcase.result) == 1:
-            message = testcase.result[0].message.lower()
+            message = (testcase.result[0].message or "").lower()
             if "failed: timeout >" in message or "timed out" in message:
                 timed_out = True
@@ -848,7 +848,11 @@ class PythonSupport:
     # === Test Result Comparison ===

     def compare_test_results(
-        self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
+        self,
+        original_results_path: Path,
+        candidate_results_path: Path,
+        project_root: Path | None = None,
+        project_classpath: str | None = None,
     ) -> tuple[bool, list]:
         """Compare test results between original and candidate code.

@@ -856,6 +860,7 @@ class PythonSupport:
             original_results_path: Path to original test results.
             candidate_results_path: Path to candidate test results.
             project_root: Project root directory.
+            project_classpath: Unused (Java only).

         Returns:
             Tuple of (are_equivalent, list of TestDiff objects).
@@ -714,6 +714,10 @@ class Optimizer:
         cleanup_paths(paths_to_cleanup)

     def cleanup_temporary_paths(self) -> None:
+        from codeflash.languages.java.test_runner import CompilationCache
+
+        CompilationCache.clear()
+
         if hasattr(get_run_tmp_file, "tmpdir"):
             get_run_tmp_file.tmpdir.cleanup()
             del get_run_tmp_file.tmpdir
@@ -26,6 +26,8 @@ dependencies = [
     "tree-sitter-javascript>=0.23.0",
     "tree-sitter-typescript>=0.23.0",
     "tree-sitter-java>=0.23.0",
+    "tree-sitter-groovy>=0.1.0",
+    "tree-sitter-kotlin>=1.0.0",
     "pytest-timeout>=2.1.0",
     "tomlkit>=0.11.7",
     "junitparser>=3.1.0",
@@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str:
     assert "qualified_name: str" in combined


+def test_extract_init_stub_attrs_define(tmp_path: Path) -> None:
+    """extract_init_stub_from_class produces a synthetic __init__ stub for @attrs.define classes."""
+    source = """
+import attrs
+from attrs.validators import instance_of
+
+@attrs.define(frozen=True)
+class ImportCST:
+    module: str = attrs.field(converter=str.lower)
+    name: str = attrs.field(validator=[instance_of(str)])
+    as_name: str = attrs.field(validator=[instance_of(str)])
+
+    def to_str(self) -> str:
+        return f"from {self.module} import {self.name}"
+"""
+    expected = "class ImportCST:\n    def __init__(self, module: str, name: str, as_name: str):\n        ..."
+    tree = ast.parse(source)
+    stub = extract_init_stub_from_class("ImportCST", source, tree)
+    assert stub == expected
+
+
+def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None:
+    """Fields using attrs factory= keyword should appear as optional (= ...) in the stub."""
+    source = """
+import attrs
+
+@attrs.define
+class ClassCST:
+    name: str = attrs.field()
+    methods: list = attrs.field(factory=list)
+    imports: set = attrs.field(factory=set)
+
+    def compute(self) -> int:
+        return len(self.methods)
+"""
+    expected = "class ClassCST:\n    def __init__(self, name: str, methods: list = ..., imports: set = ...):\n        ..."
+    tree = ast.parse(source)
+    stub = extract_init_stub_from_class("ClassCST", source, tree)
+    assert stub == expected
+
+
+def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None:
+    """When @attrs.define(init=False) but with explicit __init__, the explicit body is returned."""
+    source = """
+import attrs
+
+@attrs.define(init=False)
+class NoAutoInit:
+    x: int = attrs.field()
+
+    def __init__(self, x: int):
+        self.x = x * 2
+
+    def get(self) -> int:
+        return self.x
+"""
+    expected = "class NoAutoInit:\n    def __init__(self, x: int):\n        self.x = x * 2"
+    tree = ast.parse(source)
+    stub = extract_init_stub_from_class("NoAutoInit", source, tree)
+    assert stub == expected
+
+
 def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
     """Third-party classes should produce compact __init__ stubs, not full class source."""
     # Use a real third-party package (pydantic) so jedi can actually resolve it
@@ -303,7 +303,7 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase):
         reset_current_language()

     @patch("codeflash.code_utils.git_utils.git.Repo")
-    def test_java_diff_ignored_when_language_is_python(self, mock_repo_cls):
+    def test_java_diff_found_regardless_of_current_language(self, mock_repo_cls):
         from codeflash.languages.current import reset_current_language, set_current_language

         repo = mock_repo_cls.return_value
@@ -311,15 +311,18 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase):
         repo.working_dir = "/repo"
         repo.git.diff.return_value = JAVA_ADDITION_DIFF

+        # get_git_diff uses all registered extensions, not just the current language's
         set_current_language("python")
         try:
             result = get_git_diff(repo_directory=None, uncommitted_changes=True)
-            assert len(result) == 0
+            assert len(result) == 1
+            key = list(result.keys())[0]
+            assert str(key).endswith("Fibonacci.java")
         finally:
             reset_current_language()

     @patch("codeflash.code_utils.git_utils.git.Repo")
-    def test_mixed_lang_diff_filters_by_current_language(self, mock_repo_cls):
+    def test_mixed_lang_diff_returns_all_supported_extensions(self, mock_repo_cls):
         from codeflash.languages.current import reset_current_language, set_current_language

         repo = mock_repo_cls.return_value
@@ -327,23 +330,14 @@ class TestGetGitDiffMultiLanguage(unittest.TestCase):
         repo.working_dir = "/repo"
         repo.git.diff.return_value = MIXED_LANG_DIFF

-        # When language is Python, only .py file should be found
+        # All supported extensions are returned regardless of current language
         set_current_language("python")
         try:
             result = get_git_diff(repo_directory=None, uncommitted_changes=True)
-            assert len(result) == 1
-            key = list(result.keys())[0]
-            assert str(key).endswith("utils.py")
-        finally:
-            reset_current_language()
-
-        # When language is Java, only .java file should be found
-        set_current_language("java")
-        try:
-            result = get_git_diff(repo_directory=None, uncommitted_changes=True)
-            assert len(result) == 1
-            key = list(result.keys())[0]
-            assert str(key).endswith("App.java")
+            assert len(result) == 2
+            paths = [str(k) for k in result.keys()]
+            assert any(p.endswith("utils.py") for p in paths)
+            assert any(p.endswith("App.java") for p in paths)
         finally:
             reset_current_language()
@@ -2,8 +2,8 @@ from pathlib import Path

 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
-from codeflash.models.models import FunctionParent
+from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
+from codeflash.models.models import FunctionParent


 def test_add_codeflash_capture():
@@ -499,6 +499,184 @@ class MyTuple(typing.NamedTuple):
     test_path.unlink(missing_ok=True)


+def test_attrs_define_patched_via_module_wrapper():
+    """@attrs.define classes must NOT get a synthetic body __init__; instead a module-level
+    monkey-patch block is emitted after the class to avoid the __class__ cell TypeError
+    that arises when attrs.define(slots=True) replaces the original class object.
+    """
+    original_code = """
+import attrs
+from attrs.validators import instance_of
+
+@attrs.define
+class MyAttrsClass:
+    x: int = attrs.field(validator=[instance_of(int)])
+    y: str = attrs.field(default="hello")
+
+    def compute(self):
+        return self.x
+"""
+    test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
+    expected = f"""import attrs
+from attrs.validators import instance_of
+
+from codeflash.verification.codeflash_capture import codeflash_capture
+
+
+@attrs.define
+class MyAttrsClass:
+    x: int = attrs.field(validator=[instance_of(int)])
+    y: str = attrs.field(default='hello')
+
+    def compute(self):
+        return self.x
+_codeflash_orig_MyAttrsClass_init = MyAttrsClass.__init__
+
+def _codeflash_patched_MyAttrsClass_init(self, *args, **kwargs):
+    return _codeflash_orig_MyAttrsClass_init(self, *args, **kwargs)
+MyAttrsClass.__init__ = codeflash_capture(function_name='MyAttrsClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_MyAttrsClass_init)
+"""
+    test_path.write_text(original_code)
+
+    function = FunctionToOptimize(
+        function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrsClass")]
+    )
+
+    try:
+        instrument_codeflash_capture(function, {}, test_path.parent)
+        modified_code = test_path.read_text()
+        assert modified_code.strip() == expected.strip()
+    finally:
+        test_path.unlink(missing_ok=True)
+
+
+def test_attrs_define_frozen_patched_via_module_wrapper():
+    """@attrs.define(frozen=True) should also be monkey-patched at module level."""
+    original_code = """
+import attrs
+
+@attrs.define(frozen=True)
+class FrozenPoint:
+    x: float = attrs.field()
+    y: float = attrs.field()
+
+    def distance(self):
+        return (self.x ** 2 + self.y ** 2) ** 0.5
+"""
+    test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
+    expected = f"""import attrs
+
+from codeflash.verification.codeflash_capture import codeflash_capture
+
+
+@attrs.define(frozen=True)
+class FrozenPoint:
+    x: float = attrs.field()
+    y: float = attrs.field()
+
+    def distance(self):
+        return (self.x ** 2 + self.y ** 2) ** 0.5
+_codeflash_orig_FrozenPoint_init = FrozenPoint.__init__
+
+def _codeflash_patched_FrozenPoint_init(self, *args, **kwargs):
+    return _codeflash_orig_FrozenPoint_init(self, *args, **kwargs)
+FrozenPoint.__init__ = codeflash_capture(function_name='FrozenPoint.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_FrozenPoint_init)
+"""
+    test_path.write_text(original_code)
+
+    function = FunctionToOptimize(
+        function_name="distance", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")]
+    )
+
+    try:
+        instrument_codeflash_capture(function, {}, test_path.parent)
+        modified_code = test_path.read_text()
+        assert modified_code.strip() == expected.strip()
+    finally:
+        test_path.unlink(missing_ok=True)
+
+
+def test_attr_s_patched_via_module_wrapper():
+    """@attr.s classes should also be monkey-patched at module level."""
+    original_code = """
+import attr
+
+@attr.s
+class MyAttrClass:
+    x: int = attr.ib()
+
+    def display(self):
+        return self.x
+"""
+    test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
+    expected = f"""import attr
+
+from codeflash.verification.codeflash_capture import codeflash_capture
+
+
+@attr.s
+class MyAttrClass:
+    x: int = attr.ib()
+
+    def display(self):
+        return self.x
+_codeflash_orig_MyAttrClass_init = MyAttrClass.__init__
+
+def _codeflash_patched_MyAttrClass_init(self, *args, **kwargs):
+    return _codeflash_orig_MyAttrClass_init(self, *args, **kwargs)
+MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', tmp_dir_path='{get_run_tmp_file(Path("test_return_values")).as_posix()}', tests_root='{test_path.parent.as_posix()}', is_fto=True)(_codeflash_patched_MyAttrClass_init)
+"""
+    test_path.write_text(original_code)
+
+    function = FunctionToOptimize(
+        function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrClass")]
+    )
+
+    try:
+        instrument_codeflash_capture(function, {}, test_path.parent)
+        modified_code = test_path.read_text()
+        assert modified_code.strip() == expected.strip()
+    finally:
+        test_path.unlink(missing_ok=True)
+
+
+def test_attrs_define_init_false_skipped():
+    """@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__."""
+    original_code = """
+import attrs
+
+@attrs.define(init=False)
+class ManualInit:
+    x: int = attrs.field()
+
+    def compute(self):
+        return self.x
+"""
+    expected = """import attrs
+
+
+@attrs.define(init=False)
+class ManualInit:
+    x: int = attrs.field()
+
+    def compute(self):
+        return self.x
+"""
+    test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
+    test_path.write_text(original_code)
+
+    function = FunctionToOptimize(
+        function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_codeflash_capture(function, {}, test_path.parent)
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_dataclass_with_explicit_init_still_instrumented():
|
||||
"""A dataclass that defines its own __init__ should still be instrumented normally."""
|
||||
original_code = """
@@ -1,4 +1,4 @@
-"""Tests for ensure_multi_module_deps_installed in Java test runner."""
+"""Tests for MavenStrategy.install_multi_module_deps."""
 
 import subprocess
 from pathlib import Path
@@ -6,7 +6,7 @@ from unittest.mock import patch
 
 import pytest
 
-from codeflash.languages.java.test_runner import _multimodule_deps_installed, ensure_multi_module_deps_installed
+from codeflash.languages.java.maven_strategy import MavenStrategy, _multimodule_deps_installed
 
 
 @pytest.fixture(autouse=True)
@@ -17,21 +17,26 @@ def clear_cache():
     _multimodule_deps_installed.clear()
 
 
-def test_skipped_for_single_module():
+@pytest.fixture()
+def strategy():
+    return MavenStrategy()
+
+
+def test_skipped_for_single_module(strategy):
     """Single-module projects (test_module=None) should be a no-op."""
-    result = ensure_multi_module_deps_installed(Path("/fake"), None, {})
+    result = strategy.install_multi_module_deps(Path("/fake"), None, {})
     assert result is True
     assert len(_multimodule_deps_installed) == 0
 
 
-@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
+@patch.object(MavenStrategy, "find_executable", return_value="mvn")
 @patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
-def test_runs_install_command_with_correct_args(mock_run, mock_mvn):
+def test_runs_install_command_with_correct_args(mock_run, mock_mvn, strategy):
     """Should run mvn install -DskipTests -pl <module> -am with validation skip flags."""
     mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=0, stdout="", stderr="")
 
     root = Path("/project")
-    result = ensure_multi_module_deps_installed(root, "guava-tests", {"JAVA_HOME": "/jdk"})
+    result = strategy.install_multi_module_deps(root, "guava-tests", {"JAVA_HOME": "/jdk"})
 
     assert result is True
     mock_run.assert_called_once()
@@ -43,55 +48,52 @@ def test_runs_install_command_with_correct_args(mock_run, mock_mvn):
     assert "guava-tests" in cmd
     assert "-am" in cmd
     assert "-B" in cmd
     # Validation skip flags should be present
     assert "-Drat.skip=true" in cmd
     assert "-Dcheckstyle.skip=true" in cmd
     # cwd should be maven_root
     assert mock_run.call_args[1]["cwd"] == root
 
 
-@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
+@patch.object(MavenStrategy, "find_executable", return_value="mvn")
 @patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
-def test_caches_and_does_not_rerun(mock_run, mock_mvn):
+def test_caches_and_does_not_rerun(mock_run, mock_mvn, strategy):
     """Second call with same (root, module) should be cached — no Maven invocation."""
     mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=0, stdout="", stderr="")
 
     root = Path("/project")
-    ensure_multi_module_deps_installed(root, "guava-tests", {})
+    strategy.install_multi_module_deps(root, "guava-tests", {})
     assert mock_run.call_count == 1
 
     # Second call — should be cached
-    result = ensure_multi_module_deps_installed(root, "guava-tests", {})
+    result = strategy.install_multi_module_deps(root, "guava-tests", {})
     assert result is True
-    assert mock_run.call_count == 1  # NOT called again
+    assert mock_run.call_count == 1
 
 
-@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
+@patch.object(MavenStrategy, "find_executable", return_value="mvn")
 @patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
-def test_different_modules_not_cached(mock_run, mock_mvn):
+def test_different_modules_not_cached(mock_run, mock_mvn, strategy):
     """Different test modules should each trigger their own install."""
     mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=0, stdout="", stderr="")
 
     root = Path("/project")
-    ensure_multi_module_deps_installed(root, "module-a", {})
-    ensure_multi_module_deps_installed(root, "module-b", {})
+    strategy.install_multi_module_deps(root, "module-a", {})
+    strategy.install_multi_module_deps(root, "module-b", {})
     assert mock_run.call_count == 2
 
 
-@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value="mvn")
+@patch.object(MavenStrategy, "find_executable", return_value="mvn")
 @patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout")
-def test_returns_false_on_maven_failure(mock_run, mock_mvn):
+def test_returns_false_on_maven_failure(mock_run, mock_mvn, strategy):
     """Non-zero exit code should return False and NOT cache."""
     mock_run.return_value = subprocess.CompletedProcess(args=["mvn"], returncode=1, stdout="", stderr="BUILD FAILURE")
 
     root = Path("/project")
-    result = ensure_multi_module_deps_installed(root, "guava-tests", {})
+    result = strategy.install_multi_module_deps(root, "guava-tests", {})
     assert result is False
     assert len(_multimodule_deps_installed) == 0
 
 
-@patch("codeflash.languages.java.test_runner.find_maven_executable", return_value=None)
-def test_returns_false_when_maven_not_found(mock_mvn):
+@patch.object(MavenStrategy, "find_executable", return_value=None)
+def test_returns_false_when_maven_not_found(mock_mvn, strategy):
     """Should return False if Maven executable is not found."""
-    result = ensure_multi_module_deps_installed(Path("/fake"), "module", {})
+    result = strategy.install_multi_module_deps(Path("/fake"), "module", {})
     assert result is False
@@ -6,38 +6,36 @@ from unittest.mock import patch
 
 import pytest
 
-from codeflash.languages.java.test_runner import _build_test_filter, _run_maven_tests
+from codeflash.languages.java.maven_strategy import MavenStrategy
+from codeflash.languages.java.test_runner import _build_test_filter
 from codeflash.models.models import TestFile, TestFiles, TestType
 
 
 def test_build_test_filter_with_none_benchmarking_paths():
     """Test that _build_test_filter handles None benchmarking paths correctly."""
     # Create TestFiles with None benchmarking_file_path
     test_files = TestFiles(
         test_files=[
             TestFile(
                 instrumented_behavior_file_path=Path("/tmp/test1__perfinstrumented.java"),
-                benchmarking_file_path=None,  # None path!
+                benchmarking_file_path=None,
                 original_file_path=Path("/tmp/test1.java"),
                 test_type=TestType.EXISTING_UNIT_TEST,
             ),
             TestFile(
                 instrumented_behavior_file_path=Path("/tmp/test2__perfinstrumented.java"),
-                benchmarking_file_path=None,  # None path!
+                benchmarking_file_path=None,
                 original_file_path=Path("/tmp/test2.java"),
                 test_type=TestType.EXISTING_UNIT_TEST,
             ),
         ]
     )
 
     # In performance mode with None paths, filter should be empty
     result = _build_test_filter(test_files, mode="performance")
     assert result == "", f"Expected empty filter but got: {result}"
 
 
 def test_build_test_filter_with_valid_paths():
     """Test that _build_test_filter works correctly with valid paths."""
     # Create TestFiles with valid paths
     test_files = TestFiles(
         test_files=[
             TestFile(
@@ -49,50 +47,46 @@ def test_build_test_filter_with_valid_paths():
         ]
     )
 
     # Should produce valid filter
     result = _build_test_filter(test_files, mode="performance")
     assert result != "", "Expected non-empty filter"
     assert "Test1__perfonlyinstrumented" in result
 
 
-def test_run_maven_tests_raises_on_empty_filter():
-    """Test that _run_maven_tests raises ValueError when filter is empty."""
+def test_run_tests_via_build_tool_raises_on_empty_filter():
+    """Test that MavenStrategy.run_tests_via_build_tool raises ValueError when filter is empty."""
+    strategy = MavenStrategy()
     project_root = Path("/tmp/test_project")
     env = {}
 
     # Create TestFiles with None paths (will produce empty filter)
     test_files = TestFiles(
         test_files=[
             TestFile(
                 instrumented_behavior_file_path=Path("/tmp/test__perfinstrumented.java"),
-                benchmarking_file_path=None,  # Will cause empty filter in performance mode
+                benchmarking_file_path=None,
                 original_file_path=Path("/tmp/test.java"),
                 test_type=TestType.EXISTING_UNIT_TEST,
             )
         ]
     )
 
-    # Mock Maven executable
-    with patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven:
-        mock_maven.return_value = "mvn"
-
-        # Should raise ValueError due to empty filter
+    with patch.object(MavenStrategy, "find_executable", return_value="mvn"):
         with pytest.raises(ValueError, match="Test filter is EMPTY"):
-            _run_maven_tests(
+            strategy.run_tests_via_build_tool(
                 project_root,
                 test_files,
                 env,
                 timeout=60,
-                mode="performance",  # Performance mode with None benchmarking_file_path
+                mode="performance",
                 test_module=None,
             )
 
 
-def test_run_maven_tests_succeeds_with_valid_filter():
-    """Test that _run_maven_tests works correctly when filter is not empty."""
+def test_run_tests_via_build_tool_succeeds_with_valid_filter():
+    """Test that MavenStrategy.run_tests_via_build_tool works correctly when filter is not empty."""
+    strategy = MavenStrategy()
    project_root = Path("/tmp/test_project")
     env = {}
 
     # Create TestFiles with valid paths
     test_files = TestFiles(
         test_files=[
             TestFile(
@@ -104,20 +98,18 @@ def test_run_maven_tests_succeeds_with_valid_filter():
         ]
     )
 
     # Mock Maven executable and _run_cmd_kill_pg_on_timeout (which replaced subprocess.run)
     with (
-        patch("codeflash.languages.java.test_runner.find_maven_executable") as mock_maven,
+        patch.object(MavenStrategy, "find_executable", return_value="mvn"),
         patch("codeflash.languages.java.test_runner._run_cmd_kill_pg_on_timeout") as mock_run,
     ):
-        mock_maven.return_value = "mvn"
         mock_run.return_value = subprocess.CompletedProcess(
            args=[], returncode=0, stdout="Tests run: 1, Failures: 0, Errors: 0, Skipped: 0", stderr=""
        )
 
         # Should not raise - filter is valid
-        result = _run_maven_tests(project_root, test_files, env, timeout=60, mode="performance")
+        result = strategy.run_tests_via_build_tool(
+            project_root, test_files, env, timeout=60, mode="performance", test_module=None
+        )
 
         # Verify Maven was called with -Dtest parameter
         assert mock_run.called
         cmd = mock_run.call_args[0][0]
         assert any("-Dtest=" in arg for arg in cmd), f"Expected -Dtest parameter in command: {cmd}"
@@ -2,16 +2,17 @@
 
 import os
 from pathlib import Path
 from unittest.mock import patch
 
 from codeflash.languages.java.build_tools import (
     BuildTool,
-    add_codeflash_dependency_to_pom,
     detect_build_tool,
-    find_maven_executable,
     find_source_root,
     find_test_root,
     get_project_info,
 )
+from codeflash.languages.java.gradle_strategy import GradleStrategy
+from codeflash.languages.java.maven_strategy import MavenStrategy, add_codeflash_dependency
+from codeflash.languages.java.test_runner import _extract_modules_from_pom_content
 
 
@@ -175,8 +176,8 @@ class TestMavenExecutable:
 
     def test_find_maven_executable_system(self):
         """Test finding system Maven."""
         # This test may pass or fail depending on whether Maven is installed
-        mvn = find_maven_executable()
+        strategy = MavenStrategy()
+        mvn = strategy.find_executable(Path())
         # We can't assert it exists, just that the function doesn't crash
         if mvn:
             assert "mvn" in mvn.lower() or "maven" in mvn.lower()
@@ -188,10 +189,8 @@ class TestMavenExecutable:
         mvnw_path.write_text("#!/bin/bash\necho 'Maven Wrapper'")
         mvnw_path.chmod(0o755)
 
-        # Change to tmp_path
-        monkeypatch.chdir(tmp_path)
-
-        mvn = find_maven_executable()
+        strategy = MavenStrategy()
+        mvn = strategy.find_executable(tmp_path)
         # Should find the wrapper
         assert mvn is not None
 
@@ -377,24 +376,27 @@ class TestMavenProfiles:
 
 
 class TestMavenExecutableWithProjectRoot:
-    """Tests for find_maven_executable with project_root parameter."""
+    """Tests for MavenStrategy.find_executable with project_root parameter."""
 
     def test_find_wrapper_in_project_root(self, tmp_path):
         mvnw_path = tmp_path / "mvnw"
         mvnw_path.write_text("#!/bin/bash\necho Maven Wrapper")
         mvnw_path.chmod(0o755)
 
-        result = find_maven_executable(project_root=tmp_path)
+        strategy = MavenStrategy()
+        result = strategy.find_executable(tmp_path)
         assert result is not None
         assert str(tmp_path / "mvnw") in result
 
-    def test_fallback_to_cwd_when_no_project_root(self):
-        result = find_maven_executable()
-        # Should not crash even without project_root
+    def test_fallback_to_cwd(self, tmp_path):
+        strategy = MavenStrategy()
+        result = strategy.find_executable(tmp_path)
+        # Should not crash even with a dir that has no wrapper
 
-    def test_project_root_none_uses_cwd(self):
-        result = find_maven_executable(project_root=None)
-        # Should not crash
+    def test_with_nonexistent_wrapper(self, tmp_path):
+        strategy = MavenStrategy()
+        result = strategy.find_executable(tmp_path)
+        # Should not crash, may return system mvn or None
 
 
 class TestCustomSourceDirectoryDetection:
@@ -460,7 +462,7 @@ class TestCustomSourceDirectoryDetection:
 
 
 class TestAddCodeflashDependencyToPom:
-    """Tests for add_codeflash_dependency_to_pom, including stale system-scope replacement."""
+    """Tests for add_codeflash_dependency, including stale system-scope replacement."""
 
     def test_adds_dependency_to_clean_pom(self, tmp_path):
         pom = tmp_path / "pom.xml"
@@ -477,7 +479,7 @@ class TestAddCodeflashDependencyToPom:
             "</project>\n",
             encoding="utf-8",
         )
-        assert add_codeflash_dependency_to_pom(pom) is True
+        assert add_codeflash_dependency(pom) is True
         content = pom.read_text(encoding="utf-8")
         assert "codeflash-runtime" in content
         assert "<scope>test</scope>" in content
@@ -499,7 +501,7 @@ class TestAddCodeflashDependencyToPom:
             "</project>\n",
             encoding="utf-8",
        )
-        assert add_codeflash_dependency_to_pom(pom) is True
+        assert add_codeflash_dependency(pom) is True
         content = pom.read_text(encoding="utf-8")
         assert "<scope>test</scope>" in content
         assert "<scope>system</scope>" not in content
@@ -523,7 +525,7 @@ class TestAddCodeflashDependencyToPom:
             "</project>\n",
             encoding="utf-8",
         )
-        assert add_codeflash_dependency_to_pom(pom) is True
+        assert add_codeflash_dependency(pom) is True
         content = pom.read_text(encoding="utf-8")
         assert "<scope>test</scope>" in content
         assert "<scope>system</scope>" not in content
@@ -545,17 +547,97 @@ class TestAddCodeflashDependencyToPom:
             "</project>\n",
             encoding="utf-8",
         )
-        assert add_codeflash_dependency_to_pom(pom) is True
+        assert add_codeflash_dependency(pom) is True
         content = pom.read_text(encoding="utf-8")
         assert content.count("codeflash-runtime") == 1
 
     def test_returns_false_for_missing_pom(self, tmp_path):
         pom = tmp_path / "pom.xml"
-        assert add_codeflash_dependency_to_pom(pom) is False
+        assert add_codeflash_dependency(pom) is False
 
     def test_returns_false_when_no_dependencies_tag(self, tmp_path):
         pom = tmp_path / "pom.xml"
         pom.write_text(
             '<?xml version="1.0"?>\n<project><modelVersion>4.0.0</modelVersion></project>\n', encoding="utf-8"
         )
-        assert add_codeflash_dependency_to_pom(pom) is False
+        assert add_codeflash_dependency(pom) is False
+
+
+class TestGradleEnsureRuntimeMultiModule:
+    """Tests that ensure_runtime adds the dependency to the correct module build file."""
+
+    def _make_multi_module_project(self, tmp_path):
+        """Create a multi-module Gradle project with submodule build files."""
+        # Root
+        (tmp_path / "build.gradle.kts").write_text("// root build\n", encoding="utf-8")
+        (tmp_path / "settings.gradle.kts").write_text('include("clients", "streams")', encoding="utf-8")
+        (tmp_path / "gradlew").write_text("#!/bin/sh\necho gradle", encoding="utf-8")
+        (tmp_path / "gradlew").chmod(0o755)
+        # Submodule build files with a dependencies block
+        for module in ["clients", "streams"]:
+            module_dir = tmp_path / module
+            module_dir.mkdir()
+            (module_dir / "build.gradle.kts").write_text(
+                'plugins {\n java\n}\n\ndependencies {\n testImplementation("junit:junit:4.13.2")\n}\n',
+                encoding="utf-8",
+            )
+        return tmp_path
+
+    def test_adds_dependency_to_correct_module_build_file(self, tmp_path):
+        """When test_module='streams', the dependency must be added to streams/build.gradle.kts."""
+        project = self._make_multi_module_project(tmp_path)
+
+        strategy = GradleStrategy()
+        # Provide a fake runtime JAR
+        fake_jar = tmp_path / "fake-runtime.jar"
+        fake_jar.write_bytes(b"PK\x03\x04")  # minimal zip header
+
+        with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
+            result = strategy.ensure_runtime(project, test_module="streams")
+
+        assert result is True
+        # Dependency should be in streams/build.gradle.kts
+        streams_build = (project / "streams" / "build.gradle.kts").read_text(encoding="utf-8")
+        assert "codeflash-runtime" in streams_build
+        # And NOT in clients/build.gradle.kts or root build.gradle.kts
+        clients_build = (project / "clients" / "build.gradle.kts").read_text(encoding="utf-8")
+        assert "codeflash-runtime" not in clients_build
+        root_build = (project / "build.gradle.kts").read_text(encoding="utf-8")
+        assert "codeflash-runtime" not in root_build
+
+    def test_adds_dependency_to_root_when_no_module(self, tmp_path):
+        """When test_module=None, the dependency is added to the root build file."""
+        project = self._make_multi_module_project(tmp_path)
+
+        strategy = GradleStrategy()
+        fake_jar = tmp_path / "fake-runtime.jar"
+        fake_jar.write_bytes(b"PK\x03\x04")
+
+        with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
+            result = strategy.ensure_runtime(project, test_module=None)
+
+        assert result is True
+        root_build = (project / "build.gradle.kts").read_text(encoding="utf-8")
+        assert "codeflash-runtime" in root_build
+
+    def test_adds_dependency_to_nested_module(self, tmp_path):
+        """When test_module='connect:runtime', the dep goes to connect/runtime/build.gradle.kts."""
+        project = self._make_multi_module_project(tmp_path)
+        # Add nested module
+        nested = tmp_path / "connect" / "runtime"
+        nested.mkdir(parents=True)
+        (nested / "build.gradle.kts").write_text(
+            'plugins {\n java\n}\n\ndependencies {\n testImplementation("junit:junit:4.13.2")\n}\n',
+            encoding="utf-8",
+        )
+
+        strategy = GradleStrategy()
+        fake_jar = tmp_path / "fake-runtime.jar"
+        fake_jar.write_bytes(b"PK\x03\x04")
+
+        with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
+            result = strategy.ensure_runtime(project, test_module="connect:runtime")
+
+        assert result is True
+        nested_build = (nested / "build.gradle.kts").read_text(encoding="utf-8")
+        assert "codeflash-runtime" in nested_build
@@ -4,10 +4,10 @@ from __future__ import annotations
 
 from pathlib import Path
 
-from codeflash.languages.java.build_tools import (
+from codeflash.languages.java.maven_strategy import (
     JACOCO_PLUGIN_VERSION,
-    add_jacoco_plugin_to_pom,
-    get_jacoco_xml_path,
+    add_jacoco_plugin,
+    get_jacoco_report_path,
     is_jacoco_configured,
 )
 from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, CoverageStatus, FunctionSource
@@ -470,7 +470,7 @@ class TestJacocoPluginAddition:
         pom_path.write_text(POM_MINIMAL)
 
         # Add JaCoCo plugin
-        result = add_jacoco_plugin_to_pom(pom_path)
+        result = add_jacoco_plugin(pom_path)
         assert result is True
 
         # Verify it's now configured
@@ -483,13 +483,13 @@ class TestJacocoPluginAddition:
         assert "prepare-agent" in content
         assert "report" in content
 
-    def test_add_jacoco_plugin_to_pom_with_build(self, tmp_path: Path) -> None:
+    def test_add_jacoco_plugin_with_build(self, tmp_path: Path) -> None:
         """Test adding JaCoCo to pom.xml that has a build section."""
         pom_path = tmp_path / "pom.xml"
         pom_path.write_text(POM_WITHOUT_JACOCO)
 
         # Add JaCoCo plugin
-        result = add_jacoco_plugin_to_pom(pom_path)
+        result = add_jacoco_plugin(pom_path)
         assert result is True
 
         # Verify it's now configured
@@ -501,7 +501,7 @@ class TestJacocoPluginAddition:
         pom_path.write_text(POM_WITH_JACOCO)
 
         # Try to add JaCoCo plugin
-        result = add_jacoco_plugin_to_pom(pom_path)
+        result = add_jacoco_plugin(pom_path)
         assert result is True  # Should succeed (already present)
 
         # Verify it's still configured
@@ -513,7 +513,7 @@ class TestJacocoPluginAddition:
         pom_path.write_text(POM_NO_NAMESPACE)
 
         # Add JaCoCo plugin
-        result = add_jacoco_plugin_to_pom(pom_path)
+        result = add_jacoco_plugin(pom_path)
         assert result is True
 
         # Verify it's now configured
@@ -523,7 +523,7 @@ class TestJacocoPluginAddition:
         """Test adding JaCoCo when pom.xml doesn't exist."""
         pom_path = tmp_path / "pom.xml"
 
-        result = add_jacoco_plugin_to_pom(pom_path)
+        result = add_jacoco_plugin(pom_path)
         assert result is False
 
     def test_add_jacoco_plugin_invalid_xml(self, tmp_path: Path) -> None:
@@ -531,16 +531,16 @@ class TestJacocoPluginAddition:
         pom_path = tmp_path / "pom.xml"
         pom_path.write_text("this is not valid xml")
 
-        result = add_jacoco_plugin_to_pom(pom_path)
+        result = add_jacoco_plugin(pom_path)
         assert result is False
 
 
 class TestJacocoXmlPath:
     """Tests for JaCoCo XML path resolution."""
 
-    def test_get_jacoco_xml_path(self, tmp_path: Path) -> None:
+    def test_get_jacoco_report_path(self, tmp_path: Path) -> None:
         """Test getting the expected JaCoCo XML path."""
-        path = get_jacoco_xml_path(tmp_path)
+        path = get_jacoco_report_path(tmp_path)
 
         assert path == tmp_path / "target" / "site" / "jacoco" / "jacoco.xml"
@@ -22,7 +22,7 @@ os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
 from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.languages.base import Language
 from codeflash.languages.current import set_current_language
-from codeflash.languages.java.build_tools import find_maven_executable
+from codeflash.languages.java.maven_strategy import MavenStrategy
 from codeflash.languages.java.discovery import discover_functions_from_source
 from codeflash.languages.java.instrumentation import (
     _add_behavior_instrumentation,
@@ -1968,7 +1968,7 @@ public class AccentTest {
 
 # Skip all E2E tests if Maven is not available
 requires_maven = pytest.mark.skipif(
-    find_maven_executable() is None, reason="Maven not found - skipping execution tests"
+    MavenStrategy().find_executable(Path(".")) is None, reason="Maven not found - skipping execution tests"
 )
@ -3,7 +3,14 @@
|
|||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from codeflash.languages.java.test_runner import _extract_source_dirs_from_pom, _path_to_class_name
|
||||
from codeflash.languages.java.build_tool_strategy import module_to_dir
|
||||
from codeflash.languages.java.test_runner import (
|
||||
_extract_custom_source_dirs,
|
||||
_extract_modules_from_settings_gradle,
|
||||
_find_multi_module_root,
|
||||
_match_module_from_rel_path,
|
||||
_path_to_class_name,
|
||||
)
|
||||
|
||||
|
||||
class TestGetJavaSourcesRoot:
|
||||
|
|
@ -13,12 +20,12 @@ class TestGetJavaSourcesRoot:
|
|||
"""Create a mock FunctionOptimizer with the given tests_root."""
|
||||
from codeflash.languages.java.function_optimizer import JavaFunctionOptimizer
|
||||
|
||||
# Create a minimal mock
|
||||
mock_optimizer = MagicMock(spec=JavaFunctionOptimizer)
|
||||
mock_optimizer.test_cfg = MagicMock()
|
||||
mock_optimizer.test_cfg.tests_root = Path(tests_root)
|
||||
mock_optimizer.function_to_optimize = MagicMock()
|
||||
mock_optimizer.function_to_optimize.file_path = Path("/nonexistent/Foo.java")
|
||||
|
||||
# Bind the actual method to the mock
|
||||
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
|
||||
|
||||
return mock_optimizer
|
||||
|
|
@ -87,6 +94,168 @@ class TestGetJavaSourcesRoot:
|
|||
assert result == Path("/Users/test/Work/aerospike-client-java/test/src")
|
||||
|
||||
|
||||
class TestGetJavaSourcesRootMultiModule:
|
||||
"""Tests for _get_java_sources_root with multi-module projects."""
|
||||
|
||||
def _create_mock_optimizer(self, tests_root: str, file_path: str):
|
||||
from codeflash.languages.java.function_optimizer import JavaFunctionOptimizer
|
||||
|
||||
mock_optimizer = MagicMock(spec=JavaFunctionOptimizer)
|
||||
mock_optimizer.test_cfg = MagicMock()
|
||||
mock_optimizer.test_cfg.tests_root = Path(tests_root)
|
||||
mock_optimizer.function_to_optimize = MagicMock()
|
||||
mock_optimizer.function_to_optimize.file_path = Path(file_path)
|
||||
mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
|
||||
return mock_optimizer
|
||||
|
||||
def test_kafka_streams_module(self, tmp_path):
|
||||
"""Kafka: function in streams module should use streams test dir, not clients."""
|
||||
(tmp_path / "streams" / "src" / "main" / "java").mkdir(parents=True)
|
||||
(tmp_path / "streams" / "src" / "test" / "java").mkdir(parents=True)
|
||||
(tmp_path / "clients" / "src" / "test" / "java").mkdir(parents=True)
|
||||
|
||||
optimizer = self._create_mock_optimizer(
|
||||
tests_root=str(tmp_path / "clients" / "src" / "test" / "java"),
|
||||
file_path=str(
|
||||
tmp_path
|
||||
/ "streams"
|
||||
/ "src"
|
||||
/ "main"
|
||||
/ "java"
|
||||
/ "org"
|
||||
/ "apache"
|
||||
/ "kafka"
|
||||
/ "streams"
|
||||
/ "query"
|
||||
/ "QueryConfig.java"
|
||||
),
|
||||
)
|
||||
result = optimizer._get_java_sources_root()
|
||||
assert result == tmp_path / "streams" / "src" / "test" / "java"
|
||||
|
||||
def test_kafka_connect_module(self, tmp_path):
|
||||
"""Kafka: function in connect/runtime should use connect/runtime test dir."""
|
||||
(tmp_path / "connect" / "runtime" / "src" / "main" / "java").mkdir(parents=True)
|
||||
(tmp_path / "connect" / "runtime" / "src" / "test" / "java").mkdir(parents=True)
|
||||
(tmp_path / "clients" / "src" / "test" / "java").mkdir(parents=True)
|
||||
|
||||
optimizer = self._create_mock_optimizer(
|
||||
tests_root=str(tmp_path / "clients" / "src" / "test" / "java"),
|
||||
file_path=str(
|
||||
tmp_path
|
||||
/ "connect"
|
||||
/ "runtime"
|
||||
/ "src"
|
||||
/ "main"
|
||||
/ "java"
|
||||
/ "org"
|
||||
/ "apache"
|
||||
/ "kafka"
|
||||
/ "connect"
|
||||
/ "runtime"
|
||||
/ "Worker.java"
|
||||
),
|
||||
)
|
||||
result = optimizer._get_java_sources_root()
|
||||
assert result == tmp_path / "connect" / "runtime" / "src" / "test" / "java"
|
||||
|
||||
    def test_kafka_clients_module_same_as_config(self, tmp_path):
        """Kafka: function in clients module should still use clients test dir."""
        (tmp_path / "clients" / "src" / "main" / "java").mkdir(parents=True)
        (tmp_path / "clients" / "src" / "test" / "java").mkdir(parents=True)

        optimizer = self._create_mock_optimizer(
            tests_root=str(tmp_path / "clients" / "src" / "test" / "java"),
            file_path=str(
                tmp_path
                / "clients"
                / "src"
                / "main"
                / "java"
                / "org"
                / "apache"
                / "kafka"
                / "common"
                / "utils"
                / "Bytes.java"
            ),
        )
        result = optimizer._get_java_sources_root()
        assert result == tmp_path / "clients" / "src" / "test" / "java"

    def test_opensearch_libs_module(self, tmp_path):
        """OpenSearch: function in libs/core should use libs/core test dir."""
        (tmp_path / "libs" / "core" / "src" / "main" / "java").mkdir(parents=True)
        (tmp_path / "libs" / "core" / "src" / "test" / "java").mkdir(parents=True)
        (tmp_path / "server" / "src" / "test" / "java").mkdir(parents=True)

        optimizer = self._create_mock_optimizer(
            tests_root=str(tmp_path / "server" / "src" / "test" / "java"),
            file_path=str(
                tmp_path
                / "libs"
                / "core"
                / "src"
                / "main"
                / "java"
                / "org"
                / "opensearch"
                / "core"
                / "common"
                / "Strings.java"
            ),
        )
        result = optimizer._get_java_sources_root()
        assert result == tmp_path / "libs" / "core" / "src" / "test" / "java"

    def test_spring_boot_subproject(self, tmp_path):
        """Spring Boot: function in autoconfigure should use autoconfigure test dir."""
        (tmp_path / "spring-boot-autoconfigure" / "src" / "main" / "java").mkdir(parents=True)
        (tmp_path / "spring-boot-autoconfigure" / "src" / "test" / "java").mkdir(parents=True)
        (tmp_path / "spring-boot" / "src" / "test" / "java").mkdir(parents=True)

        optimizer = self._create_mock_optimizer(
            tests_root=str(tmp_path / "spring-boot" / "src" / "test" / "java"),
            file_path=str(
                tmp_path
                / "spring-boot-autoconfigure"
                / "src"
                / "main"
                / "java"
                / "org"
                / "springframework"
                / "boot"
                / "autoconfigure"
                / "web"
                / "ServerProperties.java"
            ),
        )
        result = optimizer._get_java_sources_root()
        assert result == tmp_path / "spring-boot-autoconfigure" / "src" / "test" / "java"

    def test_fallback_when_derived_test_dir_missing(self, tmp_path):
        """When derived test dir doesn't exist, fall back to tests_root logic."""
        (tmp_path / "module-a" / "src" / "main" / "java").mkdir(parents=True)
        # Deliberately NOT creating module-a/src/test/java
        tests_root = tmp_path / "src" / "test" / "java"
        tests_root.mkdir(parents=True)

        optimizer = self._create_mock_optimizer(
            tests_root=str(tests_root),
            file_path=str(tmp_path / "module-a" / "src" / "main" / "java" / "com" / "example" / "Foo.java"),
        )
        result = optimizer._get_java_sources_root()
        assert result == tests_root

    def test_non_standard_layout_falls_through(self, tmp_path):
        """Non-standard layout (no src/main/java) falls through to existing logic."""
        optimizer = self._create_mock_optimizer(
            tests_root=str(tmp_path / "custom" / "tests"), file_path=str(tmp_path / "custom" / "src" / "Foo.java")
        )
        result = optimizer._get_java_sources_root()
        assert result == tmp_path / "custom" / "tests"
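The behavior these tests pin down, deriving a module-local `src/test/java` from the function's `src/main/java` path and falling back to the configured tests root otherwise, can be sketched as follows. This is a hypothetical reimplementation written from the assertions above, not the actual codeflash code; the `file_path`/`tests_root` parameters mirror the mock's attributes.

```python
from pathlib import Path


def derive_test_sources_root(file_path: Path, tests_root: Path) -> Path:
    """Sketch: map <module>/src/main/java/... to <module>/src/test/java."""
    parts = file_path.parts
    # Look for the Maven/Gradle convention src/main/java in the path.
    for i in range(len(parts) - 2):
        if parts[i : i + 3] == ("src", "main", "java"):
            module_root = Path(*parts[:i])
            candidate = module_root / "src" / "test" / "java"
            if candidate.is_dir():
                return candidate
            break  # layout found but test dir missing: fall back
    return tests_root
```

Comparing path *parts* (not substrings) is what makes the module boundary exact; a non-standard layout simply never hits the `src/main/java` triple and falls through.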


class TestFixJavaTestPathsIntegration:
    """Integration tests for _fix_java_test_paths with the path fix."""


@@ -97,8 +266,9 @@ class TestFixJavaTestPathsIntegration:
        mock_optimizer = MagicMock(spec=JavaFunctionOptimizer)
        mock_optimizer.test_cfg = MagicMock()
        mock_optimizer.test_cfg.tests_root = Path(tests_root)
        mock_optimizer.function_to_optimize = MagicMock()
        mock_optimizer.function_to_optimize.file_path = Path("/nonexistent/Foo.java")

        # Bind the actual methods
        mock_optimizer._get_java_sources_root = lambda: JavaFunctionOptimizer._get_java_sources_root(mock_optimizer)
        mock_optimizer._fix_java_test_paths = lambda behavior_source, perf_source, used_paths, display_source="": (
            JavaFunctionOptimizer._fix_java_test_paths(
@@ -245,7 +415,7 @@ class TestExtractSourceDirsFromPom:
        </project>
        """
        (tmp_path / "pom.xml").write_text(pom_content)
        dirs = _extract_source_dirs_from_pom(tmp_path)
        dirs = _extract_custom_source_dirs(tmp_path)
        assert "src/main/custom" in dirs
        assert "src/test/custom" in dirs


@@ -259,11 +429,11 @@ class TestExtractSourceDirsFromPom:
        </project>
        """
        (tmp_path / "pom.xml").write_text(pom_content)
        dirs = _extract_source_dirs_from_pom(tmp_path)
        dirs = _extract_custom_source_dirs(tmp_path)
        assert dirs == []

    def test_no_pom_returns_empty(self, tmp_path):
        dirs = _extract_source_dirs_from_pom(tmp_path)
        dirs = _extract_custom_source_dirs(tmp_path)
        assert dirs == []

    def test_pom_without_build_section(self, tmp_path):

@@ -273,10 +443,191 @@ class TestExtractSourceDirsFromPom:
        </project>
        """
        (tmp_path / "pom.xml").write_text(pom_content)
        dirs = _extract_source_dirs_from_pom(tmp_path)
        dirs = _extract_custom_source_dirs(tmp_path)
        assert dirs == []

    def test_malformed_xml(self, tmp_path):
        (tmp_path / "pom.xml").write_text("this is not valid xml <<<<")
        dirs = _extract_source_dirs_from_pom(tmp_path)
        dirs = _extract_custom_source_dirs(tmp_path)
        assert dirs == []
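The contract exercised above (return custom `sourceDirectory`/`testSourceDirectory` entries from `pom.xml`, and an empty list for a missing, build-less, or malformed POM) could look roughly like this. This is a hedged sketch inferred only from the test calls; the real helper's name, signature, and default-filtering rules may differ.

```python
import xml.etree.ElementTree as ET
from pathlib import Path


def extract_custom_source_dirs(project_root: Path) -> list[str]:
    """Sketch: pull non-default source/test dirs out of a Maven pom.xml."""
    pom = project_root / "pom.xml"
    if not pom.is_file():
        return []
    try:
        root = ET.parse(pom).getroot()
    except ET.ParseError:
        return []
    dirs = []
    for elem in root.iter():
        # pom.xml usually declares a default namespace; match the local tag name.
        tag = elem.tag.rsplit("}", 1)[-1]
        if tag in ("sourceDirectory", "testSourceDirectory") and elem.text:
            text = elem.text.strip()
            # Skip the conventional defaults so only *custom* layouts surface.
            if text not in ("src/main/java", "src/test/java"):
                dirs.append(text)
    return dirs
```

Matching on the local tag name sidesteps the `http://maven.apache.org/POM/4.0.0` namespace that real POM files carry, which otherwise makes plain `find("build/sourceDirectory")` lookups fail.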


class TestMatchModuleFromRelPath:
    """Tests for _match_module_from_rel_path."""

    def test_simple_module(self):
        assert _match_module_from_rel_path(Path("streams/src/test/java/Test.java"), ["streams", "clients"]) == "streams"

    def test_nested_module(self):
        result = _match_module_from_rel_path(
            Path("connect/runtime/src/test/java/Test.java"), ["connect:runtime", "streams"]
        )
        assert result == "connect:runtime"

    def test_no_match(self):
        assert _match_module_from_rel_path(Path("unknown/src/Test.java"), ["streams", "clients"]) is None

    def test_partial_name_no_false_match(self):
        """'streams-ng' should not match module 'streams'."""
        assert _match_module_from_rel_path(Path("streams-ng/src/Test.java"), ["streams"]) is None


class TestModuleToDir:
    """Tests for module_to_dir."""

    def test_simple(self):
        assert module_to_dir("streams") == "streams"

    def test_nested(self):
        result = module_to_dir("connect:runtime")
        assert result in ("connect/runtime", "connect\\runtime")
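The conversion under test is essentially a separator swap; a minimal sketch using `os.sep` so the result is correct on both POSIX and Windows (a hypothetical stand-in for the real `module_to_dir`):

```python
import os


def module_to_dir_sketch(module: str) -> str:
    """Sketch: Gradle module notation 'connect:runtime' -> relative dir path."""
    return module.replace(":", os.sep)
```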


class TestExtractModulesFromSettingsGradle:
    """Tests for _extract_modules_from_settings_gradle."""

    def test_simple_top_level_modules(self):
        content = """include("streams", "clients", "tools")"""
        modules = _extract_modules_from_settings_gradle(content)
        assert "streams" in modules
        assert "clients" in modules
        assert "tools" in modules

    def test_nested_gradle_modules(self):
        """Nested modules like connect:runtime should be extracted."""
        content = """include("connect:runtime", "connect:api", "streams")"""
        modules = _extract_modules_from_settings_gradle(content)
        assert "connect:runtime" in modules
        assert "connect:api" in modules
        assert "streams" in modules

    def test_leading_colon_stripped(self):
        content = """include(":streams", ":clients")"""
        modules = _extract_modules_from_settings_gradle(content)
        assert "streams" in modules
        assert "clients" in modules


class TestFindMultiModuleRoot:
    """Tests for _find_multi_module_root with Gradle multi-module projects."""

    def _make_kafka_like_project(self, tmp_path):
        """Create a Kafka-like multi-module Gradle project structure."""
        # Root build files
        (tmp_path / "build.gradle.kts").write_text("// root build", encoding="utf-8")
        (tmp_path / "settings.gradle.kts").write_text(
            'include("clients", "streams", "tools", "connect:runtime")', encoding="utf-8"
        )
        # Module build files and source/test dirs
        for module in ["clients", "streams", "tools"]:
            (tmp_path / module / "src" / "main" / "java").mkdir(parents=True)
            (tmp_path / module / "src" / "test" / "java").mkdir(parents=True)
            (tmp_path / module / "build.gradle.kts").write_text(f"// {module} build", encoding="utf-8")
        # Nested module
        (tmp_path / "connect" / "runtime" / "src" / "main" / "java").mkdir(parents=True)
        (tmp_path / "connect" / "runtime" / "src" / "test" / "java").mkdir(parents=True)
        (tmp_path / "connect" / "runtime" / "build.gradle.kts").write_text("// connect:runtime build", encoding="utf-8")

    def _make_test_paths_mock(self, file_paths: list[Path]):
        """Create a mock test_paths object with test_files."""
        mock = MagicMock()
        mock.test_files = []
        for fp in file_paths:
            tf = MagicMock()
            tf.benchmarking_file_path = None
            tf.instrumented_behavior_file_path = fp
            mock.test_files.append(tf)
        return mock

    def test_streams_tests_return_streams_module(self, tmp_path):
        """When ALL test files are in streams/, should return 'streams' module."""
        self._make_kafka_like_project(tmp_path)
        test_file = tmp_path / "streams" / "src" / "test" / "java" / "org" / "apache" / "kafka" / "StreamsTest.java"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.touch()

        test_paths = self._make_test_paths_mock([test_file])
        build_root, test_module = _find_multi_module_root(tmp_path, test_paths)

        assert build_root == tmp_path
        assert test_module == "streams", f"Expected 'streams' but got '{test_module}'"

    def test_tools_tests_return_tools_module(self, tmp_path):
        """When test files are in tools/, should return 'tools' module."""
        self._make_kafka_like_project(tmp_path)
        test_file = tmp_path / "tools" / "src" / "test" / "java" / "org" / "apache" / "kafka" / "ToolsTest.java"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.touch()

        test_paths = self._make_test_paths_mock([test_file])
        build_root, test_module = _find_multi_module_root(tmp_path, test_paths)

        assert build_root == tmp_path
        assert test_module == "tools", f"Expected 'tools' but got '{test_module}'"

    def test_mixed_modules_majority_wins(self, tmp_path):
        """When tests span multiple modules, the module with the most test files wins."""
        self._make_kafka_like_project(tmp_path)
        clients_test = tmp_path / "clients" / "src" / "test" / "java" / "com" / "ClientsTest.java"
        clients_test.parent.mkdir(parents=True, exist_ok=True)
        clients_test.touch()
        streams_test_1 = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest1.java"
        streams_test_1.parent.mkdir(parents=True, exist_ok=True)
        streams_test_1.touch()
        streams_test_2 = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest2.java"
        streams_test_2.touch()

        # 1 clients test + 2 streams tests → streams wins by majority
        test_paths = self._make_test_paths_mock([clients_test, streams_test_1, streams_test_2])
        build_root, test_module = _find_multi_module_root(tmp_path, test_paths)

        assert build_root == tmp_path
        assert test_module == "streams"

    def test_mixed_modules_equal_count_deterministic(self, tmp_path):
        """When modules are tied, a module is still selected (not None)."""
        self._make_kafka_like_project(tmp_path)
        clients_test = tmp_path / "clients" / "src" / "test" / "java" / "com" / "ClientsTest.java"
        clients_test.parent.mkdir(parents=True, exist_ok=True)
        clients_test.touch()
        streams_test = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest.java"
        streams_test.parent.mkdir(parents=True, exist_ok=True)
        streams_test.touch()

        test_paths = self._make_test_paths_mock([clients_test, streams_test])
        build_root, test_module = _find_multi_module_root(tmp_path, test_paths)

        assert build_root == tmp_path
        assert test_module in ("clients", "streams")

    def test_nested_module_connect_runtime(self, tmp_path):
        """Nested Gradle module 'connect:runtime' (dir connect/runtime/) is matched."""
        self._make_kafka_like_project(tmp_path)
        test_file = (
            tmp_path / "connect" / "runtime" / "src" / "test" / "java" / "org" / "kafka" / "ConnectRuntimeTest.java"
        )
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.touch()

        test_paths = self._make_test_paths_mock([test_file])
        build_root, test_module = _find_multi_module_root(tmp_path, test_paths)

        assert build_root == tmp_path
        assert test_module == "connect:runtime"

    def test_project_root_is_submodule_test_outside(self, tmp_path):
        """When project_root is a submodule (e.g., kafka/clients) and generated
        tests are placed in a sibling module (kafka/streams), the function should
        walk up to find the repo root and return the correct module.
        """
        self._make_kafka_like_project(tmp_path)
        submodule_root = tmp_path / "clients"
        test_file = tmp_path / "streams" / "src" / "test" / "java" / "com" / "StreamsTest.java"
        test_file.parent.mkdir(parents=True, exist_ok=True)
        test_file.touch()

        test_paths = self._make_test_paths_mock([test_file])
        build_root, test_module = _find_multi_module_root(submodule_root, test_paths)

        assert build_root == tmp_path
        assert test_module == "streams"


@@ -90,9 +90,10 @@ POM_CONTENT = """<?xml version="1.0" encoding="UTF-8"?>


def skip_if_maven_not_available():
    from codeflash.languages.java.build_tools import find_maven_executable
    from codeflash.languages.java.maven_strategy import MavenStrategy

    if not MavenStrategy().find_executable(Path(".")):

    if not find_maven_executable():
        pytest.skip("Maven not available")


@@ -81,6 +81,7 @@ class TestInputValidation:

    def test_get_test_run_command_validates_input(self, tmp_path: Path):
        """Test that get_test_run_command validates test class names."""
        (tmp_path / "pom.xml").write_text("<project></project>", encoding="utf-8")
        # Valid class names should work
        cmd = get_test_run_command(tmp_path, ["MyTest", "OtherTest"])
        assert "-Dtest=MyTest,OtherTest" in " ".join(cmd)

@@ -9,7 +9,6 @@ from codeflash.languages.java.line_profiler import JavaLineProfiler

def test_parse_line_profile_results_non_python_java_json():
    set_current_language(Language.JAVA)

    with TemporaryDirectory() as tmpdir:
        tmp_path = Path(tmpdir)
        source_file = tmp_path / "Util.java"