fix: improve multi-module Gradle detection for dynamic settings.gradle.kts

- Parse listOf(...) patterns in settings.gradle.kts for projects that
  build include lists dynamically (e.g. OpenRewrite)
- Use word boundary in include regex to avoid matching variable names
  like 'includedProjects'
- Break module voting ties using codeflash.toml module-root config,
  so the function's own module is preferred over cross-module tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
HeshamHM28 2026-04-07 11:08:16 +00:00
parent 94e1b02597
commit 1fde200bc4
4 changed files with 196 additions and 47 deletions

View file

@ -9,7 +9,6 @@ from __future__ import annotations
import logging import logging
import os import os
import re import re
import shutil
import subprocess import subprocess
import tempfile import tempfile
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
@ -17,7 +16,7 @@ from pathlib import Path
from typing import Any from typing import Any
from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir from codeflash.languages.java.build_tool_strategy import BuildToolStrategy, module_to_dir
from codeflash.languages.java.build_tools import BuildTool, JavaProjectInfo from codeflash.languages.java.build_tools import CODEFLASH_RUNTIME_VERSION, BuildTool, JavaProjectInfo
_RE_INCLUDE = re.compile(r"""include\s*\(?([^)\n]+)\)?""") _RE_INCLUDE = re.compile(r"""include\s*\(?([^)\n]+)\)?""")
@ -205,8 +204,32 @@ def _is_multimodule_project(build_root: Path) -> bool:
return False return False
def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Path) -> bool: _CODEFLASH_MAVEN_COORD = f"com.codeflash:codeflash-runtime:{CODEFLASH_RUNTIME_VERSION}"
"""Add codeflash-runtime dependency wrapped in a subprojects block for multi-module projects.
def _ensure_maven_central_repo(build_file: Path, content: str) -> str:
"""Ensure mavenCentral() is present in the repositories block. Returns updated content."""
if "mavenCentral()" in content:
return content
is_kts = build_file.name.endswith(".kts")
# Try to find existing repositories block and add mavenCentral() inside it
repo_match = re.search(r"repositories\s*\{", content)
if repo_match:
insert_pos = repo_match.end()
return content[:insert_pos] + "\n mavenCentral()" + content[insert_pos:]
# No repositories block — append one
if is_kts:
content += "\nrepositories {\n mavenCentral()\n}\n"
else:
content += "\nrepositories {\n mavenCentral()\n}\n"
return content
def add_codeflash_dependency_multimodule(build_file: Path) -> bool:
"""Add codeflash-runtime dependency from Maven Central in a subprojects block for multi-module projects.
This avoids adding testImplementation to the root build file directly, which would fail This avoids adding testImplementation to the root build file directly, which would fail
if the root project doesn't apply the java plugin. if the root project doesn't apply the java plugin.
@ -222,14 +245,16 @@ def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Pat
return True return True
is_kts = build_file.name.endswith(".kts") is_kts = build_file.name.endswith(".kts")
jar_str = str(runtime_jar_path).replace("\\", "/")
if is_kts: if is_kts:
block = ( block = (
f"\nsubprojects {{\n" f"\nsubprojects {{\n"
f' plugins.withId("java") {{\n' f' plugins.withId("java") {{\n'
f" repositories {{\n"
f" mavenCentral()\n"
f" }}\n"
f" dependencies {{\n" f" dependencies {{\n"
f' testImplementation(files("{jar_str}")) // codeflash-runtime\n' f' testImplementation("{_CODEFLASH_MAVEN_COORD}") // codeflash-runtime\n'
f" }}\n" f" }}\n"
f" }}\n" f" }}\n"
f"}}\n" f"}}\n"
@ -238,8 +263,11 @@ def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Pat
block = ( block = (
f"\nsubprojects {{\n" f"\nsubprojects {{\n"
f" plugins.withId('java') {{\n" f" plugins.withId('java') {{\n"
f" repositories {{\n"
f" mavenCentral()\n"
f" }}\n"
f" dependencies {{\n" f" dependencies {{\n"
f" testImplementation files('{jar_str}') // codeflash-runtime\n" f" testImplementation '{_CODEFLASH_MAVEN_COORD}' // codeflash-runtime\n"
f" }}\n" f" }}\n"
f" }}\n" f" }}\n"
f"}}\n" f"}}\n"
@ -255,7 +283,7 @@ def add_codeflash_dependency_multimodule(build_file: Path, runtime_jar_path: Pat
return False return False
def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool: def add_codeflash_dependency(build_file: Path) -> bool:
if not build_file.exists(): if not build_file.exists():
return False return False
@ -266,13 +294,14 @@ def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
logger.info("codeflash-runtime dependency already present in %s", build_file.name) logger.info("codeflash-runtime dependency already present in %s", build_file.name)
return True return True
content = _ensure_maven_central_repo(build_file, content)
is_kts = build_file.name.endswith(".kts") is_kts = build_file.name.endswith(".kts")
jar_str = str(runtime_jar_path).replace("\\", "/")
if is_kts: if is_kts:
dep_line = f' testImplementation(files("{jar_str}")) // codeflash-runtime\n' dep_line = f' testImplementation("{_CODEFLASH_MAVEN_COORD}") // codeflash-runtime\n'
else: else:
dep_line = f" testImplementation files('{jar_str}') // codeflash-runtime\n" dep_line = f" testImplementation '{_CODEFLASH_MAVEN_COORD}' // codeflash-runtime\n"
# Use tree-sitter to find the top-level dependencies block # Use tree-sitter to find the top-level dependencies block
insert_pos = _find_top_level_dependencies_block(build_file, content) insert_pos = _find_top_level_dependencies_block(build_file, content)
@ -284,9 +313,13 @@ def add_codeflash_dependency(build_file: Path, runtime_jar_path: Path) -> bool:
# No existing dependencies block — append one # No existing dependencies block — append one
if is_kts: if is_kts:
content += f'\ndependencies {{\n testImplementation(files("{jar_str}")) // codeflash-runtime\n}}\n' content += (
f'\ndependencies {{\n testImplementation("{_CODEFLASH_MAVEN_COORD}") // codeflash-runtime\n}}\n'
)
else: else:
content += f"\ndependencies {{\n testImplementation files('{jar_str}') // codeflash-runtime\n}}\n" content += (
f"\ndependencies {{\n testImplementation '{_CODEFLASH_MAVEN_COORD}' // codeflash-runtime\n}}\n"
)
build_file.write_text(content, encoding="utf-8") build_file.write_text(content, encoding="utf-8")
logger.info("Added codeflash-runtime dependency to %s (new block)", build_file.name) logger.info("Added codeflash-runtime dependency to %s (new block)", build_file.name)
return True return True
@ -420,34 +453,21 @@ class GradleStrategy(BuildToolStrategy):
return self.find_wrapper_executable(build_root, ("gradlew", "gradlew.bat"), "gradle") return self.find_wrapper_executable(build_root, ("gradlew", "gradlew.bat"), "gradle")
def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool: def ensure_runtime(self, build_root: Path, test_module: str | None) -> bool:
runtime_jar = self.find_runtime_jar()
if runtime_jar is None:
logger.error("codeflash-runtime JAR not found. Generated tests will fail to compile.")
return False
if test_module: if test_module:
module_root = build_root / module_to_dir(test_module) module_root = build_root / module_to_dir(test_module)
else: else:
module_root = build_root module_root = build_root
libs_dir = module_root / "libs"
libs_dir.mkdir(parents=True, exist_ok=True)
dest_jar = libs_dir / "codeflash-runtime-1.0.1.jar"
if not dest_jar.exists():
logger.info("Copying codeflash-runtime JAR to %s", dest_jar)
shutil.copy2(runtime_jar, dest_jar)
build_file = find_gradle_build_file(module_root) build_file = find_gradle_build_file(module_root)
if build_file is None: if build_file is None:
logger.warning("No build.gradle(.kts) found at %s, cannot add codeflash-runtime dependency", module_root) logger.warning("No build.gradle(.kts) found at %s, cannot add codeflash-runtime dependency", module_root)
return False return False
if not test_module and _is_multimodule_project(build_root): if not test_module and _is_multimodule_project(build_root):
if not add_codeflash_dependency_multimodule(build_file, dest_jar): if not add_codeflash_dependency_multimodule(build_file):
logger.error("Failed to add codeflash-runtime dependency to %s", build_file) logger.error("Failed to add codeflash-runtime dependency to %s", build_file)
return False return False
elif not add_codeflash_dependency(build_file, dest_jar): elif not add_codeflash_dependency(build_file):
logger.error("Failed to add codeflash-runtime dependency to %s", build_file) logger.error("Failed to add codeflash-runtime dependency to %s", build_file)
return False return False

View file

@ -205,12 +205,25 @@ def _extract_modules_from_settings_gradle(content: str) -> list[str]:
Looks for include directives like: Looks for include directives like:
include("module-a", "module-b") // Kotlin DSL include("module-a", "module-b") // Kotlin DSL
include 'module-a', 'module-b' // Groovy DSL include 'module-a', 'module-b' // Groovy DSL
Also handles dynamic Kotlin DSL patterns like:
val allProjects = listOf("module-a", "module-b")
include(*(allProjects + ...).toTypedArray())
Module names may be prefixed with ':' which is stripped. Module names may be prefixed with ':' which is stripped.
""" """
modules: list[str] = [] modules: list[str] = []
for match in re.findall(r"""include\s*\(?[^)\n]*\)?""", content): # Standard include(...) directives — word boundary avoids matching variable names
# like 'includedProjects'
for match in re.findall(r"""(?:^|(?<=\s))include\s*\(?[^)\n]*\)?""", content, re.MULTILINE):
for name in re.findall(r"""['"]([^'"]+)['"]""", match): for name in re.findall(r"""['"]([^'"]+)['"]""", match):
modules.append(name.lstrip(":")) modules.append(name.lstrip(":"))
# Kotlin DSL: val ... = listOf("module-a", "module-b", ...) spanning multiple lines.
# Used when settings.gradle.kts builds the include list dynamically.
if not modules or not any("/" not in m and "." not in m for m in modules):
for match in re.findall(r"""listOf\s*\(([^)]*)\)""", content, re.DOTALL):
for name in re.findall(r"""['"]([^'"]+)['"]""", match):
stripped = name.lstrip(":")
if stripped not in modules:
modules.append(stripped)
return modules return modules
@ -269,6 +282,50 @@ def _match_module_from_rel_path(rel_path: Path, modules: list[str]) -> str | Non
return None return None
def _read_config_module_root(project_root: Path) -> str | None:
"""Read module-root from codeflash.toml or pyproject.toml."""
for cfg_name in ("codeflash.toml", "pyproject.toml"):
cfg_path = project_root / cfg_name
if cfg_path.exists():
try:
cfg_text = cfg_path.read_text(encoding="utf-8")
m = re.search(r'module-root\s*=\s*["\']([^"\']+)["\']', cfg_text)
if m:
return m.group(1).strip().strip("/")
except Exception:
pass
return None
def _infer_module_from_config(project_root: Path) -> str | None:
"""Infer the target Gradle module from codeflash config in gradle.properties.
Reads codeflash.moduleRoot or codeflash.testsRoot and extracts the first
path component as the module name. Verifies the module directory has a
build.gradle(.kts) file.
"""
props_file = project_root / "gradle.properties"
if not props_file.exists():
return None
try:
content = props_file.read_text(encoding="utf-8")
except Exception:
return None
for key in ("codeflash.moduleRoot", "codeflash.testsRoot"):
for line in content.splitlines():
line = line.strip()
if line.startswith(key + "="):
value = line.split("=", 1)[1].strip()
# Extract first path component (e.g. "rewrite-core/src/main/java" → "rewrite-core")
candidate = Path(value).parts[0] if Path(value).parts else None
if candidate:
module_dir = project_root / candidate
if (module_dir / "build.gradle.kts").exists() or (module_dir / "build.gradle").exists():
return candidate
return None
def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path, str | None]: def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path, str | None]:
"""Find the multi-module parent root if tests are in a different module. """Find the multi-module parent root if tests are in a different module.
@ -287,10 +344,18 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
test_file_paths.append(test_file.benchmarking_file_path) test_file_paths.append(test_file.benchmarking_file_path)
elif hasattr(test_file, "instrumented_behavior_file_path") and test_file.instrumented_behavior_file_path: elif hasattr(test_file, "instrumented_behavior_file_path") and test_file.instrumented_behavior_file_path:
test_file_paths.append(test_file.instrumented_behavior_file_path) test_file_paths.append(test_file.instrumented_behavior_file_path)
elif hasattr(test_file, "original_file_path") and test_file.original_file_path:
test_file_paths.append(test_file.original_file_path)
elif isinstance(test_paths, (list, tuple)): elif isinstance(test_paths, (list, tuple)):
test_file_paths = [Path(p) if isinstance(p, str) else p for p in test_paths] test_file_paths = [Path(p) if isinstance(p, str) else p for p in test_paths]
if not test_file_paths: if not test_file_paths:
# No test file paths available — try to infer the module from codeflash config
# in gradle.properties (e.g. codeflash.moduleRoot=rewrite-core/src/main/java).
module = _infer_module_from_config(project_root)
if module:
logger.info("Inferred module '%s' from codeflash config (no test file paths)", module)
return project_root, module
return project_root, None return project_root, None
test_outside_project = False test_outside_project = False
@ -320,7 +385,14 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
module_counts[matched] = module_counts.get(matched, 0) + 1 module_counts[matched] = module_counts.get(matched, 0) + 1
if module_counts: if module_counts:
best_module = max(module_counts, key=lambda m: module_counts[m]) # On ties, prefer the module matching codeflash.toml module-root
config_module = _read_config_module_root(project_root)
max_count = max(module_counts.values())
tied = [m for m, c in module_counts.items() if c == max_count]
if config_module and config_module in tied:
best_module = config_module
else:
best_module = max(module_counts, key=lambda m: module_counts[m])
logger.debug( logger.debug(
"Detected multi-module project. Root: %s, Module votes: %s, Selected: %s", "Detected multi-module project. Root: %s, Module votes: %s, Selected: %s",
project_root, project_root,
@ -328,6 +400,31 @@ def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path,
best_module, best_module,
) )
return project_root, best_module return project_root, best_module
# project_root has no sub-modules — check if it is itself a sub-module
# of a parent multi-module project (e.g. rewrite-core/ inside rewrite/).
parent = project_root.parent
while parent != parent.parent:
if _is_build_root(parent):
parent_modules = _detect_modules(parent)
if parent_modules:
try:
rel_path = project_root.relative_to(parent)
matched = _match_module_from_rel_path(rel_path, parent_modules)
if matched:
logger.debug("Detected project_root as sub-module. Root: %s, Module: %s", parent, matched)
return parent, matched
except ValueError:
pass
parent = parent.parent
# Last resort: settings.gradle may use dynamic includes that _detect_modules
# can't parse. Fall back to codeflash config in gradle.properties.
module = _infer_module_from_config(project_root)
if module:
logger.info("Inferred module '%s' from codeflash config (dynamic settings.gradle)", module)
return project_root, module
return project_root, None return project_root, None
current = project_root.parent current = project_root.parent

View file

@ -588,17 +588,14 @@ class TestGradleEnsureRuntimeMultiModule:
project = self._make_multi_module_project(tmp_path) project = self._make_multi_module_project(tmp_path)
strategy = GradleStrategy() strategy = GradleStrategy()
# Provide a fake runtime JAR result = strategy.ensure_runtime(project, test_module="streams")
fake_jar = tmp_path / "fake-runtime.jar"
fake_jar.write_bytes(b"PK\x03\x04") # minimal zip header
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module="streams")
assert result is True assert result is True
# Dependency should be in streams/build.gradle.kts # Dependency should be in streams/build.gradle.kts with Maven Central coordinate
streams_build = (project / "streams" / "build.gradle.kts").read_text(encoding="utf-8") streams_build = (project / "streams" / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in streams_build assert "codeflash-runtime" in streams_build
assert "com.codeflash:codeflash-runtime:" in streams_build
assert "mavenCentral()" in streams_build
# And NOT in clients/build.gradle.kts or root build.gradle.kts # And NOT in clients/build.gradle.kts or root build.gradle.kts
clients_build = (project / "clients" / "build.gradle.kts").read_text(encoding="utf-8") clients_build = (project / "clients" / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" not in clients_build assert "codeflash-runtime" not in clients_build
@ -610,15 +607,13 @@ class TestGradleEnsureRuntimeMultiModule:
project = self._make_multi_module_project(tmp_path) project = self._make_multi_module_project(tmp_path)
strategy = GradleStrategy() strategy = GradleStrategy()
fake_jar = tmp_path / "fake-runtime.jar" result = strategy.ensure_runtime(project, test_module=None)
fake_jar.write_bytes(b"PK\x03\x04")
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module=None)
assert result is True assert result is True
root_build = (project / "build.gradle.kts").read_text(encoding="utf-8") root_build = (project / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in root_build assert "codeflash-runtime" in root_build
assert "com.codeflash:codeflash-runtime:" in root_build
assert "mavenCentral()" in root_build
def test_adds_dependency_to_nested_module(self, tmp_path): def test_adds_dependency_to_nested_module(self, tmp_path):
"""When test_module='connect:runtime', the dep goes to connect/runtime/build.gradle.kts.""" """When test_module='connect:runtime', the dep goes to connect/runtime/build.gradle.kts."""
@ -632,12 +627,20 @@ class TestGradleEnsureRuntimeMultiModule:
) )
strategy = GradleStrategy() strategy = GradleStrategy()
fake_jar = tmp_path / "fake-runtime.jar" result = strategy.ensure_runtime(project, test_module="connect:runtime")
fake_jar.write_bytes(b"PK\x03\x04")
with patch.object(strategy, "find_runtime_jar", return_value=fake_jar):
result = strategy.ensure_runtime(project, test_module="connect:runtime")
assert result is True assert result is True
nested_build = (nested / "build.gradle.kts").read_text(encoding="utf-8") nested_build = (nested / "build.gradle.kts").read_text(encoding="utf-8")
assert "codeflash-runtime" in nested_build assert "codeflash-runtime" in nested_build
assert "com.codeflash:codeflash-runtime:" in nested_build
assert "mavenCentral()" in nested_build
def test_does_not_copy_jar_to_libs(self, tmp_path):
    """ensure_runtime should NOT copy JARs locally — Gradle resolves from Maven Central."""
    project = self._make_multi_module_project(tmp_path)
    strategy = GradleStrategy()

    result = strategy.ensure_runtime(project, test_module="streams")

    # Check success first: otherwise an early failure return would leave libs/
    # absent and make the assertion below pass vacuously.
    assert result is True
    libs_dir = project / "streams" / "libs"
    assert not libs_dir.exists()

View file

@ -631,3 +631,32 @@ class TestFindMultiModuleRoot:
assert build_root == tmp_path assert build_root == tmp_path
assert test_module == "streams" assert test_module == "streams"
def test_submodule_as_project_root_with_tests_inside(self, tmp_path):
    """A sub-module used as project_root, with its tests inside it, resolves to the parent build.

    The search should walk up from the sub-module (here ``clients/``) to the
    real multi-module root and report the sub-module's name.
    """
    self._make_kafka_like_project(tmp_path)
    module_dir = tmp_path / "clients"
    java_test = module_dir.joinpath("src", "test", "java", "com", "ClientsTest.java")
    java_test.parent.mkdir(parents=True, exist_ok=True)
    java_test.touch()
    mocked_paths = self._make_test_paths_mock([java_test])

    detected_root, detected_module = _find_multi_module_root(module_dir, mocked_paths)

    assert detected_root == tmp_path
    assert detected_module == "clients"
def test_submodule_as_project_root_nested_module(self, tmp_path):
    """A nested sub-module (connect/runtime) used as project_root is detected as 'connect:runtime'."""
    self._make_kafka_like_project(tmp_path)
    module_dir = tmp_path.joinpath("connect", "runtime")
    java_test = module_dir.joinpath("src", "test", "java", "com", "RuntimeTest.java")
    java_test.parent.mkdir(parents=True, exist_ok=True)
    java_test.touch()
    mocked_paths = self._make_test_paths_mock([java_test])

    detected_root, detected_module = _find_multi_module_root(module_dir, mocked_paths)

    assert detected_root == tmp_path
    assert detected_module == "connect:runtime"