fix: improve multi-module Gradle detection for dynamic settings.gradle.kts

- Parse listOf(...) patterns in settings.gradle.kts for projects that
  build include lists dynamically (e.g. OpenRewrite)
- Use word boundary in include regex to avoid matching variable names
  like 'includedProjects'
- Break module voting ties using codeflash.toml module-root config,
  so the function's own module is preferred over cross-module tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
HeshamHM28 2026-04-07 11:08:16 +00:00
parent 94e1b02597
commit 1fde200bc4
4 changed files with 196 additions and 47 deletions

View file

@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import os
import re
import shutil
import subprocess
import tempfile
import xml.etree.ElementTree as ET
@ -17,7 +16,7 @@ from pathlib import Path
from typing import Any
from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir
from codeflash.languages.java.build_tools import BuildTool, JavaProjectInfo
from codeflash.languages.java.build_tools import CODEFLASH_RUNTIME_VERSION, BuildTool, JavaProjectInfo
_RE_INCLUDE = re.compile(r"""include\s*\(?([^)\n]+)\)?""")
@ -205,8 +204,32 @@ def _is_multimodule_project(build_root: Path) -> bool:
return False
def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Path) -> bool:
"""Add codeflash-runtime dependency wrapped in a subprojects block for multi-module projects.
_CODEFLASH_MAVEN_COORD = f"com.codeflash:codeflash-runtime:{CODEFLASH_RUNTIME_VERSION}"
def _ensure_maven_central_repo(build_file: Path, content: str) -> str:
"""Ensure mavenCentral() is present in the repositories block. Returns updated content."""
if "mavenCentral()" in content:
return content
is_kts = build_file.name.endswith(".kts")
# Try to find existing repositories block and add mavenCentral() inside it
repo_match = re.search(r"repositories\s*\{", content)
if repo_match:
insert_pos = repo_match.end()
return content[:insert_pos] + "\n mavenCentral()" + content[insert_pos:]
# No repositories block — append one
if is_kts:
content += "\nrepositories {\n mavenCentral()\n}\n"
else:
content += "\nrepositories {\n mavenCentral()\n}\n"
return content
def add_codeflash_dependency_multimodule(build_file: Path) -> bool:
"""Add codeflash-runtime dependency from Maven Central in a subprojects block for multi-module projects.
This avoids adding testImplementation to the root build file directly, which would fail
if the root project doesn't apply the java plugin.
@ -222,14 +245,16 @@ def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Pat
return True
is_kts = build_file.name.endswith(".kts")
jar_str = str(runtime_jar_path).replace("\\", "/")
if is_kts:
block = (
f"\nsubprojects {{\n"
f' plugins.withId("java") {{\n'
f" repositories {{\n"
f" mavenCentral()\n"
f" }}\n"
f" dependencies {{\n"
f' testImplementation(files("{jar_str}")) // codeflash-runtime\n'
f' testImplementation("{_CODEFLASH_MAVEN_COORD}") // codeflash-runtime\n'
f" }}\n"
f" }}\n"
f"}}\n"
@ -238,8 +263,11 @@ def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Pat
block = (
f"\nsubprojects {{\n"
f" plugins.withId('java') {{\n"
f" repositories {{\n"
f" mavenCentral()\n"
f" }}\n"
f" dependencies {{\n"
f" testImplementation files('{jar_str}') // codeflash-runtime\n"
f" testImplementation '{_CODEFLASH_MAVEN_COORD}' // codeflash-runtime\n"
f" }}\n"
f" }}\n"
f"}}\n"
@ -255,7 +283,7 @@ def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Pat
return False
def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
def add_codeflash_dependency(build_file: Path) -> bool:
if not build_file.exists():
return False
@ -266,13 +294,14 @@ def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
logger.info("codeflash-runtime dependency already present in %s", build_file.name)
return True
content = _ensure_maven_central_repo(build_file, content)
is_kts = build_file.name.endswith(".kts")
jar_str = str(runtime_jar_path).replace("\\", "/")
if is_kts:
dep_line = f' testImplementation(files("{jar_str}")) // codeflash-runtime\n'
dep_line = f' testImplementation("{_CODEFLASH_MAVEN_COORD}") // codeflash-runtime\n'
else:
dep_line = f" testImplementation files('{jar_str}') // codeflash-runtime\n"
dep_line = f" testImplementation '{_CODEFLASH_MAVEN_COORD}' // codeflash-runtime\n"
# Use tree-sitter to find the top-level dependencies block
insert_pos = _find_top_level_dependencies_block(build_file, content)
@ -284,9 +313,13 @@ def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
# No existing dependencies block — append one
if is_kts:
content += f'\ndependencies {{\n testImplementation(files("{jar_str}")) // codeflash-runtime\n}}\n'
content += (
f'\ndependencies {{\n testImplementation("{_CODEFLASH_MAVEN_COORD}") // codeflash-runtime\n}}\n'
)
else:
content += f"\ndependencies {{\n testImplementation files('{jar_str}') // codeflash-runtime\n}}\n"
content += (
f"\ndependencies {{\n testImplementation '{_CODEFLASH_MAVEN_COORD}' // codeflash-runtime\n}}\n"
)
build_file.write_text(content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to %s (new block)", build_file.name)
return True
@ -420,34 +453,21 @@ class GradleStrategy(BuildToolStrategy):
return self.find_wrapper_executable(build_root, ("gradlew", "gradlew.bat"), "gradle")
def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
runtime_jar = self.find_runtime_jar()
if runtime_jar is None:
logger.error("codeflash-runtime JAR not found. Generated tests will fail to compile.")
return False
if test_module:
module_root = build_root / module_to_dir(test_module)
else:
module_root = build_root
libs_dir = module_root / "libs"
libs_dir.mkdir(parents=True, exist_ok=True)
dest_jar = libs_dir / "codeflash-runtime-1.0.1.jar"
if not dest_jar.exists():
logger.info("Copying codeflash-runtime JAR to %s", dest_jar)
shutil.copy2(runtime_jar, dest_jar)
build_file = find_gradle_build_file(module_root)
if build_file is None:
logger.warning("No build.gradle(.kts) found at %s, cannot add codeflash-runtime dependency", module_root)
return False
if not test_module and _is_multimodule_project(build_root):
if not add_codeflash_dependency_multimodule(build_file, dest_jar):
if not add_codeflash_dependency_multimodule(build_file):
logger.error("Failed to add codeflash-runtime dependency to %s", build_file)
return False
elif not add_codeflash_dependency(build_file, dest_jar):
elif not add_codeflash_dependency(build_file):
logger.error("Failed to add codeflash-runtime dependency to %s", build_file)
return False

View file

@ -205,12 +205,25 @@ def _extract_modules_from_settings_gradle(content: str) -> list[str]:
Looks for include directives like:
include("module-a", "module-b") // Kotlin DSL
include 'module-a', 'module-b' // Groovy DSL
Also handles dynamic Kotlin DSL patterns like:
val allProjects = listOf("module-a", "module-b")
include(*(allProjects + ...).toTypedArray())
Module names may be prefixed with ':' which is stripped.
"""
modules: list[str] = []
for match in re.findall(r"""include\s*\(?[^)\n]*\)?""", content):
# Standard include(...) directives — word boundary avoids matching variable names
# like 'includedProjects'
for match in re.findall(r"""(?:^|(?<=\s))include\s*\(?[^)\n]*\)?""", content, re.MULTILINE):
for name in re.findall(r"""['"]([^'"]+)['"]""", match):
modules.append(name.lstrip(":"))
# Kotlin DSL: val ... = listOf("module-a", "module-b", ...) spanning multiple lines.
# Used when settings.gradle.kts builds the include list dynamically.
if not modules or not any("/" not in m and "." not in m for m in modules):
for match in re.findall(r"""listOf\s*\(([^)]*)\)""", content, re.DOTALL):
for name in re.findall(r"""['"]([^'"]+)['"]""", match):
stripped = name.lstrip(":")
if stripped not in modules:
modules.append(stripped)
return modules
@ -269,6 +282,50 @@ def _match_module_from_rel_path(rel_path: Path, modules: list[str]) -> str | Non
return None
def _read_config_module_root(project_root: Path) -> str | None:
"""Read module-root from codeflash.toml or pyproject.toml."""
for cfg_name in ("codeflash.toml", "pyproject.toml"):
cfg_path = project_root / cfg_name
if cfg_path.exists():
try:
cfg_text = cfg_path.read_text(encoding="utf-8")
m = re.search(r'module-root\s*=\s*["\']([^"\']+)["\']', cfg_text)
if m:
return m.group(1).strip().strip("/")
except Exception:
pass
return None
def _infer_module_from_config(project_root: Path) -> str | None:
"""Infer the target Gradle module from codeflash config in gradle.properties.
Reads codeflash.moduleRoot or codeflash.testsRoot and extracts the first
path component as the module name. Verifies the module directory has a
build.gradle(.kts) file.
"""
props_file = project_root / "gradle.properties"
if not props_file.exists():
return None
try:
content = props_file.read_text(encoding="utf-8")
except Exception:
return None
for key in ("codeflash.moduleRoot", "codeflash.testsRoot"):
for line in content.splitlines():
line = line.strip()
if line.startswith(key + "="):
value = line.split("=", 1)[1].strip()
# Extract first path component (e.g. "rewrite-core/src/main/java" → "rewrite-core")
candidate = Path(value).parts[0] if Path(value).parts else None
if candidate:
module_dir = project_root / candidate
if (module_dir / "build.gradle.kts").exists() or (module_dir / "build.gradle").exists():
return candidate
return None
def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path, str | None]:
"""Find the multi-module parent root if tests are in a different module.
@ -287,10 +344,18 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
test_file_paths.append(test_file.benchmarking_file_path)
elif hasattr(test_file, "instrumented_behavior_file_path") and test_file.instrumented_behavior_file_path:
test_file_paths.append(test_file.instrumented_behavior_file_path)
elif hasattr(test_file, "original_file_path") and test_file.original_file_path:
test_file_paths.append(test_file.original_file_path)
elif isinstance(test_paths, (list, tuple)):
test_file_paths = [Path(p) if isinstance(p, str) else p for p in test_paths]
if not test_file_paths:
# No test file paths available — try to infer the module from codeflash config
# in gradle.properties (e.g. codeflash.moduleRoot=rewrite-core/src/main/java).
module = _infer_module_from_config(project_root)
if module:
logger.info("Inferred module '%s' from codeflash config (no test file paths)", module)
return project_root, module
return project_root, None
test_outside_project = False
@ -320,7 +385,14 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
module_counts[matched] = module_counts.get(matched, 0) + 1
if module_counts:
best_module = max(module_counts, key=lambda m: module_counts[m])
# On ties, prefer the module matching codeflash.toml module-root
config_module = _read_config_module_root(project_root)
max_count = max(module_counts.values())
tied = [m for m, c in module_counts.items() if c == max_count]
if config_module and config_module in tied:
best_module = config_module
else:
best_module = max(module_counts, key=lambda m: module_counts[m])
logger.debug(
"Detected multi-module project. Root: %s, Module votes: %s, Selected: %s",
project_root,
@ -328,6 +400,31 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
best_module,
)
return project_root, best_module
# project_root has no sub-modules — check if it is itself a sub-module
# of a parent multi-module project (e.g. rewrite-core/ inside rewrite/).
parent = project_root.parent
while parent != parent.parent:
if _is_build_root(parent):
parent_modules = _detect_modules(parent)
if parent_modules:
try:
rel_path = project_root.relative_to(parent)
matched = _match_module_from_rel_path(rel_path, parent_modules)
if matched:
logger.debug("Detected project_root as sub-module. Root: %s, Module: %s", parent, matched)
return parent, matched
except ValueError:
pass
parent = parent.parent
# Last resort: settings.gradle may use dynamic includes that _detect_modules
# can't parse. Fall back to codeflash config in gradle.properties.
module = _infer_module_from_config(project_root)
if module:
logger.info("Inferred module '%s' from codeflash config (dynamic settings.gradle)", module)
return project_root, module
return project_root, None
current = project_root.parent

View file

@ -588,17 +588,14 @@ class TestGradleEnsureRuntimeMultiModule:
project = self._make_multi_module_project(tmp_path)
strategy = GradleStrategy()
# Provide a fake runtime JAR
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04") # minimal zip header
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module="streams")
result = strategy.ensure_runtime(project, test_module="streams")
assert result is True
# Dependency should be in streams/build.gradle.kts
# Dependency should be in streams/build.gradle.kts with Maven Central coordinate
streams_build = (project / "streams" / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in streams_build
assert "com.codeflash:codeflash-runtime:" in streams_build
assert "mavenCentral()" in streams_build
# And NOT in clients/build.gradle.kts or root build.gradle.kts
clients_build = (project / "clients" / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" not in clients_build
@ -610,15 +607,13 @@ class TestGradleEnsureRuntimeMultiModule:
project = self._make_multi_module_project(tmp_path)
strategy = GradleStrategy()
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04")
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module=None)
result = strategy.ensure_runtime(project, test_module=None)
assert result is True
root_build = (project / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in root_build
assert "com.codeflash:codeflash-runtime:" in root_build
assert "mavenCentral()" in root_build
def test_adds_dependency_to_nested_module(self, tmp_path):
"""When test_module='connect:runtime', the dep goes to connect/runtime/build.gradle.kts."""
@ -632,12 +627,20 @@ class TestGradleEnsureRuntimeMultiModule:
)
strategy = GradleStrategy()
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04")
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module="connect:runtime")
result = strategy.ensure_runtime(project, test_module="connect:runtime")
assert result is True
nested_build = (nested / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in nested_build
assert "com.codeflash:codeflash-runtime:" in nested_build
assert "mavenCentral()" in nested_build
def test_does_not_copy_jar_to_libs(self, tmp_path):
    """ensure_runtime must not create a libs/ directory inside the target module."""
    root = self._make_multi_module_project(tmp_path)

    GradleStrategy().ensure_runtime(root, test_module="streams")

    # No local JAR copy is expected, so streams/libs must not exist afterwards.
    assert not (root / "streams" / "libs").exists()

View file

@ -631,3 +631,32 @@ class TestFindMultiModuleRoot:
assert build_root == tmp_path
assert test_module == "streams"
def test_submodule_as_project_root_with_tests_inside(self, tmp_path):
    """A sub-module passed as project_root resolves to the parent root and its module name.

    Here project_root is ``clients/`` and the test file lives inside it; the
    expected result is the enclosing build root plus the ``clients`` module.
    """
    self._make_kafka_like_project(tmp_path)
    module_dir = tmp_path / "clients"
    java_test = module_dir.joinpath("src", "test", "java", "com", "ClientsTest.java")
    java_test.parent.mkdir(parents=True, exist_ok=True)
    java_test.touch()

    found_root, found_module = _find_multi_module_root(module_dir, self._make_test_paths_mock([java_test]))

    assert found_root == tmp_path
    assert found_module == "clients"
def test_submodule_as_project_root_nested_module(self, tmp_path):
    """A nested sub-module (connect/runtime) passed as project_root is reported as ``connect:runtime``."""
    self._make_kafka_like_project(tmp_path)
    module_dir = tmp_path / "connect" / "runtime"
    java_test = module_dir.joinpath("src", "test", "java", "com", "RuntimeTest.java")
    java_test.parent.mkdir(parents=True, exist_ok=True)
    java_test.touch()

    found_root, found_module = _find_multi_module_root(module_dir, self._make_test_paths_mock([java_test]))

    assert found_root == tmp_path
    assert found_module == "connect:runtime"