chore: merge omni-java into feat/java-gradle-support

Resolved conflicts in:
- codeflash/version.py (kept gradle branch version)
- codeflash/languages/java/test_runner.py (kept gradle multi-module logic)
- codeflash/languages/java/support.py (kept java_test_module parameter)
- codeflash/discovery/functions_to_optimize.py (kept enhanced test filtering)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
HeshamHM28 2026-02-04 19:45:23 +00:00
commit c3be740417
63 changed files with 3338 additions and 1944 deletions

View file

@ -1,59 +0,0 @@
name: Claude Code Review
on:
pull_request:
types: [opened, synchronize]
# Optional: Only run on specific file changes
# paths:
# - "src/**/*.ts"
# - "src/**/*.tsx"
# - "src/**/*.js"
# - "src/**/*.jsx"
jobs:
claude-review:
# Optional: Filter by PR author
# if: |
# github.event.pull_request.user.login == 'external-contributor' ||
# github.event.pull_request.user.login == 'new-developer' ||
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
issues: read
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Run Claude Code Review
id: claude-review
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
use_sticky_comment: true
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}
Please review this pull request and provide feedback on:
- Code quality and best practices
- Potential bugs or issues
- Performance considerations
- Security concerns
- Test coverage
Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
# or https://code.claude.com/docs/en/cli-reference for available options
claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}

View file

@ -1,6 +1,8 @@
name: Claude Code
on:
pull_request:
types: [opened, synchronize, ready_for_review, reopened]
issue_comment:
types: [created]
pull_request_review_comment:
@ -11,19 +13,154 @@ on:
types: [submitted]
jobs:
claude:
# Automatic PR review (can fix linting issues and push)
# Blocked for fork PRs to prevent malicious code execution
pr-review:
if: |
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
(github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
github.event_name == 'pull_request' &&
github.actor != 'claude[bot]' &&
github.event.pull_request.head.repo.full_name == github.repository
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
issues: read
id-token: write
actions: read # Required for Claude to read CI results on PRs
actions: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install dependencies
run: |
uv venv --seed
uv sync
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
use_sticky_comment: true
allowed_bots: "claude[bot]"
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}
EVENT: ${{ github.event.action }}
## STEP 1: Run pre-commit checks and fix issues
First, run `uv run prek run --from-ref origin/main` to check for linting/formatting issues on files changed in this PR.
If there are any issues:
- For SAFE auto-fixable issues (formatting, import sorting, trailing whitespace, etc.), run `uv run prek run --from-ref origin/main` again to auto-fix them
- Stage the fixed files with `git add`
- Commit with message "style: auto-fix linting issues"
- Push the changes with `git push`
Do NOT attempt to fix:
- Type errors that require logic changes
- Complex refactoring suggestions
- Anything that could change behavior
## STEP 2: Review the PR
${{ github.event.action == 'synchronize' && 'This is a RE-REVIEW after new commits. First, get the list of changed files in this latest push using `gh pr diff`. Review ONLY the changed files. Check ALL existing review comments and resolve ones that are now fixed.' || 'This is the INITIAL REVIEW.' }}
Review this PR focusing ONLY on:
1. Critical bugs or logic errors
2. Security vulnerabilities
3. Breaking API changes
4. Test failures (methods with typos that won't run)
IMPORTANT:
- First check existing review comments using `gh api repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/comments`. For each existing comment, check if the issue still exists in the current code.
- If an issue is fixed, use `gh api --method PATCH repos/${{ github.repository }}/pulls/comments/COMMENT_ID -f body="✅ Fixed in latest commit"` to resolve it.
- Only create NEW inline comments for HIGH-PRIORITY issues found in changed files.
- Limit to 5-7 NEW comments maximum per review.
- Use CLAUDE.md for project-specific guidance.
- Use `gh pr comment` for summary-level feedback.
- Use `mcp__github_inline_comment__create_inline_comment` sparingly for critical code issues only.
## STEP 3: Coverage analysis
Analyze test coverage for changed files:
1. Get the list of Python files changed in this PR (excluding tests):
`git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
2. Run tests with coverage on the PR branch:
`uv run coverage run -m pytest tests/ -q --tb=no`
`uv run coverage json -o coverage-pr.json`
3. Get coverage for changed files only:
`uv run coverage report --include="<changed_files_comma_separated>"`
4. Compare with main branch coverage:
- Checkout main: `git checkout origin/main`
- Run coverage: `uv run coverage run -m pytest tests/ -q --tb=no && uv run coverage json -o coverage-main.json`
- Checkout back: `git checkout -`
5. Analyze the diff to identify:
- NEW FILES: Files that don't exist on main (require good test coverage)
- MODIFIED FILES: Files with changes (changes must be covered by tests)
6. Report in PR comment with a markdown table:
- Coverage % for each changed file (PR vs main)
- Overall coverage change
- For NEW files: Flag if coverage is below 75%
- For MODIFIED files: Flag if the changed lines are not covered by tests
- Flag if overall coverage decreased
Coverage requirements:
- New implementations/files: Must have ≥75% test coverage
- Modified code: Changed lines should be exercised by existing or new tests
- No coverage regressions: Overall coverage should not decrease
claude_args: '--allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep"'
additional_permissions: |
actions: read
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
# @claude mentions (can edit and push) - restricted to maintainers only
claude-mention:
if: |
(
github.event_name == 'issue_comment' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')
) ||
(
github.event_name == 'pull_request_review_comment' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR') &&
github.event.pull_request.head.repo.full_name == github.repository
) ||
(
github.event_name == 'pull_request_review' &&
contains(github.event.review.body, '@claude') &&
(github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'MEMBER' || github.event.review.author_association == 'COLLABORATOR') &&
github.event.pull_request.head.repo.full_name == github.repository
) ||
(
github.event_name == 'issues' &&
(contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')) &&
(github.event.issue.author_association == 'OWNER' || github.event.issue.author_association == 'MEMBER' || github.event.issue.author_association == 'COLLABORATOR')
)
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
issues: read
id-token: write
actions: read
steps:
- name: Get PR head ref
id: pr-ref
@ -44,14 +181,22 @@ jobs:
fetch-depth: 0
ref: ${{ steps.pr-ref.outputs.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install dependencies
run: |
uv venv --seed
uv sync
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
claude_args: '--allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
additional_permissions: |
actions: read
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}

View file

@ -1,19 +0,0 @@
name: Lint
on:
pull_request:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
jobs:
lint:
name: Run pre-commit hooks
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- uses: pre-commit/action@v3.0.1

18
.github/workflows/prek.yaml vendored Normal file
View file

@ -0,0 +1,18 @@
name: Lint
on: [pull_request]
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
jobs:
prek:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- uses: astral-sh/setup-uv@v6
- uses: j178/prek-action@v1
with:
extra-args: '--from-ref origin/${{ github.base_ref }} --to-ref ${{ github.sha }}'

View file

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7
rev: v0.15.0
hooks:
# Run the linter.
- id: ruff-check

View file

@ -62,6 +62,8 @@ codeflash/
- Use libcst, not ast - For Python, always use `libcst` for code parsing/modification to preserve formatting.
- Code context extraction and replacement tests must always assert for full string equality, no substring matching.
- Any new feature or bug fix that can be tested automatically must have test cases.
- If changes affect existing test expectations, update the tests accordingly. Tests must always pass after changes.
- NEVER use leading underscores for function names (e.g., `_helper`). Python has no true private functions. Always use public names.
## Code Style
@ -70,7 +72,7 @@ codeflash/
- **Tooling**: Ruff for linting/formatting, mypy strict mode, pre-commit hooks
- **Comments**: Minimal - only explain "why", not "what"
- **Docstrings**: Do not add unless explicitly requested
- **Naming**: Prefer public functions (no leading underscore) - Python doesn't have true private functions
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
## Git Commits & Pull Requests

View file

@ -20,7 +20,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.3.1",
"version": "0.4.0",
"dev": true,
"hasInstallScript": true,
"license": "MIT",

View file

@ -373,11 +373,13 @@ def _handle_show_config() -> None:
detected = detect_project(project_root)
# Check if config exists or is auto-detected
config_exists, _ = has_existing_config(project_root)
config_exists, config_file = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
if config_exists and config_file:
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
console.print()
table = Table(show_header=True, header_style="bold cyan")

View file

@ -27,6 +27,9 @@ from codeflash.cli_cmds.cli_common import apologize_and_exit
from codeflash.cli_cmds.console import console, logger
from codeflash.cli_cmds.extension import install_vscode_extension
# Import Java init module
from codeflash.cli_cmds.init_java import init_java_project
# Import JS/TS init module
from codeflash.cli_cmds.init_javascript import (
ProjectLanguage,
@ -35,9 +38,6 @@ from codeflash.cli_cmds.init_javascript import (
get_js_dependency_installation_commands,
init_js_project,
)
# Import Java init module
from codeflash.cli_cmds.init_java import init_java_project
from codeflash.code_utils.code_utils import validate_relative_directory_path
from codeflash.code_utils.compat import LF
from codeflash.code_utils.config_parser import parse_config_file
@ -1674,9 +1674,7 @@ def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path,
# Install dependencies
install_deps_cmd = get_java_dependency_installation_commands(build_tool)
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
return optimize_yml_content
return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
def get_formatter_cmds(formatter: str) -> list[str]:

View file

@ -165,9 +165,7 @@ def init_java_project() -> None:
lang_panel = Panel(
Text(
"Java project detected!\n\nI'll help you set up Codeflash for your project.",
style="cyan",
justify="center",
"Java project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center"
),
title="Java Setup",
border_style="bright_red",
@ -205,7 +203,9 @@ def init_java_project() -> None:
completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:"
if did_add_new_key:
completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
completion_message += (
"\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
)
if os.name == "nt":
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
else:
@ -234,9 +234,7 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]:
codeflash_config_path = project_root / "codeflash.toml"
if codeflash_config_path.exists():
return Confirm.ask(
"A Codeflash config already exists. Do you want to re-configure it?",
default=False,
show_default=True,
"A Codeflash config already exists. Do you want to re-configure it?", default=False, show_default=True
), None
return True, None
@ -285,14 +283,10 @@ def collect_java_setup_info() -> JavaSetupInfo:
if Confirm.ask("Would you like to change any of these settings?", default=False):
# Source root override
module_root_override = _prompt_directory_override(
"source", detected_source_root, curdir
)
module_root_override = _prompt_directory_override("source", detected_source_root, curdir)
# Test root override
test_root_override = _prompt_directory_override(
"test", detected_test_root, curdir
)
test_root_override = _prompt_directory_override("test", detected_test_root, curdir)
# Formatter override
formatter_questions = [
@ -300,7 +294,7 @@ def collect_java_setup_info() -> JavaSetupInfo:
"formatter",
message="Which code formatter do you use?",
choices=[
(f"keep detected (google-java-format)", "keep"),
("keep detected (google-java-format)", "keep"),
("google-java-format", "google-java-format"),
("spotless", "spotless"),
("other", "other"),
@ -345,7 +339,7 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")]
subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)]
options = [keep_detected_option] + subdirs[:5] + [custom_dir_option]
options = [keep_detected_option, *subdirs[:5], custom_dir_option]
questions = [
inquirer.List(
@ -364,10 +358,9 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
answer = answers[f"{dir_type}_root"]
if answer == keep_detected_option:
return None
elif answer == custom_dir_option:
if answer == custom_dir_option:
return _prompt_custom_directory(dir_type)
else:
return answer
return answer
def _prompt_custom_directory(dir_type: str) -> str:
@ -441,7 +434,7 @@ def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[st
if formatter == "spotless":
if build_tool == JavaBuildTool.MAVEN:
return ["mvn spotless:apply -DspotlessFiles=$file"]
elif build_tool == JavaBuildTool.GRADLE:
if build_tool == JavaBuildTool.GRADLE:
return ["./gradlew spotlessApply"]
return ["spotless $file"]
if formatter == "other":

View file

@ -1429,7 +1429,9 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
numerical_names: set[str] = set()
modules_used: set[str] = set()
for node in ast.walk(tree):
stack: list[ast.AST] = [tree]
while stack:
node = stack.pop()
if isinstance(node, ast.Import):
for alias in node.names:
# import numpy or import numpy as np
@ -1451,6 +1453,8 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
name = alias.asname if alias.asname else alias.name
numerical_names.add(name)
modules_used.add(module_root)
else:
stack.extend(ast.iter_child_nodes(node))
return numerical_names, modules_used

View file

@ -711,18 +711,12 @@ def _add_java_class_members(
if not new_fields and not new_methods:
return original_source
logger.debug(
f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}"
)
logger.debug(f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}")
# Import the insertion function from replacement module
from codeflash.languages.java.replacement import _insert_class_members
result = _insert_class_members(
original_source, class_name, new_fields, new_methods, analyzer
)
return result
return _insert_class_members(original_source, class_name, new_fields, new_methods, analyzer)
except Exception as e:
logger.debug(f"Error adding Java class members: {e}")
@ -959,12 +953,14 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
for file_path_str, code in file_to_code_context.items():
if file_path_str:
# Extract filename without creating Path object repeatedly
if file_path_str.endswith(target_filename) and (len(file_path_str) == len(target_filename) or file_path_str[-len(target_filename)-1] in ('/', '\\')):
if file_path_str.endswith(target_filename) and (
len(file_path_str) == len(target_filename)
or file_path_str[-len(target_filename) - 1] in ("/", "\\")
):
module_optimized_code = code
logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
break
if module_optimized_code is None:
# Also try matching if there's only one code file, but ONLY for non-Python
# languages where path matching is less strict. For Python, we require

View file

@ -6,7 +6,8 @@ from typing import Any, Union
MAX_TEST_RUN_ITERATIONS = 5
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000
TESTGEN_CONTEXT_TOKEN_LIMIT = 16000
INDIVIDUAL_TESTCASE_TIMEOUT = 15
INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput

View file

@ -6,6 +6,8 @@ import json
from pathlib import Path
from typing import Any
from codeflash.setup.detector import is_build_output_dir
PACKAGE_JSON_CACHE: dict[Path, Path] = {}
PACKAGE_JSON_DATA_CACHE: dict[Path, dict[str, Any]] = {}
@ -50,12 +52,15 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
"""Detect module root from package.json fields or directory conventions.
Detection order:
1. package.json "exports" field (extract directory from main export)
2. package.json "module" field (ESM entry point)
3. package.json "main" field (CJS entry point)
4. "src/" directory if it exists
1. src/, lib/, source/ directories (common source directories)
2. package.json "exports" field (if not in build output directory)
3. package.json "module" field (ESM, if not in build output directory)
4. package.json "main" field (CJS, if not in build output directory)
5. Fall back to "." (project root)
Build output directories (build/, dist/, out/) are skipped since they contain
compiled code, not source files.
Args:
project_root: Root directory of the project.
package_data: Parsed package.json data.
@ -64,6 +69,11 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
Detected module root path (relative to project root).
"""
# Check for common source directories first - these are always preferred
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return src_dir
# Check exports field (modern Node.js)
exports = package_data.get("exports")
if exports:
@ -80,27 +90,38 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
if entry_path and isinstance(entry_path, str):
parent = Path(entry_path).parent
if parent != Path() and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return parent.as_posix()
# Check module field (ESM)
module_field = package_data.get("module")
if module_field and isinstance(module_field, str):
parent = Path(module_field).parent
if parent != Path() and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return parent.as_posix()
# Check main field (CJS)
main_field = package_data.get("main")
if main_field and isinstance(main_field, str):
parent = Path(main_field).parent
if parent != Path() and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return parent.as_posix()
# Check for src/ directory convention
if (project_root / "src").is_dir():
return "src"
# Default to project root
return "."

View file

@ -721,9 +721,7 @@ def inject_profiling_into_existing_test(
if is_java():
from codeflash.languages.java.instrumentation import instrument_existing_test
return instrument_existing_test(
test_path, call_positions, function_to_optimize, tests_project_root, mode.value
)
return instrument_existing_test(test_path, call_positions, function_to_optimize, tests_project_root, mode.value)
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(

View file

@ -746,7 +746,11 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
return CodeStringsMarkdown(code_strings=[])
imported_names: dict[str, str] = {}
external_bases: list[tuple[str, str]] = []
# Use a set to deduplicate external base entries to avoid repeated expensive checks/imports.
external_bases_set: set[tuple[str, str]] = set()
# Local cache to avoid repeated _is_project_module calls for the same module_name.
is_project_cache: dict[str, bool] = {}
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
@ -763,21 +767,31 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
if base_name and base_name in imported_names:
module_name = imported_names[base_name]
if not _is_project_module(module_name, project_root_path):
external_bases.append((base_name, module_name))
# Check cache first to avoid repeated expensive checks.
cached = is_project_cache.get(module_name)
if cached is None:
is_project = _is_project_module(module_name, project_root_path)
is_project_cache[module_name] = is_project
else:
is_project = cached
if not external_bases:
if not is_project:
external_bases_set.add((base_name, module_name))
if not external_bases_set:
return CodeStringsMarkdown(code_strings=[])
code_strings: list[CodeString] = []
extracted: set[tuple[str, str]] = set()
for base_name, module_name in external_bases:
if (module_name, base_name) in extracted:
continue
# Cache imported modules to avoid repeated importlib.import_module calls.
imported_module_cache: dict[str, object] = {}
for base_name, module_name in external_bases_set:
try:
module = importlib.import_module(module_name)
module = imported_module_cache.get(module_name)
if module is None:
module = importlib.import_module(module_name)
imported_module_cache[module_name] = module
base_class = getattr(module, base_name, None)
if base_class is None:
continue
@ -799,7 +813,6 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ")
code_strings.append(CodeString(code=class_source, file_path=class_file))
extracted.add((module_name, base_name))
except (ImportError, ModuleNotFoundError, AttributeError):
logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
@ -854,12 +867,13 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
needed_names.add(decorator.func.value.id)
# Get type annotation names from class body (for dataclass fields)
for item in ast.walk(class_node):
for item in class_node.body:
if isinstance(item, ast.AnnAssign) and item.annotation:
collect_names_from_annotation(item.annotation, needed_names)
# Also check for field() calls which are common in dataclasses
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
needed_names.add(item.func.id)
elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
if isinstance(item.value.func, ast.Name):
needed_names.add(item.value.func.id)
# Find imports that provide these names
import_lines: list[str] = []

View file

@ -656,7 +656,7 @@ def discover_unit_tests(
# Existing Python logic
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
strategy = framework_strategies.get(cfg.test_framework, None)
strategy = framework_strategies.get(cfg.test_framework)
if not strategy:
error_message = f"Unsupported test framework: {cfg.test_framework}"
raise ValueError(error_message)

View file

@ -25,11 +25,11 @@ class PrComment:
best_async_throughput: Optional[int] = None
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
report_table = {
test_type.to_name(): result
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
if test_type.to_name()
}
report_table: dict[str, dict[str, int]] = {}
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
name = test_type.to_name()
if name:
report_table[name] = result
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
"optimization_explanation": self.optimization_explanation,

View file

@ -36,15 +36,15 @@ from codeflash.languages.current import (
reset_current_language,
set_current_language,
)
# Java language support
# Importing the module triggers registration via @register_language decorator
from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
# Import language support modules to trigger auto-registration
# This ensures all supported languages are available when this package is imported
from codeflash.languages.python import PythonSupport # noqa: F401
# Java language support
# Importing the module triggers registration via @register_language decorator
from codeflash.languages.java.support import JavaSupport # noqa: F401
from codeflash.languages.registry import (
detect_project_language,
get_language_support,

View file

@ -58,9 +58,6 @@ def set_current_language(language: Language | str) -> None:
"""
global _current_language
if _current_language is not None:
return
_current_language = Language(language) if isinstance(language, str) else language

View file

@ -13,7 +13,10 @@ import subprocess
import xml.etree.ElementTree as ET
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pathlib import Path
logger = logging.getLogger(__name__)
@ -78,9 +81,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
root = ET.fromstring(content)
# Create ElementTree from root
tree = ET.ElementTree(root)
return tree
return ET.ElementTree(root)
class BuildTool(Enum):
@ -462,13 +463,7 @@ def run_maven_tests(
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=run_env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)
# Parse test results from Surefire reports
@ -488,7 +483,7 @@ def run_maven_tests(
)
except subprocess.TimeoutExpired:
logger.error("Maven test execution timed out after %d seconds", timeout)
logger.exception("Maven test execution timed out after %d seconds", timeout)
return MavenTestResult(
success=False,
tests_run=0,
@ -568,10 +563,7 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
def compile_maven_project(
project_root: Path,
include_tests: bool = True,
env: dict[str, str] | None = None,
timeout: int = 300,
project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300
) -> tuple[bool, str, str]:
"""Compile a Maven project.
@ -605,13 +597,7 @@ def compile_maven_project(
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=run_env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
)
return result.returncode == 0, result.stdout, result.stderr
@ -1002,14 +988,7 @@ def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> boo
]
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
capture_output=True,
text=True,
timeout=60,
)
result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)
if result.returncode == 0:
logger.info("Successfully installed codeflash-runtime to local Maven repository")
@ -1085,7 +1064,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
return True
except ET.ParseError as e:
logger.error("Failed to parse pom.xml: %s", e)
logger.exception("Failed to parse pom.xml: %s", e)
return False
except Exception as e:
logger.exception("Failed to add dependency to pom.xml: %s", e)
@ -1236,7 +1215,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
if build_start != -1 and build_end != -1:
# Found main build section, find plugins within it
build_section = content[build_start:build_end + len("</build>")]
build_section = content[build_start : build_end + len("</build>")]
plugins_start_in_build = build_section.find("<plugins>")
plugins_end_in_build = build_section.rfind("</plugins>")

View file

@ -47,7 +47,16 @@ def _find_comparator_jar(project_root: Path | None = None) -> Path | None:
return jar_path
# Check local Maven repository
m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / "codeflash-runtime-1.0.0.jar"
m2_jar = (
Path.home()
/ ".m2"
/ "repository"
/ "com"
/ "codeflash"
/ "codeflash-runtime"
/ "1.0.0"
/ "codeflash-runtime-1.0.0.jar"
)
if m2_jar.exists():
return m2_jar
@ -113,8 +122,7 @@ def compare_test_results(
jar_path = comparator_jar or _find_comparator_jar(project_root)
if not jar_path or not jar_path.exists():
logger.error(
"codeflash-runtime JAR not found. "
"Please ensure the codeflash-runtime is installed in your project."
"codeflash-runtime JAR not found. Please ensure the codeflash-runtime is installed in your project."
)
return False, []
@ -155,10 +163,10 @@ def compare_test_results(
comparison = json.loads(result.stdout)
except json.JSONDecodeError as e:
logger.error(f"Failed to parse Java comparator output: {e}")
logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
logger.exception(f"Failed to parse Java comparator output: {e}")
logger.exception(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
if result.stderr:
logger.error(f"stderr: {result.stderr[:500]}")
logger.exception(f"stderr: {result.stderr[:500]}")
return False, []
# Check for errors in the JSON response
@ -178,9 +186,7 @@ def compare_test_results(
for diff in comparison.get("diffs", []):
scope_str = diff.get("scope", "return_value")
scope = TestDiffScope.RETURN_VALUE
if scope_str == "exception":
scope = TestDiffScope.DID_PASS
elif scope_str == "missing":
if scope_str in {"exception", "missing"}:
scope = TestDiffScope.DID_PASS
# Build test identifier
@ -220,20 +226,17 @@ def compare_test_results(
return equivalent, test_diffs
except subprocess.TimeoutExpired:
logger.error("Java comparator timed out")
logger.exception("Java comparator timed out")
return False, []
except FileNotFoundError:
logger.error("Java not found. Please install Java to compare test results.")
logger.exception("Java not found. Please install Java to compare test results.")
return False, []
except Exception as e:
logger.error(f"Error running Java comparator: {e}")
logger.exception(f"Error running Java comparator: {e}")
return False, []
def compare_invocations_directly(
original_results: dict,
candidate_results: dict,
) -> tuple[bool, list]:
def compare_invocations_directly(original_results: dict, candidate_results: dict) -> tuple[bool, list]:
"""Compare test invocations directly from Python dictionaries.
This is a fallback when the Java comparator is not available.

View file

@ -10,7 +10,6 @@ from __future__ import annotations
import logging
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.java.build_tools import (
@ -22,7 +21,7 @@ from codeflash.languages.java.build_tools import (
)
if TYPE_CHECKING:
pass
from pathlib import Path
logger = logging.getLogger(__name__)
@ -80,9 +79,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
project_info = get_project_info(project_root)
# Detect test framework
test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(
project_root, build_tool
)
test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(project_root, build_tool)
# Detect other dependencies
has_mockito, has_assertj = _detect_test_dependencies(project_root, build_tool)
@ -120,9 +117,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
)
def _detect_test_framework(
project_root: Path, build_tool: BuildTool
) -> tuple[str, bool, bool, bool]:
def _detect_test_framework(project_root: Path, build_tool: BuildTool) -> tuple[str, bool, bool, bool]:
"""Detect which test framework the project uses.
Args:
@ -210,9 +205,7 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
elif tag == "groupId":
group_id = child.text
if group_id == "org.junit.jupiter" or (
artifact_id and "junit-jupiter" in artifact_id
):
if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id):
has_junit5 = True
elif group_id == "junit" and artifact_id == "junit":
has_junit4 = True
@ -253,9 +246,7 @@ def _detect_test_deps_from_gradle(project_root: Path) -> tuple[bool, bool, bool]
return has_junit5, has_junit4, has_testng
def _detect_test_dependencies(
project_root: Path, build_tool: BuildTool
) -> tuple[bool, bool]:
def _detect_test_dependencies(project_root: Path, build_tool: BuildTool) -> tuple[bool, bool]:
"""Detect additional test dependencies (Mockito, AssertJ).
Returns:
@ -289,9 +280,7 @@ def _detect_test_dependencies(
return has_mockito, has_assertj
def _get_compiler_settings(
project_root: Path, build_tool: BuildTool
) -> tuple[str | None, str | None]:
def _get_compiler_settings(project_root: Path, build_tool: BuildTool) -> tuple[str | None, str | None]:
"""Get compiler source and target settings.
Returns:
@ -392,11 +381,7 @@ def is_java_project(project_root: Path) -> bool:
return True
# Check for Java source files
for pattern in ["src/**/*.java", "*.java"]:
if list(project_root.glob(pattern)):
return True
return False
return any(list(project_root.glob(pattern)) for pattern in ["src/**/*.java", "*.java"])
def get_test_file_pattern(config: JavaProjectConfig) -> str:

View file

@ -8,26 +8,27 @@ and other dependencies.
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, HelperFunction, Language
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
from codeflash.languages.java.parser import JavaAnalyzer, JavaClassNode, get_java_analyzer
from codeflash.languages.java.import_resolver import find_helper_files
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from pathlib import Path
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
class InvalidJavaSyntaxError(Exception):
"""Raised when extracted Java code is not syntactically valid."""
pass
def extract_code_context(
function: FunctionToOptimize,
@ -67,12 +68,8 @@ def extract_code_context(
try:
source = function.file_path.read_text(encoding="utf-8")
except Exception as e:
logger.error("Failed to read %s: %s", function.file_path, e)
return CodeContext(
target_code="",
target_file=function.file_path,
language=Language.JAVA,
)
logger.exception("Failed to read %s: %s", function.file_path, e)
return CodeContext(target_code="", target_file=function.file_path, language=Language.JAVA)
# Extract target function code
target_code = extract_function_source(source, function)
@ -94,9 +91,7 @@ def extract_code_context(
import_statements = [_import_to_statement(imp) for imp in imports]
# Extract helper functions
helper_functions = find_helper_functions(
function, project_root, max_helper_depth, analyzer
)
helper_functions = find_helper_functions(function, project_root, max_helper_depth, analyzer)
# Extract read-only context only if fields are NOT already in the skeleton
# Avoid duplication between target_code and read_only_context
@ -107,9 +102,8 @@ def extract_code_context(
# Validate syntax - extracted code must always be valid Java
if validate_syntax and target_code:
if not analyzer.validate_syntax(target_code):
raise InvalidJavaSyntaxError(
f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
)
msg = f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
raise InvalidJavaSyntaxError(msg)
return CodeContext(
target_code=target_code,
@ -156,7 +150,7 @@ class TypeSkeleton:
enum_constants: str,
type_indent: str,
type_kind: str, # "class", "interface", or "enum"
outer_type_skeleton: "TypeSkeleton | None" = None,
outer_type_skeleton: TypeSkeleton | None = None,
) -> None:
self.type_declaration = type_declaration
self.type_javadoc = type_javadoc
@ -173,10 +167,7 @@ ClassSkeleton = TypeSkeleton
def _extract_type_skeleton(
source: str,
type_name: str,
target_method_name: str,
analyzer: JavaAnalyzer,
source: str, type_name: str, target_method_name: str, analyzer: JavaAnalyzer
) -> TypeSkeleton | None:
"""Extract the type skeleton (class, interface, or enum) for wrapping a method.
@ -254,11 +245,7 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No
Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum".
"""
type_declarations = {
"class_declaration": "class",
"interface_declaration": "interface",
"enum_declaration": "enum",
}
type_declarations = {"class_declaration": "class", "interface_declaration": "interface", "enum_declaration": "enum"}
if node.type in type_declarations:
name_node = node.child_by_field_name("name")
@ -283,11 +270,7 @@ def _find_class_node(node: Node, class_name: str, source_bytes: bytes) -> Node |
def _get_outer_type_skeleton(
inner_type_node: Node,
source_bytes: bytes,
lines: list[str],
target_method_name: str,
analyzer: JavaAnalyzer,
inner_type_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, analyzer: JavaAnalyzer
) -> TypeSkeleton | None:
"""Get the outer type skeleton if this is an inner type.
@ -356,11 +339,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
parts: list[str] = []
# Determine which body node type to look for
body_types = {
"class": "class_body",
"interface": "interface_body",
"enum": "enum_body",
}
body_types = {"class": "class_body", "interface": "interface_body", "enum": "enum_body"}
body_type = body_types.get(type_kind, "class_body")
for child in type_node.children:
@ -374,7 +353,8 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
# Keep old function name for backwards compatibility
_extract_class_declaration = lambda node, source_bytes: _extract_type_declaration(node, source_bytes, "class")
def _extract_class_declaration(node, source_bytes):
return _extract_type_declaration(node, source_bytes, "class")
def _find_javadoc(node: Node, source_bytes: bytes) -> str | None:
@ -390,11 +370,7 @@ def _find_javadoc(node: Node, source_bytes: bytes) -> str | None:
def _extract_type_body_context(
body_node: Node,
source_bytes: bytes,
lines: list[str],
target_method_name: str,
type_kind: str,
body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, type_kind: str
) -> tuple[str, str, str]:
"""Extract fields, constructors, and enum constants from a type body.
@ -473,15 +449,10 @@ def _extract_type_body_context(
# Keep old function name for backwards compatibility
def _extract_class_body_context(
body_node: Node,
source_bytes: bytes,
lines: list[str],
target_method_name: str,
body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str
) -> tuple[str, str]:
"""Extract fields and constructors from a class body."""
fields, constructors, _ = _extract_type_body_context(
body_node, source_bytes, lines, target_method_name, "class"
)
fields, constructors, _ = _extract_type_body_context(body_node, source_bytes, lines, target_method_name, "class")
return (fields, constructors)
@ -584,10 +555,7 @@ def extract_function_source(source: str, function: FunctionToOptimize) -> str:
def find_helper_functions(
function: FunctionToOptimize,
project_root: Path,
max_depth: int = 2,
analyzer: JavaAnalyzer | None = None,
function: FunctionToOptimize, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None
) -> list[HelperFunction]:
"""Find helper functions that the target function depends on.
@ -606,11 +574,9 @@ def find_helper_functions(
visited_functions: set[str] = set()
# Find helper files through imports
helper_files = find_helper_files(
function.file_path, project_root, max_depth, analyzer
)
helper_files = find_helper_files(function.file_path, project_root, max_depth, analyzer)
for file_path, class_names in helper_files.items():
for file_path in helper_files:
try:
source = file_path.read_text(encoding="utf-8")
file_functions = discover_functions_from_source(source, file_path, analyzer=analyzer)
@ -648,10 +614,7 @@ def find_helper_functions(
return helpers
def _find_same_class_helpers(
function: FunctionToOptimize,
analyzer: JavaAnalyzer,
) -> list[HelperFunction]:
def _find_same_class_helpers(function: FunctionToOptimize, analyzer: JavaAnalyzer) -> list[HelperFunction]:
"""Find helper methods in the same class as the target function.
Args:
@ -694,9 +657,7 @@ def _find_same_class_helpers(
and method.class_name == function.class_name
and method.name in called_methods
):
func_source = source_bytes[
method.node.start_byte : method.node.end_byte
].decode("utf8")
func_source = source_bytes[method.node.start_byte : method.node.end_byte].decode("utf8")
helpers.append(
HelperFunction(
@ -715,11 +676,7 @@ def _find_same_class_helpers(
return helpers
def extract_read_only_context(
source: str,
function: FunctionToOptimize,
analyzer: JavaAnalyzer,
) -> str:
def extract_read_only_context(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer) -> str:
"""Extract read-only context (fields, constants, inner classes).
This extracts class-level context that the function might depend on
@ -767,11 +724,7 @@ def _import_to_statement(import_info) -> str:
return f"{prefix}{import_info.import_path}{suffix};"
def extract_class_context(
file_path: Path,
class_name: str,
analyzer: JavaAnalyzer | None = None,
) -> str:
def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None) -> str:
"""Extract the full context of a class.
Args:
@ -813,5 +766,5 @@ def extract_class_context(
return package_stmt + "\n".join(import_statements) + "\n\n" + class_source
except Exception as e:
logger.error("Failed to extract class context: %s", e)
logger.exception("Failed to extract class context: %s", e)
return ""

View file

@ -12,19 +12,17 @@ from typing import TYPE_CHECKING
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import FunctionFilterCriteria
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.models.function_types import FunctionParent
if TYPE_CHECKING:
pass
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
logger = logging.getLogger(__name__)
def discover_functions(
file_path: Path,
filter_criteria: FunctionFilterCriteria | None = None,
analyzer: JavaAnalyzer | None = None,
file_path: Path, filter_criteria: FunctionFilterCriteria | None = None, analyzer: JavaAnalyzer | None = None
) -> list[FunctionToOptimize]:
"""Find all optimizable functions/methods in a Java file.
@ -115,10 +113,7 @@ def discover_functions_from_source(
def _should_include_method(
method: JavaMethodNode,
criteria: FunctionFilterCriteria,
source: str,
analyzer: JavaAnalyzer,
method: JavaMethodNode, criteria: FunctionFilterCriteria, source: str, analyzer: JavaAnalyzer
) -> bool:
"""Check if a method should be included based on filter criteria.
@ -176,10 +171,7 @@ def _should_include_method(
return True
def discover_test_methods(
file_path: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionToOptimize]:
def discover_test_methods(file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]:
"""Find all JUnit test methods in a Java test file.
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
@ -232,7 +224,7 @@ def _walk_tree_for_test_methods(
for child in node.children:
if child.type == "modifiers":
for mod_child in child.children:
if mod_child.type == "marker_annotation" or mod_child.type == "annotation":
if mod_child.type in {"marker_annotation", "annotation"}:
annotation_text = analyzer.get_node_text(mod_child, source_bytes)
# Check for JUnit 5 test annotations
if any(
@ -278,10 +270,7 @@ def _walk_tree_for_test_methods(
def get_method_by_name(
file_path: Path,
method_name: str,
class_name: str | None = None,
analyzer: JavaAnalyzer | None = None,
file_path: Path, method_name: str, class_name: str | None = None, analyzer: JavaAnalyzer | None = None
) -> FunctionToOptimize | None:
"""Find a specific method by name in a Java file.
@ -306,9 +295,7 @@ def get_method_by_name(
def get_class_methods(
file_path: Path,
class_name: str,
analyzer: JavaAnalyzer | None = None,
file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None
) -> list[FunctionToOptimize]:
"""Get all methods in a specific class.

View file

@ -6,16 +6,13 @@ google-java-format or other available formatters.
from __future__ import annotations
import contextlib
import logging
import os
import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
@ -29,7 +26,7 @@ class JavaFormatter:
# Version of google-java-format to use
GOOGLE_JAVA_FORMAT_VERSION = "1.19.2"
def __init__(self, project_root: Path | None = None):
def __init__(self, project_root: Path | None = None) -> None:
"""Initialize the Java formatter.
Args:
@ -107,21 +104,13 @@ class JavaFormatter:
try:
# Write source to temp file
with tempfile.NamedTemporaryFile(
mode="w", suffix=".java", delete=False, encoding="utf-8"
) as tmp:
with tempfile.NamedTemporaryFile(mode="w", suffix=".java", delete=False, encoding="utf-8") as tmp:
tmp.write(source)
tmp_path = tmp.name
try:
result = subprocess.run(
[
self._java_executable,
"-jar",
str(jar_path),
"--replace",
tmp_path,
],
[self._java_executable, "-jar", str(jar_path), "--replace", tmp_path],
check=False,
capture_output=True,
text=True,
@ -133,16 +122,12 @@ class JavaFormatter:
with open(tmp_path, encoding="utf-8") as f:
return f.read()
else:
logger.debug(
"google-java-format failed: %s", result.stderr or result.stdout
)
logger.debug("google-java-format failed: %s", result.stderr or result.stdout)
finally:
# Clean up temp file
try:
with contextlib.suppress(OSError):
os.unlink(tmp_path)
except OSError:
pass
except subprocess.TimeoutExpired:
logger.warning("google-java-format timed out")
@ -169,9 +154,7 @@ class JavaFormatter:
if self.project_root
else None,
# In user's home directory
Path.home()
/ ".codeflash"
/ f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
Path.home() / ".codeflash" / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
# In system temp
Path(tempfile.gettempdir())
/ "codeflash"
@ -186,8 +169,7 @@ class JavaFormatter:
# Don't auto-download to avoid surprises
# Users can manually download the JAR
logger.debug(
"google-java-format JAR not found. "
"Download from https://github.com/google/google-java-format/releases"
"google-java-format JAR not found. Download from https://github.com/google/google-java-format/releases"
)
return None
@ -239,7 +221,7 @@ class JavaFormatter:
logger.info("Downloaded google-java-format to %s", jar_path)
return jar_path
except Exception as e:
logger.error("Failed to download google-java-format: %s", e)
logger.exception("Failed to download google-java-format: %s", e)
return None

View file

@ -8,14 +8,15 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.languages.java.build_tools import find_source_root, find_test_root, get_project_info
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, get_java_analyzer
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
pass
from pathlib import Path
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo
logger = logging.getLogger(__name__)
@ -35,18 +36,7 @@ class JavaImportResolver:
"""Resolves Java imports to file paths within a project."""
# Standard Java packages that are always external
STANDARD_PACKAGES = frozenset(
[
"java",
"javax",
"sun",
"com.sun",
"jdk",
"org.w3c",
"org.xml",
"org.ietf",
]
)
STANDARD_PACKAGES = frozenset(["java", "javax", "sun", "com.sun", "jdk", "org.w3c", "org.xml", "org.ietf"])
# Common third-party package prefixes
COMMON_EXTERNAL_PREFIXES = frozenset(
@ -66,7 +56,7 @@ class JavaImportResolver:
]
)
def __init__(self, project_root: Path):
def __init__(self, project_root: Path) -> None:
"""Initialize the import resolver.
Args:
@ -156,10 +146,7 @@ class JavaImportResolver:
def _is_standard_library(self, import_path: str) -> bool:
"""Check if an import is from the Java standard library."""
for prefix in self.STANDARD_PACKAGES:
if import_path.startswith(prefix + ".") or import_path == prefix:
return True
return False
return any(import_path.startswith(prefix + ".") or import_path == prefix for prefix in self.STANDARD_PACKAGES)
def _is_external_library(self, import_path: str) -> bool:
"""Check if an import is from a known external library."""
@ -249,9 +236,7 @@ class JavaImportResolver:
return None
def get_imports_from_file(
self, file_path: Path, analyzer: JavaAnalyzer | None = None
) -> list[ResolvedImport]:
def get_imports_from_file(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]:
"""Get and resolve all imports from a Java file.
Args:
@ -272,9 +257,7 @@ class JavaImportResolver:
logger.warning("Failed to get imports from %s: %s", file_path, e)
return []
def get_project_imports(
self, file_path: Path, analyzer: JavaAnalyzer | None = None
) -> list[ResolvedImport]:
def get_project_imports(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]:
"""Get only the imports that resolve to files within the project.
Args:
@ -308,10 +291,7 @@ def resolve_imports_for_file(
def find_helper_files(
file_path: Path,
project_root: Path,
max_depth: int = 2,
analyzer: JavaAnalyzer | None = None,
file_path: Path, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None
) -> dict[Path, list[str]]:
"""Find helper files imported by a Java file, recursively.

View file

@ -17,16 +17,16 @@ from __future__ import annotations
import logging
import re
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
from typing import Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
@ -36,7 +36,8 @@ def _get_function_name(func: Any) -> str:
return func.function_name
if hasattr(func, "name"):
return func.name
raise AttributeError(f"Cannot get function name from {type(func)}")
msg = f"Cannot get function name from {type(func)}"
raise AttributeError(msg)
def _get_qualified_name(func: Any) -> str:
@ -135,7 +136,7 @@ def instrument_existing_test(
try:
source = test_path.read_text(encoding="utf-8")
except Exception as e:
logger.error("Failed to read test file %s: %s", test_path, e)
logger.exception("Failed to read test file %s: %s", test_path, e)
return False, f"Failed to read test file: {e}"
func_name = _get_function_name(function_to_optimize)
@ -227,7 +228,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
result.append(imp)
imports_added = True
continue
if stripped.startswith("public class") or stripped.startswith("class"):
if stripped.startswith(("public class", "class")):
# No imports found, add before class
for imp in import_statements:
result.append(imp)
@ -244,7 +245,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
i = 0
iteration_counter = 0
# Pre-compile the regex pattern once
method_call_pattern = _get_method_call_pattern(func_name)
@ -291,11 +291,10 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
while i < len(lines) and brace_depth > 0:
body_line = lines[i]
# Count braces more efficiently using string methods
open_count = body_line.count('{')
close_count = body_line.count('}')
open_count = body_line.count("{")
close_count = body_line.count("}")
brace_depth += open_count - close_count
if brace_depth > 0:
body_lines.append(body_line)
i += 1
@ -581,7 +580,7 @@ def create_benchmark_test(
method_id = _get_qualified_name(target_function)
class_name = getattr(target_function, "class_name", None) or "Target"
benchmark_code = f"""
return f"""
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
@ -615,7 +614,6 @@ public class {class_name}Benchmark {{
}}
}}
"""
return benchmark_code
def remove_instrumentation(source: str) -> str:
@ -713,7 +711,7 @@ def _add_import(source: str, import_statement: str) -> str:
# Find the last import or package statement
for i, line in enumerate(lines):
stripped = line.strip()
if stripped.startswith("import ") or stripped.startswith("package "):
if stripped.startswith(("import ", "package ")):
insert_idx = i + 1
elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
# First non-import, non-comment line
@ -725,13 +723,11 @@ def _add_import(source: str, import_statement: str) -> str:
return "".join(lines)
@lru_cache(maxsize=128)
def _get_method_call_pattern(func_name: str):
"""Cache compiled regex patterns for method call matching."""
return re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
re.MULTILINE
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)
@ -739,6 +735,5 @@ def _get_method_call_pattern(func_name: str):
def _get_method_call_pattern(func_name: str):
"""Cache compiled regex patterns for method call matching."""
return re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
re.MULTILINE
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)

View file

@ -13,8 +13,6 @@ from typing import TYPE_CHECKING
from tree_sitter import Language, Parser
if TYPE_CHECKING:
from pathlib import Path
from tree_sitter import Node, Tree
logger = logging.getLogger(__name__)
@ -222,9 +220,7 @@ class JavaAnalyzer:
current_class=new_class if node.type in type_declarations else current_class,
)
def _extract_method_info(
self, node: Node, source_bytes: bytes, current_class: str | None
) -> JavaMethodNode | None:
def _extract_method_info(self, node: Node, source_bytes: bytes, current_class: str | None) -> JavaMethodNode | None:
"""Extract method information from a method_declaration node."""
name = ""
is_static = False
@ -347,9 +343,7 @@ class JavaAnalyzer:
for child in node.children:
self._walk_tree_for_classes(child, source_bytes, classes, is_inner)
def _extract_class_info(
self, node: Node, source_bytes: bytes, is_inner: bool
) -> JavaClassNode | None:
def _extract_class_info(self, node: Node, source_bytes: bytes, is_inner: bool) -> JavaClassNode | None:
"""Extract class information from a class_declaration node."""
name = ""
is_public = False

View file

@ -18,10 +18,10 @@ from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
pass
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
@ -35,11 +35,7 @@ class ParsedOptimization:
new_helper_methods: list[str] # Source text of new helper methods to add
def _parse_optimization_source(
new_source: str,
target_method_name: str,
analyzer: JavaAnalyzer,
) -> ParsedOptimization:
def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
"""Parse optimization source to extract method and additional class members.
The new_source may contain:
@ -96,18 +92,12 @@ def _parse_optimization_source(
new_fields.append(field.source_text)
return ParsedOptimization(
target_method_source=target_method_source,
new_fields=new_fields,
new_helper_methods=new_helper_methods,
target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
)
def _insert_class_members(
source: str,
class_name: str,
fields: list[str],
methods: list[str],
analyzer: JavaAnalyzer,
source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
) -> str:
"""Insert new class members (fields and methods) into a class.
@ -212,10 +202,7 @@ def _insert_class_members(
def replace_function(
source: str,
function: FunctionToOptimize,
new_source: str,
analyzer: JavaAnalyzer | None = None,
source: str, function: FunctionToOptimize, new_source: str, analyzer: JavaAnalyzer | None = None
) -> str:
"""Replace a function in source code with new implementation.
@ -257,9 +244,9 @@ def replace_function(
# Find all methods matching the name (there may be overloads)
matching_methods = [
m for m in methods
if m.name == func_name
and (function.class_name is None or m.class_name == function.class_name)
m
for m in methods
if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
]
if len(matching_methods) == 1:
@ -296,10 +283,7 @@ def replace_function(
break
if not target_method:
# Fallback: use the first match
logger.warning(
"Multiple overloads of %s found but no line match, using first match",
func_name,
)
logger.warning("Multiple overloads of %s found but no line match, using first match", func_name)
target_method = matching_methods[0]
target_overload_index = 0
@ -342,18 +326,16 @@ def replace_function(
len(new_helpers_to_add),
class_name,
)
source = _insert_class_members(
source, class_name, new_fields_to_add, new_helpers_to_add, analyzer
)
source = _insert_class_members(source, class_name, new_fields_to_add, new_helpers_to_add, analyzer)
# Re-find the target method after modifications
# Line numbers have shifted, but the relative order of overloads is preserved
# Use the target_overload_index we saved earlier
methods = analyzer.find_methods(source)
matching_methods = [
m for m in methods
if m.name == func_name
and (function.class_name is None or m.class_name == function.class_name)
m
for m in methods
if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
]
if matching_methods and target_overload_index < len(matching_methods):
@ -398,9 +380,7 @@ def replace_function(
before = lines[: start_line - 1] # Lines before the method
after = lines[end_line:] # Lines after the method
result = "".join(before) + indented_new_source + "".join(after)
return result
return "".join(before) + indented_new_source + "".join(after)
def _get_indentation(line: str) -> str:
@ -460,10 +440,7 @@ def _apply_indentation(lines: list[str], base_indent: str) -> str:
def replace_method_body(
source: str,
function: FunctionToOptimize,
new_body: str,
analyzer: JavaAnalyzer | None = None,
source: str, function: FunctionToOptimize, new_body: str, analyzer: JavaAnalyzer | None = None
) -> str:
"""Replace just the body of a method, preserving signature.
@ -600,11 +577,7 @@ def insert_method(
return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8")
def remove_method(
source: str,
function: FunctionToOptimize,
analyzer: JavaAnalyzer | None = None,
) -> str:
def remove_method(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None) -> str:
"""Remove a method from source code.
Args:
@ -648,9 +621,7 @@ def remove_method(
def remove_test_functions(
test_source: str,
functions_to_remove: list[str],
analyzer: JavaAnalyzer | None = None,
test_source: str, functions_to_remove: list[str], analyzer: JavaAnalyzer | None = None
) -> str:
"""Remove specific test functions from test source code.
@ -669,9 +640,7 @@ def remove_test_functions(
methods = analyzer.find_methods(test_source)
# Sort by start line in reverse order (remove from end first)
methods_to_remove = [
m for m in methods if m.name in functions_to_remove
]
methods_to_remove = [m for m in methods if m.name in functions_to_remove]
methods_to_remove.sort(key=lambda m: m.start_line, reverse=True)
result = test_source
@ -728,9 +697,7 @@ def add_runtime_comments(
if original_ns > 0:
speedup = ((original_ns - optimized_ns) / original_ns) * 100
summary_lines.append(
f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)"
)
summary_lines.append(f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)")
# Insert after imports
lines = test_source.splitlines(keepends=True)

View file

@ -7,20 +7,9 @@ required methods for Java language support in codeflash.
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import (
CodeContext,
FunctionFilterCriteria,
HelperFunction,
Language,
LanguageSupport,
TestInfo,
TestResult,
)
from codeflash.languages.registry import register_language
from codeflash.languages.base import Language, LanguageSupport
from codeflash.languages.java.build_tools import find_test_root
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
from codeflash.languages.java.config import detect_java_project
@ -33,11 +22,7 @@ from codeflash.languages.java.instrumentation import (
instrument_for_benchmarking,
)
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.replacement import (
add_runtime_comments,
remove_test_functions,
replace_function,
)
from codeflash.languages.java.replacement import add_runtime_comments, remove_test_functions, replace_function
from codeflash.languages.java.test_discovery import discover_tests
from codeflash.languages.java.test_runner import (
parse_test_results,
@ -45,9 +30,14 @@ from codeflash.languages.java.test_runner import (
run_benchmarking_tests,
run_tests,
)
from codeflash.languages.registry import register_language
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult
logger = logging.getLogger(__name__)
@ -112,23 +102,17 @@ class JavaSupport(LanguageSupport):
# === Code Analysis ===
def extract_code_context(
self, function: FunctionToOptimize, project_root: Path, module_root: Path
) -> CodeContext:
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
"""Extract function code and its dependencies."""
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
def find_helper_functions(
self, function: FunctionToOptimize, project_root: Path
) -> list[HelperFunction]:
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
"""Find helper functions called by the target function."""
return find_helper_functions(function, project_root, analyzer=self._analyzer)
# === Code Transformation ===
def replace_function(
self, source: str, function: FunctionToOptimize, new_source: str
) -> str:
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
"""Replace a function in source code with new implementation."""
return replace_function(source, function, new_source, self._analyzer)
@ -140,11 +124,7 @@ class JavaSupport(LanguageSupport):
# === Test Execution ===
def run_tests(
self,
test_files: Sequence[Path],
cwd: Path,
env: dict[str, str],
timeout: int,
self, test_files: Sequence[Path], cwd: Path, env: dict[str, str], timeout: int
) -> tuple[list[TestResult], Path]:
"""Run tests and return results."""
return run_tests(list(test_files), cwd, env, timeout)
@ -155,15 +135,11 @@ class JavaSupport(LanguageSupport):
# === Instrumentation ===
def instrument_for_behavior(
self, source: str, functions: Sequence[FunctionToOptimize]
) -> str:
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
"""Add behavior instrumentation to capture inputs/outputs."""
return instrument_for_behavior(source, functions, self._analyzer)
def instrument_for_benchmarking(
self, test_source: str, target_function: FunctionToOptimize
) -> str:
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
"""Add timing instrumentation to test code."""
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
@ -180,32 +156,22 @@ class JavaSupport(LanguageSupport):
# === Test Editing ===
def add_runtime_comments(
self,
test_source: str,
original_runtimes: dict[str, int],
optimized_runtimes: dict[str, int],
self, test_source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
) -> str:
"""Add runtime performance comments to test source code."""
return add_runtime_comments(test_source, original_runtimes, optimized_runtimes, self._analyzer)
def remove_test_functions(
self, test_source: str, functions_to_remove: list[str]
) -> str:
def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str:
"""Remove specific test functions from test source code."""
return remove_test_functions(test_source, functions_to_remove, self._analyzer)
# === Test Result Comparison ===
def compare_test_results(
self,
original_results_path: Path,
candidate_results_path: Path,
project_root: Path | None = None,
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
) -> tuple[bool, list]:
"""Compare test results between original and candidate code."""
return _compare_test_results(
original_results_path, candidate_results_path, project_root=project_root
)
return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
# === Configuration ===
@ -308,12 +274,7 @@ class JavaSupport(LanguageSupport):
) -> tuple[bool, str | None]:
"""Inject profiling code into an existing test file."""
return instrument_existing_test(
test_path,
call_positions,
function_to_optimize,
tests_project_root,
mode,
self._analyzer,
test_path, call_positions, function_to_optimize, tests_project_root, mode, self._analyzer
)
def instrument_source_for_line_profiler(

View file

@ -7,27 +7,26 @@ specific functions, mapping source functions to their tests.
from __future__ import annotations
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import TestInfo
from codeflash.languages.java.config import detect_java_project
from codeflash.languages.java.discovery import discover_test_methods
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from collections.abc import Sequence
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
def discover_tests(
test_root: Path,
source_functions: Sequence[FunctionToOptimize],
analyzer: JavaAnalyzer | None = None,
test_root: Path, source_functions: Sequence[FunctionToOptimize], analyzer: JavaAnalyzer | None = None
) -> dict[str, list[TestInfo]]:
"""Map source functions to their tests via static analysis.
@ -56,9 +55,7 @@ def discover_tests(
# Find all test files (various naming conventions)
test_files = (
list(test_root.rglob("*Test.java"))
+ list(test_root.rglob("*Tests.java"))
+ list(test_root.rglob("Test*.java"))
list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
)
# Result map
@ -71,16 +68,12 @@ def discover_tests(
for test_method in test_methods:
# Find which source functions this test might exercise
matched_functions = _match_test_to_functions(
test_method, source, function_map, analyzer
)
matched_functions = _match_test_to_functions(test_method, source, function_map, analyzer)
for func_name in matched_functions:
result[func_name].append(
TestInfo(
test_name=test_method.function_name,
test_file=test_file,
test_class=test_method.class_name,
test_name=test_method.function_name, test_file=test_file, test_class=test_method.class_name
)
)
@ -114,7 +107,7 @@ def _match_test_to_functions(
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
test_name_lower = test_method.function_name.lower()
for func_name, func_info in function_map.items():
for func_info in function_map.values():
if func_info.function_name.lower() in test_name_lower:
matched.append(func_info.qualified_name)
@ -125,11 +118,7 @@ def _match_test_to_functions(
# Find method calls within the test method's line range
method_calls = _find_method_calls_in_range(
tree.root_node,
source_bytes,
test_method.starting_line,
test_method.ending_line,
analyzer,
tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer
)
for call_name in method_calls:
@ -151,7 +140,7 @@ def _match_test_to_functions(
source_class_name = source_class_name[4:]
# Look for functions in the matching class
for func_name, func_info in function_map.items():
for func_info in function_map.values():
if func_info.class_name == source_class_name:
if func_info.qualified_name not in matched:
matched.append(func_info.qualified_name)
@ -161,7 +150,7 @@ def _match_test_to_functions(
# This handles cases like TestQueryBlob importing Buffer and calling Buffer methods
imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer)
for func_name, func_info in function_map.items():
for func_info in function_map.values():
if func_info.qualified_name in matched:
continue
@ -172,11 +161,7 @@ def _match_test_to_functions(
return matched
def _extract_imports(
node,
source_bytes: bytes,
analyzer: JavaAnalyzer,
) -> set[str]:
def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
"""Extract imported class names from a Java file.
Args:
@ -224,7 +209,7 @@ def _extract_imports(
# Regular import: extract class name from scoped_identifier
for child in n.children:
if child.type == "scoped_identifier" or child.type == "identifier":
if child.type in {"scoped_identifier", "identifier"}:
import_path = analyzer.get_node_text(child, source_bytes)
# Extract just the class name (last part)
# e.g., "com.example.Buffer" -> "Buffer"
@ -244,11 +229,7 @@ def _extract_imports(
def _find_method_calls_in_range(
node,
source_bytes: bytes,
start_line: int,
end_line: int,
analyzer: JavaAnalyzer,
node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
) -> list[str]:
"""Find method calls within a line range.
@ -278,17 +259,13 @@ def _find_method_calls_in_range(
calls.append(analyzer.get_node_text(name_node, source_bytes))
for child in node.children:
calls.extend(
_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer)
)
calls.extend(_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer))
return calls
def find_tests_for_function(
function: FunctionToOptimize,
test_root: Path,
analyzer: JavaAnalyzer | None = None,
function: FunctionToOptimize, test_root: Path, analyzer: JavaAnalyzer | None = None
) -> list[TestInfo]:
"""Find tests that exercise a specific function.
@ -305,10 +282,7 @@ def find_tests_for_function(
return result.get(function.qualified_name, [])
def get_test_class_for_source_class(
source_class_name: str,
test_root: Path,
) -> Path | None:
def get_test_class_for_source_class(source_class_name: str, test_root: Path) -> Path | None:
"""Find the test class file for a source class.
Args:
@ -320,11 +294,7 @@ def get_test_class_for_source_class(
"""
# Try common naming patterns
patterns = [
f"{source_class_name}Test.java",
f"Test{source_class_name}.java",
f"{source_class_name}Tests.java",
]
patterns = [f"{source_class_name}Test.java", f"Test{source_class_name}.java", f"{source_class_name}Tests.java"]
for pattern in patterns:
matches = list(test_root.rglob(pattern))
@ -334,10 +304,7 @@ def get_test_class_for_source_class(
return None
def discover_all_tests(
test_root: Path,
analyzer: JavaAnalyzer | None = None,
) -> list[FunctionToOptimize]:
def discover_all_tests(test_root: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]:
"""Discover all test methods in a test directory.
Args:
@ -353,9 +320,7 @@ def discover_all_tests(
# Find all test files (various naming conventions)
test_files = (
list(test_root.rglob("*Test.java"))
+ list(test_root.rglob("*Tests.java"))
+ list(test_root.rglob("Test*.java"))
list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
)
for test_file in test_files:
@ -391,24 +356,18 @@ def is_test_file(file_path: Path) -> bool:
name = file_path.name
# Check naming patterns
if name.endswith("Test.java") or name.endswith("Tests.java"):
if name.endswith(("Test.java", "Tests.java")):
return True
if name.startswith("Test") and name.endswith(".java"):
return True
# Check if it's in a test directory
path_parts = file_path.parts
for part in path_parts:
if part in ("test", "tests", "src/test"):
return True
return False
return any(part in ("test", "tests", "src/test") for part in path_parts)
def get_test_methods_for_class(
test_file: Path,
test_class_name: str | None = None,
analyzer: JavaAnalyzer | None = None,
test_file: Path, test_class_name: str | None = None, analyzer: JavaAnalyzer | None = None
) -> list[FunctionToOptimize]:
"""Get all test methods in a specific test class.
@ -430,8 +389,7 @@ def get_test_methods_for_class(
def build_test_mapping_for_project(
project_root: Path,
analyzer: JavaAnalyzer | None = None,
project_root: Path, analyzer: JavaAnalyzer | None = None
) -> dict[str, list[TestInfo]]:
"""Build a complete test mapping for a project.

View file

@ -79,10 +79,11 @@ def _validate_test_filter(test_filter: str) -> str:
name_to_validate = pattern.replace("*", "A") # Replace * with a valid char
if not _validate_java_class_name(name_to_validate):
raise ValueError(
msg = (
f"Invalid test class name or pattern: '{pattern}'. "
f"Test names must follow Java identifier rules (letters, digits, underscores, dots, dollar signs)."
)
raise ValueError(msg)
return test_filter
@ -387,10 +388,7 @@ def run_behavioral_tests(
def _compile_tests(
project_root: Path,
env: dict[str, str],
test_module: str | None = None,
timeout: int = 120,
project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 120
) -> subprocess.CompletedProcess:
"""Compile test code using Maven (without running tests).
@ -407,12 +405,7 @@ def _compile_tests(
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found")
return subprocess.CompletedProcess(
args=["mvn"],
returncode=-1,
stdout="",
stderr="Maven not found",
)
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
cmd = [mvn, "test-compile", "-e"] # Show errors but not verbose output
@ -423,37 +416,20 @@ def _compile_tests(
try:
return subprocess.run(
cmd,
check=False,
cwd=project_root,
env=env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
)
except subprocess.TimeoutExpired:
logger.error("Maven compilation timed out after %d seconds", timeout)
logger.exception("Maven compilation timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
args=cmd,
returncode=-2,
stdout="",
stderr=f"Compilation timed out after {timeout} seconds",
args=cmd, returncode=-2, stdout="", stderr=f"Compilation timed out after {timeout} seconds"
)
except Exception as e:
logger.exception("Maven compilation failed: %s", e)
return subprocess.CompletedProcess(
args=cmd,
returncode=-1,
stdout="",
stderr=str(e),
)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def _get_test_classpath(
project_root: Path,
env: dict[str, str],
test_module: str | None = None,
timeout: int = 60,
project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 60
) -> str | None:
"""Get the test classpath from Maven.
@ -474,13 +450,7 @@ def _get_test_classpath(
# Create temp file for classpath output
cp_file = project_root / ".codeflash_classpath.txt"
cmd = [
mvn,
"dependency:build-classpath",
"-DincludeScope=test",
f"-Dmdep.outputFile={cp_file}",
"-q",
]
cmd = [mvn, "dependency:build-classpath", "-DincludeScope=test", f"-Dmdep.outputFile={cp_file}", "-q"]
if test_module:
cmd.extend(["-pl", test_module])
@ -489,13 +459,7 @@ def _get_test_classpath(
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
)
if result.returncode != 0:
@ -527,7 +491,7 @@ def _get_test_classpath(
return os.pathsep.join(cp_parts)
except subprocess.TimeoutExpired:
logger.error("Getting classpath timed out")
logger.exception("Getting classpath timed out")
return None
except Exception as e:
logger.exception("Failed to get classpath: %s", e)
@ -602,30 +566,16 @@ def _run_tests_direct(
try:
return subprocess.run(
cmd,
check=False,
cwd=working_dir,
env=env,
capture_output=True,
text=True,
timeout=timeout,
cmd, check=False, cwd=working_dir, env=env, capture_output=True, text=True, timeout=timeout
)
except subprocess.TimeoutExpired:
logger.error("Direct test execution timed out after %d seconds", timeout)
logger.exception("Direct test execution timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
args=cmd,
returncode=-2,
stdout="",
stderr=f"Test execution timed out after {timeout} seconds",
args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds"
)
except Exception as e:
logger.exception("Direct test execution failed: %s", e)
return subprocess.CompletedProcess(
args=cmd,
returncode=-1,
stdout="",
stderr=str(e),
)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def _get_test_class_names(test_paths: Any, mode: str = "performance") -> list[str]:
@ -680,10 +630,7 @@ def _get_empty_result(maven_root: Path, test_module: str | None) -> tuple[Path,
result_xml_path = _get_combined_junit_xml(surefire_dir, -1)
empty_result = subprocess.CompletedProcess(
args=["java", "-cp", "...", "ConsoleLauncher"],
returncode=-1,
stdout="",
stderr="No test classes found",
args=["java", "-cp", "...", "ConsoleLauncher"], returncode=-1, stdout="", stderr="No test classes found"
)
return result_xml_path, empty_result
@ -742,12 +689,7 @@ def _run_benchmarking_tests_maven(
run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)
result = _run_maven_tests(
maven_root,
test_paths,
run_env,
timeout=per_loop_timeout,
mode="performance",
test_module=test_module,
maven_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
)
last_result = result
@ -760,17 +702,14 @@ def _run_benchmarking_tests_maven(
elapsed = time.time() - total_start_time
if loop_idx >= min_loops and elapsed >= target_duration_seconds:
logger.debug(
"Stopping Maven benchmark after %d loops (%.2fs elapsed)",
loop_idx,
elapsed,
)
logger.debug("Stopping Maven benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
break
# Check if we have timing markers even if some tests failed
# We should continue looping if we're getting valid timing data
if result.returncode != 0:
import re
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
if not has_timing_markers:
@ -885,8 +824,15 @@ def run_benchmarking_tests(
# Fall back to Maven-based execution
logger.warning("Falling back to Maven-based test execution")
return _run_benchmarking_tests_maven(
test_paths, test_env, cwd, timeout, project_root,
min_loops, max_loops, target_duration_seconds, inner_iterations
test_paths,
test_env,
cwd,
timeout,
project_root,
min_loops,
max_loops,
target_duration_seconds,
inner_iterations,
)
logger.debug("Compilation completed in %.2fs", compile_time)
@ -898,8 +844,15 @@ def run_benchmarking_tests(
if not classpath:
logger.warning("Failed to get classpath, falling back to Maven-based execution")
return _run_benchmarking_tests_maven(
test_paths, test_env, cwd, timeout, project_root,
min_loops, max_loops, target_duration_seconds, inner_iterations
test_paths,
test_env,
cwd,
timeout,
project_root,
min_loops,
max_loops,
target_duration_seconds,
inner_iterations,
)
# Step 3: Run tests multiple times directly via JVM
@ -937,12 +890,7 @@ def run_benchmarking_tests(
# Run tests directly with XML report generation
loop_start = time.time()
result = _run_tests_direct(
classpath,
test_classes,
run_env,
working_dir,
timeout=per_loop_timeout,
reports_dir=reports_dir,
classpath, test_classes, run_env, working_dir, timeout=per_loop_timeout, reports_dir=reports_dir
)
loop_time = time.time() - loop_start
@ -959,12 +907,7 @@ def run_benchmarking_tests(
# Check if JUnit Console Launcher is not available (JUnit 4 projects)
# Fall back to Maven-based execution in this case
if (
loop_idx == 1
and result.returncode != 0
and result.stderr
and "ConsoleLauncher" in result.stderr
):
if loop_idx == 1 and result.returncode != 0 and result.stderr and "ConsoleLauncher" in result.stderr:
logger.debug("JUnit Console Launcher not available, falling back to Maven-based execution")
return _run_benchmarking_tests_maven(
test_paths,
@ -993,6 +936,7 @@ def run_benchmarking_tests(
# Check if tests failed - continue looping if we have timing markers
if result.returncode != 0:
import re
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
if not has_timing_markers:
@ -1319,12 +1263,7 @@ def _run_maven_tests_impl(
mvn = find_maven_executable()
if not mvn:
logger.error("Maven not found")
return subprocess.CompletedProcess(
args=["mvn"],
returncode=-1,
stdout="",
stderr="Maven not found",
)
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
# Build test filter
test_filter = _build_test_filter(test_paths, mode=mode)
@ -1354,33 +1293,18 @@ def _run_maven_tests_impl(
logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root)
try:
result = subprocess.run(
cmd,
check=False,
cwd=project_root,
env=env,
capture_output=True,
text=True,
timeout=timeout,
return subprocess.run(
cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
)
return result
except subprocess.TimeoutExpired:
logger.error("Maven test execution timed out after %d seconds", timeout)
logger.exception("Maven test execution timed out after %d seconds", timeout)
return subprocess.CompletedProcess(
args=cmd,
returncode=-2,
stdout="",
stderr=f"Test execution timed out after {timeout} seconds",
args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds"
)
except Exception as e:
logger.exception("Maven test execution failed: %s", e)
return subprocess.CompletedProcess(
args=cmd,
returncode=-1,
stdout="",
stderr=str(e),
)
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
def _build_test_filter(test_paths: Any, mode: str = "behavior") -> str:
@ -1440,7 +1364,7 @@ def _path_to_class_name(path: Path) -> str | None:
Fully qualified class name, or None if unable to determine.
"""
if not path.suffix == ".java":
if path.suffix != ".java":
return None
# Try to extract package from path
@ -1463,7 +1387,7 @@ def _path_to_class_name(path: Path) -> str | None:
break
if java_idx is not None:
class_parts = parts[java_idx + 1:]
class_parts = parts[java_idx + 1 :]
# Remove .java extension from last part
class_parts[-1] = class_parts[-1].replace(".java", "")
return ".".join(class_parts)
@ -1472,12 +1396,7 @@ def _path_to_class_name(path: Path) -> str | None:
return path.stem
def run_tests(
test_files: list[Path],
cwd: Path,
env: dict[str, str],
timeout: int,
) -> tuple[list[TestResult], Path]:
def run_tests(test_files: list[Path], cwd: Path, env: dict[str, str], timeout: int) -> tuple[list[TestResult], Path]:
"""Run tests and return results.
Args:
@ -1610,10 +1529,7 @@ def _parse_surefire_xml(xml_file: Path) -> list[TestResult]:
return results
def get_test_run_command(
project_root: Path,
test_classes: list[str] | None = None,
) -> list[str]:
def get_test_run_command(project_root: Path, test_classes: list[str] | None = None) -> list[str]:
"""Get the command to run Java tests.
Args:
@ -1633,10 +1549,8 @@ def get_test_run_command(
validated_classes = []
for test_class in test_classes:
if not _validate_java_class_name(test_class):
raise ValueError(
f"Invalid test class name: '{test_class}'. "
f"Test names must follow Java identifier rules."
)
msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
raise ValueError(msg)
validated_classes.append(test_class)
cmd.append(f"-Dtest={','.join(validated_classes)}")

View file

@ -210,10 +210,10 @@ class ReferenceFinder:
# Check if this file imports from the re-export file
import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)
trigger_check = True
if import_info:
context.visited_files.add(file_path)
import_name, original_import = import_info
import_name, _original_import = import_info
file_refs = self._find_references_in_file(
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
)
@ -651,15 +651,18 @@ class ReferenceFinder:
"""
references: list[Reference] = []
export_name = exported.export_name or exported.function_name
# Skip expensive parsing if export name not in source
if export_name not in source_code:
return references
exports = analyzer.find_exports(source_code)
lines = source_code.splitlines()
for exp in exports:
if not exp.is_reexport:
continue
# Check if this re-exports our function
export_name = exported.export_name or exported.function_name
for name, alias in exp.exported_names:
if name == export_name:
# This is a re-export of our function

View file

@ -11,6 +11,8 @@ import logging
import re
from typing import TYPE_CHECKING
from codeflash.languages.current import is_typescript
if TYPE_CHECKING:
from pathlib import Path
@ -44,9 +46,10 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
"""Detect the module system used by a JavaScript/TypeScript project.
Detection strategy:
1. Check package.json for "type" field
2. If file_path provided, check file extension (.mjs = ESM, .cjs = CommonJS)
3. Analyze import statements in the file
1. Check file extension for explicit module type (.mjs, .cjs, .ts, .tsx, .mts)
- TypeScript files always use ESM syntax regardless of package.json
2. Check package.json for explicit "type" field (only if explicitly set)
3. Analyze import/export statements in the file content
4. Default to CommonJS if uncertain
Args:
@ -57,13 +60,29 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
ModuleSystem constant (COMMONJS, ES_MODULE, or UNKNOWN).
"""
# Strategy 1: Check package.json
# Strategy 1: Check file extension first for explicit module type indicators
# TypeScript files always use ESM syntax (import/export)
if file_path:
suffix = file_path.suffix.lower()
if suffix == ".mjs":
logger.debug("Detected ES Module from .mjs extension")
return ModuleSystem.ES_MODULE
if suffix == ".cjs":
logger.debug("Detected CommonJS from .cjs extension")
return ModuleSystem.COMMONJS
if suffix in (".ts", ".tsx", ".mts"):
# TypeScript always uses ESM syntax (import/export)
# even if package.json doesn't have "type": "module"
logger.debug("Detected ES Module from TypeScript file extension")
return ModuleSystem.ES_MODULE
# Strategy 2: Check package.json for explicit type field
package_json = project_root / "package.json"
if package_json.exists():
try:
with package_json.open("r") as f:
pkg = json.load(f)
pkg_type = pkg.get("type", "commonjs")
pkg_type = pkg.get("type") # Don't default - only use if explicitly set
if pkg_type == "module":
logger.debug("Detected ES Module from package.json type field")
@ -71,44 +90,35 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
if pkg_type == "commonjs":
logger.debug("Detected CommonJS from package.json type field")
return ModuleSystem.COMMONJS
# If type is not explicitly set, continue to file content analysis
except Exception as e:
logger.warning("Failed to parse package.json: %s", e)
# Strategy 2: Check file extension
if file_path:
suffix = file_path.suffix
if suffix == ".mjs":
logger.debug("Detected ES Module from .mjs extension")
return ModuleSystem.ES_MODULE
if suffix == ".cjs":
logger.debug("Detected CommonJS from .cjs extension")
return ModuleSystem.COMMONJS
# Strategy 3: Analyze file content for import/export patterns
if file_path and file_path.exists():
try:
content = file_path.read_text()
# Strategy 3: Analyze file content
if file_path.exists():
try:
content = file_path.read_text()
# Look for ES module syntax
has_import = "import " in content and "from " in content
has_export = "export " in content or "export default" in content or "export {" in content
# Look for ES module syntax
has_import = "import " in content and "from " in content
has_export = "export " in content or "export default" in content or "export {" in content
# Look for CommonJS syntax
has_require = "require(" in content
has_module_exports = "module.exports" in content or "exports." in content
# Look for CommonJS syntax
has_require = "require(" in content
has_module_exports = "module.exports" in content or "exports." in content
# Determine based on what we found
if (has_import or has_export) and not (has_require or has_module_exports):
logger.debug("Detected ES Module from import/export statements")
return ModuleSystem.ES_MODULE
# Determine based on what we found
if (has_import or has_export) and not (has_require or has_module_exports):
logger.debug("Detected ES Module from import/export statements")
return ModuleSystem.ES_MODULE
if (has_require or has_module_exports) and not (has_import or has_export):
logger.debug("Detected CommonJS from require/module.exports")
return ModuleSystem.COMMONJS
if (has_require or has_module_exports) and not (has_import or has_export):
logger.debug("Detected CommonJS from require/module.exports")
return ModuleSystem.COMMONJS
except Exception as e:
logger.warning("Failed to analyze file %s: %s", file_path, e)
except Exception as e:
logger.warning("Failed to analyze file %s: %s", file_path, e)
# Default to CommonJS (more common and backward compatible)
logger.debug("Defaulting to CommonJS")
@ -199,11 +209,42 @@ def add_js_extension(module_path: str) -> str:
return module_path
def _convert_destructuring_to_imports(names_str: str) -> str:
"""Convert destructuring aliases to import aliases.
Converts:
a, b -> a, b
a: aliasA -> a as aliasA
a, b: aliasB -> a, b as aliasB
Args:
names_str: The destructuring pattern string (e.g., "a, b: aliasB")
Returns:
Import names string with aliases using 'as' syntax
"""
# Split by commas and process each name
parts = []
for name in names_str.split(","):
name = name.strip()
if ":" in name:
# Convert destructuring alias to import alias
# "a: aliasA" -> "a as aliasA"
original, alias = name.split(":", 1)
parts.append(f"{original.strip()} as {alias.strip()}")
else:
parts.append(name)
return ", ".join(parts)
# Replace destructured requires with named imports
def replace_destructured(match: re.Match) -> str:
    """Rewrite `const { ... } = require('mod')` as an ESM named import.

    Merge-conflict residue previously left an early `return` before the
    alias conversion, making the conversion unreachable; the alias
    conversion must run so `{ a: b }` becomes `import { a as b }`.
    """
    names = match.group(2).strip()
    module_path = add_js_extension(match.group(3))
    # Convert destructuring aliases (a: b) to import aliases (a as b)
    converted_names = _convert_destructuring_to_imports(names)
    return f"import {{ {converted_names} }} from '{module_path}';"
# Replace property access requires with named imports with alias
@ -234,12 +275,14 @@ def convert_commonjs_to_esm(code: str) -> str:
"""Convert CommonJS require statements to ES Module imports.
Converts:
const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
const foo = require('./module'); -> import foo from './module';
const foo = require('./module').default; -> import foo from './module';
const foo = require('./module').bar; -> import { bar as foo } from './module';
const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
const { foo: alias } = require('./module'); -> import { foo as alias } from './module';
const foo = require('./module'); -> import foo from './module';
const foo = require('./module').default; -> import foo from './module';
const foo = require('./module').bar; -> import { bar as foo } from './module';
Special handling:
- Destructuring aliases (a: b) are converted to import aliases (a as b)
- Local codeflash helper (./codeflash-jest-helper) is converted to npm package codeflash
because the local helper uses CommonJS exports which don't work in ESM projects
@ -299,36 +342,89 @@ def convert_esm_to_commonjs(code: str) -> str:
return default_import.sub(replace_default, code)
def ensure_module_system_compatibility(code: str, target_module_system: str) -> str:
def uses_ts_jest(project_root: Path) -> bool:
    """Check if the project uses ts-jest for TypeScript transformation.

    ts-jest handles module interoperability internally, allowing mixed
    CommonJS/ESM imports without explicit conversion.

    Args:
        project_root: The project root directory.

    Returns:
        True if ts-jest is being used, False otherwise.
    """
    # ts-jest declared in package.json dependencies or devDependencies.
    package_json = project_root / "package.json"
    if package_json.exists():
        try:
            pkg = json.loads(package_json.read_text())
        except Exception as e:
            logger.debug(f"Failed to read package.json for ts-jest detection: {e}")  # noqa: G004
        else:
            if "ts-jest" in pkg.get("devDependencies", {}) or "ts-jest" in pkg.get("dependencies", {}):
                return True
    # A jest config that references ts-jest (e.g. preset: 'ts-jest') also counts.
    for config_name in ("jest.config.js", "jest.config.cjs", "jest.config.ts", "jest.config.mjs"):
        config_path = project_root / config_name
        if not config_path.exists():
            continue
        try:
            if "ts-jest" in config_path.read_text():
                return True
        except Exception as e:
            logger.debug(f"Failed to read {config_name}: {e}")  # noqa: G004
    return False
def ensure_module_system_compatibility(code: str, target_module_system: str, project_root: Path | None = None) -> str:
    """Ensure code uses the correct module system syntax.

    Detects the module system used in ``code`` and converts it to the target
    system when needed. Mixed-style code (e.g. ESM imports plus CommonJS
    ``require`` calls for npm packages) is also converted.

    If the project uses ts-jest, no conversion is performed because ts-jest
    handles module interoperability internally.

    Args:
        code: JavaScript code to check and potentially convert.
        target_module_system: Target ModuleSystem (COMMONJS or ES_MODULE).
        project_root: Project root directory for ts-jest detection.

    Returns:
        Converted code, or unchanged if no conversion is needed or ts-jest
        handles interop.
    """
    # If ts-jest is installed, skip conversion - it handles interop natively
    if is_typescript() and project_root and uses_ts_jest(project_root):
        logger.debug(
            f"Skipping module system conversion (target was {target_module_system}). "  # noqa: G004
            "ts-jest handles interop natively."
        )
        return code

    # Detect which module system(s) the code currently uses. Merge residue
    # previously duplicated this logic; it is consolidated into one pass here.
    has_require = "require(" in code
    has_module_exports = "module.exports" in code or "exports." in code
    has_import = "import " in code and "from " in code
    has_export = "export " in code
    is_commonjs = has_require or has_module_exports
    is_esm = has_import or has_export

    if target_module_system == ModuleSystem.ES_MODULE:
        # Convert when any require() appears (handles mixed ESM+CJS code),
        # or when the code is purely CommonJS.
        if has_require or (is_commonjs and not is_esm):
            logger.debug("Converting CommonJS requires to ESM imports")
            return convert_commonjs_to_esm(code)
    elif target_module_system == ModuleSystem.COMMONJS:
        # Convert when any import appears, or when the code is purely ESM.
        if has_import or (is_esm and not is_commonjs):
            logger.debug("Converting ESM imports to CommonJS requires")
            return convert_esm_to_commonjs(code)

    logger.debug("No module system conversion needed")
    return code
@ -355,12 +451,8 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
# Check if the code uses test functions that need to be imported
test_globals = ["describe", "test", "it", "expect", "vi", "beforeEach", "afterEach", "beforeAll", "afterAll"]
needs_import = any(f"{global_name}(" in code or f"{global_name} (" in code for global_name in test_globals)
if not needs_import:
return code
# Determine which globals are actually used in the code
# Combine detection and collection into a single pass
used_globals = [g for g in test_globals if f"{g}(" in code or f"{g} (" in code]
if not used_globals:
return code
@ -373,9 +465,14 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
insert_index = 0
for i, line in enumerate(lines):
stripped = line.strip()
if stripped and not stripped.startswith("//") and not stripped.startswith("/*") and not stripped.startswith("*"):
if (
stripped
and not stripped.startswith("//")
and not stripped.startswith("/*")
and not stripped.startswith("*")
):
# Check if this line is an import/require - insert after imports
if stripped.startswith("import ") or stripped.startswith("const ") or stripped.startswith("let "):
if stripped.startswith(("import ", "const ", "let ")):
continue
insert_index = i
break

View file

@ -296,6 +296,33 @@ def _find_node_project_root(file_path: Path) -> Path | None:
return None
def _find_monorepo_root(start_path: Path) -> Path | None:
"""Find the monorepo workspace root by looking for workspace markers.
Traverses up from the given path to find a directory containing
monorepo workspace markers like yarn.lock, pnpm-workspace.yaml, etc.
Args:
start_path: A path within the monorepo.
Returns:
The monorepo root directory, or None if not found.
"""
monorepo_markers = ["yarn.lock", "pnpm-workspace.yaml", "lerna.json", "package-lock.json"]
current = start_path if start_path.is_dir() else start_path.parent
while current != current.parent:
# Check for monorepo markers
if any((current / marker).exists() for marker in monorepo_markers):
# Verify it has node_modules (it's the workspace root)
if (current / "node_modules").exists():
return current
current = current.parent
return None
def _find_jest_config(project_root: Path) -> Path | None:
"""Find Jest configuration file in the project.
@ -797,6 +824,12 @@ def run_jest_benchmarking_tests(
jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
# Pass monorepo root to loop-runner for jest-runner resolution
monorepo_root = _find_monorepo_root(effective_cwd)
if monorepo_root:
jest_env["CODEFLASH_MONOREPO_ROOT"] = str(monorepo_root)
logger.debug(f"Detected monorepo root: {monorepo_root}")
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
jest_env["CODEFLASH_TEST_ITERATION"] = "0"

View file

@ -321,7 +321,7 @@ class PythonSupport:
function_pos = None
for name in names:
if name.type == "function" and name.name == function.name:
if name.type == "function" and name.name == function.function_name:
# Check for class parent if it's a method
if function.class_name:
parent = name.parent()

View file

@ -258,6 +258,11 @@ class TreeSitterAnalyzer:
if func_info.is_arrow and not include_arrow_functions:
should_include = False
# Skip arrow functions that are object properties (e.g., { foo: () => {} })
# These are not standalone functions - they're values in object literals
if func_info.is_arrow and node.parent and node.parent.type == "pair":
should_include = False
if should_include:
functions.append(func_info)

View file

@ -15,6 +15,7 @@ from codeflash.models.test_type import TestType
if TYPE_CHECKING:
from collections.abc import Iterator
import enum
import re
import sys
@ -325,9 +326,7 @@ class CodeStringsMarkdown(BaseModel):
"""
if "file_to_path" in self._cache:
return self._cache["file_to_path"]
result = {
str(code_string.file_path): code_string.code for code_string in self.code_strings
}
result = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
self._cache["file_to_path"] = result
return result
@ -877,15 +876,14 @@ class TestResults(BaseModel): # noqa: PLW1641
return max(test_result.loop_index for test_result in self.test_results)
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
    """Return pass/fail counts per test type, counting only loop index 1.

    Merge-conflict residue previously kept both the old and new loop
    bodies, which double-counted every first-loop result; each result is
    now counted exactly once.
    """
    report: dict[TestType, dict[str, int]] = {tt: {"passed": 0, "failed": 0} for tt in TestType}
    for test_result in self.test_results:
        # Only the first benchmarking loop reflects pass/fail status.
        if test_result.loop_index != 1:
            continue
        key = "passed" if test_result.did_pass else "failed"
        report[test_result.test_type][key] += 1
    return report
@staticmethod

View file

@ -12,11 +12,13 @@ class TestType(Enum):
def to_name(self) -> str:
    """Return the human-readable display name for this test type.

    Merge-conflict residue previously left an inline copy of the name
    mapping plus an unreachable second return; the single module-level
    _TO_NAME_MAP is the source of truth.
    """
    # INIT_STATE_TEST is internal and has no display name.
    if self is TestType.INIT_STATE_TEST:
        return ""
    return _TO_NAME_MAP[self]
# Display names for the user-facing test types. INIT_STATE_TEST is
# intentionally absent: it is internal and rendered as an empty string.
_TO_NAME_MAP: dict[TestType, str] = {
    TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
    TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
    TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
    TestType.REPLAY_TEST: "⏪ Replay Tests",
    TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
}

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import ast
import concurrent.futures
import logging
import os
import queue
import random
@ -23,7 +22,7 @@ from rich.tree import Tree
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
from codeflash.cli_cmds.console import DEBUG_MODE, code_print, console, logger, lsp_log, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code
from codeflash.code_utils.code_replacer import (
@ -146,9 +145,70 @@ if TYPE_CHECKING:
from codeflash.verification.verification_utils import TestConfig
def log_code_after_replacement(file_path: Path, candidate_index: int) -> None:
    """Log the full file content after code replacement in verbose mode."""
    if not DEBUG_MODE:
        return
    suffix_to_lang = {".java": "java", ".py": "python", ".js": "javascript", ".ts": "typescript"}
    try:
        source = file_path.read_text(encoding="utf-8")
        syntax = Syntax(
            source,
            suffix_to_lang.get(file_path.suffix.lower(), "text"),
            line_numbers=True,
            theme="monokai",
            word_wrap=True,
        )
        title = f"[bold blue]Code After Replacement (Candidate {candidate_index})[/] [dim]({file_path.name})[/]"
        console.print(Panel(syntax, title=title, border_style="blue"))
    except Exception as e:
        # Best-effort verbose logging must never break the optimization run.
        logger.debug(f"Failed to log code after replacement: {e}")
def log_instrumented_test(test_source: str, test_name: str, test_type: str, language: str) -> None:
    """Log instrumented test code in verbose mode."""
    if not DEBUG_MODE:
        return
    # Cap very large test sources so the console stays readable.
    limit = 15000
    shown = test_source if len(test_source) <= limit else test_source[:limit] + "\n\n... [truncated] ..."
    panel = Panel(
        Syntax(shown, language, line_numbers=True, theme="monokai", word_wrap=True),
        title=f"[bold magenta]Instrumented Test: {test_name}[/] [dim]({test_type})[/]",
        border_style="magenta",
    )
    console.print(panel)
def log_test_run_output(stdout: str, stderr: str, test_type: str, returncode: int = 0) -> None:
    """Log test run stdout/stderr in verbose mode."""
    if not DEBUG_MODE:
        return
    limit = 10000

    def _clip(text: str) -> str:
        # Cap captured output so huge test logs don't flood the console.
        return text[:limit] + ("...[truncated]" if len(text) > limit else "")

    if stdout and stdout.strip():
        console.print(
            Panel(
                _clip(stdout),
                title=f"[bold green]{test_type} - stdout[/] [dim](exit: {returncode})[/]",
                border_style="green" if returncode == 0 else "red",
            )
        )
    if stderr and stderr.strip():
        console.print(Panel(_clip(stderr), title=f"[bold yellow]{test_type} - stderr[/]", border_style="yellow"))
def log_optimization_context(function_name: str, code_context: CodeOptimizationContext) -> None:
"""Log optimization context details when in verbose mode using Rich formatting."""
if logger.getEffectiveLevel() > logging.DEBUG:
if not DEBUG_MODE:
return
console.rule()
@ -496,6 +556,16 @@ class FunctionOptimizer:
should_run_experiment = self.experiment_id is not None
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})
# Early check: if --no-gen-tests is set, verify there are existing tests for this function
if self.args.no_gen_tests:
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
if not self.function_to_tests.get(func_qualname):
return Failure(
f"No existing tests found for '{self.function_to_optimize.function_name}'. "
f"Cannot optimize without tests when --no-gen-tests is set."
)
self.cleanup_leftover_test_return_values()
file_name_from_test_module_name.cache_clear()
ctx_result = self.get_code_optimization_context()
@ -566,7 +636,7 @@ class FunctionOptimizer:
# Normalize codeflash imports in JS/TS tests to use npm package
if not is_python():
module_system = detect_module_system(self.project_root)
module_system = detect_module_system(self.project_root, self.function_to_optimize.file_path)
if module_system == "esm":
generated_tests = inject_test_globals(generated_tests)
if is_typescript():
@ -594,18 +664,32 @@ class FunctionOptimizer:
generated_test.instrumented_perf_test_source = modified_perf_source
used_behavior_paths.add(behavior_path)
logger.debug(
f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}"
)
logger.debug(f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}")
with behavior_path.open("w", encoding="utf8") as f:
f.write(generated_test.instrumented_behavior_test_source)
logger.debug(f"[PIPELINE] Wrote behavioral test to {behavior_path}")
# Verbose: Log instrumented behavior test
log_instrumented_test(
generated_test.instrumented_behavior_test_source,
behavior_path.name,
"Behavioral Test",
language=self.function_to_optimize.language,
)
with perf_path.open("w", encoding="utf8") as f:
f.write(generated_test.instrumented_perf_test_source)
logger.debug(f"[PIPELINE] Wrote perf test to {perf_path}")
# Verbose: Log instrumented performance test
log_instrumented_test(
generated_test.instrumented_perf_test_source,
perf_path.name,
"Performance Test",
language=self.function_to_optimize.language,
)
# File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.)
test_file_obj = TestFile(
instrumented_behavior_file_path=generated_test.behavior_file_path,
@ -675,22 +759,24 @@ class FunctionOptimizer:
parts = tests_root.parts
# Look for standard Java package prefixes that indicate the start of package structure
standard_package_prefixes = ('com', 'org', 'net', 'io', 'edu', 'gov')
standard_package_prefixes = ("com", "org", "net", "io", "edu", "gov")
for i, part in enumerate(parts):
if part in standard_package_prefixes:
# Found start of package path, return everything before it
if i > 0:
java_sources_root = Path(*parts[:i])
logger.debug(f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})")
logger.debug(
f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})"
)
return java_sources_root
# If no standard package prefix found, check if there's a 'java' directory
# (standard Maven structure: src/test/java)
for i, part in enumerate(parts):
if part == 'java' and i > 0:
if part == "java" and i > 0:
# Return up to and including 'java'
java_sources_root = Path(*parts[:i + 1])
java_sources_root = Path(*parts[: i + 1])
logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}")
return java_sources_root
@ -721,16 +807,16 @@ class FunctionOptimizer:
import re
# Extract package from behavior source
package_match = re.search(r'^\s*package\s+([\w.]+)\s*;', behavior_source, re.MULTILINE)
package_match = re.search(r"^\s*package\s+([\w.]+)\s*;", behavior_source, re.MULTILINE)
package_name = package_match.group(1) if package_match else ""
# Extract class name from behavior source
# Use more specific pattern to avoid matching words like "command" or text in comments
class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', behavior_source, re.MULTILINE)
class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", behavior_source, re.MULTILINE)
behavior_class = class_match.group(1) if class_match else "GeneratedTest"
# Extract class name from perf source
perf_class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', perf_source, re.MULTILINE)
perf_class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", perf_source, re.MULTILINE)
perf_class = perf_class_match.group(1) if perf_class_match else "GeneratedPerfTest"
# Build paths with package structure
@ -770,22 +856,20 @@ class FunctionOptimizer:
perf_path = new_perf_path
# Rename class in source code - replace the class declaration
modified_behavior_source = re.sub(
rf'^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)',
rf'\g<1>{new_behavior_class}\g<2>',
rf"^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)",
rf"\g<1>{new_behavior_class}\g<2>",
behavior_source,
count=1,
flags=re.MULTILINE,
)
modified_perf_source = re.sub(
rf'^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)',
rf'\g<1>{new_perf_class}\g<2>',
rf"^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)",
rf"\g<1>{new_perf_class}\g<2>",
perf_source,
count=1,
flags=re.MULTILINE,
)
logger.debug(
f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}"
)
logger.debug(f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}")
break
index += 1
@ -1202,6 +1286,9 @@ class FunctionOptimizer:
logger.info("No functions were replaced in the optimized code. Skipping optimization candidate.")
console.rule()
return None
# Verbose: Log code after replacement
log_code_after_replacement(self.function_to_optimize.file_path, candidate_index)
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
logger.error(e)
self.write_code_and_helpers(
@ -1767,6 +1854,14 @@ class FunctionOptimizer:
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_behavior_test)
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
# Verbose: Log instrumented existing behavior test
log_instrumented_test(
injected_behavior_test,
new_behavioral_test_path.name,
"Existing Behavioral Test",
language=self.function_to_optimize.language,
)
else:
msg = "injected_behavior_test is None"
raise ValueError(msg)
@ -1776,6 +1871,14 @@ class FunctionOptimizer:
_f.write(injected_perf_test)
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
# Verbose: Log instrumented existing performance test
log_instrumented_test(
injected_perf_test,
new_perf_test_path.name,
"Existing Performance Test",
language=self.function_to_optimize.language,
)
unique_instrumented_test_files.add(new_behavioral_test_path)
unique_instrumented_test_files.add(new_perf_test_path)
@ -2242,7 +2345,7 @@ class FunctionOptimizer:
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n"
existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
existing_tests, replay_tests, _concolic_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
test_cfg=self.test_cfg,
@ -2885,6 +2988,11 @@ class FunctionOptimizer:
else:
msg = f"Unexpected testing type: {testing_type}"
raise ValueError(msg)
# Verbose: Log test run output
log_test_run_output(
run_result.stdout, run_result.stderr, f"Test Run ({testing_type.name})", run_result.returncode
)
except subprocess.TimeoutExpired:
logger.exception(
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"

View file

@ -21,6 +21,8 @@ from typing import Any
import tomlkit
# Directory names treated as build output (compiled artifacts, not source files).
_BUILD_DIRS = frozenset({"build", "dist", "out", ".next", ".nuxt"})
@dataclass
class DetectedProject:
@ -310,14 +312,21 @@ def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
"""Detect JavaScript/TypeScript module root.
Priority:
1. package.json "exports" field
2. package.json "module" field (ESM)
3. package.json "main" field (CJS)
4. src/ directory
5. lib/ directory
6. Project root
1. src/, lib/, source/ directories (common source directories)
2. package.json "exports" field (if not in build output directory)
3. package.json "module" field (ESM, if not in build output directory)
4. package.json "main" field (CJS, if not in build output directory)
5. Project root
Build output directories (build/, dist/, out/) are skipped since they contain
compiled code, not source files.
"""
# Check for common source directories first - these are always preferred
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return project_root / src_dir, f"{src_dir}/ directory"
package_json_path = project_root / "package.json"
package_data: dict[str, Any] = {}
@ -334,32 +343,52 @@ def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
entry_path = _extract_entry_path(exports)
if entry_path:
parent = Path(entry_path).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return project_root / parent, f'{parent.as_posix()}/ (from package.json "exports")'
# Check module field (ESM)
module_field = package_data.get("module")
if module_field and isinstance(module_field, str):
parent = Path(module_field).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return project_root / parent, f'{parent.as_posix()}/ (from package.json "module")'
# Check main field (CJS)
main_field = package_data.get("main")
if main_field and isinstance(main_field, str):
parent = Path(main_field).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return project_root / parent, f'{parent.as_posix()}/ (from package.json "main")'
# Check for common source directories
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return project_root / src_dir, f"{src_dir}/ directory"
# Default to project root
return project_root, "project root"
def is_build_output_dir(path: Path) -> bool:
    """Check if a path is within a common build output directory.

    Build output directories contain compiled code and should be skipped
    in favor of source directories.
    """
    # True when any path component names a known build output directory.
    return any(part in _BUILD_DIRS for part in path.parts)
def _extract_entry_path(exports: Any) -> str | None:
"""Extract entry path from package.json exports field."""
if isinstance(exports, str):

View file

@ -180,19 +180,31 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
# Handle file paths (contain slashes and extensions like .js/.ts)
if "/" in test_class_path or "\\" in test_class_path:
# This is a file path, not a Python module path
# Try the path as-is if it's absolute
potential_path = Path(test_class_path)
if potential_path.is_absolute() and potential_path.exists():
return potential_path
# Try to resolve relative to base_dir's parent (project root)
project_root = base_dir.parent
potential_path = project_root / test_class_path
if potential_path.exists():
return potential_path
# Normalize to resolve .. and . components
try:
potential_path = potential_path.resolve()
if potential_path.exists():
return potential_path
except (OSError, RuntimeError):
pass
# Also try relative to base_dir itself
potential_path = base_dir / test_class_path
if potential_path.exists():
return potential_path
# Try the path as-is if it's absolute
potential_path = Path(test_class_path)
if potential_path.exists():
return potential_path
try:
potential_path = potential_path.resolve()
if potential_path.exists():
return potential_path
except (OSError, RuntimeError):
pass
return None
# First try the full path (Python module path)
@ -512,8 +524,10 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# Check if the file name matches the module path
file_stem = test_file.instrumented_behavior_file_path.stem
# The instrumented file has __perfinstrumented suffix
original_class = file_stem.replace("__perfinstrumented", "").replace("__perfonlyinstrumented", "")
if original_class == test_module_path or file_stem == test_module_path:
original_class = file_stem.replace("__perfinstrumented", "").replace(
"__perfonlyinstrumented", ""
)
if test_module_path in (original_class, file_stem):
test_file_path = test_file.instrumented_behavior_file_path
break
# Check original file path
@ -551,7 +565,9 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# Default to GENERATED_REGRESSION for Jest/Java tests when test type can't be determined
if test_type is None and (is_jest or is_java_test):
test_type = TestType.GENERATED_REGRESSION
logger.debug(f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})")
logger.debug(
f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})"
)
elif test_type is None:
# Skip results where test type cannot be determined
logger.debug(f"Skipping result for {test_function_name}: could not determine test type")
@ -791,16 +807,25 @@ def parse_jest_test_xml(
if not test_file_path.exists():
test_file_path = base_dir / test_file_name
if test_file_path is None or not test_file_path.exists():
# For Jest tests in monorepos, test files may not exist after cleanup
# but we can still parse results and infer test type from the path
if test_file_path is None:
logger.warning(f"Could not resolve test file for Jest test: {test_class_path}")
continue
# Get test type if not already set from lookup
if test_type is None:
if test_type is None and test_file_path.exists():
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
if test_type is None:
# Default to GENERATED_REGRESSION for Jest tests
test_type = TestType.GENERATED_REGRESSION
# Infer test type from filename pattern
filename = test_file_path.name
if "__perf_test_" in filename or "_perf_test_" in filename:
test_type = TestType.GENERATED_PERFORMANCE
elif "__unit_test_" in filename or "_unit_test_" in filename:
test_type = TestType.GENERATED_REGRESSION
else:
# Default to GENERATED_REGRESSION for Jest tests
test_type = TestType.GENERATED_REGRESSION
# For Jest tests, keep the relative file path with extension intact
# (Python uses module_name_from_file_path which strips extensions)

View file

@ -132,11 +132,25 @@ def run_behavioral_tests(
# Check if there's a language support for this test framework that implements run_behavioral_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_behavioral_tests"):
# Java tests need longer timeout due to Maven startup overhead
# Use Java-specific timeout if no explicit timeout provided
from codeflash.code_utils.config_consts import JAVA_TESTCASE_TIMEOUT
effective_timeout = pytest_timeout
if test_framework == "junit5" and pytest_timeout is not None:
# For Java, use a minimum timeout to account for Maven overhead
effective_timeout = max(pytest_timeout, JAVA_TESTCASE_TIMEOUT)
if effective_timeout != pytest_timeout:
logger.debug(
f"Increased Java test timeout from {pytest_timeout}s to {effective_timeout}s "
"to account for Maven startup overhead"
)
return language_support.run_behavioral_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=pytest_timeout,
timeout=effective_timeout,
project_root=js_project_root,
enable_coverage=enable_coverage,
candidate_index=candidate_index,
@ -331,11 +345,25 @@ def run_benchmarking_tests(
# Check if there's a language support for this test framework that implements run_benchmarking_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):
# Java tests need longer timeout due to Maven startup overhead
# Use Java-specific timeout if no explicit timeout provided
from codeflash.code_utils.config_consts import JAVA_TESTCASE_TIMEOUT
effective_timeout = pytest_timeout
if test_framework == "junit5" and pytest_timeout is not None:
# For Java, use a minimum timeout to account for Maven overhead
effective_timeout = max(pytest_timeout, JAVA_TESTCASE_TIMEOUT)
if effective_timeout != pytest_timeout:
logger.debug(
f"Increased Java test timeout from {pytest_timeout}s to {effective_timeout}s "
"to account for Maven startup overhead"
)
return language_support.run_benchmarking_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=pytest_timeout,
timeout=effective_timeout,
project_root=js_project_root,
min_loops=pytest_min_loops,
max_loops=pytest_max_loops,

View file

@ -149,7 +149,9 @@ class TestConfig:
pom_path = current / "pom.xml"
if pom_path.exists():
parent_config = detect_java_project(current)
if parent_config and (parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng):
if parent_config and (
parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng
):
return parent_config.test_framework
current = current.parent

View file

@ -82,7 +82,10 @@ def generate_tests(
)
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system)
# Skip conversion if ts-jest is installed (handles interop natively)
generated_test_source = ensure_module_system_compatibility(
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
)
# Ensure vitest imports are present when using vitest framework
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)

View file

@ -1,60 +0,0 @@
# codeflash
AI-powered code performance optimization for JavaScript and TypeScript.
## Installation
```bash
npm install -g codeflash
# or
npx codeflash
```
## Quick Start
1. Get your API key from [codeflash.ai](https://codeflash.ai)
2. Set your API key:
```bash
export CODEFLASH_API_KEY=your-api-key
```
3. Optimize a function:
```bash
codeflash --file src/utils.ts --function slowFunction
```
## Usage
```bash
# Optimize a specific function
codeflash --file <path> --function <name>
# Optimize all functions in a directory
codeflash --all src/
# Initialize GitHub Actions workflow
codeflash init-actions
# Verify setup
codeflash --verify-setup
```
## Requirements
- Node.js >= 16.0.0
- A codeflash API key
## Supported Platforms
- Linux (x64, arm64)
- macOS (x64, arm64)
- Windows (x64)
## Documentation
See [codeflash.ai/docs](https://codeflash.ai/docs) for full documentation.
## License
BSL-1.1

View file

@ -1,47 +0,0 @@
#!/usr/bin/env node
/**
* Wrapper script for codeflash CLI.
* Invokes the downloaded binary with all passed arguments.
*/
const { spawn } = require('child_process');
const path = require('path');
const fs = require('fs');
function getBinaryPath() {
const binDir = __dirname;
const isWindows = process.platform === 'win32';
return path.join(binDir, isWindows ? 'codeflash.exe' : 'codeflash-binary');
}
/**
 * Entry point: locate the downloaded codeflash binary and exec it with
 * this process's arguments, mirroring its exit status.
 */
function main() {
  const binaryPath = getBinaryPath();
  // Fail fast with a reinstall hint if the postinstall download never ran.
  if (!fs.existsSync(binaryPath)) {
    console.error('\x1b[31mError: codeflash binary not found.\x1b[0m');
    console.error('Try reinstalling: npm install codeflash');
    process.exit(1);
  }
  // Pass all arguments to the binary
  const args = process.argv.slice(2);
  // stdio: 'inherit' wires the child straight to this terminal so colors,
  // prompts, and Ctrl-C behave as if the binary were invoked directly.
  const child = spawn(binaryPath, args, {
    stdio: 'inherit',
    env: process.env,
  });
  child.on('error', (error) => {
    // Spawn-level failure (e.g. binary not executable), distinct from a bad exit code.
    console.error(`\x1b[31mError running codeflash: ${error.message}\x1b[0m`);
    process.exit(1);
  });
  child.on('exit', (code, signal) => {
    // A signal-terminated child has no exit code; report a generic failure.
    if (signal) {
      process.exit(1);
    }
    process.exit(code || 0);
  });
}

main();

View file

@ -1,47 +0,0 @@
{
"name": "codeflash",
"version": "0.0.0",
"description": "AI-powered code performance optimization - automatically find and fix slow code",
"keywords": [
"codeflash",
"performance",
"optimization",
"ai",
"code",
"profiler",
"typescript",
"javascript"
],
"author": "CodeFlash Inc. <contact@codeflash.ai>",
"license": "BSL-1.1",
"homepage": "https://codeflash.ai",
"repository": {
"type": "git",
"url": "git+https://github.com/codeflash-ai/codeflash.git"
},
"bugs": {
"url": "https://github.com/codeflash-ai/codeflash/issues"
},
"bin": {
"codeflash": "./bin/codeflash"
},
"scripts": {
"postinstall": "node lib/install.js"
},
"engines": {
"node": ">=16.0.0"
},
"os": [
"darwin",
"linux",
"win32"
],
"cpu": [
"x64",
"arm64"
],
"files": [
"bin/",
"lib/"
]
}

View file

@ -1,12 +1,12 @@
{
"name": "codeflash",
"version": "0.5.0",
"version": "0.7.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "codeflash",
"version": "0.5.0",
"version": "0.7.0",
"hasInstallScript": true,
"license": "MIT",
"dependencies": {

View file

@ -1,6 +1,6 @@
{
"name": "codeflash",
"version": "0.5.0",
"version": "0.7.0",
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
"main": "runtime/index.js",
"types": "runtime/index.d.ts",

View file

@ -0,0 +1,73 @@
/**
 * Test: Dynamic environment variable reading
 *
 * This test verifies that the performance configuration functions read
 * environment variables at runtime rather than at module load time.
 *
 * This is critical for Vitest compatibility, where modules may be cached
 * and loaded before environment variables are set.
 *
 * Run with: node __tests__/dynamic-env-vars.test.js
 */

const assert = require('assert');

// Clear any existing env vars before loading the module.
// NOTE: the delete-before-require ordering is the whole point of this test —
// if the module captured env vars at load time, Test 2 below would fail.
delete process.env.CODEFLASH_PERF_LOOP_COUNT;
delete process.env.CODEFLASH_PERF_MIN_LOOPS;
delete process.env.CODEFLASH_PERF_TARGET_DURATION_MS;
delete process.env.CODEFLASH_PERF_BATCH_SIZE;
delete process.env.CODEFLASH_PERF_STABILITY_CHECK;
delete process.env.CODEFLASH_PERF_CURRENT_BATCH;

// Now load the module - at this point env vars are not set
const capture = require('../capture');

console.log('Testing dynamic environment variable reading...\n');

// Test 1: Default values when env vars are not set
console.log('Test 1: Default values');
assert.strictEqual(capture.getPerfLoopCount(), 1, 'getPerfLoopCount default should be 1');
assert.strictEqual(capture.getPerfMinLoops(), 5, 'getPerfMinLoops default should be 5');
assert.strictEqual(capture.getPerfTargetDurationMs(), 10000, 'getPerfTargetDurationMs default should be 10000');
assert.strictEqual(capture.getPerfBatchSize(), 10, 'getPerfBatchSize default should be 10');
assert.strictEqual(capture.getPerfStabilityCheck(), false, 'getPerfStabilityCheck default should be false');
assert.strictEqual(capture.getPerfCurrentBatch(), 0, 'getPerfCurrentBatch default should be 0');
console.log('  PASS: All defaults correct\n');

// Test 2: Values change when env vars are set AFTER module load
// This is the critical test - if these were constants, they would still return defaults
console.log('Test 2: Dynamic reading after module load');
process.env.CODEFLASH_PERF_LOOP_COUNT = '100';
process.env.CODEFLASH_PERF_MIN_LOOPS = '10';
process.env.CODEFLASH_PERF_TARGET_DURATION_MS = '5000';
process.env.CODEFLASH_PERF_BATCH_SIZE = '20';
process.env.CODEFLASH_PERF_STABILITY_CHECK = 'true';
process.env.CODEFLASH_PERF_CURRENT_BATCH = '5';
assert.strictEqual(capture.getPerfLoopCount(), 100, 'getPerfLoopCount should read 100 from env');
assert.strictEqual(capture.getPerfMinLoops(), 10, 'getPerfMinLoops should read 10 from env');
assert.strictEqual(capture.getPerfTargetDurationMs(), 5000, 'getPerfTargetDurationMs should read 5000 from env');
assert.strictEqual(capture.getPerfBatchSize(), 20, 'getPerfBatchSize should read 20 from env');
assert.strictEqual(capture.getPerfStabilityCheck(), true, 'getPerfStabilityCheck should read true from env');
assert.strictEqual(capture.getPerfCurrentBatch(), 5, 'getPerfCurrentBatch should read 5 from env');
console.log('  PASS: Dynamic reading works correctly\n');

// Test 3: Values change again when env vars are modified
console.log('Test 3: Values update when env vars change');
process.env.CODEFLASH_PERF_LOOP_COUNT = '500';
process.env.CODEFLASH_PERF_BATCH_SIZE = '50';
assert.strictEqual(capture.getPerfLoopCount(), 500, 'getPerfLoopCount should update to 500');
assert.strictEqual(capture.getPerfBatchSize(), 50, 'getPerfBatchSize should update to 50');
console.log('  PASS: Values update correctly\n');

// Cleanup so this test does not leak configuration into sibling tests.
delete process.env.CODEFLASH_PERF_LOOP_COUNT;
delete process.env.CODEFLASH_PERF_MIN_LOOPS;
delete process.env.CODEFLASH_PERF_TARGET_DURATION_MS;
delete process.env.CODEFLASH_PERF_BATCH_SIZE;
delete process.env.CODEFLASH_PERF_STABILITY_CHECK;
delete process.env.CODEFLASH_PERF_CURRENT_BATCH;

console.log('All tests passed!');

View file

@ -47,14 +47,30 @@ const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
// Batch 1: Test1(5 loops) → Test2(5 loops) → Test3(5 loops)
// Batch 2: Test1(5 loops) → Test2(5 loops) → Test3(5 loops)
// ...until time budget exhausted
const PERF_LOOP_COUNT = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '1', 10);
const PERF_MIN_LOOPS = parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
const PERF_TARGET_DURATION_MS = parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
const PERF_BATCH_SIZE = parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
const PERF_STABILITY_CHECK = (process.env.CODEFLASH_PERF_STABILITY_CHECK || 'false').toLowerCase() === 'true';
//
// IMPORTANT: These are getter functions, NOT constants!
// Vitest caches modules and may load this file before env vars are set.
// Using getter functions ensures we read the env vars at runtime when they're actually needed.
// Total timing loops to run per test invocation (default 1 = no looping).
function getPerfLoopCount() {
  return parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '1', 10);
}

// Minimum loops that must complete before the shared time budget may stop looping.
function getPerfMinLoops() {
  return parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
}

// Shared wall-clock budget in milliseconds for all perf loops combined.
function getPerfTargetDurationMs() {
  return parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
}

// Loops executed per capturePerf call when an external loop-runner drives batching.
function getPerfBatchSize() {
  return parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
}

// Whether early-stop-on-stable-timings is enabled (env string 'true' enables it).
function getPerfStabilityCheck() {
  return (process.env.CODEFLASH_PERF_STABILITY_CHECK || 'false').toLowerCase() === 'true';
}
// Current batch number - set by loop-runner before each batch
// This allows continuous loop indices even when Jest resets module state
const PERF_CURRENT_BATCH = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '0', 10);
function getPerfCurrentBatch() {
return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '0', 10);
}
// Stability constants (matching Python's config_consts.py)
const STABILITY_WINDOW_SIZE = 0.35;
@ -86,7 +102,7 @@ function checkSharedTimeLimit() {
return false;
}
const elapsed = Date.now() - sharedPerfState.startTime;
if (elapsed >= PERF_TARGET_DURATION_MS && sharedPerfState.totalLoopsCompleted >= PERF_MIN_LOOPS) {
if (elapsed >= getPerfTargetDurationMs() && sharedPerfState.totalLoopsCompleted >= getPerfMinLoops()) {
sharedPerfState.shouldStop = true;
return true;
}
@ -111,7 +127,7 @@ function getInvocationLoopIndex(invocationKey) {
// Calculate global loop index using batch number from environment
// PERF_CURRENT_BATCH is 1-based (set by loop-runner before each batch)
const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
const globalIndex = (currentBatch - 1) * PERF_BATCH_SIZE + localIndex;
const globalIndex = (currentBatch - 1) * getPerfBatchSize() + localIndex;
return globalIndex;
}
@ -606,7 +622,7 @@ function capture(funcName, lineId, fn, ...args) {
*/
function capturePerf(funcName, lineId, fn, ...args) {
// Check if we should skip looping entirely (shared time budget exceeded)
const shouldLoop = PERF_LOOP_COUNT > 1 && !checkSharedTimeLimit();
const shouldLoop = getPerfLoopCount() > 1 && !checkSharedTimeLimit();
// Get test context (computed once, reused across batch)
let testModulePath;
@ -636,9 +652,9 @@ function capturePerf(funcName, lineId, fn, ...args) {
// If so, just execute the function once without timing (for test assertions)
const peekLoopIndex = (sharedPerfState.invocationLoopCounts[invocationKey] || 0);
const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
const nextGlobalIndex = (currentBatch - 1) * PERF_BATCH_SIZE + peekLoopIndex + 1;
const nextGlobalIndex = (currentBatch - 1) * getPerfBatchSize() + peekLoopIndex + 1;
if (shouldLoop && nextGlobalIndex > PERF_LOOP_COUNT) {
if (shouldLoop && nextGlobalIndex > getPerfLoopCount()) {
// All loops completed, just execute once for test assertion
return fn(...args);
}
@ -654,7 +670,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
// For Vitest (no loop-runner), do all loops internally in a single call
const batchSize = shouldLoop
? (hasExternalLoopRunner ? PERF_BATCH_SIZE : PERF_LOOP_COUNT)
? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount())
: 1;
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
@ -667,7 +683,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
const loopIndex = getInvocationLoopIndex(invocationKey);
// Check if we've exceeded max loops for this invocation
if (loopIndex > PERF_LOOP_COUNT) {
if (loopIndex > getPerfLoopCount()) {
break;
}
@ -872,7 +888,11 @@ module.exports = {
LOOP_INDEX,
OUTPUT_FILE,
TEST_ITERATION,
// Batch configuration
PERF_BATCH_SIZE,
PERF_LOOP_COUNT,
// Batch configuration (getter functions for dynamic env var reading)
getPerfBatchSize,
getPerfLoopCount,
getPerfMinLoops,
getPerfTargetDurationMs,
getPerfStabilityCheck,
getPerfCurrentBatch,
};

View file

@ -30,13 +30,74 @@
const { createRequire } = require('module');
const path = require('path');
const fs = require('fs');
/**
* Resolve jest-runner with monorepo support.
* Uses CODEFLASH_MONOREPO_ROOT environment variable if available,
* otherwise walks up the directory tree looking for node_modules/jest-runner.
*/
function resolveJestRunner() {
// Try standard resolution first (works in simple projects)
try {
return require.resolve('jest-runner');
} catch (e) {
// Standard resolution failed - try monorepo-aware resolution
}
// If Python detected a monorepo root, check there first
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
if (monorepoRoot) {
const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner');
if (fs.existsSync(jestRunnerPath)) {
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
if (fs.existsSync(packageJsonPath)) {
return jestRunnerPath;
}
}
}
// Fallback: Walk up from cwd looking for node_modules/jest-runner
const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json'];
let currentDir = process.cwd();
const visitedDirs = new Set();
while (currentDir !== path.dirname(currentDir)) {
// Avoid infinite loops
if (visitedDirs.has(currentDir)) break;
visitedDirs.add(currentDir);
// Try node_modules/jest-runner at this level
const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner');
if (fs.existsSync(jestRunnerPath)) {
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
if (fs.existsSync(packageJsonPath)) {
return jestRunnerPath;
}
}
// Check if this is a workspace root (has monorepo markers)
const isWorkspaceRoot = monorepoMarkers.some(marker =>
fs.existsSync(path.join(currentDir, marker))
);
if (isWorkspaceRoot) {
// Found workspace root but no jest-runner - stop searching
break;
}
currentDir = path.dirname(currentDir);
}
throw new Error('jest-runner not found');
}
// Try to load jest-runner - it's a peer dependency that must be installed by the user
let runTest;
let jestRunnerAvailable = false;
try {
const jestRunnerPath = require.resolve('jest-runner');
const jestRunnerPath = resolveJestRunner();
const internalRequire = createRequire(jestRunnerPath);
runTest = internalRequire('./runTest').default;
jestRunnerAvailable = true;

View file

@ -79,8 +79,9 @@ dev = [
"types-greenlet>=3.1.0.20241221,<4",
"types-pexpect>=4.9.0.20241208,<5",
"types-unidiff>=0.7.0.20240505,<0.8",
"uv>=0.6.2",
"pre-commit>=4.2.0,<5",
"prek>=0.2.25",
"ty>=0.0.14",
"uv>=0.9.29",
]
tests = [
"black>=25.9.0",
@ -272,6 +273,7 @@ ignore = [
"ANN401", # typing.Any disallowed
"ARG001", # Unused function argument (common in abstract/interface methods)
"TRY300", # Consider moving to else block
"FURB110", # if-exp-instead-of-or-operator - we prefer explicit if-else over "or"
"TRY401", # Redundant exception in logging.exception
"PLR0911", # Too many return statements
"PLW0603", # Global statement

View file

@ -127,13 +127,14 @@ class TestDetectModuleRoot:
assert result == "lib"
def test_detects_from_exports_object_dot(self, tmp_path: Path) -> None:
"""Should detect module root from exports object with '.' key."""
"""Should skip build output dirs and return '.' when no src dir exists."""
(tmp_path / "dist").mkdir()
package_data = {"exports": {".": "./dist/index.js"}}
result = detect_module_root(tmp_path, package_data)
assert result == "dist"
# dist is a build output directory, so it's skipped
assert result == "."
def test_detects_from_exports_object_nested(self, tmp_path: Path) -> None:
"""Should detect module root from nested exports object."""
@ -227,13 +228,14 @@ class TestDetectModuleRoot:
assert result == "src"
def test_handles_deeply_nested_exports(self, tmp_path: Path) -> None:
"""Should handle deeply nested export paths."""
"""Should handle deeply nested export paths but skip build output dirs."""
(tmp_path / "packages" / "core" / "dist").mkdir(parents=True)
package_data = {"exports": {".": {"import": "./packages/core/dist/index.mjs"}}}
result = detect_module_root(tmp_path, package_data)
assert result == "packages/core/dist"
# dist is a build output directory, so it's skipped even when nested
assert result == "."
def test_handles_empty_exports(self, tmp_path: Path) -> None:
"""Should handle empty exports gracefully."""
@ -756,7 +758,7 @@ class TestRealWorldPackageJsonExamples:
assert config["formatter_cmds"] == ["npx eslint --fix $file"]
def test_library_with_exports(self, tmp_path: Path) -> None:
"""Should handle library with modern exports field."""
"""Should handle library with modern exports field, skipping build output dirs."""
(tmp_path / "dist").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(
@ -773,7 +775,8 @@ class TestRealWorldPackageJsonExamples:
assert result is not None
config, _ = result
assert config["module_root"] == str((tmp_path / "dist").resolve())
# dist is a build output directory, so it's skipped and falls back to project root
assert config["module_root"] == str(tmp_path.resolve())
def test_monorepo_package(self, tmp_path: Path) -> None:
"""Should handle monorepo package configuration."""

View file

@ -24,6 +24,8 @@ from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context_for_language
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
@ -1694,6 +1696,240 @@ const FIELD_KEYS = {
};"""
assert context.read_only_context == expected_read_only
def test_with_tricky_helpers(self, ts_support, temp_project):
"""Test function returning object with computed property names."""
code = """import { WebClient, ChatPostMessageArguments } from "@slack/web-api"
// Dependencies interface for easier testing
export interface SendSlackMessageDependencies {
WebClient: typeof WebClient
getSlackToken: () => string | undefined
getSlackChannelId: () => string | undefined
console: typeof console
}
// Default dependencies
let dependencies: SendSlackMessageDependencies = {
WebClient,
getSlackToken: () => process.env.SLACK_TOKEN,
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
console,
}
// For testing - allow dependency injection
export function setSendSlackMessageDependencies(deps: Partial<SendSlackMessageDependencies>) {
dependencies = { ...dependencies, ...deps }
}
export function resetSendSlackMessageDependencies() {
dependencies = {
WebClient,
getSlackToken: () => process.env.SLACK_TOKEN,
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
console,
}
}
// Initialize web client
let web: WebClient | null = null
export function initializeWebClient() {
const SLACK_TOKEN = dependencies.getSlackToken()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
if (!SLACK_TOKEN) {
throw new Error("Missing SLACK_TOKEN")
}
if (!SLACK_CHANNEL_ID) {
throw new Error("Missing SLACK_CHANNEL_ID")
}
if (!web) {
web = new dependencies.WebClient(SLACK_TOKEN, {})
}
return web
}
// For testing - allow resetting the web client
export function resetWebClient() {
web = null
}
/**
* Send a message to Slack
*
* @param {string|object} message - Text message or Block Kit message object
* @param {string|null} channel - Channel ID, defaults to SLACK_CHANNEL_ID
* @param {boolean} returnData - Whether to return the full Slack API response
* @returns {Promise<boolean|object>} - True or API response
*/
export const sendSlackMessage = async (
message: any,
channel: string | null = null,
returnData: boolean = false,
): Promise<boolean | object> => {
return new Promise(async (resolve, reject) => {
try {
const webClient = initializeWebClient()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
const channelId = channel || SLACK_CHANNEL_ID
// Configure the message payload depending on the input type
let payload: ChatPostMessageArguments
if (typeof message === "string") {
payload = {
channel: channelId,
text: message,
}
} else if (message && typeof message === "object") {
if (message.blocks) {
payload = {
channel: channelId,
text: message.text || "Notification from CodeFlash",
blocks: message.blocks,
}
} else {
dependencies.console.warn("Object passed to sendSlackMessage without blocks property")
payload = {
channel: channelId,
text: JSON.stringify(message),
}
}
} else {
dependencies.console.error("Invalid message type", typeof message)
payload = {
channel: channelId,
text: "Invalid message",
}
}
// console.log("Sending payload to Slack:", JSON.stringify(payload, null, 2));
const resp = await webClient.chat.postMessage(payload)
return resolve(returnData ? resp : true)
} catch (error) {
dependencies.console.error("Error sending Slack message:", error)
return resolve(returnData ? { error } : true)
}
})
}
"""
file_path = temp_project / "slack_util.ts"
file_path.write_text(code, encoding="utf-8")
target_func = "sendSlackMessage"
functions = ts_support.discover_functions(file_path)
func_info = next(f for f in functions if f.function_name == target_func)
fto = FunctionToOptimize(
function_name=target_func,
file_path=file_path,
parents=func_info.parents,
starting_line=func_info.starting_line,
ending_line=func_info.ending_line,
starting_col=func_info.starting_col,
ending_col=func_info.ending_col,
is_async=func_info.is_async,
language="typescript",
)
ctx = get_code_optimization_context_for_language(
fto, temp_project
)
# The read_writable_code should contain the target function AND helper functions
expected_read_writable = """```typescript:slack_util.ts
import { WebClient, ChatPostMessageArguments } from "@slack/web-api"
export const sendSlackMessage = async (
message: any,
channel: string | null = null,
returnData: boolean = false,
): Promise<boolean | object> => {
return new Promise(async (resolve, reject) => {
try {
const webClient = initializeWebClient()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
const channelId = channel || SLACK_CHANNEL_ID
// Configure the message payload depending on the input type
let payload: ChatPostMessageArguments
if (typeof message === "string") {
payload = {
channel: channelId,
text: message,
}
} else if (message && typeof message === "object") {
if (message.blocks) {
payload = {
channel: channelId,
text: message.text || "Notification from CodeFlash",
blocks: message.blocks,
}
} else {
dependencies.console.warn("Object passed to sendSlackMessage without blocks property")
payload = {
channel: channelId,
text: JSON.stringify(message),
}
}
} else {
dependencies.console.error("Invalid message type", typeof message)
payload = {
channel: channelId,
text: "Invalid message",
}
}
// console.log("Sending payload to Slack:", JSON.stringify(payload, null, 2));
const resp = await webClient.chat.postMessage(payload)
return resolve(returnData ? resp : true)
} catch (error) {
dependencies.console.error("Error sending Slack message:", error)
return resolve(returnData ? { error } : true)
}
})
}
export function initializeWebClient() {
const SLACK_TOKEN = dependencies.getSlackToken()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
if (!SLACK_TOKEN) {
throw new Error("Missing SLACK_TOKEN")
}
if (!SLACK_CHANNEL_ID) {
throw new Error("Missing SLACK_CHANNEL_ID")
}
if (!web) {
web = new dependencies.WebClient(SLACK_TOKEN, {})
}
return web
}
```"""
# The read_only_context should contain global variables (dependencies object, web client)
# but NOT have invalid floating object properties
expected_read_only = """let dependencies: SendSlackMessageDependencies = {
WebClient,
getSlackToken: () => process.env.SLACK_TOKEN,
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
console,
}
let web: WebClient | null = null"""
assert ctx.read_writable_code.markdown == expected_read_writable
assert ctx.read_only_context_code == expected_read_only
class TestContextProperties:
"""Tests for CodeContext object properties."""

View file

@ -5,7 +5,13 @@ import json
import tempfile
from pathlib import Path
from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system, get_import_statement
from codeflash.languages.javascript.module_system import (
ModuleSystem,
convert_commonjs_to_esm,
convert_esm_to_commonjs,
detect_module_system,
get_import_statement,
)
class TestModuleSystemDetection:
@ -51,6 +57,39 @@ class TestModuleSystemDetection:
result = detect_module_system(project_root, file_path)
assert result == ModuleSystem.COMMONJS
def test_detect_esm_from_typescript_extension(self):
"""Test detection of ES modules from TypeScript file extensions."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
# Test .ts files
ts_file = project_root / "module.ts"
ts_file.write_text("export const foo = 'bar';")
assert detect_module_system(project_root, ts_file) == ModuleSystem.ES_MODULE
# Test .tsx files
tsx_file = project_root / "component.tsx"
tsx_file.write_text("export const Component = () => <div />;")
assert detect_module_system(project_root, tsx_file) == ModuleSystem.ES_MODULE
# Test .mts files
mts_file = project_root / "module.mts"
mts_file.write_text("export const foo = 'bar';")
assert detect_module_system(project_root, mts_file) == ModuleSystem.ES_MODULE
def test_typescript_ignores_package_json_commonjs(self):
"""Test that TypeScript files are detected as ESM even with CommonJS package.json."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
# Create package.json with explicit commonjs type
package_json = project_root / "package.json"
package_json.write_text(json.dumps({"type": "commonjs"}))
# TypeScript file should still be detected as ESM
ts_file = project_root / "module.ts"
ts_file.write_text("export const foo = 'bar';")
assert detect_module_system(project_root, ts_file) == ModuleSystem.ES_MODULE
def test_detect_esm_from_import_syntax(self):
"""Test detection of ES modules from import syntax."""
with tempfile.TemporaryDirectory() as tmpdir:
@ -159,3 +198,90 @@ class TestImportStatementGeneration:
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
assert result == "const { foo } = require('../../utils');"
class TestModuleSystemConversion:
"""Tests for CommonJS <-> ESM conversion."""
def test_convert_simple_destructured_require(self):
"""Test converting simple destructured require to import."""
code = "const { foo, bar } = require('./module');"
result = convert_commonjs_to_esm(code)
assert result == "import { foo, bar } from './module';"
def test_convert_destructured_require_with_alias(self):
"""Test converting destructured require with alias to import with 'as'."""
code = "const { foo: aliasedFoo } = require('./module');"
result = convert_commonjs_to_esm(code)
assert result == "import { foo as aliasedFoo } from './module';"
def test_convert_mixed_destructured_require(self):
"""Test converting mixed destructured require (some aliased, some not)."""
code = "const { foo, bar: aliasedBar, baz } = require('./module');"
result = convert_commonjs_to_esm(code)
assert result == "import { foo, bar as aliasedBar, baz } from './module';"
def test_convert_destructured_with_whitespace(self):
"""Test that whitespace is handled correctly in destructuring."""
code = "const { foo : aliasedFoo , bar } = require('./module');"
result = convert_commonjs_to_esm(code)
assert result == "import { foo as aliasedFoo, bar } from './module';"
def test_convert_simple_require(self):
"""Test converting simple require to default import."""
code = "const module = require('./module');"
result = convert_commonjs_to_esm(code)
assert result == "import module from './module';"
def test_convert_property_access_require(self):
"""Test converting require with property access to named import."""
code = "const foo = require('./module').bar;"
result = convert_commonjs_to_esm(code)
assert result == "import { bar as foo } from './module';"
def test_convert_property_access_default(self):
"""Test converting require().default to default import."""
code = "const foo = require('./module').default;"
result = convert_commonjs_to_esm(code)
assert result == "import foo from './module';"
def test_convert_multiple_requires(self):
"""Test converting multiple requires in one code block."""
code = """const { db: dbCore, cache } = require('@budibase/backend-core');
const utils = require('./utils');
const { process } = require('./processor');"""
result = convert_commonjs_to_esm(code)
expected = """import { db as dbCore, cache } from '@budibase/backend-core';
import utils from './utils';
import { process } from './processor';"""
assert result == expected
def test_convert_esm_to_commonjs_named(self):
"""Test converting named imports to destructured require."""
code = "import { foo, bar } from './module';"
result = convert_esm_to_commonjs(code)
assert result == "const { foo, bar } = require('./module');"
def test_convert_esm_to_commonjs_default(self):
"""Test converting default import to simple require."""
code = "import module from './module';"
result = convert_esm_to_commonjs(code)
assert result == "const module = require('./module');"
def test_convert_esm_to_commonjs_with_alias(self):
"""Test converting import with 'as' to destructured require.
Note: ESM uses 'as' but the regex keeps it as-is in the output.
This is acceptable since the test is primarily for CommonJS -> ESM conversion.
"""
code = "import { foo as aliasedFoo } from './module';"
result = convert_esm_to_commonjs(code)
# The current implementation preserves 'as' syntax which works for our use case
assert result == "const { foo as aliasedFoo } = require('./module');"
def test_real_world_budibase_import(self):
"""Test the real-world case from Budibase that was failing."""
code = "const { queue, context, db: dbCore, cache, events } = require('@budibase/backend-core');"
result = convert_commonjs_to_esm(code)
expected = "import { queue, context, db as dbCore, cache, events } from '@budibase/backend-core';"
assert result == expected

View file

@ -15,6 +15,8 @@ from pathlib import Path
import pytest
from codeflash.code_utils.code_replacer import replace_function_definitions_for_language
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (
ModuleSystem,
convert_commonjs_to_esm,
@ -300,12 +302,144 @@ export function calculate(x, y) {
assert "return add(x, y);" in result
class TestModuleSystemCompatibility:
"""Tests for module system compatibility."""
class TestTsJestSkipsConversion:
"""Tests verifying that module system conversion is skipped when ts-jest is installed.
def test_convert_mixed_code_to_esm(self):
"""Test converting mixed CJS/ESM code to pure ESM - exact output."""
code = """\
When ts-jest is installed, it handles module interoperability internally,
so we skip conversion to avoid breaking valid imports.
"""
def __init__(self):
set_current_language(Language.TYPESCRIPT)
def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
"""Test that CommonJS is NOT converted to ESM when ts-jest is installed."""
# Create a project with ts-jest
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"ts-jest": "^29.0.0"}}')
commonjs_test = """\
const Logger = require('../utils/logger');
const { helper } = require('../utils/helpers');
describe('Logger', () => {
test('should work', () => {
const logger = new Logger();
expect(logger).toBeDefined();
});
});
"""
# With ts-jest, no conversion should happen
result = ensure_module_system_compatibility(commonjs_test, ModuleSystem.ES_MODULE, tmp_path)
assert result == commonjs_test, (
f"CommonJS should NOT be converted when ts-jest is installed.\n"
f"Expected (unchanged):\n{commonjs_test}\n\nGot:\n{result}"
)
def test_esm_not_converted_when_ts_jest_installed(self, tmp_path):
"""Test that ESM is NOT converted to CommonJS when ts-jest is installed."""
# Create a project with ts-jest
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"ts-jest": "^29.0.0"}}')
esm_test = """\
import Logger from '../utils/logger';
import { helper } from '../utils/helpers';
describe('Logger', () => {
test('should work', () => {
const logger = new Logger();
expect(logger).toBeDefined();
});
});
"""
# With ts-jest, no conversion should happen
result = ensure_module_system_compatibility(esm_test, ModuleSystem.COMMONJS, tmp_path)
assert result == esm_test, (
f"ESM should NOT be converted when ts-jest is installed.\n"
f"Expected (unchanged):\n{esm_test}\n\nGot:\n{result}"
)
def test_ts_jest_detected_in_jest_config(self, tmp_path):
"""Test that ts-jest is detected from jest.config.js content."""
# Create a project with ts-jest in jest.config.js (not package.json)
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {}}')
jest_config = tmp_path / "jest.config.js"
jest_config.write_text("module.exports = { preset: 'ts-jest' };")
commonjs_test = "const x = require('./module');"
result = ensure_module_system_compatibility(commonjs_test, ModuleSystem.ES_MODULE, tmp_path)
assert result == commonjs_test, "Should skip conversion when ts-jest is in jest.config.js"
class TestModuleSystemConversion:
"""Tests for module system conversion when ts-jest is NOT installed.
Without ts-jest, we convert between CommonJS and ESM as needed.
"""
def test_commonjs_converted_to_esm_without_ts_jest(self, tmp_path):
"""Test that CommonJS is converted to ESM when ts-jest is NOT installed."""
# Create a project WITHOUT ts-jest
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
commonjs_code = """\
const { helper } = require('./helpers');
const logger = require('./logger');
function process() {
return helper();
}
"""
result = ensure_module_system_compatibility(commonjs_code, ModuleSystem.ES_MODULE, tmp_path)
# Should be converted to ESM
assert "import { helper } from './helpers';" in result
assert "import logger from './logger';" in result
assert "require(" not in result
def test_esm_converted_to_commonjs_without_ts_jest(self, tmp_path):
"""Test that ESM is converted to CommonJS when ts-jest is NOT installed."""
# Create a project WITHOUT ts-jest
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
esm_code = """\
import { helper } from './helpers';
import logger from './logger';
function process() {
return helper();
}
"""
result = ensure_module_system_compatibility(esm_code, ModuleSystem.COMMONJS, tmp_path)
# Should be converted to CommonJS
assert "const { helper } = require('./helpers');" in result
assert "const logger = require('./logger');" in result
assert "import " not in result
def test_no_conversion_when_project_root_is_none(self):
"""Test that conversion happens when project_root is None (can't detect ts-jest)."""
commonjs_code = "const x = require('./module');"
# Without project_root, we can't detect ts-jest, so conversion should happen
result = ensure_module_system_compatibility(commonjs_code, ModuleSystem.ES_MODULE, None)
# Should be converted to ESM
assert "import x from './module';" in result
def test_mixed_code_not_converted(self, tmp_path):
"""Test that mixed CJS/ESM code is NOT converted (already has both)."""
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
mixed_code = """\
import { existing } from './module.js';
const { helper } = require('./helpers');
@ -313,32 +447,16 @@ function process() {
return existing() + helper();
}
"""
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
# Mixed code has both import and require, so no conversion
result = ensure_module_system_compatibility(mixed_code, ModuleSystem.ES_MODULE, tmp_path)
# Should convert require to import
assert "import { helper } from './helpers';" in result
assert "require" not in result, f"require should be converted to import. Got:\n{result}"
assert result == mixed_code, "Mixed code should not be converted"
def test_convert_mixed_code_to_commonjs(self):
"""Test converting mixed ESM/CJS code to pure CommonJS - exact output."""
code = """\
const { existing } = require('./module');
import { helper } from './helpers.js';
function process() {
return existing() + helper();
}
"""
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS)
# Should convert import to require
assert "const { helper } = require('./helpers');" in result
assert "import " not in result.split("\n")[0] or "import " not in result, (
f"import should be converted to require. Got:\n{result}"
)
def test_pure_esm_unchanged(self):
def test_pure_esm_unchanged_for_esm_target(self, tmp_path):
"""Test that pure ESM code is unchanged when targeting ESM."""
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
code = """\
import { add } from './math.js';
@ -346,11 +464,14 @@ export function sum(a, b) {
return add(a, b);
}
"""
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
assert result == code, f"Pure ESM code should be unchanged.\nExpected:\n{code}\n\nGot:\n{result}"
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE, tmp_path)
assert result == code, "Pure ESM code should be unchanged for ESM target"
def test_pure_commonjs_unchanged(self):
def test_pure_commonjs_unchanged_for_commonjs_target(self, tmp_path):
"""Test that pure CommonJS code is unchanged when targeting CommonJS."""
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
code = """\
const { add } = require('./math');
@ -360,8 +481,8 @@ function sum(a, b) {
module.exports = { sum };
"""
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS)
assert result == code, f"Pure CommonJS code should be unchanged.\nExpected:\n{code}\n\nGot:\n{result}"
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS, tmp_path)
assert result == code, "Pure CommonJS code should be unchanged for CommonJS target"
class TestImportStatementGeneration:

View file

@ -555,3 +555,125 @@ def standalone():
standalone_funcs = [f for f in functions if f.class_name is None]
assert len(standalone_funcs) == 1
# === Tests for find_references method ===
# These tests verify that PythonSupport correctly finds references to functions
# using jedi, including the fix for using function.function_name instead of function.name.
def test_find_references_simple_function(python_support, tmp_path):
"""Test finding references to a simple function.
This test specifically exercises the code path that was fixed in the
regression where function.name was used instead of function.function_name.
"""
from codeflash.models.function_types import FunctionToOptimize
# Create source file with function definition
source_file = tmp_path / "utils.py"
source_file.write_text("""def helper_function(x):
return x * 2
""")
# Create a file that imports and uses the function
consumer_file = tmp_path / "consumer.py"
consumer_file.write_text("""from utils import helper_function
def process(value):
return helper_function(value) + 1
""")
func = FunctionToOptimize(
function_name="helper_function",
file_path=source_file,
starting_line=1,
ending_line=2,
)
refs = python_support.find_references(func, project_root=tmp_path)
assert len(refs) >= 1
ref_files = {str(r.file_path) for r in refs}
assert any("consumer.py" in f for f in ref_files)
def test_find_references_class_method(python_support, tmp_path):
"""Test finding references to a class method.
This verifies the class_name attribute is correctly used to disambiguate methods.
"""
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
# Create source file with class and method
source_file = tmp_path / "calculator.py"
source_file.write_text("""class Calculator:
def add(self, a, b):
return a + b
""")
# Create a file that uses the class method
consumer_file = tmp_path / "main.py"
consumer_file.write_text("""from calculator import Calculator
def compute():
calc = Calculator()
return calc.add(1, 2)
""")
func = FunctionToOptimize(
function_name="add",
file_path=source_file,
parents=[FunctionParent(name="Calculator", type="ClassDef")],
starting_line=2,
ending_line=3,
is_method=True,
)
refs = python_support.find_references(func, project_root=tmp_path)
assert len(refs) >= 1
ref_files = {str(r.file_path) for r in refs}
assert any("main.py" in f for f in ref_files)
def test_find_references_no_references(python_support, tmp_path):
"""Test that find_references returns empty list when no references exist."""
from codeflash.models.function_types import FunctionToOptimize
source_file = tmp_path / "isolated.py"
source_file.write_text("""def isolated_function():
return 42
""")
func = FunctionToOptimize(
function_name="isolated_function",
file_path=source_file,
starting_line=1,
ending_line=2,
)
refs = python_support.find_references(func, project_root=tmp_path)
assert refs == []
def test_find_references_nonexistent_function(python_support, tmp_path):
"""Test that find_references handles nonexistent functions gracefully."""
from codeflash.models.function_types import FunctionToOptimize
source_file = tmp_path / "source.py"
source_file.write_text("""def existing_function():
return 1
""")
func = FunctionToOptimize(
function_name="nonexistent_function",
file_path=source_file,
starting_line=1,
ending_line=2,
)
refs = python_support.find_references(func, project_root=tmp_path)
assert refs == []

View file

@ -14,6 +14,7 @@ from codeflash.setup.detector import (
_find_project_root,
detect_project,
has_existing_config,
is_build_output_dir,
)
@ -139,15 +140,15 @@ class TestDetectModuleRoot:
assert "pyproject.toml" in detail
def test_js_detects_from_exports(self, tmp_path):
"""Should detect module root from package.json exports."""
"""Should detect module root from package.json exports when no common src dir exists."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./src/index.js"}
"exports": {".": "./packages/core/index.js"}
}))
(tmp_path / "src").mkdir()
(tmp_path / "packages" / "core").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "src"
assert module_root == tmp_path / "packages" / "core"
assert "exports" in detail
def test_js_detects_src_convention(self, tmp_path):
@ -158,6 +159,214 @@ class TestDetectModuleRoot:
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "src"
def test_js_prefers_src_over_build_src(self, tmp_path):
"""Should prefer src/ over build/src/ even when package.json points to build/."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "build/src/index.js",
"module": "build/src/index.js"
}))
(tmp_path / "src").mkdir()
(tmp_path / "build" / "src").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "src"
assert "src/ directory" in detail
def test_js_skips_build_dir_from_main(self, tmp_path):
"""Should skip build output directories from package.json main field."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "build/index.js"
}))
(tmp_path / "build").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path
assert "project root" in detail
def test_js_skips_dist_dir_from_exports(self, tmp_path):
"""Should skip dist output directories from package.json exports field."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./dist/index.js"}
}))
(tmp_path / "dist").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path
assert "project root" in detail
def test_js_skips_out_dir_from_module(self, tmp_path):
"""Should skip out output directories from package.json module field."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"module": "out/esm/index.js"
}))
(tmp_path / "out" / "esm").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path
assert "project root" in detail
def test_js_prefers_lib_over_build_dir(self, tmp_path):
"""Should prefer lib/ over build output directories."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "dist/index.js"
}))
(tmp_path / "lib").mkdir()
(tmp_path / "dist").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "lib"
assert "lib/ directory" in detail
def test_js_prefers_source_over_build_dir(self, tmp_path):
"""Should prefer source/ over build output directories."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "build/index.js"
}))
(tmp_path / "source").mkdir()
(tmp_path / "build").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "source"
assert "source/ directory" in detail
def test_js_falls_back_to_valid_exports_path(self, tmp_path):
"""Should use exports path when no common source dirs exist and path is not build output."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./packages/core/index.js"}
}))
(tmp_path / "packages" / "core").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "packages" / "core"
assert "exports" in detail
def test_js_falls_back_to_valid_main_path(self, tmp_path):
"""Should use main path when no common source dirs exist and path is not build output."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "packages/main/index.js"
}))
(tmp_path / "packages" / "main").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "packages" / "main"
assert "main" in detail
def test_js_falls_back_to_valid_module_path(self, tmp_path):
"""Should use module path when no common source dirs exist and path is not build output."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"module": "esm/index.js"
}))
(tmp_path / "esm").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "esm"
assert "module" in detail
def test_js_returns_project_root_when_all_paths_are_build_output(self, tmp_path):
"""Should return project root when all package.json paths point to build outputs."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"main": "dist/cjs/index.js",
"module": "dist/esm/index.js",
"exports": {".": "./build/index.js"}
}))
(tmp_path / "dist" / "cjs").mkdir(parents=True)
(tmp_path / "dist" / "esm").mkdir(parents=True)
(tmp_path / "build").mkdir()
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path
assert "project root" in detail
def test_js_handles_malformed_package_json(self, tmp_path):
"""Should handle malformed package.json gracefully."""
(tmp_path / "package.json").write_text("{ invalid json }")
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path
assert "project root" in detail
class TestIsBuildOutputDir:
"""Tests for is_build_output_dir function."""
def test_detects_build_dir(self):
"""Should detect build/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path("build"))
assert is_build_output_dir(Path("build/src"))
assert is_build_output_dir(Path("build/src/index.js"))
def test_detects_dist_dir(self):
"""Should detect dist/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path("dist"))
assert is_build_output_dir(Path("dist/esm"))
assert is_build_output_dir(Path("dist/cjs/index.js"))
def test_detects_out_dir(self):
"""Should detect out/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path("out"))
assert is_build_output_dir(Path("out/src"))
def test_detects_next_dir(self):
"""Should detect .next/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path(".next"))
assert is_build_output_dir(Path(".next/static"))
def test_detects_nuxt_dir(self):
"""Should detect .nuxt/ as build output."""
from pathlib import Path
assert is_build_output_dir(Path(".nuxt"))
assert is_build_output_dir(Path(".nuxt/dist"))
def test_detects_nested_build_dir(self):
"""Should detect build dir nested in path."""
from pathlib import Path
assert is_build_output_dir(Path("packages/build/index.js"))
assert is_build_output_dir(Path("foo/dist/bar"))
def test_does_not_detect_src(self):
"""Should not detect src/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("src"))
assert not is_build_output_dir(Path("src/index.js"))
def test_does_not_detect_lib(self):
"""Should not detect lib/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("lib"))
assert not is_build_output_dir(Path("lib/utils"))
def test_does_not_detect_source(self):
"""Should not detect source/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("source"))
def test_does_not_detect_packages(self):
"""Should not detect packages/ as build output."""
from pathlib import Path
assert not is_build_output_dir(Path("packages"))
assert not is_build_output_dir(Path("packages/core"))
def test_does_not_detect_similar_names(self):
"""Should not detect directories with similar but different names."""
from pathlib import Path
assert not is_build_output_dir(Path("builder"))
assert not is_build_output_dir(Path("distribution"))
assert not is_build_output_dir(Path("output"))
class TestDetectTestsRoot:
"""Tests for tests root detection."""

View file

@ -857,6 +857,48 @@ class TestE2ECLIFlags:
# Should complete without error
_handle_show_config()
def test_show_config_displays_config_path_when_saved(self, project_with_existing_config, monkeypatch):
"""Should display config file path when saved config exists."""
monkeypatch.chdir(project_with_existing_config)
# Track what gets printed
printed_messages = []
def mock_print(msg="", *args, **kwargs):
printed_messages.append(str(msg))
from codeflash.cli_cmds import console
monkeypatch.setattr(console.console, "print", mock_print)
from codeflash.cli_cmds.cli import _handle_show_config
_handle_show_config()
# Verify config path is displayed
all_output = "\n".join(printed_messages)
assert "pyproject.toml" in all_output
assert "Config file:" in all_output
def test_show_config_no_path_when_auto_detected(self, python_src_layout, monkeypatch):
"""Should not display config file path when config is auto-detected."""
monkeypatch.chdir(python_src_layout)
# Track what gets printed
printed_messages = []
def mock_print(msg="", *args, **kwargs):
printed_messages.append(str(msg))
from codeflash.cli_cmds import console
monkeypatch.setattr(console.console, "print", mock_print)
from codeflash.cli_cmds.cli import _handle_show_config
_handle_show_config()
# Verify no config path line is displayed
all_output = "\n".join(printed_messages)
assert "Config file:" not in all_output
assert "Auto-detected" in all_output
def test_reset_config_removes_from_pyproject(self, project_with_existing_config, monkeypatch):
"""Should remove codeflash config from pyproject.toml."""
monkeypatch.chdir(project_with_existing_config)

2123
uv.lock

File diff suppressed because it is too large Load diff