chore: merge omni-java into feat/java-gradle-support
Resolved conflicts in: - codeflash/version.py (kept gradle branch version) - codeflash/languages/java/test_runner.py (kept gradle multi-module logic) - codeflash/languages/java/support.py (kept java_test_module parameter) - codeflash/discovery/functions_to_optimize.py (kept enhanced test filtering) Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
c3be740417
63 changed files with 3338 additions and 1944 deletions
59
.github/workflows/claude-code-review.yml
vendored
59
.github/workflows/claude-code-review.yml
vendored
|
|
@ -1,59 +0,0 @@
|
|||
name: Claude Code Review
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize]
|
||||
# Optional: Only run on specific file changes
|
||||
# paths:
|
||||
# - "src/**/*.ts"
|
||||
# - "src/**/*.tsx"
|
||||
# - "src/**/*.js"
|
||||
# - "src/**/*.jsx"
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
# Optional: Filter by PR author
|
||||
# if: |
|
||||
# github.event.pull_request.user.login == 'external-contributor' ||
|
||||
# github.event.pull_request.user.login == 'new-developer' ||
|
||||
# github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Run Claude Code Review
|
||||
id: claude-review
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_foundry: "true"
|
||||
use_sticky_comment: true
|
||||
prompt: |
|
||||
REPO: ${{ github.repository }}
|
||||
PR NUMBER: ${{ github.event.pull_request.number }}
|
||||
|
||||
Please review this pull request and provide feedback on:
|
||||
- Code quality and best practices
|
||||
- Potential bugs or issues
|
||||
- Performance considerations
|
||||
- Security concerns
|
||||
- Test coverage
|
||||
|
||||
Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.
|
||||
|
||||
# See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
|
||||
# or https://code.claude.com/docs/en/cli-reference for available options
|
||||
claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'
|
||||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
|
||||
|
||||
159
.github/workflows/claude.yml
vendored
159
.github/workflows/claude.yml
vendored
|
|
@ -1,6 +1,8 @@
|
|||
name: Claude Code
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, synchronize, ready_for_review, reopened]
|
||||
issue_comment:
|
||||
types: [created]
|
||||
pull_request_review_comment:
|
||||
|
|
@ -11,19 +13,154 @@ on:
|
|||
types: [submitted]
|
||||
|
||||
jobs:
|
||||
claude:
|
||||
# Automatic PR review (can fix linting issues and push)
|
||||
# Blocked for fork PRs to prevent malicious code execution
|
||||
pr-review:
|
||||
if: |
|
||||
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
|
||||
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
|
||||
(github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
|
||||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
|
||||
github.event_name == 'pull_request' &&
|
||||
github.actor != 'claude[bot]' &&
|
||||
github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read # Required for Claude to read CI results on PRs
|
||||
actions: read
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_foundry: "true"
|
||||
use_sticky_comment: true
|
||||
allowed_bots: "claude[bot]"
|
||||
prompt: |
|
||||
REPO: ${{ github.repository }}
|
||||
PR NUMBER: ${{ github.event.pull_request.number }}
|
||||
EVENT: ${{ github.event.action }}
|
||||
|
||||
## STEP 1: Run pre-commit checks and fix issues
|
||||
|
||||
First, run `uv run prek run --from-ref origin/main` to check for linting/formatting issues on files changed in this PR.
|
||||
|
||||
If there are any issues:
|
||||
- For SAFE auto-fixable issues (formatting, import sorting, trailing whitespace, etc.), run `uv run prek run --from-ref origin/main` again to auto-fix them
|
||||
- Stage the fixed files with `git add`
|
||||
- Commit with message "style: auto-fix linting issues"
|
||||
- Push the changes with `git push`
|
||||
|
||||
Do NOT attempt to fix:
|
||||
- Type errors that require logic changes
|
||||
- Complex refactoring suggestions
|
||||
- Anything that could change behavior
|
||||
|
||||
## STEP 2: Review the PR
|
||||
|
||||
${{ github.event.action == 'synchronize' && 'This is a RE-REVIEW after new commits. First, get the list of changed files in this latest push using `gh pr diff`. Review ONLY the changed files. Check ALL existing review comments and resolve ones that are now fixed.' || 'This is the INITIAL REVIEW.' }}
|
||||
|
||||
Review this PR focusing ONLY on:
|
||||
1. Critical bugs or logic errors
|
||||
2. Security vulnerabilities
|
||||
3. Breaking API changes
|
||||
4. Test failures (methods with typos that wont run)
|
||||
|
||||
IMPORTANT:
|
||||
- First check existing review comments using `gh api repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/comments`. For each existing comment, check if the issue still exists in the current code.
|
||||
- If an issue is fixed, use `gh api --method PATCH repos/${{ github.repository }}/pulls/comments/COMMENT_ID -f body="✅ Fixed in latest commit"` to resolve it.
|
||||
- Only create NEW inline comments for HIGH-PRIORITY issues found in changed files.
|
||||
- Limit to 5-7 NEW comments maximum per review.
|
||||
- Use CLAUDE.md for project-specific guidance.
|
||||
- Use `gh pr comment` for summary-level feedback.
|
||||
- Use `mcp__github_inline_comment__create_inline_comment` sparingly for critical code issues only.
|
||||
|
||||
## STEP 3: Coverage analysis
|
||||
|
||||
Analyze test coverage for changed files:
|
||||
|
||||
1. Get the list of Python files changed in this PR (excluding tests):
|
||||
`git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
|
||||
|
||||
2. Run tests with coverage on the PR branch:
|
||||
`uv run coverage run -m pytest tests/ -q --tb=no`
|
||||
`uv run coverage json -o coverage-pr.json`
|
||||
|
||||
3. Get coverage for changed files only:
|
||||
`uv run coverage report --include="<changed_files_comma_separated>"`
|
||||
|
||||
4. Compare with main branch coverage:
|
||||
- Checkout main: `git checkout origin/main`
|
||||
- Run coverage: `uv run coverage run -m pytest tests/ -q --tb=no && uv run coverage json -o coverage-main.json`
|
||||
- Checkout back: `git checkout -`
|
||||
|
||||
5. Analyze the diff to identify:
|
||||
- NEW FILES: Files that don't exist on main (require good test coverage)
|
||||
- MODIFIED FILES: Files with changes (changes must be covered by tests)
|
||||
|
||||
6. Report in PR comment with a markdown table:
|
||||
- Coverage % for each changed file (PR vs main)
|
||||
- Overall coverage change
|
||||
- For NEW files: Flag if coverage is below 75%
|
||||
- For MODIFIED files: Flag if the changed lines are not covered by tests
|
||||
- Flag if overall coverage decreased
|
||||
|
||||
Coverage requirements:
|
||||
- New implementations/files: Must have ≥75% test coverage
|
||||
- Modified code: Changed lines should be exercised by existing or new tests
|
||||
- No coverage regressions: Overall coverage should not decrease
|
||||
claude_args: '--allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
|
||||
|
||||
# @claude mentions (can edit and push) - restricted to maintainers only
|
||||
claude-mention:
|
||||
if: |
|
||||
(
|
||||
github.event_name == 'issue_comment' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')
|
||||
) ||
|
||||
(
|
||||
github.event_name == 'pull_request_review_comment' &&
|
||||
contains(github.event.comment.body, '@claude') &&
|
||||
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR') &&
|
||||
github.event.pull_request.head.repo.full_name == github.repository
|
||||
) ||
|
||||
(
|
||||
github.event_name == 'pull_request_review' &&
|
||||
contains(github.event.review.body, '@claude') &&
|
||||
(github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'MEMBER' || github.event.review.author_association == 'COLLABORATOR') &&
|
||||
github.event.pull_request.head.repo.full_name == github.repository
|
||||
) ||
|
||||
(
|
||||
github.event_name == 'issues' &&
|
||||
(contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')) &&
|
||||
(github.event.issue.author_association == 'OWNER' || github.event.issue.author_association == 'MEMBER' || github.event.issue.author_association == 'COLLABORATOR')
|
||||
)
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
actions: read
|
||||
steps:
|
||||
- name: Get PR head ref
|
||||
id: pr-ref
|
||||
|
|
@ -44,14 +181,22 @@ jobs:
|
|||
fetch-depth: 0
|
||||
ref: ${{ steps.pr-ref.outputs.ref }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv venv --seed
|
||||
uv sync
|
||||
|
||||
- name: Run Claude Code
|
||||
id: claude
|
||||
uses: anthropics/claude-code-action@v1
|
||||
with:
|
||||
use_foundry: "true"
|
||||
claude_args: '--allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
env:
|
||||
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
|
||||
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
|
||||
|
||||
|
|
|
|||
19
.github/workflows/pre-commit.yaml
vendored
19
.github/workflows/pre-commit.yaml
vendored
|
|
@ -1,19 +0,0 @@
|
|||
name: Lint
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
name: Run pre-commit hooks
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
18
.github/workflows/prek.yaml
vendored
Normal file
18
.github/workflows/prek.yaml
vendored
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
name: Lint
|
||||
on: [pull_request]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
prek:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
- uses: j178/prek-action@v1
|
||||
with:
|
||||
extra-args: '--from-ref origin/${{ github.base_ref }} --to-ref ${{ github.sha }}'
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.12.7
|
||||
rev: v0.15.0
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ codeflash/
|
|||
- Use libcst, not ast - For Python, always use `libcst` for code parsing/modification to preserve formatting.
|
||||
- Code context extraction and replacement tests must always assert for full string equality, no substring matching.
|
||||
- Any new feature or bug fix that can be tested automatically must have test cases.
|
||||
- If changes affect existing test expectations, update the tests accordingly. Tests must always pass after changes.
|
||||
- NEVER use leading underscores for function names (e.g., `_helper`). Python has no true private functions. Always use public names.
|
||||
|
||||
## Code Style
|
||||
|
||||
|
|
@ -70,7 +72,7 @@ codeflash/
|
|||
- **Tooling**: Ruff for linting/formatting, mypy strict mode, pre-commit hooks
|
||||
- **Comments**: Minimal - only explain "why", not "what"
|
||||
- **Docstrings**: Do not add unless explicitly requested
|
||||
- **Naming**: Prefer public functions (no leading underscore) - Python doesn't have true private functions
|
||||
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
|
||||
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
|
||||
|
||||
## Git Commits & Pull Requests
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@
|
|||
}
|
||||
},
|
||||
"../../../packages/codeflash": {
|
||||
"version": "0.3.1",
|
||||
"version": "0.4.0",
|
||||
"dev": true,
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
|
|
|
|||
|
|
@ -373,11 +373,13 @@ def _handle_show_config() -> None:
|
|||
detected = detect_project(project_root)
|
||||
|
||||
# Check if config exists or is auto-detected
|
||||
config_exists, _ = has_existing_config(project_root)
|
||||
config_exists, config_file = has_existing_config(project_root)
|
||||
status = "Saved config" if config_exists else "Auto-detected (not saved)"
|
||||
|
||||
console.print()
|
||||
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
|
||||
if config_exists and config_file:
|
||||
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
|
||||
console.print()
|
||||
|
||||
table = Table(show_header=True, header_style="bold cyan")
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ from codeflash.cli_cmds.cli_common import apologize_and_exit
|
|||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.cli_cmds.extension import install_vscode_extension
|
||||
|
||||
# Import Java init module
|
||||
from codeflash.cli_cmds.init_java import init_java_project
|
||||
|
||||
# Import JS/TS init module
|
||||
from codeflash.cli_cmds.init_javascript import (
|
||||
ProjectLanguage,
|
||||
|
|
@ -35,9 +38,6 @@ from codeflash.cli_cmds.init_javascript import (
|
|||
get_js_dependency_installation_commands,
|
||||
init_js_project,
|
||||
)
|
||||
|
||||
# Import Java init module
|
||||
from codeflash.cli_cmds.init_java import init_java_project
|
||||
from codeflash.code_utils.code_utils import validate_relative_directory_path
|
||||
from codeflash.code_utils.compat import LF
|
||||
from codeflash.code_utils.config_parser import parse_config_file
|
||||
|
|
@ -1674,9 +1674,7 @@ def _customize_java_workflow_content(optimize_yml_content: str, git_root: Path,
|
|||
|
||||
# Install dependencies
|
||||
install_deps_cmd = get_java_dependency_installation_commands(build_tool)
|
||||
optimize_yml_content = optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
|
||||
|
||||
return optimize_yml_content
|
||||
return optimize_yml_content.replace("{{ install_dependencies_command }}", install_deps_cmd)
|
||||
|
||||
|
||||
def get_formatter_cmds(formatter: str) -> list[str]:
|
||||
|
|
|
|||
|
|
@ -165,9 +165,7 @@ def init_java_project() -> None:
|
|||
|
||||
lang_panel = Panel(
|
||||
Text(
|
||||
"Java project detected!\n\nI'll help you set up Codeflash for your project.",
|
||||
style="cyan",
|
||||
justify="center",
|
||||
"Java project detected!\n\nI'll help you set up Codeflash for your project.", style="cyan", justify="center"
|
||||
),
|
||||
title="Java Setup",
|
||||
border_style="bright_red",
|
||||
|
|
@ -205,7 +203,9 @@ def init_java_project() -> None:
|
|||
completion_message = "Codeflash is now set up for your Java project!\n\nYou can now run any of these commands:"
|
||||
|
||||
if did_add_new_key:
|
||||
completion_message += "\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
|
||||
completion_message += (
|
||||
"\n\nDon't forget to restart your shell to load the CODEFLASH_API_KEY environment variable!"
|
||||
)
|
||||
if os.name == "nt":
|
||||
reload_cmd = f". {get_shell_rc_path()}" if is_powershell() else f"call {get_shell_rc_path()}"
|
||||
else:
|
||||
|
|
@ -234,9 +234,7 @@ def should_modify_java_config() -> tuple[bool, dict[str, Any] | None]:
|
|||
codeflash_config_path = project_root / "codeflash.toml"
|
||||
if codeflash_config_path.exists():
|
||||
return Confirm.ask(
|
||||
"A Codeflash config already exists. Do you want to re-configure it?",
|
||||
default=False,
|
||||
show_default=True,
|
||||
"A Codeflash config already exists. Do you want to re-configure it?", default=False, show_default=True
|
||||
), None
|
||||
|
||||
return True, None
|
||||
|
|
@ -285,14 +283,10 @@ def collect_java_setup_info() -> JavaSetupInfo:
|
|||
|
||||
if Confirm.ask("Would you like to change any of these settings?", default=False):
|
||||
# Source root override
|
||||
module_root_override = _prompt_directory_override(
|
||||
"source", detected_source_root, curdir
|
||||
)
|
||||
module_root_override = _prompt_directory_override("source", detected_source_root, curdir)
|
||||
|
||||
# Test root override
|
||||
test_root_override = _prompt_directory_override(
|
||||
"test", detected_test_root, curdir
|
||||
)
|
||||
test_root_override = _prompt_directory_override("test", detected_test_root, curdir)
|
||||
|
||||
# Formatter override
|
||||
formatter_questions = [
|
||||
|
|
@ -300,7 +294,7 @@ def collect_java_setup_info() -> JavaSetupInfo:
|
|||
"formatter",
|
||||
message="Which code formatter do you use?",
|
||||
choices=[
|
||||
(f"keep detected (google-java-format)", "keep"),
|
||||
("keep detected (google-java-format)", "keep"),
|
||||
("google-java-format", "google-java-format"),
|
||||
("spotless", "spotless"),
|
||||
("other", "other"),
|
||||
|
|
@ -345,7 +339,7 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
|
|||
subdirs = [d.name for d in curdir.iterdir() if d.is_dir() and not d.name.startswith(".")]
|
||||
subdirs = [d for d in subdirs if d not in ("target", "build", ".git", ".idea", detected)]
|
||||
|
||||
options = [keep_detected_option] + subdirs[:5] + [custom_dir_option]
|
||||
options = [keep_detected_option, *subdirs[:5], custom_dir_option]
|
||||
|
||||
questions = [
|
||||
inquirer.List(
|
||||
|
|
@ -364,10 +358,9 @@ def _prompt_directory_override(dir_type: str, detected: str, curdir: Path) -> st
|
|||
answer = answers[f"{dir_type}_root"]
|
||||
if answer == keep_detected_option:
|
||||
return None
|
||||
elif answer == custom_dir_option:
|
||||
if answer == custom_dir_option:
|
||||
return _prompt_custom_directory(dir_type)
|
||||
else:
|
||||
return answer
|
||||
return answer
|
||||
|
||||
|
||||
def _prompt_custom_directory(dir_type: str) -> str:
|
||||
|
|
@ -441,7 +434,7 @@ def get_java_formatter_cmd(formatter: str, build_tool: JavaBuildTool) -> list[st
|
|||
if formatter == "spotless":
|
||||
if build_tool == JavaBuildTool.MAVEN:
|
||||
return ["mvn spotless:apply -DspotlessFiles=$file"]
|
||||
elif build_tool == JavaBuildTool.GRADLE:
|
||||
if build_tool == JavaBuildTool.GRADLE:
|
||||
return ["./gradlew spotlessApply"]
|
||||
return ["spotless $file"]
|
||||
if formatter == "other":
|
||||
|
|
|
|||
|
|
@ -1429,7 +1429,9 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
|
|||
numerical_names: set[str] = set()
|
||||
modules_used: set[str] = set()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
stack: list[ast.AST] = [tree]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
# import numpy or import numpy as np
|
||||
|
|
@ -1451,6 +1453,8 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
|
|||
name = alias.asname if alias.asname else alias.name
|
||||
numerical_names.add(name)
|
||||
modules_used.add(module_root)
|
||||
else:
|
||||
stack.extend(ast.iter_child_nodes(node))
|
||||
|
||||
return numerical_names, modules_used
|
||||
|
||||
|
|
|
|||
|
|
@ -711,18 +711,12 @@ def _add_java_class_members(
|
|||
if not new_fields and not new_methods:
|
||||
return original_source
|
||||
|
||||
logger.debug(
|
||||
f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}"
|
||||
)
|
||||
logger.debug(f"Adding {len(new_fields)} new fields and {len(new_methods)} helper methods to class {class_name}")
|
||||
|
||||
# Import the insertion function from replacement module
|
||||
from codeflash.languages.java.replacement import _insert_class_members
|
||||
|
||||
result = _insert_class_members(
|
||||
original_source, class_name, new_fields, new_methods, analyzer
|
||||
)
|
||||
|
||||
return result
|
||||
return _insert_class_members(original_source, class_name, new_fields, new_methods, analyzer)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding Java class members: {e}")
|
||||
|
|
@ -959,12 +953,14 @@ def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStrin
|
|||
for file_path_str, code in file_to_code_context.items():
|
||||
if file_path_str:
|
||||
# Extract filename without creating Path object repeatedly
|
||||
if file_path_str.endswith(target_filename) and (len(file_path_str) == len(target_filename) or file_path_str[-len(target_filename)-1] in ('/', '\\')):
|
||||
if file_path_str.endswith(target_filename) and (
|
||||
len(file_path_str) == len(target_filename)
|
||||
or file_path_str[-len(target_filename) - 1] in ("/", "\\")
|
||||
):
|
||||
module_optimized_code = code
|
||||
logger.debug(f"Matched {file_path_str} to {relative_path} by filename")
|
||||
break
|
||||
|
||||
|
||||
if module_optimized_code is None:
|
||||
# Also try matching if there's only one code file, but ONLY for non-Python
|
||||
# languages where path matching is less strict. For Python, we require
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ from typing import Any, Union
|
|||
MAX_TEST_RUN_ITERATIONS = 5
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 16000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 16000
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT = 15
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT = 15 # For Python pytest
|
||||
JAVA_TESTCASE_TIMEOUT = 120 # Java Maven tests need more time due to startup overhead
|
||||
MAX_FUNCTION_TEST_SECONDS = 60
|
||||
MIN_IMPROVEMENT_THRESHOLD = 0.05
|
||||
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import json
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from codeflash.setup.detector import is_build_output_dir
|
||||
|
||||
PACKAGE_JSON_CACHE: dict[Path, Path] = {}
|
||||
PACKAGE_JSON_DATA_CACHE: dict[Path, dict[str, Any]] = {}
|
||||
|
||||
|
|
@ -50,12 +52,15 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
|
|||
"""Detect module root from package.json fields or directory conventions.
|
||||
|
||||
Detection order:
|
||||
1. package.json "exports" field (extract directory from main export)
|
||||
2. package.json "module" field (ESM entry point)
|
||||
3. package.json "main" field (CJS entry point)
|
||||
4. "src/" directory if it exists
|
||||
1. src/, lib/, source/ directories (common source directories)
|
||||
2. package.json "exports" field (if not in build output directory)
|
||||
3. package.json "module" field (ESM, if not in build output directory)
|
||||
4. package.json "main" field (CJS, if not in build output directory)
|
||||
5. Fall back to "." (project root)
|
||||
|
||||
Build output directories (build/, dist/, out/) are skipped since they contain
|
||||
compiled code, not source files.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project.
|
||||
package_data: Parsed package.json data.
|
||||
|
|
@ -64,6 +69,11 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
|
|||
Detected module root path (relative to project root).
|
||||
|
||||
"""
|
||||
# Check for common source directories first - these are always preferred
|
||||
for src_dir in ["src", "lib", "source"]:
|
||||
if (project_root / src_dir).is_dir():
|
||||
return src_dir
|
||||
|
||||
# Check exports field (modern Node.js)
|
||||
exports = package_data.get("exports")
|
||||
if exports:
|
||||
|
|
@ -80,27 +90,38 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
|
|||
|
||||
if entry_path and isinstance(entry_path, str):
|
||||
parent = Path(entry_path).parent
|
||||
if parent != Path() and (project_root / parent).is_dir():
|
||||
if (
|
||||
parent != Path()
|
||||
and parent.as_posix() != "."
|
||||
and (project_root / parent).is_dir()
|
||||
and not is_build_output_dir(parent)
|
||||
):
|
||||
return parent.as_posix()
|
||||
|
||||
# Check module field (ESM)
|
||||
module_field = package_data.get("module")
|
||||
if module_field and isinstance(module_field, str):
|
||||
parent = Path(module_field).parent
|
||||
if parent != Path() and (project_root / parent).is_dir():
|
||||
if (
|
||||
parent != Path()
|
||||
and parent.as_posix() != "."
|
||||
and (project_root / parent).is_dir()
|
||||
and not is_build_output_dir(parent)
|
||||
):
|
||||
return parent.as_posix()
|
||||
|
||||
# Check main field (CJS)
|
||||
main_field = package_data.get("main")
|
||||
if main_field and isinstance(main_field, str):
|
||||
parent = Path(main_field).parent
|
||||
if parent != Path() and (project_root / parent).is_dir():
|
||||
if (
|
||||
parent != Path()
|
||||
and parent.as_posix() != "."
|
||||
and (project_root / parent).is_dir()
|
||||
and not is_build_output_dir(parent)
|
||||
):
|
||||
return parent.as_posix()
|
||||
|
||||
# Check for src/ directory convention
|
||||
if (project_root / "src").is_dir():
|
||||
return "src"
|
||||
|
||||
# Default to project root
|
||||
return "."
|
||||
|
||||
|
|
|
|||
|
|
@ -721,9 +721,7 @@ def inject_profiling_into_existing_test(
|
|||
if is_java():
|
||||
from codeflash.languages.java.instrumentation import instrument_existing_test
|
||||
|
||||
return instrument_existing_test(
|
||||
test_path, call_positions, function_to_optimize, tests_project_root, mode.value
|
||||
)
|
||||
return instrument_existing_test(test_path, call_positions, function_to_optimize, tests_project_root, mode.value)
|
||||
|
||||
if function_to_optimize.is_async:
|
||||
return inject_async_profiling_into_existing_test(
|
||||
|
|
|
|||
|
|
@ -746,7 +746,11 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
|
|||
return CodeStringsMarkdown(code_strings=[])
|
||||
|
||||
imported_names: dict[str, str] = {}
|
||||
external_bases: list[tuple[str, str]] = []
|
||||
# Use a set to deduplicate external base entries to avoid repeated expensive checks/imports.
|
||||
external_bases_set: set[tuple[str, str]] = set()
|
||||
# Local cache to avoid repeated _is_project_module calls for the same module_name.
|
||||
is_project_cache: dict[str, bool] = {}
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ImportFrom) and node.module:
|
||||
for alias in node.names:
|
||||
|
|
@ -763,21 +767,31 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
|
|||
|
||||
if base_name and base_name in imported_names:
|
||||
module_name = imported_names[base_name]
|
||||
if not _is_project_module(module_name, project_root_path):
|
||||
external_bases.append((base_name, module_name))
|
||||
# Check cache first to avoid repeated expensive checks.
|
||||
cached = is_project_cache.get(module_name)
|
||||
if cached is None:
|
||||
is_project = _is_project_module(module_name, project_root_path)
|
||||
is_project_cache[module_name] = is_project
|
||||
else:
|
||||
is_project = cached
|
||||
|
||||
if not external_bases:
|
||||
if not is_project:
|
||||
external_bases_set.add((base_name, module_name))
|
||||
|
||||
if not external_bases_set:
|
||||
return CodeStringsMarkdown(code_strings=[])
|
||||
|
||||
code_strings: list[CodeString] = []
|
||||
extracted: set[tuple[str, str]] = set()
|
||||
|
||||
for base_name, module_name in external_bases:
|
||||
if (module_name, base_name) in extracted:
|
||||
continue
|
||||
# Cache imported modules to avoid repeated importlib.import_module calls.
|
||||
imported_module_cache: dict[str, object] = {}
|
||||
|
||||
for base_name, module_name in external_bases_set:
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
module = imported_module_cache.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
imported_module_cache[module_name] = module
|
||||
|
||||
base_class = getattr(module, base_name, None)
|
||||
if base_class is None:
|
||||
continue
|
||||
|
|
@ -799,7 +813,6 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
|
|||
|
||||
class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ")
|
||||
code_strings.append(CodeString(code=class_source, file_path=class_file))
|
||||
extracted.add((module_name, base_name))
|
||||
|
||||
except (ImportError, ModuleNotFoundError, AttributeError):
|
||||
logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
|
||||
|
|
@ -854,12 +867,13 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
|
|||
needed_names.add(decorator.func.value.id)
|
||||
|
||||
# Get type annotation names from class body (for dataclass fields)
|
||||
for item in ast.walk(class_node):
|
||||
for item in class_node.body:
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
collect_names_from_annotation(item.annotation, needed_names)
|
||||
# Also check for field() calls which are common in dataclasses
|
||||
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
|
||||
needed_names.add(item.func.id)
|
||||
elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
|
||||
if isinstance(item.value.func, ast.Name):
|
||||
needed_names.add(item.value.func.id)
|
||||
|
||||
# Find imports that provide these names
|
||||
import_lines: list[str] = []
|
||||
|
|
|
|||
|
|
@ -656,7 +656,7 @@ def discover_unit_tests(
|
|||
|
||||
# Existing Python logic
|
||||
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
|
||||
strategy = framework_strategies.get(cfg.test_framework, None)
|
||||
strategy = framework_strategies.get(cfg.test_framework)
|
||||
if not strategy:
|
||||
error_message = f"Unsupported test framework: {cfg.test_framework}"
|
||||
raise ValueError(error_message)
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ class PrComment:
|
|||
best_async_throughput: Optional[int] = None
|
||||
|
||||
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
|
||||
report_table = {
|
||||
test_type.to_name(): result
|
||||
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
|
||||
if test_type.to_name()
|
||||
}
|
||||
report_table: dict[str, dict[str, int]] = {}
|
||||
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
|
||||
name = test_type.to_name()
|
||||
if name:
|
||||
report_table[name] = result
|
||||
|
||||
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
|
||||
"optimization_explanation": self.optimization_explanation,
|
||||
|
|
|
|||
|
|
@ -36,15 +36,15 @@ from codeflash.languages.current import (
|
|||
reset_current_language,
|
||||
set_current_language,
|
||||
)
|
||||
|
||||
# Java language support
|
||||
# Importing the module triggers registration via @register_language decorator
|
||||
from codeflash.languages.java.support import JavaSupport # noqa: F401
|
||||
from codeflash.languages.javascript import JavaScriptSupport, TypeScriptSupport # noqa: F401
|
||||
|
||||
# Import language support modules to trigger auto-registration
|
||||
# This ensures all supported languages are available when this package is imported
|
||||
from codeflash.languages.python import PythonSupport # noqa: F401
|
||||
|
||||
# Java language support
|
||||
# Importing the module triggers registration via @register_language decorator
|
||||
from codeflash.languages.java.support import JavaSupport # noqa: F401
|
||||
from codeflash.languages.registry import (
|
||||
detect_project_language,
|
||||
get_language_support,
|
||||
|
|
|
|||
|
|
@ -58,9 +58,6 @@ def set_current_language(language: Language | str) -> None:
|
|||
|
||||
"""
|
||||
global _current_language
|
||||
|
||||
if _current_language is not None:
|
||||
return
|
||||
_current_language = Language(language) if isinstance(language, str) else language
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,10 @@ import subprocess
|
|||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -78,9 +81,7 @@ def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
|
|||
root = ET.fromstring(content)
|
||||
|
||||
# Create ElementTree from root
|
||||
tree = ET.ElementTree(root)
|
||||
|
||||
return tree
|
||||
return ET.ElementTree(root)
|
||||
|
||||
|
||||
class BuildTool(Enum):
|
||||
|
|
@ -462,13 +463,7 @@ def run_maven_tests(
|
|||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=run_env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
|
||||
)
|
||||
|
||||
# Parse test results from Surefire reports
|
||||
|
|
@ -488,7 +483,7 @@ def run_maven_tests(
|
|||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Maven test execution timed out after %d seconds", timeout)
|
||||
logger.exception("Maven test execution timed out after %d seconds", timeout)
|
||||
return MavenTestResult(
|
||||
success=False,
|
||||
tests_run=0,
|
||||
|
|
@ -568,10 +563,7 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:
|
|||
|
||||
|
||||
def compile_maven_project(
|
||||
project_root: Path,
|
||||
include_tests: bool = True,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout: int = 300,
|
||||
project_root: Path, include_tests: bool = True, env: dict[str, str] | None = None, timeout: int = 300
|
||||
) -> tuple[bool, str, str]:
|
||||
"""Compile a Maven project.
|
||||
|
||||
|
|
@ -605,13 +597,7 @@ def compile_maven_project(
|
|||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=run_env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cmd, check=False, cwd=project_root, env=run_env, capture_output=True, text=True, timeout=timeout
|
||||
)
|
||||
|
||||
return result.returncode == 0, result.stdout, result.stderr
|
||||
|
|
@ -1002,14 +988,7 @@ def install_codeflash_runtime(project_root: Path, runtime_jar_path: Path) -> boo
|
|||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
result = subprocess.run(cmd, check=False, cwd=project_root, capture_output=True, text=True, timeout=60)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info("Successfully installed codeflash-runtime to local Maven repository")
|
||||
|
|
@ -1085,7 +1064,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
|
|||
return True
|
||||
|
||||
except ET.ParseError as e:
|
||||
logger.error("Failed to parse pom.xml: %s", e)
|
||||
logger.exception("Failed to parse pom.xml: %s", e)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.exception("Failed to add dependency to pom.xml: %s", e)
|
||||
|
|
@ -1236,7 +1215,7 @@ def add_jacoco_plugin_to_pom(pom_path: Path) -> bool:
|
|||
|
||||
if build_start != -1 and build_end != -1:
|
||||
# Found main build section, find plugins within it
|
||||
build_section = content[build_start:build_end + len("</build>")]
|
||||
build_section = content[build_start : build_end + len("</build>")]
|
||||
plugins_start_in_build = build_section.find("<plugins>")
|
||||
plugins_end_in_build = build_section.rfind("</plugins>")
|
||||
|
||||
|
|
|
|||
|
|
@ -47,7 +47,16 @@ def _find_comparator_jar(project_root: Path | None = None) -> Path | None:
|
|||
return jar_path
|
||||
|
||||
# Check local Maven repository
|
||||
m2_jar = Path.home() / ".m2" / "repository" / "com" / "codeflash" / "codeflash-runtime" / "1.0.0" / "codeflash-runtime-1.0.0.jar"
|
||||
m2_jar = (
|
||||
Path.home()
|
||||
/ ".m2"
|
||||
/ "repository"
|
||||
/ "com"
|
||||
/ "codeflash"
|
||||
/ "codeflash-runtime"
|
||||
/ "1.0.0"
|
||||
/ "codeflash-runtime-1.0.0.jar"
|
||||
)
|
||||
if m2_jar.exists():
|
||||
return m2_jar
|
||||
|
||||
|
|
@ -113,8 +122,7 @@ def compare_test_results(
|
|||
jar_path = comparator_jar or _find_comparator_jar(project_root)
|
||||
if not jar_path or not jar_path.exists():
|
||||
logger.error(
|
||||
"codeflash-runtime JAR not found. "
|
||||
"Please ensure the codeflash-runtime is installed in your project."
|
||||
"codeflash-runtime JAR not found. Please ensure the codeflash-runtime is installed in your project."
|
||||
)
|
||||
return False, []
|
||||
|
||||
|
|
@ -155,10 +163,10 @@ def compare_test_results(
|
|||
|
||||
comparison = json.loads(result.stdout)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse Java comparator output: {e}")
|
||||
logger.error(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
|
||||
logger.exception(f"Failed to parse Java comparator output: {e}")
|
||||
logger.exception(f"stdout: {result.stdout[:500] if result.stdout else '(empty)'}")
|
||||
if result.stderr:
|
||||
logger.error(f"stderr: {result.stderr[:500]}")
|
||||
logger.exception(f"stderr: {result.stderr[:500]}")
|
||||
return False, []
|
||||
|
||||
# Check for errors in the JSON response
|
||||
|
|
@ -178,9 +186,7 @@ def compare_test_results(
|
|||
for diff in comparison.get("diffs", []):
|
||||
scope_str = diff.get("scope", "return_value")
|
||||
scope = TestDiffScope.RETURN_VALUE
|
||||
if scope_str == "exception":
|
||||
scope = TestDiffScope.DID_PASS
|
||||
elif scope_str == "missing":
|
||||
if scope_str in {"exception", "missing"}:
|
||||
scope = TestDiffScope.DID_PASS
|
||||
|
||||
# Build test identifier
|
||||
|
|
@ -220,20 +226,17 @@ def compare_test_results(
|
|||
return equivalent, test_diffs
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Java comparator timed out")
|
||||
logger.exception("Java comparator timed out")
|
||||
return False, []
|
||||
except FileNotFoundError:
|
||||
logger.error("Java not found. Please install Java to compare test results.")
|
||||
logger.exception("Java not found. Please install Java to compare test results.")
|
||||
return False, []
|
||||
except Exception as e:
|
||||
logger.error(f"Error running Java comparator: {e}")
|
||||
logger.exception(f"Error running Java comparator: {e}")
|
||||
return False, []
|
||||
|
||||
|
||||
def compare_invocations_directly(
|
||||
original_results: dict,
|
||||
candidate_results: dict,
|
||||
) -> tuple[bool, list]:
|
||||
def compare_invocations_directly(original_results: dict, candidate_results: dict) -> tuple[bool, list]:
|
||||
"""Compare test invocations directly from Python dictionaries.
|
||||
|
||||
This is a fallback when the Java comparator is not available.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from __future__ import annotations
|
|||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.java.build_tools import (
|
||||
|
|
@ -22,7 +21,7 @@ from codeflash.languages.java.build_tools import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -80,9 +79,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
|
|||
project_info = get_project_info(project_root)
|
||||
|
||||
# Detect test framework
|
||||
test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(
|
||||
project_root, build_tool
|
||||
)
|
||||
test_framework, has_junit5, has_junit4, has_testng = _detect_test_framework(project_root, build_tool)
|
||||
|
||||
# Detect other dependencies
|
||||
has_mockito, has_assertj = _detect_test_dependencies(project_root, build_tool)
|
||||
|
|
@ -120,9 +117,7 @@ def detect_java_project(project_root: Path) -> JavaProjectConfig | None:
|
|||
)
|
||||
|
||||
|
||||
def _detect_test_framework(
|
||||
project_root: Path, build_tool: BuildTool
|
||||
) -> tuple[str, bool, bool, bool]:
|
||||
def _detect_test_framework(project_root: Path, build_tool: BuildTool) -> tuple[str, bool, bool, bool]:
|
||||
"""Detect which test framework the project uses.
|
||||
|
||||
Args:
|
||||
|
|
@ -210,9 +205,7 @@ def _detect_test_deps_from_pom(project_root: Path) -> tuple[bool, bool, bool]:
|
|||
elif tag == "groupId":
|
||||
group_id = child.text
|
||||
|
||||
if group_id == "org.junit.jupiter" or (
|
||||
artifact_id and "junit-jupiter" in artifact_id
|
||||
):
|
||||
if group_id == "org.junit.jupiter" or (artifact_id and "junit-jupiter" in artifact_id):
|
||||
has_junit5 = True
|
||||
elif group_id == "junit" and artifact_id == "junit":
|
||||
has_junit4 = True
|
||||
|
|
@ -253,9 +246,7 @@ def _detect_test_deps_from_gradle(project_root: Path) -> tuple[bool, bool, bool]
|
|||
return has_junit5, has_junit4, has_testng
|
||||
|
||||
|
||||
def _detect_test_dependencies(
|
||||
project_root: Path, build_tool: BuildTool
|
||||
) -> tuple[bool, bool]:
|
||||
def _detect_test_dependencies(project_root: Path, build_tool: BuildTool) -> tuple[bool, bool]:
|
||||
"""Detect additional test dependencies (Mockito, AssertJ).
|
||||
|
||||
Returns:
|
||||
|
|
@ -289,9 +280,7 @@ def _detect_test_dependencies(
|
|||
return has_mockito, has_assertj
|
||||
|
||||
|
||||
def _get_compiler_settings(
|
||||
project_root: Path, build_tool: BuildTool
|
||||
) -> tuple[str | None, str | None]:
|
||||
def _get_compiler_settings(project_root: Path, build_tool: BuildTool) -> tuple[str | None, str | None]:
|
||||
"""Get compiler source and target settings.
|
||||
|
||||
Returns:
|
||||
|
|
@ -392,11 +381,7 @@ def is_java_project(project_root: Path) -> bool:
|
|||
return True
|
||||
|
||||
# Check for Java source files
|
||||
for pattern in ["src/**/*.java", "*.java"]:
|
||||
if list(project_root.glob(pattern)):
|
||||
return True
|
||||
|
||||
return False
|
||||
return any(list(project_root.glob(pattern)) for pattern in ["src/**/*.java", "*.java"])
|
||||
|
||||
|
||||
def get_test_file_pattern(config: JavaProjectConfig) -> str:
|
||||
|
|
|
|||
|
|
@ -8,26 +8,27 @@ and other dependencies.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import CodeContext, HelperFunction, Language
|
||||
from codeflash.languages.java.discovery import discover_functions_from_source
|
||||
from codeflash.languages.java.import_resolver import JavaImportResolver, find_helper_files
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaClassNode, get_java_analyzer
|
||||
from codeflash.languages.java.import_resolver import find_helper_files
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidJavaSyntaxError(Exception):
|
||||
"""Raised when extracted Java code is not syntactically valid."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def extract_code_context(
|
||||
function: FunctionToOptimize,
|
||||
|
|
@ -67,12 +68,8 @@ def extract_code_context(
|
|||
try:
|
||||
source = function.file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error("Failed to read %s: %s", function.file_path, e)
|
||||
return CodeContext(
|
||||
target_code="",
|
||||
target_file=function.file_path,
|
||||
language=Language.JAVA,
|
||||
)
|
||||
logger.exception("Failed to read %s: %s", function.file_path, e)
|
||||
return CodeContext(target_code="", target_file=function.file_path, language=Language.JAVA)
|
||||
|
||||
# Extract target function code
|
||||
target_code = extract_function_source(source, function)
|
||||
|
|
@ -94,9 +91,7 @@ def extract_code_context(
|
|||
import_statements = [_import_to_statement(imp) for imp in imports]
|
||||
|
||||
# Extract helper functions
|
||||
helper_functions = find_helper_functions(
|
||||
function, project_root, max_helper_depth, analyzer
|
||||
)
|
||||
helper_functions = find_helper_functions(function, project_root, max_helper_depth, analyzer)
|
||||
|
||||
# Extract read-only context only if fields are NOT already in the skeleton
|
||||
# Avoid duplication between target_code and read_only_context
|
||||
|
|
@ -107,9 +102,8 @@ def extract_code_context(
|
|||
# Validate syntax - extracted code must always be valid Java
|
||||
if validate_syntax and target_code:
|
||||
if not analyzer.validate_syntax(target_code):
|
||||
raise InvalidJavaSyntaxError(
|
||||
f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
|
||||
)
|
||||
msg = f"Extracted code for {function.function_name} is not syntactically valid Java:\n{target_code}"
|
||||
raise InvalidJavaSyntaxError(msg)
|
||||
|
||||
return CodeContext(
|
||||
target_code=target_code,
|
||||
|
|
@ -156,7 +150,7 @@ class TypeSkeleton:
|
|||
enum_constants: str,
|
||||
type_indent: str,
|
||||
type_kind: str, # "class", "interface", or "enum"
|
||||
outer_type_skeleton: "TypeSkeleton | None" = None,
|
||||
outer_type_skeleton: TypeSkeleton | None = None,
|
||||
) -> None:
|
||||
self.type_declaration = type_declaration
|
||||
self.type_javadoc = type_javadoc
|
||||
|
|
@ -173,10 +167,7 @@ ClassSkeleton = TypeSkeleton
|
|||
|
||||
|
||||
def _extract_type_skeleton(
|
||||
source: str,
|
||||
type_name: str,
|
||||
target_method_name: str,
|
||||
analyzer: JavaAnalyzer,
|
||||
source: str, type_name: str, target_method_name: str, analyzer: JavaAnalyzer
|
||||
) -> TypeSkeleton | None:
|
||||
"""Extract the type skeleton (class, interface, or enum) for wrapping a method.
|
||||
|
||||
|
|
@ -254,11 +245,7 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No
|
|||
Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum".
|
||||
|
||||
"""
|
||||
type_declarations = {
|
||||
"class_declaration": "class",
|
||||
"interface_declaration": "interface",
|
||||
"enum_declaration": "enum",
|
||||
}
|
||||
type_declarations = {"class_declaration": "class", "interface_declaration": "interface", "enum_declaration": "enum"}
|
||||
|
||||
if node.type in type_declarations:
|
||||
name_node = node.child_by_field_name("name")
|
||||
|
|
@ -283,11 +270,7 @@ def _find_class_node(node: Node, class_name: str, source_bytes: bytes) -> Node |
|
|||
|
||||
|
||||
def _get_outer_type_skeleton(
|
||||
inner_type_node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
target_method_name: str,
|
||||
analyzer: JavaAnalyzer,
|
||||
inner_type_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, analyzer: JavaAnalyzer
|
||||
) -> TypeSkeleton | None:
|
||||
"""Get the outer type skeleton if this is an inner type.
|
||||
|
||||
|
|
@ -356,11 +339,7 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
|
|||
parts: list[str] = []
|
||||
|
||||
# Determine which body node type to look for
|
||||
body_types = {
|
||||
"class": "class_body",
|
||||
"interface": "interface_body",
|
||||
"enum": "enum_body",
|
||||
}
|
||||
body_types = {"class": "class_body", "interface": "interface_body", "enum": "enum_body"}
|
||||
body_type = body_types.get(type_kind, "class_body")
|
||||
|
||||
for child in type_node.children:
|
||||
|
|
@ -374,7 +353,8 @@ def _extract_type_declaration(type_node: Node, source_bytes: bytes, type_kind: s
|
|||
|
||||
|
||||
# Keep old function name for backwards compatibility
|
||||
_extract_class_declaration = lambda node, source_bytes: _extract_type_declaration(node, source_bytes, "class")
|
||||
def _extract_class_declaration(node, source_bytes):
|
||||
return _extract_type_declaration(node, source_bytes, "class")
|
||||
|
||||
|
||||
def _find_javadoc(node: Node, source_bytes: bytes) -> str | None:
|
||||
|
|
@ -390,11 +370,7 @@ def _find_javadoc(node: Node, source_bytes: bytes) -> str | None:
|
|||
|
||||
|
||||
def _extract_type_body_context(
|
||||
body_node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
target_method_name: str,
|
||||
type_kind: str,
|
||||
body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str, type_kind: str
|
||||
) -> tuple[str, str, str]:
|
||||
"""Extract fields, constructors, and enum constants from a type body.
|
||||
|
||||
|
|
@ -473,15 +449,10 @@ def _extract_type_body_context(
|
|||
|
||||
# Keep old function name for backwards compatibility
|
||||
def _extract_class_body_context(
|
||||
body_node: Node,
|
||||
source_bytes: bytes,
|
||||
lines: list[str],
|
||||
target_method_name: str,
|
||||
body_node: Node, source_bytes: bytes, lines: list[str], target_method_name: str
|
||||
) -> tuple[str, str]:
|
||||
"""Extract fields and constructors from a class body."""
|
||||
fields, constructors, _ = _extract_type_body_context(
|
||||
body_node, source_bytes, lines, target_method_name, "class"
|
||||
)
|
||||
fields, constructors, _ = _extract_type_body_context(body_node, source_bytes, lines, target_method_name, "class")
|
||||
return (fields, constructors)
|
||||
|
||||
|
||||
|
|
@ -584,10 +555,7 @@ def extract_function_source(source: str, function: FunctionToOptimize) -> str:
|
|||
|
||||
|
||||
def find_helper_functions(
|
||||
function: FunctionToOptimize,
|
||||
project_root: Path,
|
||||
max_depth: int = 2,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
function: FunctionToOptimize, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[HelperFunction]:
|
||||
"""Find helper functions that the target function depends on.
|
||||
|
||||
|
|
@ -606,11 +574,9 @@ def find_helper_functions(
|
|||
visited_functions: set[str] = set()
|
||||
|
||||
# Find helper files through imports
|
||||
helper_files = find_helper_files(
|
||||
function.file_path, project_root, max_depth, analyzer
|
||||
)
|
||||
helper_files = find_helper_files(function.file_path, project_root, max_depth, analyzer)
|
||||
|
||||
for file_path, class_names in helper_files.items():
|
||||
for file_path in helper_files:
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
file_functions = discover_functions_from_source(source, file_path, analyzer=analyzer)
|
||||
|
|
@ -648,10 +614,7 @@ def find_helper_functions(
|
|||
return helpers
|
||||
|
||||
|
||||
def _find_same_class_helpers(
|
||||
function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> list[HelperFunction]:
|
||||
def _find_same_class_helpers(function: FunctionToOptimize, analyzer: JavaAnalyzer) -> list[HelperFunction]:
|
||||
"""Find helper methods in the same class as the target function.
|
||||
|
||||
Args:
|
||||
|
|
@ -694,9 +657,7 @@ def _find_same_class_helpers(
|
|||
and method.class_name == function.class_name
|
||||
and method.name in called_methods
|
||||
):
|
||||
func_source = source_bytes[
|
||||
method.node.start_byte : method.node.end_byte
|
||||
].decode("utf8")
|
||||
func_source = source_bytes[method.node.start_byte : method.node.end_byte].decode("utf8")
|
||||
|
||||
helpers.append(
|
||||
HelperFunction(
|
||||
|
|
@ -715,11 +676,7 @@ def _find_same_class_helpers(
|
|||
return helpers
|
||||
|
||||
|
||||
def extract_read_only_context(
|
||||
source: str,
|
||||
function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> str:
|
||||
def extract_read_only_context(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer) -> str:
|
||||
"""Extract read-only context (fields, constants, inner classes).
|
||||
|
||||
This extracts class-level context that the function might depend on
|
||||
|
|
@ -767,11 +724,7 @@ def _import_to_statement(import_info) -> str:
|
|||
return f"{prefix}{import_info.import_path}{suffix};"
|
||||
|
||||
|
||||
def extract_class_context(
|
||||
file_path: Path,
|
||||
class_name: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
def extract_class_context(file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None) -> str:
|
||||
"""Extract the full context of a class.
|
||||
|
||||
Args:
|
||||
|
|
@ -813,5 +766,5 @@ def extract_class_context(
|
|||
return package_stmt + "\n".join(import_statements) + "\n\n" + class_source
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to extract class context: %s", e)
|
||||
logger.exception("Failed to extract class context: %s", e)
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -12,19 +12,17 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import FunctionFilterCriteria
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_functions(
|
||||
file_path: Path,
|
||||
filter_criteria: FunctionFilterCriteria | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
file_path: Path, filter_criteria: FunctionFilterCriteria | None = None, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions/methods in a Java file.
|
||||
|
||||
|
|
@ -115,10 +113,7 @@ def discover_functions_from_source(
|
|||
|
||||
|
||||
def _should_include_method(
|
||||
method: JavaMethodNode,
|
||||
criteria: FunctionFilterCriteria,
|
||||
source: str,
|
||||
analyzer: JavaAnalyzer,
|
||||
method: JavaMethodNode, criteria: FunctionFilterCriteria, source: str, analyzer: JavaAnalyzer
|
||||
) -> bool:
|
||||
"""Check if a method should be included based on filter criteria.
|
||||
|
||||
|
|
@ -176,10 +171,7 @@ def _should_include_method(
|
|||
return True
|
||||
|
||||
|
||||
def discover_test_methods(
|
||||
file_path: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionToOptimize]:
|
||||
def discover_test_methods(file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]:
|
||||
"""Find all JUnit test methods in a Java test file.
|
||||
|
||||
Looks for methods annotated with @Test, @ParameterizedTest, @RepeatedTest, etc.
|
||||
|
|
@ -232,7 +224,7 @@ def _walk_tree_for_test_methods(
|
|||
for child in node.children:
|
||||
if child.type == "modifiers":
|
||||
for mod_child in child.children:
|
||||
if mod_child.type == "marker_annotation" or mod_child.type == "annotation":
|
||||
if mod_child.type in {"marker_annotation", "annotation"}:
|
||||
annotation_text = analyzer.get_node_text(mod_child, source_bytes)
|
||||
# Check for JUnit 5 test annotations
|
||||
if any(
|
||||
|
|
@ -278,10 +270,7 @@ def _walk_tree_for_test_methods(
|
|||
|
||||
|
||||
def get_method_by_name(
|
||||
file_path: Path,
|
||||
method_name: str,
|
||||
class_name: str | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
file_path: Path, method_name: str, class_name: str | None = None, analyzer: JavaAnalyzer | None = None
|
||||
) -> FunctionToOptimize | None:
|
||||
"""Find a specific method by name in a Java file.
|
||||
|
||||
|
|
@ -306,9 +295,7 @@ def get_method_by_name(
|
|||
|
||||
|
||||
def get_class_methods(
|
||||
file_path: Path,
|
||||
class_name: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
file_path: Path, class_name: str, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Get all methods in a specific class.
|
||||
|
||||
|
|
|
|||
|
|
@ -6,16 +6,13 @@ google-java-format or other available formatters.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -29,7 +26,7 @@ class JavaFormatter:
|
|||
# Version of google-java-format to use
|
||||
GOOGLE_JAVA_FORMAT_VERSION = "1.19.2"
|
||||
|
||||
def __init__(self, project_root: Path | None = None):
|
||||
def __init__(self, project_root: Path | None = None) -> None:
|
||||
"""Initialize the Java formatter.
|
||||
|
||||
Args:
|
||||
|
|
@ -107,21 +104,13 @@ class JavaFormatter:
|
|||
|
||||
try:
|
||||
# Write source to temp file
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".java", delete=False, encoding="utf-8"
|
||||
) as tmp:
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".java", delete=False, encoding="utf-8") as tmp:
|
||||
tmp.write(source)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
self._java_executable,
|
||||
"-jar",
|
||||
str(jar_path),
|
||||
"--replace",
|
||||
tmp_path,
|
||||
],
|
||||
[self._java_executable, "-jar", str(jar_path), "--replace", tmp_path],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
|
|
@ -133,16 +122,12 @@ class JavaFormatter:
|
|||
with open(tmp_path, encoding="utf-8") as f:
|
||||
return f.read()
|
||||
else:
|
||||
logger.debug(
|
||||
"google-java-format failed: %s", result.stderr or result.stdout
|
||||
)
|
||||
logger.debug("google-java-format failed: %s", result.stderr or result.stdout)
|
||||
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("google-java-format timed out")
|
||||
|
|
@ -169,9 +154,7 @@ class JavaFormatter:
|
|||
if self.project_root
|
||||
else None,
|
||||
# In user's home directory
|
||||
Path.home()
|
||||
/ ".codeflash"
|
||||
/ f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
|
||||
Path.home() / ".codeflash" / f"google-java-format-{self.GOOGLE_JAVA_FORMAT_VERSION}-all-deps.jar",
|
||||
# In system temp
|
||||
Path(tempfile.gettempdir())
|
||||
/ "codeflash"
|
||||
|
|
@ -186,8 +169,7 @@ class JavaFormatter:
|
|||
# Don't auto-download to avoid surprises
|
||||
# Users can manually download the JAR
|
||||
logger.debug(
|
||||
"google-java-format JAR not found. "
|
||||
"Download from https://github.com/google/google-java-format/releases"
|
||||
"google-java-format JAR not found. Download from https://github.com/google/google-java-format/releases"
|
||||
)
|
||||
return None
|
||||
|
||||
|
|
@ -239,7 +221,7 @@ class JavaFormatter:
|
|||
logger.info("Downloaded google-java-format to %s", jar_path)
|
||||
return jar_path
|
||||
except Exception as e:
|
||||
logger.error("Failed to download google-java-format: %s", e)
|
||||
logger.exception("Failed to download google-java-format: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,14 +8,15 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.java.build_tools import find_source_root, find_test_root, get_project_info
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo, get_java_analyzer
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaImportInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,18 +36,7 @@ class JavaImportResolver:
|
|||
"""Resolves Java imports to file paths within a project."""
|
||||
|
||||
# Standard Java packages that are always external
|
||||
STANDARD_PACKAGES = frozenset(
|
||||
[
|
||||
"java",
|
||||
"javax",
|
||||
"sun",
|
||||
"com.sun",
|
||||
"jdk",
|
||||
"org.w3c",
|
||||
"org.xml",
|
||||
"org.ietf",
|
||||
]
|
||||
)
|
||||
STANDARD_PACKAGES = frozenset(["java", "javax", "sun", "com.sun", "jdk", "org.w3c", "org.xml", "org.ietf"])
|
||||
|
||||
# Common third-party package prefixes
|
||||
COMMON_EXTERNAL_PREFIXES = frozenset(
|
||||
|
|
@ -66,7 +56,7 @@ class JavaImportResolver:
|
|||
]
|
||||
)
|
||||
|
||||
def __init__(self, project_root: Path):
|
||||
def __init__(self, project_root: Path) -> None:
|
||||
"""Initialize the import resolver.
|
||||
|
||||
Args:
|
||||
|
|
@ -156,10 +146,7 @@ class JavaImportResolver:
|
|||
|
||||
def _is_standard_library(self, import_path: str) -> bool:
|
||||
"""Check if an import is from the Java standard library."""
|
||||
for prefix in self.STANDARD_PACKAGES:
|
||||
if import_path.startswith(prefix + ".") or import_path == prefix:
|
||||
return True
|
||||
return False
|
||||
return any(import_path.startswith(prefix + ".") or import_path == prefix for prefix in self.STANDARD_PACKAGES)
|
||||
|
||||
def _is_external_library(self, import_path: str) -> bool:
|
||||
"""Check if an import is from a known external library."""
|
||||
|
|
@ -249,9 +236,7 @@ class JavaImportResolver:
|
|||
|
||||
return None
|
||||
|
||||
def get_imports_from_file(
|
||||
self, file_path: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[ResolvedImport]:
|
||||
def get_imports_from_file(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]:
|
||||
"""Get and resolve all imports from a Java file.
|
||||
|
||||
Args:
|
||||
|
|
@ -272,9 +257,7 @@ class JavaImportResolver:
|
|||
logger.warning("Failed to get imports from %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
def get_project_imports(
|
||||
self, file_path: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[ResolvedImport]:
|
||||
def get_project_imports(self, file_path: Path, analyzer: JavaAnalyzer | None = None) -> list[ResolvedImport]:
|
||||
"""Get only the imports that resolve to files within the project.
|
||||
|
||||
Args:
|
||||
|
|
@ -308,10 +291,7 @@ def resolve_imports_for_file(
|
|||
|
||||
|
||||
def find_helper_files(
|
||||
file_path: Path,
|
||||
project_root: Path,
|
||||
max_depth: int = 2,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
file_path: Path, project_root: Path, max_depth: int = 2, analyzer: JavaAnalyzer | None = None
|
||||
) -> dict[Path, list[str]]:
|
||||
"""Find helper files imported by a Java file, recursively.
|
||||
|
||||
|
|
|
|||
|
|
@ -17,16 +17,16 @@ from __future__ import annotations
|
|||
import logging
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -36,7 +36,8 @@ def _get_function_name(func: Any) -> str:
|
|||
return func.function_name
|
||||
if hasattr(func, "name"):
|
||||
return func.name
|
||||
raise AttributeError(f"Cannot get function name from {type(func)}")
|
||||
msg = f"Cannot get function name from {type(func)}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
||||
def _get_qualified_name(func: Any) -> str:
|
||||
|
|
@ -135,7 +136,7 @@ def instrument_existing_test(
|
|||
try:
|
||||
source = test_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error("Failed to read test file %s: %s", test_path, e)
|
||||
logger.exception("Failed to read test file %s: %s", test_path, e)
|
||||
return False, f"Failed to read test file: {e}"
|
||||
|
||||
func_name = _get_function_name(function_to_optimize)
|
||||
|
|
@ -227,7 +228,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
result.append(imp)
|
||||
imports_added = True
|
||||
continue
|
||||
if stripped.startswith("public class") or stripped.startswith("class"):
|
||||
if stripped.startswith(("public class", "class")):
|
||||
# No imports found, add before class
|
||||
for imp in import_statements:
|
||||
result.append(imp)
|
||||
|
|
@ -244,7 +245,6 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
i = 0
|
||||
iteration_counter = 0
|
||||
|
||||
|
||||
# Pre-compile the regex pattern once
|
||||
method_call_pattern = _get_method_call_pattern(func_name)
|
||||
|
||||
|
|
@ -291,11 +291,10 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
while i < len(lines) and brace_depth > 0:
|
||||
body_line = lines[i]
|
||||
# Count braces more efficiently using string methods
|
||||
open_count = body_line.count('{')
|
||||
close_count = body_line.count('}')
|
||||
open_count = body_line.count("{")
|
||||
close_count = body_line.count("}")
|
||||
brace_depth += open_count - close_count
|
||||
|
||||
|
||||
if brace_depth > 0:
|
||||
body_lines.append(body_line)
|
||||
i += 1
|
||||
|
|
@ -581,7 +580,7 @@ def create_benchmark_test(
|
|||
method_id = _get_qualified_name(target_function)
|
||||
class_name = getattr(target_function, "class_name", None) or "Target"
|
||||
|
||||
benchmark_code = f"""
|
||||
return f"""
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.DisplayName;
|
||||
|
||||
|
|
@ -615,7 +614,6 @@ public class {class_name}Benchmark {{
|
|||
}}
|
||||
}}
|
||||
"""
|
||||
return benchmark_code
|
||||
|
||||
|
||||
def remove_instrumentation(source: str) -> str:
|
||||
|
|
@ -713,7 +711,7 @@ def _add_import(source: str, import_statement: str) -> str:
|
|||
# Find the last import or package statement
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped.startswith("import ") or stripped.startswith("package "):
|
||||
if stripped.startswith(("import ", "package ")):
|
||||
insert_idx = i + 1
|
||||
elif stripped and not stripped.startswith("//") and not stripped.startswith("/*"):
|
||||
# First non-import, non-comment line
|
||||
|
|
@ -725,13 +723,11 @@ def _add_import(source: str, import_statement: str) -> str:
|
|||
return "".join(lines)
|
||||
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _get_method_call_pattern(func_name: str):
|
||||
"""Cache compiled regex patterns for method call matching."""
|
||||
return re.compile(
|
||||
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
|
||||
re.MULTILINE
|
||||
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -739,6 +735,5 @@ def _get_method_call_pattern(func_name: str):
|
|||
def _get_method_call_pattern(func_name: str):
|
||||
"""Cache compiled regex patterns for method call matching."""
|
||||
return re.compile(
|
||||
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
|
||||
re.MULTILINE
|
||||
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ from typing import TYPE_CHECKING
|
|||
from tree_sitter import Language, Parser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from tree_sitter import Node, Tree
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -222,9 +220,7 @@ class JavaAnalyzer:
|
|||
current_class=new_class if node.type in type_declarations else current_class,
|
||||
)
|
||||
|
||||
def _extract_method_info(
|
||||
self, node: Node, source_bytes: bytes, current_class: str | None
|
||||
) -> JavaMethodNode | None:
|
||||
def _extract_method_info(self, node: Node, source_bytes: bytes, current_class: str | None) -> JavaMethodNode | None:
|
||||
"""Extract method information from a method_declaration node."""
|
||||
name = ""
|
||||
is_static = False
|
||||
|
|
@ -347,9 +343,7 @@ class JavaAnalyzer:
|
|||
for child in node.children:
|
||||
self._walk_tree_for_classes(child, source_bytes, classes, is_inner)
|
||||
|
||||
def _extract_class_info(
|
||||
self, node: Node, source_bytes: bytes, is_inner: bool
|
||||
) -> JavaClassNode | None:
|
||||
def _extract_class_info(self, node: Node, source_bytes: bytes, is_inner: bool) -> JavaClassNode | None:
|
||||
"""Extract class information from a class_declaration node."""
|
||||
name = ""
|
||||
is_public = False
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, JavaMethodNode, get_java_analyzer
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -35,11 +35,7 @@ class ParsedOptimization:
|
|||
new_helper_methods: list[str] # Source text of new helper methods to add
|
||||
|
||||
|
||||
def _parse_optimization_source(
|
||||
new_source: str,
|
||||
target_method_name: str,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> ParsedOptimization:
|
||||
def _parse_optimization_source(new_source: str, target_method_name: str, analyzer: JavaAnalyzer) -> ParsedOptimization:
|
||||
"""Parse optimization source to extract method and additional class members.
|
||||
|
||||
The new_source may contain:
|
||||
|
|
@ -96,18 +92,12 @@ def _parse_optimization_source(
|
|||
new_fields.append(field.source_text)
|
||||
|
||||
return ParsedOptimization(
|
||||
target_method_source=target_method_source,
|
||||
new_fields=new_fields,
|
||||
new_helper_methods=new_helper_methods,
|
||||
target_method_source=target_method_source, new_fields=new_fields, new_helper_methods=new_helper_methods
|
||||
)
|
||||
|
||||
|
||||
def _insert_class_members(
|
||||
source: str,
|
||||
class_name: str,
|
||||
fields: list[str],
|
||||
methods: list[str],
|
||||
analyzer: JavaAnalyzer,
|
||||
source: str, class_name: str, fields: list[str], methods: list[str], analyzer: JavaAnalyzer
|
||||
) -> str:
|
||||
"""Insert new class members (fields and methods) into a class.
|
||||
|
||||
|
|
@ -212,10 +202,7 @@ def _insert_class_members(
|
|||
|
||||
|
||||
def replace_function(
|
||||
source: str,
|
||||
function: FunctionToOptimize,
|
||||
new_source: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
source: str, function: FunctionToOptimize, new_source: str, analyzer: JavaAnalyzer | None = None
|
||||
) -> str:
|
||||
"""Replace a function in source code with new implementation.
|
||||
|
||||
|
|
@ -257,9 +244,9 @@ def replace_function(
|
|||
|
||||
# Find all methods matching the name (there may be overloads)
|
||||
matching_methods = [
|
||||
m for m in methods
|
||||
if m.name == func_name
|
||||
and (function.class_name is None or m.class_name == function.class_name)
|
||||
m
|
||||
for m in methods
|
||||
if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
|
||||
]
|
||||
|
||||
if len(matching_methods) == 1:
|
||||
|
|
@ -296,10 +283,7 @@ def replace_function(
|
|||
break
|
||||
if not target_method:
|
||||
# Fallback: use the first match
|
||||
logger.warning(
|
||||
"Multiple overloads of %s found but no line match, using first match",
|
||||
func_name,
|
||||
)
|
||||
logger.warning("Multiple overloads of %s found but no line match, using first match", func_name)
|
||||
target_method = matching_methods[0]
|
||||
target_overload_index = 0
|
||||
|
||||
|
|
@ -342,18 +326,16 @@ def replace_function(
|
|||
len(new_helpers_to_add),
|
||||
class_name,
|
||||
)
|
||||
source = _insert_class_members(
|
||||
source, class_name, new_fields_to_add, new_helpers_to_add, analyzer
|
||||
)
|
||||
source = _insert_class_members(source, class_name, new_fields_to_add, new_helpers_to_add, analyzer)
|
||||
|
||||
# Re-find the target method after modifications
|
||||
# Line numbers have shifted, but the relative order of overloads is preserved
|
||||
# Use the target_overload_index we saved earlier
|
||||
methods = analyzer.find_methods(source)
|
||||
matching_methods = [
|
||||
m for m in methods
|
||||
if m.name == func_name
|
||||
and (function.class_name is None or m.class_name == function.class_name)
|
||||
m
|
||||
for m in methods
|
||||
if m.name == func_name and (function.class_name is None or m.class_name == function.class_name)
|
||||
]
|
||||
|
||||
if matching_methods and target_overload_index < len(matching_methods):
|
||||
|
|
@ -398,9 +380,7 @@ def replace_function(
|
|||
before = lines[: start_line - 1] # Lines before the method
|
||||
after = lines[end_line:] # Lines after the method
|
||||
|
||||
result = "".join(before) + indented_new_source + "".join(after)
|
||||
|
||||
return result
|
||||
return "".join(before) + indented_new_source + "".join(after)
|
||||
|
||||
|
||||
def _get_indentation(line: str) -> str:
|
||||
|
|
@ -460,10 +440,7 @@ def _apply_indentation(lines: list[str], base_indent: str) -> str:
|
|||
|
||||
|
||||
def replace_method_body(
|
||||
source: str,
|
||||
function: FunctionToOptimize,
|
||||
new_body: str,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
source: str, function: FunctionToOptimize, new_body: str, analyzer: JavaAnalyzer | None = None
|
||||
) -> str:
|
||||
"""Replace just the body of a method, preserving signature.
|
||||
|
||||
|
|
@ -600,11 +577,7 @@ def insert_method(
|
|||
return (before + separator.encode("utf8") + indented_method.encode("utf8") + after).decode("utf8")
|
||||
|
||||
|
||||
def remove_method(
|
||||
source: str,
|
||||
function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> str:
|
||||
def remove_method(source: str, function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None) -> str:
|
||||
"""Remove a method from source code.
|
||||
|
||||
Args:
|
||||
|
|
@ -648,9 +621,7 @@ def remove_method(
|
|||
|
||||
|
||||
def remove_test_functions(
|
||||
test_source: str,
|
||||
functions_to_remove: list[str],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
test_source: str, functions_to_remove: list[str], analyzer: JavaAnalyzer | None = None
|
||||
) -> str:
|
||||
"""Remove specific test functions from test source code.
|
||||
|
||||
|
|
@ -669,9 +640,7 @@ def remove_test_functions(
|
|||
methods = analyzer.find_methods(test_source)
|
||||
|
||||
# Sort by start line in reverse order (remove from end first)
|
||||
methods_to_remove = [
|
||||
m for m in methods if m.name in functions_to_remove
|
||||
]
|
||||
methods_to_remove = [m for m in methods if m.name in functions_to_remove]
|
||||
methods_to_remove.sort(key=lambda m: m.start_line, reverse=True)
|
||||
|
||||
result = test_source
|
||||
|
|
@ -728,9 +697,7 @@ def add_runtime_comments(
|
|||
|
||||
if original_ns > 0:
|
||||
speedup = ((original_ns - optimized_ns) / original_ns) * 100
|
||||
summary_lines.append(
|
||||
f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)"
|
||||
)
|
||||
summary_lines.append(f"// {inv_id}: {original_ms:.3f}ms -> {optimized_ms:.3f}ms ({speedup:.1f}% faster)")
|
||||
|
||||
# Insert after imports
|
||||
lines = test_source.splitlines(keepends=True)
|
||||
|
|
|
|||
|
|
@ -7,20 +7,9 @@ required methods for Java language support in codeflash.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
FunctionFilterCriteria,
|
||||
HelperFunction,
|
||||
Language,
|
||||
LanguageSupport,
|
||||
TestInfo,
|
||||
TestResult,
|
||||
)
|
||||
from codeflash.languages.registry import register_language
|
||||
from codeflash.languages.base import Language, LanguageSupport
|
||||
from codeflash.languages.java.build_tools import find_test_root
|
||||
from codeflash.languages.java.comparator import compare_test_results as _compare_test_results
|
||||
from codeflash.languages.java.config import detect_java_project
|
||||
|
|
@ -33,11 +22,7 @@ from codeflash.languages.java.instrumentation import (
|
|||
instrument_for_benchmarking,
|
||||
)
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
from codeflash.languages.java.replacement import (
|
||||
add_runtime_comments,
|
||||
remove_test_functions,
|
||||
replace_function,
|
||||
)
|
||||
from codeflash.languages.java.replacement import add_runtime_comments, remove_test_functions, replace_function
|
||||
from codeflash.languages.java.test_discovery import discover_tests
|
||||
from codeflash.languages.java.test_runner import (
|
||||
parse_test_results,
|
||||
|
|
@ -45,9 +30,14 @@ from codeflash.languages.java.test_runner import (
|
|||
run_benchmarking_tests,
|
||||
run_tests,
|
||||
)
|
||||
from codeflash.languages.registry import register_language
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import CodeContext, FunctionFilterCriteria, HelperFunction, TestInfo, TestResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -112,23 +102,17 @@ class JavaSupport(LanguageSupport):
|
|||
|
||||
# === Code Analysis ===
|
||||
|
||||
def extract_code_context(
|
||||
self, function: FunctionToOptimize, project_root: Path, module_root: Path
|
||||
) -> CodeContext:
|
||||
def extract_code_context(self, function: FunctionToOptimize, project_root: Path, module_root: Path) -> CodeContext:
|
||||
"""Extract function code and its dependencies."""
|
||||
return extract_code_context(function, project_root, module_root, analyzer=self._analyzer)
|
||||
|
||||
def find_helper_functions(
|
||||
self, function: FunctionToOptimize, project_root: Path
|
||||
) -> list[HelperFunction]:
|
||||
def find_helper_functions(self, function: FunctionToOptimize, project_root: Path) -> list[HelperFunction]:
|
||||
"""Find helper functions called by the target function."""
|
||||
return find_helper_functions(function, project_root, analyzer=self._analyzer)
|
||||
|
||||
# === Code Transformation ===
|
||||
|
||||
def replace_function(
|
||||
self, source: str, function: FunctionToOptimize, new_source: str
|
||||
) -> str:
|
||||
def replace_function(self, source: str, function: FunctionToOptimize, new_source: str) -> str:
|
||||
"""Replace a function in source code with new implementation."""
|
||||
return replace_function(source, function, new_source, self._analyzer)
|
||||
|
||||
|
|
@ -140,11 +124,7 @@ class JavaSupport(LanguageSupport):
|
|||
# === Test Execution ===
|
||||
|
||||
def run_tests(
|
||||
self,
|
||||
test_files: Sequence[Path],
|
||||
cwd: Path,
|
||||
env: dict[str, str],
|
||||
timeout: int,
|
||||
self, test_files: Sequence[Path], cwd: Path, env: dict[str, str], timeout: int
|
||||
) -> tuple[list[TestResult], Path]:
|
||||
"""Run tests and return results."""
|
||||
return run_tests(list(test_files), cwd, env, timeout)
|
||||
|
|
@ -155,15 +135,11 @@ class JavaSupport(LanguageSupport):
|
|||
|
||||
# === Instrumentation ===
|
||||
|
||||
def instrument_for_behavior(
|
||||
self, source: str, functions: Sequence[FunctionToOptimize]
|
||||
) -> str:
|
||||
def instrument_for_behavior(self, source: str, functions: Sequence[FunctionToOptimize]) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs."""
|
||||
return instrument_for_behavior(source, functions, self._analyzer)
|
||||
|
||||
def instrument_for_benchmarking(
|
||||
self, test_source: str, target_function: FunctionToOptimize
|
||||
) -> str:
|
||||
def instrument_for_benchmarking(self, test_source: str, target_function: FunctionToOptimize) -> str:
|
||||
"""Add timing instrumentation to test code."""
|
||||
return instrument_for_benchmarking(test_source, target_function, self._analyzer)
|
||||
|
||||
|
|
@ -180,32 +156,22 @@ class JavaSupport(LanguageSupport):
|
|||
# === Test Editing ===
|
||||
|
||||
def add_runtime_comments(
|
||||
self,
|
||||
test_source: str,
|
||||
original_runtimes: dict[str, int],
|
||||
optimized_runtimes: dict[str, int],
|
||||
self, test_source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
|
||||
) -> str:
|
||||
"""Add runtime performance comments to test source code."""
|
||||
return add_runtime_comments(test_source, original_runtimes, optimized_runtimes, self._analyzer)
|
||||
|
||||
def remove_test_functions(
|
||||
self, test_source: str, functions_to_remove: list[str]
|
||||
) -> str:
|
||||
def remove_test_functions(self, test_source: str, functions_to_remove: list[str]) -> str:
|
||||
"""Remove specific test functions from test source code."""
|
||||
return remove_test_functions(test_source, functions_to_remove, self._analyzer)
|
||||
|
||||
# === Test Result Comparison ===
|
||||
|
||||
def compare_test_results(
|
||||
self,
|
||||
original_results_path: Path,
|
||||
candidate_results_path: Path,
|
||||
project_root: Path | None = None,
|
||||
self, original_results_path: Path, candidate_results_path: Path, project_root: Path | None = None
|
||||
) -> tuple[bool, list]:
|
||||
"""Compare test results between original and candidate code."""
|
||||
return _compare_test_results(
|
||||
original_results_path, candidate_results_path, project_root=project_root
|
||||
)
|
||||
return _compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
|
||||
|
||||
# === Configuration ===
|
||||
|
||||
|
|
@ -308,12 +274,7 @@ class JavaSupport(LanguageSupport):
|
|||
) -> tuple[bool, str | None]:
|
||||
"""Inject profiling code into an existing test file."""
|
||||
return instrument_existing_test(
|
||||
test_path,
|
||||
call_positions,
|
||||
function_to_optimize,
|
||||
tests_project_root,
|
||||
mode,
|
||||
self._analyzer,
|
||||
test_path, call_positions, function_to_optimize, tests_project_root, mode, self._analyzer
|
||||
)
|
||||
|
||||
def instrument_source_for_line_profiler(
|
||||
|
|
|
|||
|
|
@ -7,27 +7,26 @@ specific functions, mapping source functions to their tests.
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import TestInfo
|
||||
from codeflash.languages.java.config import detect_java_project
|
||||
from codeflash.languages.java.discovery import discover_test_methods
|
||||
from codeflash.languages.java.parser import JavaAnalyzer, get_java_analyzer
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def discover_tests(
|
||||
test_root: Path,
|
||||
source_functions: Sequence[FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
test_root: Path, source_functions: Sequence[FunctionToOptimize], analyzer: JavaAnalyzer | None = None
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Map source functions to their tests via static analysis.
|
||||
|
||||
|
|
@ -56,9 +55,7 @@ def discover_tests(
|
|||
|
||||
# Find all test files (various naming conventions)
|
||||
test_files = (
|
||||
list(test_root.rglob("*Test.java"))
|
||||
+ list(test_root.rglob("*Tests.java"))
|
||||
+ list(test_root.rglob("Test*.java"))
|
||||
list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
|
||||
)
|
||||
|
||||
# Result map
|
||||
|
|
@ -71,16 +68,12 @@ def discover_tests(
|
|||
|
||||
for test_method in test_methods:
|
||||
# Find which source functions this test might exercise
|
||||
matched_functions = _match_test_to_functions(
|
||||
test_method, source, function_map, analyzer
|
||||
)
|
||||
matched_functions = _match_test_to_functions(test_method, source, function_map, analyzer)
|
||||
|
||||
for func_name in matched_functions:
|
||||
result[func_name].append(
|
||||
TestInfo(
|
||||
test_name=test_method.function_name,
|
||||
test_file=test_file,
|
||||
test_class=test_method.class_name,
|
||||
test_name=test_method.function_name, test_file=test_file, test_class=test_method.class_name
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -114,7 +107,7 @@ def _match_test_to_functions(
|
|||
# e.g., testAdd -> add, testCalculatorAdd -> Calculator.add
|
||||
test_name_lower = test_method.function_name.lower()
|
||||
|
||||
for func_name, func_info in function_map.items():
|
||||
for func_info in function_map.values():
|
||||
if func_info.function_name.lower() in test_name_lower:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
||||
|
|
@ -125,11 +118,7 @@ def _match_test_to_functions(
|
|||
|
||||
# Find method calls within the test method's line range
|
||||
method_calls = _find_method_calls_in_range(
|
||||
tree.root_node,
|
||||
source_bytes,
|
||||
test_method.starting_line,
|
||||
test_method.ending_line,
|
||||
analyzer,
|
||||
tree.root_node, source_bytes, test_method.starting_line, test_method.ending_line, analyzer
|
||||
)
|
||||
|
||||
for call_name in method_calls:
|
||||
|
|
@ -151,7 +140,7 @@ def _match_test_to_functions(
|
|||
source_class_name = source_class_name[4:]
|
||||
|
||||
# Look for functions in the matching class
|
||||
for func_name, func_info in function_map.items():
|
||||
for func_info in function_map.values():
|
||||
if func_info.class_name == source_class_name:
|
||||
if func_info.qualified_name not in matched:
|
||||
matched.append(func_info.qualified_name)
|
||||
|
|
@ -161,7 +150,7 @@ def _match_test_to_functions(
|
|||
# This handles cases like TestQueryBlob importing Buffer and calling Buffer methods
|
||||
imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer)
|
||||
|
||||
for func_name, func_info in function_map.items():
|
||||
for func_info in function_map.values():
|
||||
if func_info.qualified_name in matched:
|
||||
continue
|
||||
|
||||
|
|
@ -172,11 +161,7 @@ def _match_test_to_functions(
|
|||
return matched
|
||||
|
||||
|
||||
def _extract_imports(
|
||||
node,
|
||||
source_bytes: bytes,
|
||||
analyzer: JavaAnalyzer,
|
||||
) -> set[str]:
|
||||
def _extract_imports(node, source_bytes: bytes, analyzer: JavaAnalyzer) -> set[str]:
|
||||
"""Extract imported class names from a Java file.
|
||||
|
||||
Args:
|
||||
|
|
@ -224,7 +209,7 @@ def _extract_imports(
|
|||
|
||||
# Regular import: extract class name from scoped_identifier
|
||||
for child in n.children:
|
||||
if child.type == "scoped_identifier" or child.type == "identifier":
|
||||
if child.type in {"scoped_identifier", "identifier"}:
|
||||
import_path = analyzer.get_node_text(child, source_bytes)
|
||||
# Extract just the class name (last part)
|
||||
# e.g., "com.example.Buffer" -> "Buffer"
|
||||
|
|
@ -244,11 +229,7 @@ def _extract_imports(
|
|||
|
||||
|
||||
def _find_method_calls_in_range(
|
||||
node,
|
||||
source_bytes: bytes,
|
||||
start_line: int,
|
||||
end_line: int,
|
||||
analyzer: JavaAnalyzer,
|
||||
node, source_bytes: bytes, start_line: int, end_line: int, analyzer: JavaAnalyzer
|
||||
) -> list[str]:
|
||||
"""Find method calls within a line range.
|
||||
|
||||
|
|
@ -278,17 +259,13 @@ def _find_method_calls_in_range(
|
|||
calls.append(analyzer.get_node_text(name_node, source_bytes))
|
||||
|
||||
for child in node.children:
|
||||
calls.extend(
|
||||
_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer)
|
||||
)
|
||||
calls.extend(_find_method_calls_in_range(child, source_bytes, start_line, end_line, analyzer))
|
||||
|
||||
return calls
|
||||
|
||||
|
||||
def find_tests_for_function(
|
||||
function: FunctionToOptimize,
|
||||
test_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
function: FunctionToOptimize, test_root: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[TestInfo]:
|
||||
"""Find tests that exercise a specific function.
|
||||
|
||||
|
|
@ -305,10 +282,7 @@ def find_tests_for_function(
|
|||
return result.get(function.qualified_name, [])
|
||||
|
||||
|
||||
def get_test_class_for_source_class(
|
||||
source_class_name: str,
|
||||
test_root: Path,
|
||||
) -> Path | None:
|
||||
def get_test_class_for_source_class(source_class_name: str, test_root: Path) -> Path | None:
|
||||
"""Find the test class file for a source class.
|
||||
|
||||
Args:
|
||||
|
|
@ -320,11 +294,7 @@ def get_test_class_for_source_class(
|
|||
|
||||
"""
|
||||
# Try common naming patterns
|
||||
patterns = [
|
||||
f"{source_class_name}Test.java",
|
||||
f"Test{source_class_name}.java",
|
||||
f"{source_class_name}Tests.java",
|
||||
]
|
||||
patterns = [f"{source_class_name}Test.java", f"Test{source_class_name}.java", f"{source_class_name}Tests.java"]
|
||||
|
||||
for pattern in patterns:
|
||||
matches = list(test_root.rglob(pattern))
|
||||
|
|
@ -334,10 +304,7 @@ def get_test_class_for_source_class(
|
|||
return None
|
||||
|
||||
|
||||
def discover_all_tests(
|
||||
test_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
) -> list[FunctionToOptimize]:
|
||||
def discover_all_tests(test_root: Path, analyzer: JavaAnalyzer | None = None) -> list[FunctionToOptimize]:
|
||||
"""Discover all test methods in a test directory.
|
||||
|
||||
Args:
|
||||
|
|
@ -353,9 +320,7 @@ def discover_all_tests(
|
|||
|
||||
# Find all test files (various naming conventions)
|
||||
test_files = (
|
||||
list(test_root.rglob("*Test.java"))
|
||||
+ list(test_root.rglob("*Tests.java"))
|
||||
+ list(test_root.rglob("Test*.java"))
|
||||
list(test_root.rglob("*Test.java")) + list(test_root.rglob("*Tests.java")) + list(test_root.rglob("Test*.java"))
|
||||
)
|
||||
|
||||
for test_file in test_files:
|
||||
|
|
@ -391,24 +356,18 @@ def is_test_file(file_path: Path) -> bool:
|
|||
name = file_path.name
|
||||
|
||||
# Check naming patterns
|
||||
if name.endswith("Test.java") or name.endswith("Tests.java"):
|
||||
if name.endswith(("Test.java", "Tests.java")):
|
||||
return True
|
||||
if name.startswith("Test") and name.endswith(".java"):
|
||||
return True
|
||||
|
||||
# Check if it's in a test directory
|
||||
path_parts = file_path.parts
|
||||
for part in path_parts:
|
||||
if part in ("test", "tests", "src/test"):
|
||||
return True
|
||||
|
||||
return False
|
||||
return any(part in ("test", "tests", "src/test") for part in path_parts)
|
||||
|
||||
|
||||
def get_test_methods_for_class(
|
||||
test_file: Path,
|
||||
test_class_name: str | None = None,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
test_file: Path, test_class_name: str | None = None, analyzer: JavaAnalyzer | None = None
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Get all test methods in a specific test class.
|
||||
|
||||
|
|
@ -430,8 +389,7 @@ def get_test_methods_for_class(
|
|||
|
||||
|
||||
def build_test_mapping_for_project(
|
||||
project_root: Path,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
project_root: Path, analyzer: JavaAnalyzer | None = None
|
||||
) -> dict[str, list[TestInfo]]:
|
||||
"""Build a complete test mapping for a project.
|
||||
|
||||
|
|
|
|||
|
|
@ -79,10 +79,11 @@ def _validate_test_filter(test_filter: str) -> str:
|
|||
name_to_validate = pattern.replace("*", "A") # Replace * with a valid char
|
||||
|
||||
if not _validate_java_class_name(name_to_validate):
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"Invalid test class name or pattern: '{pattern}'. "
|
||||
f"Test names must follow Java identifier rules (letters, digits, underscores, dots, dollar signs)."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
return test_filter
|
||||
|
||||
|
|
@ -387,10 +388,7 @@ def run_behavioral_tests(
|
|||
|
||||
|
||||
def _compile_tests(
|
||||
project_root: Path,
|
||||
env: dict[str, str],
|
||||
test_module: str | None = None,
|
||||
timeout: int = 120,
|
||||
project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 120
|
||||
) -> subprocess.CompletedProcess:
|
||||
"""Compile test code using Maven (without running tests).
|
||||
|
||||
|
|
@ -407,12 +405,7 @@ def _compile_tests(
|
|||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
logger.error("Maven not found")
|
||||
return subprocess.CompletedProcess(
|
||||
args=["mvn"],
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr="Maven not found",
|
||||
)
|
||||
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
|
||||
|
||||
cmd = [mvn, "test-compile", "-e"] # Show errors but not verbose output
|
||||
|
||||
|
|
@ -423,37 +416,20 @@ def _compile_tests(
|
|||
|
||||
try:
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Maven compilation timed out after %d seconds", timeout)
|
||||
logger.exception("Maven compilation timed out after %d seconds", timeout)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-2,
|
||||
stdout="",
|
||||
stderr=f"Compilation timed out after {timeout} seconds",
|
||||
args=cmd, returncode=-2, stdout="", stderr=f"Compilation timed out after {timeout} seconds"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Maven compilation failed: %s", e)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr=str(e),
|
||||
)
|
||||
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
|
||||
|
||||
|
||||
def _get_test_classpath(
|
||||
project_root: Path,
|
||||
env: dict[str, str],
|
||||
test_module: str | None = None,
|
||||
timeout: int = 60,
|
||||
project_root: Path, env: dict[str, str], test_module: str | None = None, timeout: int = 60
|
||||
) -> str | None:
|
||||
"""Get the test classpath from Maven.
|
||||
|
||||
|
|
@ -474,13 +450,7 @@ def _get_test_classpath(
|
|||
# Create temp file for classpath output
|
||||
cp_file = project_root / ".codeflash_classpath.txt"
|
||||
|
||||
cmd = [
|
||||
mvn,
|
||||
"dependency:build-classpath",
|
||||
"-DincludeScope=test",
|
||||
f"-Dmdep.outputFile={cp_file}",
|
||||
"-q",
|
||||
]
|
||||
cmd = [mvn, "dependency:build-classpath", "-DincludeScope=test", f"-Dmdep.outputFile={cp_file}", "-q"]
|
||||
|
||||
if test_module:
|
||||
cmd.extend(["-pl", test_module])
|
||||
|
|
@ -489,13 +459,7 @@ def _get_test_classpath(
|
|||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
|
|
@ -527,7 +491,7 @@ def _get_test_classpath(
|
|||
return os.pathsep.join(cp_parts)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Getting classpath timed out")
|
||||
logger.exception("Getting classpath timed out")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to get classpath: %s", e)
|
||||
|
|
@ -602,30 +566,16 @@ def _run_tests_direct(
|
|||
|
||||
try:
|
||||
return subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=working_dir,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
cmd, check=False, cwd=working_dir, env=env, capture_output=True, text=True, timeout=timeout
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Direct test execution timed out after %d seconds", timeout)
|
||||
logger.exception("Direct test execution timed out after %d seconds", timeout)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-2,
|
||||
stdout="",
|
||||
stderr=f"Test execution timed out after {timeout} seconds",
|
||||
args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Direct test execution failed: %s", e)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr=str(e),
|
||||
)
|
||||
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
|
||||
|
||||
|
||||
def _get_test_class_names(test_paths: Any, mode: str = "performance") -> list[str]:
|
||||
|
|
@ -680,10 +630,7 @@ def _get_empty_result(maven_root: Path, test_module: str | None) -> tuple[Path,
|
|||
result_xml_path = _get_combined_junit_xml(surefire_dir, -1)
|
||||
|
||||
empty_result = subprocess.CompletedProcess(
|
||||
args=["java", "-cp", "...", "ConsoleLauncher"],
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr="No test classes found",
|
||||
args=["java", "-cp", "...", "ConsoleLauncher"], returncode=-1, stdout="", stderr="No test classes found"
|
||||
)
|
||||
return result_xml_path, empty_result
|
||||
|
||||
|
|
@ -742,12 +689,7 @@ def _run_benchmarking_tests_maven(
|
|||
run_env["CODEFLASH_INNER_ITERATIONS"] = str(inner_iterations)
|
||||
|
||||
result = _run_maven_tests(
|
||||
maven_root,
|
||||
test_paths,
|
||||
run_env,
|
||||
timeout=per_loop_timeout,
|
||||
mode="performance",
|
||||
test_module=test_module,
|
||||
maven_root, test_paths, run_env, timeout=per_loop_timeout, mode="performance", test_module=test_module
|
||||
)
|
||||
|
||||
last_result = result
|
||||
|
|
@ -760,17 +702,14 @@ def _run_benchmarking_tests_maven(
|
|||
|
||||
elapsed = time.time() - total_start_time
|
||||
if loop_idx >= min_loops and elapsed >= target_duration_seconds:
|
||||
logger.debug(
|
||||
"Stopping Maven benchmark after %d loops (%.2fs elapsed)",
|
||||
loop_idx,
|
||||
elapsed,
|
||||
)
|
||||
logger.debug("Stopping Maven benchmark after %d loops (%.2fs elapsed)", loop_idx, elapsed)
|
||||
break
|
||||
|
||||
# Check if we have timing markers even if some tests failed
|
||||
# We should continue looping if we're getting valid timing data
|
||||
if result.returncode != 0:
|
||||
import re
|
||||
|
||||
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
|
||||
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
|
||||
if not has_timing_markers:
|
||||
|
|
@ -885,8 +824,15 @@ def run_benchmarking_tests(
|
|||
# Fall back to Maven-based execution
|
||||
logger.warning("Falling back to Maven-based test execution")
|
||||
return _run_benchmarking_tests_maven(
|
||||
test_paths, test_env, cwd, timeout, project_root,
|
||||
min_loops, max_loops, target_duration_seconds, inner_iterations
|
||||
test_paths,
|
||||
test_env,
|
||||
cwd,
|
||||
timeout,
|
||||
project_root,
|
||||
min_loops,
|
||||
max_loops,
|
||||
target_duration_seconds,
|
||||
inner_iterations,
|
||||
)
|
||||
|
||||
logger.debug("Compilation completed in %.2fs", compile_time)
|
||||
|
|
@ -898,8 +844,15 @@ def run_benchmarking_tests(
|
|||
if not classpath:
|
||||
logger.warning("Failed to get classpath, falling back to Maven-based execution")
|
||||
return _run_benchmarking_tests_maven(
|
||||
test_paths, test_env, cwd, timeout, project_root,
|
||||
min_loops, max_loops, target_duration_seconds, inner_iterations
|
||||
test_paths,
|
||||
test_env,
|
||||
cwd,
|
||||
timeout,
|
||||
project_root,
|
||||
min_loops,
|
||||
max_loops,
|
||||
target_duration_seconds,
|
||||
inner_iterations,
|
||||
)
|
||||
|
||||
# Step 3: Run tests multiple times directly via JVM
|
||||
|
|
@ -937,12 +890,7 @@ def run_benchmarking_tests(
|
|||
# Run tests directly with XML report generation
|
||||
loop_start = time.time()
|
||||
result = _run_tests_direct(
|
||||
classpath,
|
||||
test_classes,
|
||||
run_env,
|
||||
working_dir,
|
||||
timeout=per_loop_timeout,
|
||||
reports_dir=reports_dir,
|
||||
classpath, test_classes, run_env, working_dir, timeout=per_loop_timeout, reports_dir=reports_dir
|
||||
)
|
||||
loop_time = time.time() - loop_start
|
||||
|
||||
|
|
@ -959,12 +907,7 @@ def run_benchmarking_tests(
|
|||
|
||||
# Check if JUnit Console Launcher is not available (JUnit 4 projects)
|
||||
# Fall back to Maven-based execution in this case
|
||||
if (
|
||||
loop_idx == 1
|
||||
and result.returncode != 0
|
||||
and result.stderr
|
||||
and "ConsoleLauncher" in result.stderr
|
||||
):
|
||||
if loop_idx == 1 and result.returncode != 0 and result.stderr and "ConsoleLauncher" in result.stderr:
|
||||
logger.debug("JUnit Console Launcher not available, falling back to Maven-based execution")
|
||||
return _run_benchmarking_tests_maven(
|
||||
test_paths,
|
||||
|
|
@ -993,6 +936,7 @@ def run_benchmarking_tests(
|
|||
# Check if tests failed - continue looping if we have timing markers
|
||||
if result.returncode != 0:
|
||||
import re
|
||||
|
||||
timing_pattern = re.compile(r"!######[^:]*:[^:]*:[^:]*:[^:]*:[^:]+:[^:]+######!")
|
||||
has_timing_markers = bool(timing_pattern.search(result.stdout or ""))
|
||||
if not has_timing_markers:
|
||||
|
|
@ -1319,12 +1263,7 @@ def _run_maven_tests_impl(
|
|||
mvn = find_maven_executable()
|
||||
if not mvn:
|
||||
logger.error("Maven not found")
|
||||
return subprocess.CompletedProcess(
|
||||
args=["mvn"],
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr="Maven not found",
|
||||
)
|
||||
return subprocess.CompletedProcess(args=["mvn"], returncode=-1, stdout="", stderr="Maven not found")
|
||||
|
||||
# Build test filter
|
||||
test_filter = _build_test_filter(test_paths, mode=mode)
|
||||
|
|
@ -1354,33 +1293,18 @@ def _run_maven_tests_impl(
|
|||
logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
check=False,
|
||||
cwd=project_root,
|
||||
env=env,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=timeout,
|
||||
return subprocess.run(
|
||||
cmd, check=False, cwd=project_root, env=env, capture_output=True, text=True, timeout=timeout
|
||||
)
|
||||
return result
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Maven test execution timed out after %d seconds", timeout)
|
||||
logger.exception("Maven test execution timed out after %d seconds", timeout)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-2,
|
||||
stdout="",
|
||||
stderr=f"Test execution timed out after {timeout} seconds",
|
||||
args=cmd, returncode=-2, stdout="", stderr=f"Test execution timed out after {timeout} seconds"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Maven test execution failed: %s", e)
|
||||
return subprocess.CompletedProcess(
|
||||
args=cmd,
|
||||
returncode=-1,
|
||||
stdout="",
|
||||
stderr=str(e),
|
||||
)
|
||||
return subprocess.CompletedProcess(args=cmd, returncode=-1, stdout="", stderr=str(e))
|
||||
|
||||
|
||||
def _build_test_filter(test_paths: Any, mode: str = "behavior") -> str:
|
||||
|
|
@ -1440,7 +1364,7 @@ def _path_to_class_name(path: Path) -> str | None:
|
|||
Fully qualified class name, or None if unable to determine.
|
||||
|
||||
"""
|
||||
if not path.suffix == ".java":
|
||||
if path.suffix != ".java":
|
||||
return None
|
||||
|
||||
# Try to extract package from path
|
||||
|
|
@ -1463,7 +1387,7 @@ def _path_to_class_name(path: Path) -> str | None:
|
|||
break
|
||||
|
||||
if java_idx is not None:
|
||||
class_parts = parts[java_idx + 1:]
|
||||
class_parts = parts[java_idx + 1 :]
|
||||
# Remove .java extension from last part
|
||||
class_parts[-1] = class_parts[-1].replace(".java", "")
|
||||
return ".".join(class_parts)
|
||||
|
|
@ -1472,12 +1396,7 @@ def _path_to_class_name(path: Path) -> str | None:
|
|||
return path.stem
|
||||
|
||||
|
||||
def run_tests(
|
||||
test_files: list[Path],
|
||||
cwd: Path,
|
||||
env: dict[str, str],
|
||||
timeout: int,
|
||||
) -> tuple[list[TestResult], Path]:
|
||||
def run_tests(test_files: list[Path], cwd: Path, env: dict[str, str], timeout: int) -> tuple[list[TestResult], Path]:
|
||||
"""Run tests and return results.
|
||||
|
||||
Args:
|
||||
|
|
@ -1610,10 +1529,7 @@ def _parse_surefire_xml(xml_file: Path) -> list[TestResult]:
|
|||
return results
|
||||
|
||||
|
||||
def get_test_run_command(
|
||||
project_root: Path,
|
||||
test_classes: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
def get_test_run_command(project_root: Path, test_classes: list[str] | None = None) -> list[str]:
|
||||
"""Get the command to run Java tests.
|
||||
|
||||
Args:
|
||||
|
|
@ -1633,10 +1549,8 @@ def get_test_run_command(
|
|||
validated_classes = []
|
||||
for test_class in test_classes:
|
||||
if not _validate_java_class_name(test_class):
|
||||
raise ValueError(
|
||||
f"Invalid test class name: '{test_class}'. "
|
||||
f"Test names must follow Java identifier rules."
|
||||
)
|
||||
msg = f"Invalid test class name: '{test_class}'. Test names must follow Java identifier rules."
|
||||
raise ValueError(msg)
|
||||
validated_classes.append(test_class)
|
||||
|
||||
cmd.append(f"-Dtest={','.join(validated_classes)}")
|
||||
|
|
|
|||
|
|
@ -210,10 +210,10 @@ class ReferenceFinder:
|
|||
|
||||
# Check if this file imports from the re-export file
|
||||
import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)
|
||||
|
||||
trigger_check = True
|
||||
if import_info:
|
||||
context.visited_files.add(file_path)
|
||||
import_name, original_import = import_info
|
||||
import_name, _original_import = import_info
|
||||
file_refs = self._find_references_in_file(
|
||||
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
|
||||
)
|
||||
|
|
@ -651,15 +651,18 @@ class ReferenceFinder:
|
|||
|
||||
"""
|
||||
references: list[Reference] = []
|
||||
export_name = exported.export_name or exported.function_name
|
||||
|
||||
# Skip expensive parsing if export name not in source
|
||||
if export_name not in source_code:
|
||||
return references
|
||||
|
||||
exports = analyzer.find_exports(source_code)
|
||||
lines = source_code.splitlines()
|
||||
|
||||
for exp in exports:
|
||||
if not exp.is_reexport:
|
||||
continue
|
||||
|
||||
# Check if this re-exports our function
|
||||
export_name = exported.export_name or exported.function_name
|
||||
for name, alias in exp.exported_names:
|
||||
if name == export_name:
|
||||
# This is a re-export of our function
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ import logging
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.current import is_typescript
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -44,9 +46,10 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
|
|||
"""Detect the module system used by a JavaScript/TypeScript project.
|
||||
|
||||
Detection strategy:
|
||||
1. Check package.json for "type" field
|
||||
2. If file_path provided, check file extension (.mjs = ESM, .cjs = CommonJS)
|
||||
3. Analyze import statements in the file
|
||||
1. Check file extension for explicit module type (.mjs, .cjs, .ts, .tsx, .mts)
|
||||
- TypeScript files always use ESM syntax regardless of package.json
|
||||
2. Check package.json for explicit "type" field (only if explicitly set)
|
||||
3. Analyze import/export statements in the file content
|
||||
4. Default to CommonJS if uncertain
|
||||
|
||||
Args:
|
||||
|
|
@ -57,13 +60,29 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
|
|||
ModuleSystem constant (COMMONJS, ES_MODULE, or UNKNOWN).
|
||||
|
||||
"""
|
||||
# Strategy 1: Check package.json
|
||||
# Strategy 1: Check file extension first for explicit module type indicators
|
||||
# TypeScript files always use ESM syntax (import/export)
|
||||
if file_path:
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix == ".mjs":
|
||||
logger.debug("Detected ES Module from .mjs extension")
|
||||
return ModuleSystem.ES_MODULE
|
||||
if suffix == ".cjs":
|
||||
logger.debug("Detected CommonJS from .cjs extension")
|
||||
return ModuleSystem.COMMONJS
|
||||
if suffix in (".ts", ".tsx", ".mts"):
|
||||
# TypeScript always uses ESM syntax (import/export)
|
||||
# even if package.json doesn't have "type": "module"
|
||||
logger.debug("Detected ES Module from TypeScript file extension")
|
||||
return ModuleSystem.ES_MODULE
|
||||
|
||||
# Strategy 2: Check package.json for explicit type field
|
||||
package_json = project_root / "package.json"
|
||||
if package_json.exists():
|
||||
try:
|
||||
with package_json.open("r") as f:
|
||||
pkg = json.load(f)
|
||||
pkg_type = pkg.get("type", "commonjs")
|
||||
pkg_type = pkg.get("type") # Don't default - only use if explicitly set
|
||||
|
||||
if pkg_type == "module":
|
||||
logger.debug("Detected ES Module from package.json type field")
|
||||
|
|
@ -71,44 +90,35 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
|
|||
if pkg_type == "commonjs":
|
||||
logger.debug("Detected CommonJS from package.json type field")
|
||||
return ModuleSystem.COMMONJS
|
||||
# If type is not explicitly set, continue to file content analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse package.json: %s", e)
|
||||
|
||||
# Strategy 2: Check file extension
|
||||
if file_path:
|
||||
suffix = file_path.suffix
|
||||
if suffix == ".mjs":
|
||||
logger.debug("Detected ES Module from .mjs extension")
|
||||
return ModuleSystem.ES_MODULE
|
||||
if suffix == ".cjs":
|
||||
logger.debug("Detected CommonJS from .cjs extension")
|
||||
return ModuleSystem.COMMONJS
|
||||
# Strategy 3: Analyze file content for import/export patterns
|
||||
if file_path and file_path.exists():
|
||||
try:
|
||||
content = file_path.read_text()
|
||||
|
||||
# Strategy 3: Analyze file content
|
||||
if file_path.exists():
|
||||
try:
|
||||
content = file_path.read_text()
|
||||
# Look for ES module syntax
|
||||
has_import = "import " in content and "from " in content
|
||||
has_export = "export " in content or "export default" in content or "export {" in content
|
||||
|
||||
# Look for ES module syntax
|
||||
has_import = "import " in content and "from " in content
|
||||
has_export = "export " in content or "export default" in content or "export {" in content
|
||||
# Look for CommonJS syntax
|
||||
has_require = "require(" in content
|
||||
has_module_exports = "module.exports" in content or "exports." in content
|
||||
|
||||
# Look for CommonJS syntax
|
||||
has_require = "require(" in content
|
||||
has_module_exports = "module.exports" in content or "exports." in content
|
||||
# Determine based on what we found
|
||||
if (has_import or has_export) and not (has_require or has_module_exports):
|
||||
logger.debug("Detected ES Module from import/export statements")
|
||||
return ModuleSystem.ES_MODULE
|
||||
|
||||
# Determine based on what we found
|
||||
if (has_import or has_export) and not (has_require or has_module_exports):
|
||||
logger.debug("Detected ES Module from import/export statements")
|
||||
return ModuleSystem.ES_MODULE
|
||||
if (has_require or has_module_exports) and not (has_import or has_export):
|
||||
logger.debug("Detected CommonJS from require/module.exports")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
if (has_require or has_module_exports) and not (has_import or has_export):
|
||||
logger.debug("Detected CommonJS from require/module.exports")
|
||||
return ModuleSystem.COMMONJS
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to analyze file %s: %s", file_path, e)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to analyze file %s: %s", file_path, e)
|
||||
|
||||
# Default to CommonJS (more common and backward compatible)
|
||||
logger.debug("Defaulting to CommonJS")
|
||||
|
|
@ -199,11 +209,42 @@ def add_js_extension(module_path: str) -> str:
|
|||
return module_path
|
||||
|
||||
|
||||
def _convert_destructuring_to_imports(names_str: str) -> str:
|
||||
"""Convert destructuring aliases to import aliases.
|
||||
|
||||
Converts:
|
||||
a, b -> a, b
|
||||
a: aliasA -> a as aliasA
|
||||
a, b: aliasB -> a, b as aliasB
|
||||
|
||||
Args:
|
||||
names_str: The destructuring pattern string (e.g., "a, b: aliasB")
|
||||
|
||||
Returns:
|
||||
Import names string with aliases using 'as' syntax
|
||||
|
||||
"""
|
||||
# Split by commas and process each name
|
||||
parts = []
|
||||
for name in names_str.split(","):
|
||||
name = name.strip()
|
||||
if ":" in name:
|
||||
# Convert destructuring alias to import alias
|
||||
# "a: aliasA" -> "a as aliasA"
|
||||
original, alias = name.split(":", 1)
|
||||
parts.append(f"{original.strip()} as {alias.strip()}")
|
||||
else:
|
||||
parts.append(name)
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
# Replace destructured requires with named imports
|
||||
def replace_destructured(match: re.Match) -> str:
|
||||
names = match.group(2).strip()
|
||||
module_path = add_js_extension(match.group(3))
|
||||
return f"import {{ {names} }} from '{module_path}';"
|
||||
# Convert destructuring aliases (a: b) to import aliases (a as b)
|
||||
converted_names = _convert_destructuring_to_imports(names)
|
||||
return f"import {{ {converted_names} }} from '{module_path}';"
|
||||
|
||||
|
||||
# Replace property access requires with named imports with alias
|
||||
|
|
@ -234,12 +275,14 @@ def convert_commonjs_to_esm(code: str) -> str:
|
|||
"""Convert CommonJS require statements to ES Module imports.
|
||||
|
||||
Converts:
|
||||
const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
|
||||
const foo = require('./module'); -> import foo from './module';
|
||||
const foo = require('./module').default; -> import foo from './module';
|
||||
const foo = require('./module').bar; -> import { bar as foo } from './module';
|
||||
const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
|
||||
const { foo: alias } = require('./module'); -> import { foo as alias } from './module';
|
||||
const foo = require('./module'); -> import foo from './module';
|
||||
const foo = require('./module').default; -> import foo from './module';
|
||||
const foo = require('./module').bar; -> import { bar as foo } from './module';
|
||||
|
||||
Special handling:
|
||||
- Destructuring aliases (a: b) are converted to import aliases (a as b)
|
||||
- Local codeflash helper (./codeflash-jest-helper) is converted to npm package codeflash
|
||||
because the local helper uses CommonJS exports which don't work in ESM projects
|
||||
|
||||
|
|
@ -299,36 +342,89 @@ def convert_esm_to_commonjs(code: str) -> str:
|
|||
return default_import.sub(replace_default, code)
|
||||
|
||||
|
||||
def ensure_module_system_compatibility(code: str, target_module_system: str) -> str:
|
||||
def uses_ts_jest(project_root: Path) -> bool:
|
||||
"""Check if the project uses ts-jest for TypeScript transformation.
|
||||
|
||||
ts-jest handles module interoperability internally, allowing mixed
|
||||
CommonJS/ESM imports without explicit conversion.
|
||||
|
||||
Args:
|
||||
project_root: The project root directory.
|
||||
|
||||
Returns:
|
||||
True if ts-jest is being used, False otherwise.
|
||||
|
||||
"""
|
||||
# Check for ts-jest in devDependencies or dependencies
|
||||
package_json = project_root / "package.json"
|
||||
if package_json.exists():
|
||||
try:
|
||||
with package_json.open("r") as f:
|
||||
pkg = json.load(f)
|
||||
dev_deps = pkg.get("devDependencies", {})
|
||||
deps = pkg.get("dependencies", {})
|
||||
if "ts-jest" in dev_deps or "ts-jest" in deps:
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read package.json for ts-jest detection: {e}") # noqa: G004
|
||||
|
||||
# Also check for jest.config with ts-jest preset
|
||||
for config_file in ["jest.config.js", "jest.config.cjs", "jest.config.ts", "jest.config.mjs"]:
|
||||
config_path = project_root / config_file
|
||||
if config_path.exists():
|
||||
try:
|
||||
content = config_path.read_text()
|
||||
if "ts-jest" in content:
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to read {config_file}: {e}") # noqa: G004
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def ensure_module_system_compatibility(code: str, target_module_system: str, project_root: Path | None = None) -> str:
|
||||
"""Ensure code uses the correct module system syntax.
|
||||
|
||||
Detects the current module system in the code and converts if needed.
|
||||
Handles mixed-style code (e.g., ESM imports with CommonJS require for npm packages).
|
||||
If the project uses ts-jest, no conversion is performed because ts-jest
|
||||
handles module interoperability internally. Otherwise, converts between
|
||||
CommonJS and ES Modules as needed.
|
||||
|
||||
Args:
|
||||
code: JavaScript code to check and potentially convert.
|
||||
target_module_system: Target ModuleSystem (COMMONJS or ES_MODULE).
|
||||
project_root: Project root directory for ts-jest detection.
|
||||
|
||||
Returns:
|
||||
Code with correct module system syntax.
|
||||
Converted code, or unchanged if ts-jest handles interop.
|
||||
|
||||
"""
|
||||
# If ts-jest is installed, skip conversion - it handles interop natively
|
||||
if is_typescript() and project_root and uses_ts_jest(project_root):
|
||||
logger.debug(
|
||||
f"Skipping module system conversion (target was {target_module_system}). " # noqa: G004
|
||||
"ts-jest handles interop natively."
|
||||
)
|
||||
return code
|
||||
|
||||
# Detect current module system in code
|
||||
has_require = "require(" in code
|
||||
has_module_exports = "module.exports" in code or "exports." in code
|
||||
has_import = "import " in code and "from " in code
|
||||
has_export = "export " in code
|
||||
|
||||
if target_module_system == ModuleSystem.ES_MODULE:
|
||||
# Convert any require() statements to imports for ESM projects
|
||||
# This handles mixed code (ESM imports + CommonJS requires for npm packages)
|
||||
if has_require:
|
||||
logger.debug("Converting CommonJS requires to ESM imports")
|
||||
return convert_commonjs_to_esm(code)
|
||||
elif target_module_system == ModuleSystem.COMMONJS:
|
||||
# Convert any import statements to requires for CommonJS projects
|
||||
if has_import:
|
||||
logger.debug("Converting ESM imports to CommonJS requires")
|
||||
return convert_esm_to_commonjs(code)
|
||||
is_commonjs = has_require or has_module_exports
|
||||
is_esm = has_import or has_export
|
||||
|
||||
# Convert if needed
|
||||
if target_module_system == ModuleSystem.ES_MODULE and is_commonjs and not is_esm:
|
||||
logger.debug("Converting CommonJS to ES Module syntax")
|
||||
return convert_commonjs_to_esm(code)
|
||||
|
||||
if target_module_system == ModuleSystem.COMMONJS and is_esm and not is_commonjs:
|
||||
logger.debug("Converting ES Module to CommonJS syntax")
|
||||
return convert_esm_to_commonjs(code)
|
||||
|
||||
logger.debug("No module system conversion needed")
|
||||
return code
|
||||
|
||||
|
||||
|
|
@ -355,12 +451,8 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
|
|||
|
||||
# Check if the code uses test functions that need to be imported
|
||||
test_globals = ["describe", "test", "it", "expect", "vi", "beforeEach", "afterEach", "beforeAll", "afterAll"]
|
||||
needs_import = any(f"{global_name}(" in code or f"{global_name} (" in code for global_name in test_globals)
|
||||
|
||||
if not needs_import:
|
||||
return code
|
||||
|
||||
# Determine which globals are actually used in the code
|
||||
# Combine detection and collection into a single pass
|
||||
used_globals = [g for g in test_globals if f"{g}(" in code or f"{g} (" in code]
|
||||
if not used_globals:
|
||||
return code
|
||||
|
|
@ -373,9 +465,14 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
|
|||
insert_index = 0
|
||||
for i, line in enumerate(lines):
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("//") and not stripped.startswith("/*") and not stripped.startswith("*"):
|
||||
if (
|
||||
stripped
|
||||
and not stripped.startswith("//")
|
||||
and not stripped.startswith("/*")
|
||||
and not stripped.startswith("*")
|
||||
):
|
||||
# Check if this line is an import/require - insert after imports
|
||||
if stripped.startswith("import ") or stripped.startswith("const ") or stripped.startswith("let "):
|
||||
if stripped.startswith(("import ", "const ", "let ")):
|
||||
continue
|
||||
insert_index = i
|
||||
break
|
||||
|
|
|
|||
|
|
@ -296,6 +296,33 @@ def _find_node_project_root(file_path: Path) -> Path | None:
|
|||
return None
|
||||
|
||||
|
||||
def _find_monorepo_root(start_path: Path) -> Path | None:
|
||||
"""Find the monorepo workspace root by looking for workspace markers.
|
||||
|
||||
Traverses up from the given path to find a directory containing
|
||||
monorepo workspace markers like yarn.lock, pnpm-workspace.yaml, etc.
|
||||
|
||||
Args:
|
||||
start_path: A path within the monorepo.
|
||||
|
||||
Returns:
|
||||
The monorepo root directory, or None if not found.
|
||||
|
||||
"""
|
||||
monorepo_markers = ["yarn.lock", "pnpm-workspace.yaml", "lerna.json", "package-lock.json"]
|
||||
current = start_path if start_path.is_dir() else start_path.parent
|
||||
|
||||
while current != current.parent:
|
||||
# Check for monorepo markers
|
||||
if any((current / marker).exists() for marker in monorepo_markers):
|
||||
# Verify it has node_modules (it's the workspace root)
|
||||
if (current / "node_modules").exists():
|
||||
return current
|
||||
current = current.parent
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_jest_config(project_root: Path) -> Path | None:
|
||||
"""Find Jest configuration file in the project.
|
||||
|
||||
|
|
@ -797,6 +824,12 @@ def run_jest_benchmarking_tests(
|
|||
jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
|
||||
jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
|
||||
jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
|
||||
|
||||
# Pass monorepo root to loop-runner for jest-runner resolution
|
||||
monorepo_root = _find_monorepo_root(effective_cwd)
|
||||
if monorepo_root:
|
||||
jest_env["CODEFLASH_MONOREPO_ROOT"] = str(monorepo_root)
|
||||
logger.debug(f"Detected monorepo root: {monorepo_root}")
|
||||
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
|
||||
jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
|
||||
jest_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
|
|
|
|||
|
|
@ -321,7 +321,7 @@ class PythonSupport:
|
|||
|
||||
function_pos = None
|
||||
for name in names:
|
||||
if name.type == "function" and name.name == function.name:
|
||||
if name.type == "function" and name.name == function.function_name:
|
||||
# Check for class parent if it's a method
|
||||
if function.class_name:
|
||||
parent = name.parent()
|
||||
|
|
|
|||
|
|
@ -258,6 +258,11 @@ class TreeSitterAnalyzer:
|
|||
if func_info.is_arrow and not include_arrow_functions:
|
||||
should_include = False
|
||||
|
||||
# Skip arrow functions that are object properties (e.g., { foo: () => {} })
|
||||
# These are not standalone functions - they're values in object literals
|
||||
if func_info.is_arrow and node.parent and node.parent.type == "pair":
|
||||
should_include = False
|
||||
|
||||
if should_include:
|
||||
functions.append(func_info)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from codeflash.models.test_type import TestType
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
import enum
|
||||
import re
|
||||
import sys
|
||||
|
|
@ -325,9 +326,7 @@ class CodeStringsMarkdown(BaseModel):
|
|||
"""
|
||||
if "file_to_path" in self._cache:
|
||||
return self._cache["file_to_path"]
|
||||
result = {
|
||||
str(code_string.file_path): code_string.code for code_string in self.code_strings
|
||||
}
|
||||
result = {str(code_string.file_path): code_string.code for code_string in self.code_strings}
|
||||
self._cache["file_to_path"] = result
|
||||
return result
|
||||
|
||||
|
|
@ -877,15 +876,14 @@ class TestResults(BaseModel): # noqa: PLW1641
|
|||
return max(test_result.loop_index for test_result in self.test_results)
|
||||
|
||||
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
|
||||
report = {}
|
||||
for test_type in TestType:
|
||||
report[test_type] = {"passed": 0, "failed": 0}
|
||||
report: dict[TestType, dict[str, int]] = {tt: {"passed": 0, "failed": 0} for tt in TestType}
|
||||
for test_result in self.test_results:
|
||||
if test_result.loop_index == 1:
|
||||
if test_result.did_pass:
|
||||
report[test_result.test_type]["passed"] += 1
|
||||
else:
|
||||
report[test_result.test_type]["failed"] += 1
|
||||
if test_result.loop_index != 1:
|
||||
continue
|
||||
if test_result.did_pass:
|
||||
report[test_result.test_type]["passed"] += 1
|
||||
else:
|
||||
report[test_result.test_type]["failed"] += 1
|
||||
return report
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ class TestType(Enum):
|
|||
def to_name(self) -> str:
|
||||
if self is TestType.INIT_STATE_TEST:
|
||||
return ""
|
||||
names = {
|
||||
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
|
||||
}
|
||||
return names[self]
|
||||
return _TO_NAME_MAP[self]
|
||||
|
||||
|
||||
_TO_NAME_MAP: dict[TestType, str] = {
|
||||
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
|
||||
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
|
||||
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
|
||||
TestType.REPLAY_TEST: "⏪ Replay Tests",
|
||||
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import random
|
||||
|
|
@ -23,7 +22,7 @@ from rich.tree import Tree
|
|||
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
|
||||
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
|
||||
from codeflash.benchmarking.utils import process_benchmark_data
|
||||
from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
|
||||
from codeflash.cli_cmds.console import DEBUG_MODE, code_print, console, logger, lsp_log, progress_bar
|
||||
from codeflash.code_utils import env_utils
|
||||
from codeflash.code_utils.code_extractor import get_opt_review_metrics, is_numerical_code
|
||||
from codeflash.code_utils.code_replacer import (
|
||||
|
|
@ -146,9 +145,70 @@ if TYPE_CHECKING:
|
|||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
def log_code_after_replacement(file_path: Path, candidate_index: int) -> None:
|
||||
"""Log the full file content after code replacement in verbose mode."""
|
||||
if not DEBUG_MODE:
|
||||
return
|
||||
|
||||
try:
|
||||
code = file_path.read_text(encoding="utf-8")
|
||||
lang_map = {".java": "java", ".py": "python", ".js": "javascript", ".ts": "typescript"}
|
||||
language = lang_map.get(file_path.suffix.lower(), "text")
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Syntax(code, language, line_numbers=True, theme="monokai", word_wrap=True),
|
||||
title=f"[bold blue]Code After Replacement (Candidate {candidate_index})[/] [dim]({file_path.name})[/]",
|
||||
border_style="blue",
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to log code after replacement: {e}")
|
||||
|
||||
|
||||
def log_instrumented_test(test_source: str, test_name: str, test_type: str, language: str) -> None:
|
||||
"""Log instrumented test code in verbose mode."""
|
||||
if not DEBUG_MODE:
|
||||
return
|
||||
|
||||
display_source = test_source
|
||||
if len(test_source) > 15000:
|
||||
display_source = test_source[:15000] + "\n\n... [truncated] ..."
|
||||
|
||||
console.print(
|
||||
Panel(
|
||||
Syntax(display_source, language, line_numbers=True, theme="monokai", word_wrap=True),
|
||||
title=f"[bold magenta]Instrumented Test: {test_name}[/] [dim]({test_type})[/]",
|
||||
border_style="magenta",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def log_test_run_output(stdout: str, stderr: str, test_type: str, returncode: int = 0) -> None:
|
||||
"""Log test run stdout/stderr in verbose mode."""
|
||||
if not DEBUG_MODE:
|
||||
return
|
||||
|
||||
max_len = 10000
|
||||
|
||||
if stdout and stdout.strip():
|
||||
display_stdout = stdout[:max_len] + ("...[truncated]" if len(stdout) > max_len else "")
|
||||
console.print(
|
||||
Panel(
|
||||
display_stdout,
|
||||
title=f"[bold green]{test_type} - stdout[/] [dim](exit: {returncode})[/]",
|
||||
border_style="green" if returncode == 0 else "red",
|
||||
)
|
||||
)
|
||||
|
||||
if stderr and stderr.strip():
|
||||
display_stderr = stderr[:max_len] + ("...[truncated]" if len(stderr) > max_len else "")
|
||||
console.print(Panel(display_stderr, title=f"[bold yellow]{test_type} - stderr[/]", border_style="yellow"))
|
||||
|
||||
|
||||
def log_optimization_context(function_name: str, code_context: CodeOptimizationContext) -> None:
|
||||
"""Log optimization context details when in verbose mode using Rich formatting."""
|
||||
if logger.getEffectiveLevel() > logging.DEBUG:
|
||||
if not DEBUG_MODE:
|
||||
return
|
||||
|
||||
console.rule()
|
||||
|
|
@ -496,6 +556,16 @@ class FunctionOptimizer:
|
|||
should_run_experiment = self.experiment_id is not None
|
||||
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
|
||||
ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})
|
||||
|
||||
# Early check: if --no-gen-tests is set, verify there are existing tests for this function
|
||||
if self.args.no_gen_tests:
|
||||
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
|
||||
if not self.function_to_tests.get(func_qualname):
|
||||
return Failure(
|
||||
f"No existing tests found for '{self.function_to_optimize.function_name}'. "
|
||||
f"Cannot optimize without tests when --no-gen-tests is set."
|
||||
)
|
||||
|
||||
self.cleanup_leftover_test_return_values()
|
||||
file_name_from_test_module_name.cache_clear()
|
||||
ctx_result = self.get_code_optimization_context()
|
||||
|
|
@ -566,7 +636,7 @@ class FunctionOptimizer:
|
|||
|
||||
# Normalize codeflash imports in JS/TS tests to use npm package
|
||||
if not is_python():
|
||||
module_system = detect_module_system(self.project_root)
|
||||
module_system = detect_module_system(self.project_root, self.function_to_optimize.file_path)
|
||||
if module_system == "esm":
|
||||
generated_tests = inject_test_globals(generated_tests)
|
||||
if is_typescript():
|
||||
|
|
@ -594,18 +664,32 @@ class FunctionOptimizer:
|
|||
generated_test.instrumented_perf_test_source = modified_perf_source
|
||||
used_behavior_paths.add(behavior_path)
|
||||
|
||||
logger.debug(
|
||||
f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}"
|
||||
)
|
||||
logger.debug(f"[PIPELINE] Test {i + 1}: behavior_path={behavior_path}, perf_path={perf_path}")
|
||||
|
||||
with behavior_path.open("w", encoding="utf8") as f:
|
||||
f.write(generated_test.instrumented_behavior_test_source)
|
||||
logger.debug(f"[PIPELINE] Wrote behavioral test to {behavior_path}")
|
||||
|
||||
# Verbose: Log instrumented behavior test
|
||||
log_instrumented_test(
|
||||
generated_test.instrumented_behavior_test_source,
|
||||
behavior_path.name,
|
||||
"Behavioral Test",
|
||||
language=self.function_to_optimize.language,
|
||||
)
|
||||
|
||||
with perf_path.open("w", encoding="utf8") as f:
|
||||
f.write(generated_test.instrumented_perf_test_source)
|
||||
logger.debug(f"[PIPELINE] Wrote perf test to {perf_path}")
|
||||
|
||||
# Verbose: Log instrumented performance test
|
||||
log_instrumented_test(
|
||||
generated_test.instrumented_perf_test_source,
|
||||
perf_path.name,
|
||||
"Performance Test",
|
||||
language=self.function_to_optimize.language,
|
||||
)
|
||||
|
||||
# File paths are expected to be absolute - resolved at their source (CLI, TestConfig, etc.)
|
||||
test_file_obj = TestFile(
|
||||
instrumented_behavior_file_path=generated_test.behavior_file_path,
|
||||
|
|
@ -675,22 +759,24 @@ class FunctionOptimizer:
|
|||
parts = tests_root.parts
|
||||
|
||||
# Look for standard Java package prefixes that indicate the start of package structure
|
||||
standard_package_prefixes = ('com', 'org', 'net', 'io', 'edu', 'gov')
|
||||
standard_package_prefixes = ("com", "org", "net", "io", "edu", "gov")
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
if part in standard_package_prefixes:
|
||||
# Found start of package path, return everything before it
|
||||
if i > 0:
|
||||
java_sources_root = Path(*parts[:i])
|
||||
logger.debug(f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})")
|
||||
logger.debug(
|
||||
f"[JAVA] Detected Java sources root: {java_sources_root} (from tests_root: {tests_root})"
|
||||
)
|
||||
return java_sources_root
|
||||
|
||||
# If no standard package prefix found, check if there's a 'java' directory
|
||||
# (standard Maven structure: src/test/java)
|
||||
for i, part in enumerate(parts):
|
||||
if part == 'java' and i > 0:
|
||||
if part == "java" and i > 0:
|
||||
# Return up to and including 'java'
|
||||
java_sources_root = Path(*parts[:i + 1])
|
||||
java_sources_root = Path(*parts[: i + 1])
|
||||
logger.debug(f"[JAVA] Detected Maven-style Java sources root: {java_sources_root}")
|
||||
return java_sources_root
|
||||
|
||||
|
|
@ -721,16 +807,16 @@ class FunctionOptimizer:
|
|||
import re
|
||||
|
||||
# Extract package from behavior source
|
||||
package_match = re.search(r'^\s*package\s+([\w.]+)\s*;', behavior_source, re.MULTILINE)
|
||||
package_match = re.search(r"^\s*package\s+([\w.]+)\s*;", behavior_source, re.MULTILINE)
|
||||
package_name = package_match.group(1) if package_match else ""
|
||||
|
||||
# Extract class name from behavior source
|
||||
# Use more specific pattern to avoid matching words like "command" or text in comments
|
||||
class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', behavior_source, re.MULTILINE)
|
||||
class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", behavior_source, re.MULTILINE)
|
||||
behavior_class = class_match.group(1) if class_match else "GeneratedTest"
|
||||
|
||||
# Extract class name from perf source
|
||||
perf_class_match = re.search(r'^(?:public\s+)?class\s+(\w+)', perf_source, re.MULTILINE)
|
||||
perf_class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", perf_source, re.MULTILINE)
|
||||
perf_class = perf_class_match.group(1) if perf_class_match else "GeneratedPerfTest"
|
||||
|
||||
# Build paths with package structure
|
||||
|
|
@ -770,22 +856,20 @@ class FunctionOptimizer:
|
|||
perf_path = new_perf_path
|
||||
# Rename class in source code - replace the class declaration
|
||||
modified_behavior_source = re.sub(
|
||||
rf'^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)',
|
||||
rf'\g<1>{new_behavior_class}\g<2>',
|
||||
rf"^((?:public\s+)?class\s+){re.escape(behavior_class)}(\b)",
|
||||
rf"\g<1>{new_behavior_class}\g<2>",
|
||||
behavior_source,
|
||||
count=1,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
modified_perf_source = re.sub(
|
||||
rf'^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)',
|
||||
rf'\g<1>{new_perf_class}\g<2>',
|
||||
rf"^((?:public\s+)?class\s+){re.escape(perf_class)}(\b)",
|
||||
rf"\g<1>{new_perf_class}\g<2>",
|
||||
perf_source,
|
||||
count=1,
|
||||
flags=re.MULTILINE,
|
||||
)
|
||||
logger.debug(
|
||||
f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}"
|
||||
)
|
||||
logger.debug(f"[JAVA] Renamed duplicate test class from {behavior_class} to {new_behavior_class}")
|
||||
break
|
||||
index += 1
|
||||
|
||||
|
|
@ -1202,6 +1286,9 @@ class FunctionOptimizer:
|
|||
logger.info("No functions were replaced in the optimized code. Skipping optimization candidate.")
|
||||
console.rule()
|
||||
return None
|
||||
|
||||
# Verbose: Log code after replacement
|
||||
log_code_after_replacement(self.function_to_optimize.file_path, candidate_index)
|
||||
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
|
||||
logger.error(e)
|
||||
self.write_code_and_helpers(
|
||||
|
|
@ -1767,6 +1854,14 @@ class FunctionOptimizer:
|
|||
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_behavior_test)
|
||||
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
|
||||
|
||||
# Verbose: Log instrumented existing behavior test
|
||||
log_instrumented_test(
|
||||
injected_behavior_test,
|
||||
new_behavioral_test_path.name,
|
||||
"Existing Behavioral Test",
|
||||
language=self.function_to_optimize.language,
|
||||
)
|
||||
else:
|
||||
msg = "injected_behavior_test is None"
|
||||
raise ValueError(msg)
|
||||
|
|
@ -1776,6 +1871,14 @@ class FunctionOptimizer:
|
|||
_f.write(injected_perf_test)
|
||||
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
|
||||
|
||||
# Verbose: Log instrumented existing performance test
|
||||
log_instrumented_test(
|
||||
injected_perf_test,
|
||||
new_perf_test_path.name,
|
||||
"Existing Performance Test",
|
||||
language=self.function_to_optimize.language,
|
||||
)
|
||||
|
||||
unique_instrumented_test_files.add(new_behavioral_test_path)
|
||||
unique_instrumented_test_files.add(new_perf_test_path)
|
||||
|
||||
|
|
@ -2242,7 +2345,7 @@ class FunctionOptimizer:
|
|||
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
|
||||
generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n"
|
||||
|
||||
existing_tests, replay_tests, concolic_tests = existing_tests_source_for(
|
||||
existing_tests, replay_tests, _concolic_tests = existing_tests_source_for(
|
||||
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
|
||||
function_to_all_tests,
|
||||
test_cfg=self.test_cfg,
|
||||
|
|
@ -2885,6 +2988,11 @@ class FunctionOptimizer:
|
|||
else:
|
||||
msg = f"Unexpected testing type: {testing_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Verbose: Log test run output
|
||||
log_test_run_output(
|
||||
run_result.stdout, run_result.stderr, f"Test Run ({testing_type.name})", run_result.returncode
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.exception(
|
||||
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ from typing import Any
|
|||
|
||||
import tomlkit
|
||||
|
||||
_BUILD_DIRS = frozenset({"build", "dist", "out", ".next", ".nuxt"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetectedProject:
|
||||
|
|
@ -310,14 +312,21 @@ def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
|
|||
"""Detect JavaScript/TypeScript module root.
|
||||
|
||||
Priority:
|
||||
1. package.json "exports" field
|
||||
2. package.json "module" field (ESM)
|
||||
3. package.json "main" field (CJS)
|
||||
4. src/ directory
|
||||
5. lib/ directory
|
||||
6. Project root
|
||||
1. src/, lib/, source/ directories (common source directories)
|
||||
2. package.json "exports" field (if not in build output directory)
|
||||
3. package.json "module" field (ESM, if not in build output directory)
|
||||
4. package.json "main" field (CJS, if not in build output directory)
|
||||
5. Project root
|
||||
|
||||
Build output directories (build/, dist/, out/) are skipped since they contain
|
||||
compiled code, not source files.
|
||||
|
||||
"""
|
||||
# Check for common source directories first - these are always preferred
|
||||
for src_dir in ["src", "lib", "source"]:
|
||||
if (project_root / src_dir).is_dir():
|
||||
return project_root / src_dir, f"{src_dir}/ directory"
|
||||
|
||||
package_json_path = project_root / "package.json"
|
||||
package_data: dict[str, Any] = {}
|
||||
|
||||
|
|
@ -334,32 +343,52 @@ def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
|
|||
entry_path = _extract_entry_path(exports)
|
||||
if entry_path:
|
||||
parent = Path(entry_path).parent
|
||||
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
|
||||
if (
|
||||
parent != Path()
|
||||
and parent.as_posix() != "."
|
||||
and (project_root / parent).is_dir()
|
||||
and not is_build_output_dir(parent)
|
||||
):
|
||||
return project_root / parent, f'{parent.as_posix()}/ (from package.json "exports")'
|
||||
|
||||
# Check module field (ESM)
|
||||
module_field = package_data.get("module")
|
||||
if module_field and isinstance(module_field, str):
|
||||
parent = Path(module_field).parent
|
||||
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
|
||||
if (
|
||||
parent != Path()
|
||||
and parent.as_posix() != "."
|
||||
and (project_root / parent).is_dir()
|
||||
and not is_build_output_dir(parent)
|
||||
):
|
||||
return project_root / parent, f'{parent.as_posix()}/ (from package.json "module")'
|
||||
|
||||
# Check main field (CJS)
|
||||
main_field = package_data.get("main")
|
||||
if main_field and isinstance(main_field, str):
|
||||
parent = Path(main_field).parent
|
||||
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
|
||||
if (
|
||||
parent != Path()
|
||||
and parent.as_posix() != "."
|
||||
and (project_root / parent).is_dir()
|
||||
and not is_build_output_dir(parent)
|
||||
):
|
||||
return project_root / parent, f'{parent.as_posix()}/ (from package.json "main")'
|
||||
|
||||
# Check for common source directories
|
||||
for src_dir in ["src", "lib", "source"]:
|
||||
if (project_root / src_dir).is_dir():
|
||||
return project_root / src_dir, f"{src_dir}/ directory"
|
||||
|
||||
# Default to project root
|
||||
return project_root, "project root"
|
||||
|
||||
|
||||
def is_build_output_dir(path: Path) -> bool:
|
||||
"""Check if a path is within a common build output directory.
|
||||
|
||||
Build output directories contain compiled code and should be skipped
|
||||
in favor of source directories.
|
||||
|
||||
"""
|
||||
return not _BUILD_DIRS.isdisjoint(path.parts)
|
||||
|
||||
|
||||
def _extract_entry_path(exports: Any) -> str | None:
|
||||
"""Extract entry path from package.json exports field."""
|
||||
if isinstance(exports, str):
|
||||
|
|
|
|||
|
|
@ -180,19 +180,31 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
|
|||
# Handle file paths (contain slashes and extensions like .js/.ts)
|
||||
if "/" in test_class_path or "\\" in test_class_path:
|
||||
# This is a file path, not a Python module path
|
||||
# Try the path as-is if it's absolute
|
||||
potential_path = Path(test_class_path)
|
||||
if potential_path.is_absolute() and potential_path.exists():
|
||||
return potential_path
|
||||
|
||||
# Try to resolve relative to base_dir's parent (project root)
|
||||
project_root = base_dir.parent
|
||||
potential_path = project_root / test_class_path
|
||||
if potential_path.exists():
|
||||
return potential_path
|
||||
# Normalize to resolve .. and . components
|
||||
try:
|
||||
potential_path = potential_path.resolve()
|
||||
if potential_path.exists():
|
||||
return potential_path
|
||||
except (OSError, RuntimeError):
|
||||
pass
|
||||
|
||||
# Also try relative to base_dir itself
|
||||
potential_path = base_dir / test_class_path
|
||||
if potential_path.exists():
|
||||
return potential_path
|
||||
# Try the path as-is if it's absolute
|
||||
potential_path = Path(test_class_path)
|
||||
if potential_path.exists():
|
||||
return potential_path
|
||||
try:
|
||||
potential_path = potential_path.resolve()
|
||||
if potential_path.exists():
|
||||
return potential_path
|
||||
except (OSError, RuntimeError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
# First try the full path (Python module path)
|
||||
|
|
@ -512,8 +524,10 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
# Check if the file name matches the module path
|
||||
file_stem = test_file.instrumented_behavior_file_path.stem
|
||||
# The instrumented file has __perfinstrumented suffix
|
||||
original_class = file_stem.replace("__perfinstrumented", "").replace("__perfonlyinstrumented", "")
|
||||
if original_class == test_module_path or file_stem == test_module_path:
|
||||
original_class = file_stem.replace("__perfinstrumented", "").replace(
|
||||
"__perfonlyinstrumented", ""
|
||||
)
|
||||
if test_module_path in (original_class, file_stem):
|
||||
test_file_path = test_file.instrumented_behavior_file_path
|
||||
break
|
||||
# Check original file path
|
||||
|
|
@ -551,7 +565,9 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
# Default to GENERATED_REGRESSION for Jest/Java tests when test type can't be determined
|
||||
if test_type is None and (is_jest or is_java_test):
|
||||
test_type = TestType.GENERATED_REGRESSION
|
||||
logger.debug(f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})")
|
||||
logger.debug(
|
||||
f"[PARSE-DEBUG] defaulting to GENERATED_REGRESSION ({'Jest' if is_jest else 'Java'})"
|
||||
)
|
||||
elif test_type is None:
|
||||
# Skip results where test type cannot be determined
|
||||
logger.debug(f"Skipping result for {test_function_name}: could not determine test type")
|
||||
|
|
@ -791,16 +807,25 @@ def parse_jest_test_xml(
|
|||
if not test_file_path.exists():
|
||||
test_file_path = base_dir / test_file_name
|
||||
|
||||
if test_file_path is None or not test_file_path.exists():
|
||||
# For Jest tests in monorepos, test files may not exist after cleanup
|
||||
# but we can still parse results and infer test type from the path
|
||||
if test_file_path is None:
|
||||
logger.warning(f"Could not resolve test file for Jest test: {test_class_path}")
|
||||
continue
|
||||
|
||||
# Get test type if not already set from lookup
|
||||
if test_type is None:
|
||||
if test_type is None and test_file_path.exists():
|
||||
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
||||
if test_type is None:
|
||||
# Default to GENERATED_REGRESSION for Jest tests
|
||||
test_type = TestType.GENERATED_REGRESSION
|
||||
# Infer test type from filename pattern
|
||||
filename = test_file_path.name
|
||||
if "__perf_test_" in filename or "_perf_test_" in filename:
|
||||
test_type = TestType.GENERATED_PERFORMANCE
|
||||
elif "__unit_test_" in filename or "_unit_test_" in filename:
|
||||
test_type = TestType.GENERATED_REGRESSION
|
||||
else:
|
||||
# Default to GENERATED_REGRESSION for Jest tests
|
||||
test_type = TestType.GENERATED_REGRESSION
|
||||
|
||||
# For Jest tests, keep the relative file path with extension intact
|
||||
# (Python uses module_name_from_file_path which strips extensions)
|
||||
|
|
|
|||
|
|
@ -132,11 +132,25 @@ def run_behavioral_tests(
|
|||
# Check if there's a language support for this test framework that implements run_behavioral_tests
|
||||
language_support = get_language_support_by_framework(test_framework)
|
||||
if language_support is not None and hasattr(language_support, "run_behavioral_tests"):
|
||||
# Java tests need longer timeout due to Maven startup overhead
|
||||
# Use Java-specific timeout if no explicit timeout provided
|
||||
from codeflash.code_utils.config_consts import JAVA_TESTCASE_TIMEOUT
|
||||
|
||||
effective_timeout = pytest_timeout
|
||||
if test_framework == "junit5" and pytest_timeout is not None:
|
||||
# For Java, use a minimum timeout to account for Maven overhead
|
||||
effective_timeout = max(pytest_timeout, JAVA_TESTCASE_TIMEOUT)
|
||||
if effective_timeout != pytest_timeout:
|
||||
logger.debug(
|
||||
f"Increased Java test timeout from {pytest_timeout}s to {effective_timeout}s "
|
||||
"to account for Maven startup overhead"
|
||||
)
|
||||
|
||||
return language_support.run_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=pytest_timeout,
|
||||
timeout=effective_timeout,
|
||||
project_root=js_project_root,
|
||||
enable_coverage=enable_coverage,
|
||||
candidate_index=candidate_index,
|
||||
|
|
@ -331,11 +345,25 @@ def run_benchmarking_tests(
|
|||
# Check if there's a language support for this test framework that implements run_benchmarking_tests
|
||||
language_support = get_language_support_by_framework(test_framework)
|
||||
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):
|
||||
# Java tests need longer timeout due to Maven startup overhead
|
||||
# Use Java-specific timeout if no explicit timeout provided
|
||||
from codeflash.code_utils.config_consts import JAVA_TESTCASE_TIMEOUT
|
||||
|
||||
effective_timeout = pytest_timeout
|
||||
if test_framework == "junit5" and pytest_timeout is not None:
|
||||
# For Java, use a minimum timeout to account for Maven overhead
|
||||
effective_timeout = max(pytest_timeout, JAVA_TESTCASE_TIMEOUT)
|
||||
if effective_timeout != pytest_timeout:
|
||||
logger.debug(
|
||||
f"Increased Java test timeout from {pytest_timeout}s to {effective_timeout}s "
|
||||
"to account for Maven startup overhead"
|
||||
)
|
||||
|
||||
return language_support.run_benchmarking_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=pytest_timeout,
|
||||
timeout=effective_timeout,
|
||||
project_root=js_project_root,
|
||||
min_loops=pytest_min_loops,
|
||||
max_loops=pytest_max_loops,
|
||||
|
|
|
|||
|
|
@ -149,7 +149,9 @@ class TestConfig:
|
|||
pom_path = current / "pom.xml"
|
||||
if pom_path.exists():
|
||||
parent_config = detect_java_project(current)
|
||||
if parent_config and (parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng):
|
||||
if parent_config and (
|
||||
parent_config.has_junit4 or parent_config.has_junit5 or parent_config.has_testng
|
||||
):
|
||||
return parent_config.test_framework
|
||||
current = current.parent
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,10 @@ def generate_tests(
|
|||
)
|
||||
|
||||
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
|
||||
generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system)
|
||||
# Skip conversion if ts-jest is installed (handles interop natively)
|
||||
generated_test_source = ensure_module_system_compatibility(
|
||||
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
|
||||
)
|
||||
|
||||
# Ensure vitest imports are present when using vitest framework
|
||||
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
|
||||
|
|
|
|||
|
|
@ -1,60 +0,0 @@
|
|||
# codeflash
|
||||
|
||||
AI-powered code performance optimization for JavaScript and TypeScript.
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
npm install -g codeflash
|
||||
# or
|
||||
npx codeflash
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. Get your API key from [codeflash.ai](https://codeflash.ai)
|
||||
|
||||
2. Set your API key:
|
||||
```bash
|
||||
export CODEFLASH_API_KEY=your-api-key
|
||||
```
|
||||
|
||||
3. Optimize a function:
|
||||
```bash
|
||||
codeflash --file src/utils.ts --function slowFunction
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Optimize a specific function
|
||||
codeflash --file <path> --function <name>
|
||||
|
||||
# Optimize all functions in a directory
|
||||
codeflash --all src/
|
||||
|
||||
# Initialize GitHub Actions workflow
|
||||
codeflash init-actions
|
||||
|
||||
# Verify setup
|
||||
codeflash --verify-setup
|
||||
```
|
||||
|
||||
## Requirements
|
||||
|
||||
- Node.js >= 16.0.0
|
||||
- A codeflash API key
|
||||
|
||||
## Supported Platforms
|
||||
|
||||
- Linux (x64, arm64)
|
||||
- macOS (x64, arm64)
|
||||
- Windows (x64)
|
||||
|
||||
## Documentation
|
||||
|
||||
See [codeflash.ai/docs](https://codeflash.ai/docs) for full documentation.
|
||||
|
||||
## License
|
||||
|
||||
BSL-1.1
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
#!/usr/bin/env node
|
||||
/**
|
||||
* Wrapper script for codeflash CLI.
|
||||
* Invokes the downloaded binary with all passed arguments.
|
||||
*/
|
||||
|
||||
const { spawn } = require('child_process');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
|
||||
function getBinaryPath() {
|
||||
const binDir = __dirname;
|
||||
const isWindows = process.platform === 'win32';
|
||||
return path.join(binDir, isWindows ? 'codeflash.exe' : 'codeflash-binary');
|
||||
}
|
||||
|
||||
function main() {
|
||||
const binaryPath = getBinaryPath();
|
||||
|
||||
if (!fs.existsSync(binaryPath)) {
|
||||
console.error('\x1b[31mError: codeflash binary not found.\x1b[0m');
|
||||
console.error('Try reinstalling: npm install codeflash');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Pass all arguments to the binary
|
||||
const args = process.argv.slice(2);
|
||||
|
||||
const child = spawn(binaryPath, args, {
|
||||
stdio: 'inherit',
|
||||
env: process.env,
|
||||
});
|
||||
|
||||
child.on('error', (error) => {
|
||||
console.error(`\x1b[31mError running codeflash: ${error.message}\x1b[0m`);
|
||||
process.exit(1);
|
||||
});
|
||||
|
||||
child.on('exit', (code, signal) => {
|
||||
if (signal) {
|
||||
process.exit(1);
|
||||
}
|
||||
process.exit(code || 0);
|
||||
});
|
||||
}
|
||||
|
||||
main();
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.0.0",
|
||||
"description": "AI-powered code performance optimization - automatically find and fix slow code",
|
||||
"keywords": [
|
||||
"codeflash",
|
||||
"performance",
|
||||
"optimization",
|
||||
"ai",
|
||||
"code",
|
||||
"profiler",
|
||||
"typescript",
|
||||
"javascript"
|
||||
],
|
||||
"author": "CodeFlash Inc. <contact@codeflash.ai>",
|
||||
"license": "BSL-1.1",
|
||||
"homepage": "https://codeflash.ai",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/codeflash-ai/codeflash.git"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/codeflash-ai/codeflash/issues"
|
||||
},
|
||||
"bin": {
|
||||
"codeflash": "./bin/codeflash"
|
||||
},
|
||||
"scripts": {
|
||||
"postinstall": "node lib/install.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=16.0.0"
|
||||
},
|
||||
"os": [
|
||||
"darwin",
|
||||
"linux",
|
||||
"win32"
|
||||
],
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
],
|
||||
"files": [
|
||||
"bin/",
|
||||
"lib/"
|
||||
]
|
||||
}
|
||||
4
packages/codeflash/package-lock.json
generated
4
packages/codeflash/package-lock.json
generated
|
|
@ -1,12 +1,12 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.5.0",
|
||||
"version": "0.7.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "codeflash",
|
||||
"version": "0.5.0",
|
||||
"version": "0.7.0",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.5.0",
|
||||
"version": "0.7.0",
|
||||
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
|
||||
"main": "runtime/index.js",
|
||||
"types": "runtime/index.d.ts",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
/**
|
||||
* Test: Dynamic environment variable reading
|
||||
*
|
||||
* This test verifies that the performance configuration functions read
|
||||
* environment variables at runtime rather than at module load time.
|
||||
*
|
||||
* This is critical for Vitest compatibility, where modules may be cached
|
||||
* and loaded before environment variables are set.
|
||||
*
|
||||
* Run with: node __tests__/dynamic-env-vars.test.js
|
||||
*/
|
||||
|
||||
const assert = require('assert');
|
||||
|
||||
// Clear any existing env vars before loading the module
|
||||
delete process.env.CODEFLASH_PERF_LOOP_COUNT;
|
||||
delete process.env.CODEFLASH_PERF_MIN_LOOPS;
|
||||
delete process.env.CODEFLASH_PERF_TARGET_DURATION_MS;
|
||||
delete process.env.CODEFLASH_PERF_BATCH_SIZE;
|
||||
delete process.env.CODEFLASH_PERF_STABILITY_CHECK;
|
||||
delete process.env.CODEFLASH_PERF_CURRENT_BATCH;
|
||||
|
||||
// Now load the module - at this point env vars are not set
|
||||
const capture = require('../capture');
|
||||
|
||||
console.log('Testing dynamic environment variable reading...\n');
|
||||
|
||||
// Test 1: Default values when env vars are not set
|
||||
console.log('Test 1: Default values');
|
||||
assert.strictEqual(capture.getPerfLoopCount(), 1, 'getPerfLoopCount default should be 1');
|
||||
assert.strictEqual(capture.getPerfMinLoops(), 5, 'getPerfMinLoops default should be 5');
|
||||
assert.strictEqual(capture.getPerfTargetDurationMs(), 10000, 'getPerfTargetDurationMs default should be 10000');
|
||||
assert.strictEqual(capture.getPerfBatchSize(), 10, 'getPerfBatchSize default should be 10');
|
||||
assert.strictEqual(capture.getPerfStabilityCheck(), false, 'getPerfStabilityCheck default should be false');
|
||||
assert.strictEqual(capture.getPerfCurrentBatch(), 0, 'getPerfCurrentBatch default should be 0');
|
||||
console.log(' PASS: All defaults correct\n');
|
||||
|
||||
// Test 2: Values change when env vars are set AFTER module load
|
||||
// This is the critical test - if these were constants, they would still return defaults
|
||||
console.log('Test 2: Dynamic reading after module load');
|
||||
process.env.CODEFLASH_PERF_LOOP_COUNT = '100';
|
||||
process.env.CODEFLASH_PERF_MIN_LOOPS = '10';
|
||||
process.env.CODEFLASH_PERF_TARGET_DURATION_MS = '5000';
|
||||
process.env.CODEFLASH_PERF_BATCH_SIZE = '20';
|
||||
process.env.CODEFLASH_PERF_STABILITY_CHECK = 'true';
|
||||
process.env.CODEFLASH_PERF_CURRENT_BATCH = '5';
|
||||
|
||||
assert.strictEqual(capture.getPerfLoopCount(), 100, 'getPerfLoopCount should read 100 from env');
|
||||
assert.strictEqual(capture.getPerfMinLoops(), 10, 'getPerfMinLoops should read 10 from env');
|
||||
assert.strictEqual(capture.getPerfTargetDurationMs(), 5000, 'getPerfTargetDurationMs should read 5000 from env');
|
||||
assert.strictEqual(capture.getPerfBatchSize(), 20, 'getPerfBatchSize should read 20 from env');
|
||||
assert.strictEqual(capture.getPerfStabilityCheck(), true, 'getPerfStabilityCheck should read true from env');
|
||||
assert.strictEqual(capture.getPerfCurrentBatch(), 5, 'getPerfCurrentBatch should read 5 from env');
|
||||
console.log(' PASS: Dynamic reading works correctly\n');
|
||||
|
||||
// Test 3: Values change again when env vars are modified
|
||||
console.log('Test 3: Values update when env vars change');
|
||||
process.env.CODEFLASH_PERF_LOOP_COUNT = '500';
|
||||
process.env.CODEFLASH_PERF_BATCH_SIZE = '50';
|
||||
|
||||
assert.strictEqual(capture.getPerfLoopCount(), 500, 'getPerfLoopCount should update to 500');
|
||||
assert.strictEqual(capture.getPerfBatchSize(), 50, 'getPerfBatchSize should update to 50');
|
||||
console.log(' PASS: Values update correctly\n');
|
||||
|
||||
// Cleanup
|
||||
delete process.env.CODEFLASH_PERF_LOOP_COUNT;
|
||||
delete process.env.CODEFLASH_PERF_MIN_LOOPS;
|
||||
delete process.env.CODEFLASH_PERF_TARGET_DURATION_MS;
|
||||
delete process.env.CODEFLASH_PERF_BATCH_SIZE;
|
||||
delete process.env.CODEFLASH_PERF_STABILITY_CHECK;
|
||||
delete process.env.CODEFLASH_PERF_CURRENT_BATCH;
|
||||
|
||||
console.log('All tests passed!');
|
||||
|
|
@ -47,14 +47,30 @@ const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
|
|||
// Batch 1: Test1(5 loops) → Test2(5 loops) → Test3(5 loops)
|
||||
// Batch 2: Test1(5 loops) → Test2(5 loops) → Test3(5 loops)
|
||||
// ...until time budget exhausted
|
||||
const PERF_LOOP_COUNT = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '1', 10);
|
||||
const PERF_MIN_LOOPS = parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
|
||||
const PERF_TARGET_DURATION_MS = parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
|
||||
const PERF_BATCH_SIZE = parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
|
||||
const PERF_STABILITY_CHECK = (process.env.CODEFLASH_PERF_STABILITY_CHECK || 'false').toLowerCase() === 'true';
|
||||
//
|
||||
// IMPORTANT: These are getter functions, NOT constants!
|
||||
// Vitest caches modules and may load this file before env vars are set.
|
||||
// Using getter functions ensures we read the env vars at runtime when they're actually needed.
|
||||
function getPerfLoopCount() {
|
||||
return parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '1', 10);
|
||||
}
|
||||
function getPerfMinLoops() {
|
||||
return parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
|
||||
}
|
||||
function getPerfTargetDurationMs() {
|
||||
return parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
|
||||
}
|
||||
function getPerfBatchSize() {
|
||||
return parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
|
||||
}
|
||||
function getPerfStabilityCheck() {
|
||||
return (process.env.CODEFLASH_PERF_STABILITY_CHECK || 'false').toLowerCase() === 'true';
|
||||
}
|
||||
// Current batch number - set by loop-runner before each batch
|
||||
// This allows continuous loop indices even when Jest resets module state
|
||||
const PERF_CURRENT_BATCH = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '0', 10);
|
||||
function getPerfCurrentBatch() {
|
||||
return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '0', 10);
|
||||
}
|
||||
|
||||
// Stability constants (matching Python's config_consts.py)
|
||||
const STABILITY_WINDOW_SIZE = 0.35;
|
||||
|
|
@ -86,7 +102,7 @@ function checkSharedTimeLimit() {
|
|||
return false;
|
||||
}
|
||||
const elapsed = Date.now() - sharedPerfState.startTime;
|
||||
if (elapsed >= PERF_TARGET_DURATION_MS && sharedPerfState.totalLoopsCompleted >= PERF_MIN_LOOPS) {
|
||||
if (elapsed >= getPerfTargetDurationMs() && sharedPerfState.totalLoopsCompleted >= getPerfMinLoops()) {
|
||||
sharedPerfState.shouldStop = true;
|
||||
return true;
|
||||
}
|
||||
|
|
@ -111,7 +127,7 @@ function getInvocationLoopIndex(invocationKey) {
|
|||
// Calculate global loop index using batch number from environment
|
||||
// PERF_CURRENT_BATCH is 1-based (set by loop-runner before each batch)
|
||||
const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
|
||||
const globalIndex = (currentBatch - 1) * PERF_BATCH_SIZE + localIndex;
|
||||
const globalIndex = (currentBatch - 1) * getPerfBatchSize() + localIndex;
|
||||
|
||||
return globalIndex;
|
||||
}
|
||||
|
|
@ -606,7 +622,7 @@ function capture(funcName, lineId, fn, ...args) {
|
|||
*/
|
||||
function capturePerf(funcName, lineId, fn, ...args) {
|
||||
// Check if we should skip looping entirely (shared time budget exceeded)
|
||||
const shouldLoop = PERF_LOOP_COUNT > 1 && !checkSharedTimeLimit();
|
||||
const shouldLoop = getPerfLoopCount() > 1 && !checkSharedTimeLimit();
|
||||
|
||||
// Get test context (computed once, reused across batch)
|
||||
let testModulePath;
|
||||
|
|
@ -636,9 +652,9 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
// If so, just execute the function once without timing (for test assertions)
|
||||
const peekLoopIndex = (sharedPerfState.invocationLoopCounts[invocationKey] || 0);
|
||||
const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
|
||||
const nextGlobalIndex = (currentBatch - 1) * PERF_BATCH_SIZE + peekLoopIndex + 1;
|
||||
const nextGlobalIndex = (currentBatch - 1) * getPerfBatchSize() + peekLoopIndex + 1;
|
||||
|
||||
if (shouldLoop && nextGlobalIndex > PERF_LOOP_COUNT) {
|
||||
if (shouldLoop && nextGlobalIndex > getPerfLoopCount()) {
|
||||
// All loops completed, just execute once for test assertion
|
||||
return fn(...args);
|
||||
}
|
||||
|
|
@ -654,7 +670,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
|
||||
// For Vitest (no loop-runner), do all loops internally in a single call
|
||||
const batchSize = shouldLoop
|
||||
? (hasExternalLoopRunner ? PERF_BATCH_SIZE : PERF_LOOP_COUNT)
|
||||
? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount())
|
||||
: 1;
|
||||
|
||||
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
|
||||
|
|
@ -667,7 +683,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
|
|||
const loopIndex = getInvocationLoopIndex(invocationKey);
|
||||
|
||||
// Check if we've exceeded max loops for this invocation
|
||||
if (loopIndex > PERF_LOOP_COUNT) {
|
||||
if (loopIndex > getPerfLoopCount()) {
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
@ -872,7 +888,11 @@ module.exports = {
|
|||
LOOP_INDEX,
|
||||
OUTPUT_FILE,
|
||||
TEST_ITERATION,
|
||||
// Batch configuration
|
||||
PERF_BATCH_SIZE,
|
||||
PERF_LOOP_COUNT,
|
||||
// Batch configuration (getter functions for dynamic env var reading)
|
||||
getPerfBatchSize,
|
||||
getPerfLoopCount,
|
||||
getPerfMinLoops,
|
||||
getPerfTargetDurationMs,
|
||||
getPerfStabilityCheck,
|
||||
getPerfCurrentBatch,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -30,13 +30,74 @@
|
|||
|
||||
const { createRequire } = require('module');
|
||||
const path = require('path');
|
||||
const fs = require('fs');
|
||||
|
||||
/**
|
||||
* Resolve jest-runner with monorepo support.
|
||||
* Uses CODEFLASH_MONOREPO_ROOT environment variable if available,
|
||||
* otherwise walks up the directory tree looking for node_modules/jest-runner.
|
||||
*/
|
||||
function resolveJestRunner() {
|
||||
// Try standard resolution first (works in simple projects)
|
||||
try {
|
||||
return require.resolve('jest-runner');
|
||||
} catch (e) {
|
||||
// Standard resolution failed - try monorepo-aware resolution
|
||||
}
|
||||
|
||||
// If Python detected a monorepo root, check there first
|
||||
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
|
||||
if (monorepoRoot) {
|
||||
const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner');
|
||||
if (fs.existsSync(jestRunnerPath)) {
|
||||
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
|
||||
if (fs.existsSync(packageJsonPath)) {
|
||||
return jestRunnerPath;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: Walk up from cwd looking for node_modules/jest-runner
|
||||
const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json'];
|
||||
let currentDir = process.cwd();
|
||||
const visitedDirs = new Set();
|
||||
|
||||
while (currentDir !== path.dirname(currentDir)) {
|
||||
// Avoid infinite loops
|
||||
if (visitedDirs.has(currentDir)) break;
|
||||
visitedDirs.add(currentDir);
|
||||
|
||||
// Try node_modules/jest-runner at this level
|
||||
const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner');
|
||||
if (fs.existsSync(jestRunnerPath)) {
|
||||
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
|
||||
if (fs.existsSync(packageJsonPath)) {
|
||||
return jestRunnerPath;
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a workspace root (has monorepo markers)
|
||||
const isWorkspaceRoot = monorepoMarkers.some(marker =>
|
||||
fs.existsSync(path.join(currentDir, marker))
|
||||
);
|
||||
|
||||
if (isWorkspaceRoot) {
|
||||
// Found workspace root but no jest-runner - stop searching
|
||||
break;
|
||||
}
|
||||
|
||||
currentDir = path.dirname(currentDir);
|
||||
}
|
||||
|
||||
throw new Error('jest-runner not found');
|
||||
}
|
||||
|
||||
// Try to load jest-runner - it's a peer dependency that must be installed by the user
|
||||
let runTest;
|
||||
let jestRunnerAvailable = false;
|
||||
|
||||
try {
|
||||
const jestRunnerPath = require.resolve('jest-runner');
|
||||
const jestRunnerPath = resolveJestRunner();
|
||||
const internalRequire = createRequire(jestRunnerPath);
|
||||
runTest = internalRequire('./runTest').default;
|
||||
jestRunnerAvailable = true;
|
||||
|
|
|
|||
|
|
@ -79,8 +79,9 @@ dev = [
|
|||
"types-greenlet>=3.1.0.20241221,<4",
|
||||
"types-pexpect>=4.9.0.20241208,<5",
|
||||
"types-unidiff>=0.7.0.20240505,<0.8",
|
||||
"uv>=0.6.2",
|
||||
"pre-commit>=4.2.0,<5",
|
||||
"prek>=0.2.25",
|
||||
"ty>=0.0.14",
|
||||
"uv>=0.9.29",
|
||||
]
|
||||
tests = [
|
||||
"black>=25.9.0",
|
||||
|
|
@ -272,6 +273,7 @@ ignore = [
|
|||
"ANN401", # typing.Any disallowed
|
||||
"ARG001", # Unused function argument (common in abstract/interface methods)
|
||||
"TRY300", # Consider moving to else block
|
||||
"FURB110", # if-exp-instead-of-or-operator - we prefer explicit if-else over "or"
|
||||
"TRY401", # Redundant exception in logging.exception
|
||||
"PLR0911", # Too many return statements
|
||||
"PLW0603", # Global statement
|
||||
|
|
|
|||
|
|
@ -127,13 +127,14 @@ class TestDetectModuleRoot:
|
|||
assert result == "lib"
|
||||
|
||||
def test_detects_from_exports_object_dot(self, tmp_path: Path) -> None:
|
||||
"""Should detect module root from exports object with '.' key."""
|
||||
"""Should skip build output dirs and return '.' when no src dir exists."""
|
||||
(tmp_path / "dist").mkdir()
|
||||
package_data = {"exports": {".": "./dist/index.js"}}
|
||||
|
||||
result = detect_module_root(tmp_path, package_data)
|
||||
|
||||
assert result == "dist"
|
||||
# dist is a build output directory, so it's skipped
|
||||
assert result == "."
|
||||
|
||||
def test_detects_from_exports_object_nested(self, tmp_path: Path) -> None:
|
||||
"""Should detect module root from nested exports object."""
|
||||
|
|
@ -227,13 +228,14 @@ class TestDetectModuleRoot:
|
|||
assert result == "src"
|
||||
|
||||
def test_handles_deeply_nested_exports(self, tmp_path: Path) -> None:
|
||||
"""Should handle deeply nested export paths."""
|
||||
"""Should handle deeply nested export paths but skip build output dirs."""
|
||||
(tmp_path / "packages" / "core" / "dist").mkdir(parents=True)
|
||||
package_data = {"exports": {".": {"import": "./packages/core/dist/index.mjs"}}}
|
||||
|
||||
result = detect_module_root(tmp_path, package_data)
|
||||
|
||||
assert result == "packages/core/dist"
|
||||
# dist is a build output directory, so it's skipped even when nested
|
||||
assert result == "."
|
||||
|
||||
def test_handles_empty_exports(self, tmp_path: Path) -> None:
|
||||
"""Should handle empty exports gracefully."""
|
||||
|
|
@ -756,7 +758,7 @@ class TestRealWorldPackageJsonExamples:
|
|||
assert config["formatter_cmds"] == ["npx eslint --fix $file"]
|
||||
|
||||
def test_library_with_exports(self, tmp_path: Path) -> None:
|
||||
"""Should handle library with modern exports field."""
|
||||
"""Should handle library with modern exports field, skipping build output dirs."""
|
||||
(tmp_path / "dist").mkdir()
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text(
|
||||
|
|
@ -773,7 +775,8 @@ class TestRealWorldPackageJsonExamples:
|
|||
|
||||
assert result is not None
|
||||
config, _ = result
|
||||
assert config["module_root"] == str((tmp_path / "dist").resolve())
|
||||
# dist is a build output directory, so it's skipped and falls back to project root
|
||||
assert config["module_root"] == str(tmp_path.resolve())
|
||||
|
||||
def test_monorepo_package(self, tmp_path: Path) -> None:
|
||||
"""Should handle monorepo package configuration."""
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
|
||||
|
|
@ -1694,6 +1696,240 @@ const FIELD_KEYS = {
|
|||
};"""
|
||||
assert context.read_only_context == expected_read_only
|
||||
|
||||
def test_with_tricky_helpers(self, ts_support, temp_project):
|
||||
"""Test function returning object with computed property names."""
|
||||
code = """import { WebClient, ChatPostMessageArguments } from "@slack/web-api"
|
||||
|
||||
// Dependencies interface for easier testing
|
||||
export interface SendSlackMessageDependencies {
|
||||
WebClient: typeof WebClient
|
||||
getSlackToken: () => string | undefined
|
||||
getSlackChannelId: () => string | undefined
|
||||
console: typeof console
|
||||
}
|
||||
|
||||
// Default dependencies
|
||||
let dependencies: SendSlackMessageDependencies = {
|
||||
WebClient,
|
||||
getSlackToken: () => process.env.SLACK_TOKEN,
|
||||
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
|
||||
console,
|
||||
}
|
||||
|
||||
// For testing - allow dependency injection
|
||||
export function setSendSlackMessageDependencies(deps: Partial<SendSlackMessageDependencies>) {
|
||||
dependencies = { ...dependencies, ...deps }
|
||||
}
|
||||
|
||||
export function resetSendSlackMessageDependencies() {
|
||||
dependencies = {
|
||||
WebClient,
|
||||
getSlackToken: () => process.env.SLACK_TOKEN,
|
||||
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
|
||||
console,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize web client
|
||||
let web: WebClient | null = null
|
||||
|
||||
export function initializeWebClient() {
|
||||
const SLACK_TOKEN = dependencies.getSlackToken()
|
||||
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
|
||||
|
||||
if (!SLACK_TOKEN) {
|
||||
throw new Error("Missing SLACK_TOKEN")
|
||||
}
|
||||
|
||||
if (!SLACK_CHANNEL_ID) {
|
||||
throw new Error("Missing SLACK_CHANNEL_ID")
|
||||
}
|
||||
|
||||
if (!web) {
|
||||
web = new dependencies.WebClient(SLACK_TOKEN, {})
|
||||
}
|
||||
|
||||
return web
|
||||
}
|
||||
|
||||
// For testing - allow resetting the web client
|
||||
export function resetWebClient() {
|
||||
web = null
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to Slack
|
||||
*
|
||||
* @param {string|object} message - Text message or Block Kit message object
|
||||
* @param {string|null} channel - Channel ID, defaults to SLACK_CHANNEL_ID
|
||||
* @param {boolean} returnData - Whether to return the full Slack API response
|
||||
* @returns {Promise<boolean|object>} - True or API response
|
||||
*/
|
||||
export const sendSlackMessage = async (
|
||||
message: any,
|
||||
channel: string | null = null,
|
||||
returnData: boolean = false,
|
||||
): Promise<boolean | object> => {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
const webClient = initializeWebClient()
|
||||
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
|
||||
const channelId = channel || SLACK_CHANNEL_ID
|
||||
|
||||
// Configure the message payload depending on the input type
|
||||
let payload: ChatPostMessageArguments
|
||||
|
||||
if (typeof message === "string") {
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: message,
|
||||
}
|
||||
} else if (message && typeof message === "object") {
|
||||
if (message.blocks) {
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: message.text || "Notification from CodeFlash",
|
||||
blocks: message.blocks,
|
||||
}
|
||||
} else {
|
||||
dependencies.console.warn("Object passed to sendSlackMessage without blocks property")
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: JSON.stringify(message),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dependencies.console.error("Invalid message type", typeof message)
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: "Invalid message",
|
||||
}
|
||||
}
|
||||
|
||||
// console.log("Sending payload to Slack:", JSON.stringify(payload, null, 2));
|
||||
|
||||
const resp = await webClient.chat.postMessage(payload)
|
||||
return resolve(returnData ? resp : true)
|
||||
} catch (error) {
|
||||
dependencies.console.error("Error sending Slack message:", error)
|
||||
return resolve(returnData ? { error } : true)
|
||||
}
|
||||
})
|
||||
}
|
||||
"""
|
||||
file_path = temp_project / "slack_util.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
target_func = "sendSlackMessage"
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
func_info = next(f for f in functions if f.function_name == target_func)
|
||||
fto = FunctionToOptimize(
|
||||
function_name=target_func,
|
||||
file_path=file_path,
|
||||
parents=func_info.parents,
|
||||
starting_line=func_info.starting_line,
|
||||
ending_line=func_info.ending_line,
|
||||
starting_col=func_info.starting_col,
|
||||
ending_col=func_info.ending_col,
|
||||
is_async=func_info.is_async,
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
ctx = get_code_optimization_context_for_language(
|
||||
fto, temp_project
|
||||
)
|
||||
|
||||
# The read_writable_code should contain the target function AND helper functions
|
||||
expected_read_writable = """```typescript:slack_util.ts
|
||||
import { WebClient, ChatPostMessageArguments } from "@slack/web-api"
|
||||
|
||||
export const sendSlackMessage = async (
|
||||
message: any,
|
||||
channel: string | null = null,
|
||||
returnData: boolean = false,
|
||||
): Promise<boolean | object> => {
|
||||
return new Promise(async (resolve, reject) => {
|
||||
try {
|
||||
const webClient = initializeWebClient()
|
||||
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
|
||||
const channelId = channel || SLACK_CHANNEL_ID
|
||||
|
||||
// Configure the message payload depending on the input type
|
||||
let payload: ChatPostMessageArguments
|
||||
|
||||
if (typeof message === "string") {
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: message,
|
||||
}
|
||||
} else if (message && typeof message === "object") {
|
||||
if (message.blocks) {
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: message.text || "Notification from CodeFlash",
|
||||
blocks: message.blocks,
|
||||
}
|
||||
} else {
|
||||
dependencies.console.warn("Object passed to sendSlackMessage without blocks property")
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: JSON.stringify(message),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dependencies.console.error("Invalid message type", typeof message)
|
||||
payload = {
|
||||
channel: channelId,
|
||||
text: "Invalid message",
|
||||
}
|
||||
}
|
||||
|
||||
// console.log("Sending payload to Slack:", JSON.stringify(payload, null, 2));
|
||||
|
||||
const resp = await webClient.chat.postMessage(payload)
|
||||
return resolve(returnData ? resp : true)
|
||||
} catch (error) {
|
||||
dependencies.console.error("Error sending Slack message:", error)
|
||||
return resolve(returnData ? { error } : true)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
export function initializeWebClient() {
|
||||
const SLACK_TOKEN = dependencies.getSlackToken()
|
||||
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
|
||||
|
||||
if (!SLACK_TOKEN) {
|
||||
throw new Error("Missing SLACK_TOKEN")
|
||||
}
|
||||
|
||||
if (!SLACK_CHANNEL_ID) {
|
||||
throw new Error("Missing SLACK_CHANNEL_ID")
|
||||
}
|
||||
|
||||
if (!web) {
|
||||
web = new dependencies.WebClient(SLACK_TOKEN, {})
|
||||
}
|
||||
|
||||
return web
|
||||
}
|
||||
```"""
|
||||
|
||||
# The read_only_context should contain global variables (dependencies object, web client)
|
||||
# but NOT have invalid floating object properties
|
||||
expected_read_only = """let dependencies: SendSlackMessageDependencies = {
|
||||
WebClient,
|
||||
getSlackToken: () => process.env.SLACK_TOKEN,
|
||||
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
|
||||
console,
|
||||
}
|
||||
let web: WebClient | null = null"""
|
||||
|
||||
assert ctx.read_writable_code.markdown == expected_read_writable
|
||||
assert ctx.read_only_context_code == expected_read_only
|
||||
|
||||
|
||||
|
||||
class TestContextProperties:
|
||||
"""Tests for CodeContext object properties."""
|
||||
|
|
|
|||
|
|
@ -5,7 +5,13 @@ import json
|
|||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system, get_import_statement
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ModuleSystem,
|
||||
convert_commonjs_to_esm,
|
||||
convert_esm_to_commonjs,
|
||||
detect_module_system,
|
||||
get_import_statement,
|
||||
)
|
||||
|
||||
|
||||
class TestModuleSystemDetection:
|
||||
|
|
@ -51,6 +57,39 @@ class TestModuleSystemDetection:
|
|||
result = detect_module_system(project_root, file_path)
|
||||
assert result == ModuleSystem.COMMONJS
|
||||
|
||||
def test_detect_esm_from_typescript_extension(self):
|
||||
"""Test detection of ES modules from TypeScript file extensions."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
|
||||
# Test .ts files
|
||||
ts_file = project_root / "module.ts"
|
||||
ts_file.write_text("export const foo = 'bar';")
|
||||
assert detect_module_system(project_root, ts_file) == ModuleSystem.ES_MODULE
|
||||
|
||||
# Test .tsx files
|
||||
tsx_file = project_root / "component.tsx"
|
||||
tsx_file.write_text("export const Component = () => <div />;")
|
||||
assert detect_module_system(project_root, tsx_file) == ModuleSystem.ES_MODULE
|
||||
|
||||
# Test .mts files
|
||||
mts_file = project_root / "module.mts"
|
||||
mts_file.write_text("export const foo = 'bar';")
|
||||
assert detect_module_system(project_root, mts_file) == ModuleSystem.ES_MODULE
|
||||
|
||||
def test_typescript_ignores_package_json_commonjs(self):
|
||||
"""Test that TypeScript files are detected as ESM even with CommonJS package.json."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
project_root = Path(tmpdir)
|
||||
# Create package.json with explicit commonjs type
|
||||
package_json = project_root / "package.json"
|
||||
package_json.write_text(json.dumps({"type": "commonjs"}))
|
||||
|
||||
# TypeScript file should still be detected as ESM
|
||||
ts_file = project_root / "module.ts"
|
||||
ts_file.write_text("export const foo = 'bar';")
|
||||
assert detect_module_system(project_root, ts_file) == ModuleSystem.ES_MODULE
|
||||
|
||||
def test_detect_esm_from_import_syntax(self):
|
||||
"""Test detection of ES modules from import syntax."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
|
@ -159,3 +198,90 @@ class TestImportStatementGeneration:
|
|||
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
|
||||
|
||||
assert result == "const { foo } = require('../../utils');"
|
||||
|
||||
|
||||
class TestModuleSystemConversion:
|
||||
"""Tests for CommonJS <-> ESM conversion."""
|
||||
|
||||
def test_convert_simple_destructured_require(self):
|
||||
"""Test converting simple destructured require to import."""
|
||||
code = "const { foo, bar } = require('./module');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import { foo, bar } from './module';"
|
||||
|
||||
def test_convert_destructured_require_with_alias(self):
|
||||
"""Test converting destructured require with alias to import with 'as'."""
|
||||
code = "const { foo: aliasedFoo } = require('./module');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import { foo as aliasedFoo } from './module';"
|
||||
|
||||
def test_convert_mixed_destructured_require(self):
|
||||
"""Test converting mixed destructured require (some aliased, some not)."""
|
||||
code = "const { foo, bar: aliasedBar, baz } = require('./module');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import { foo, bar as aliasedBar, baz } from './module';"
|
||||
|
||||
def test_convert_destructured_with_whitespace(self):
|
||||
"""Test that whitespace is handled correctly in destructuring."""
|
||||
code = "const { foo : aliasedFoo , bar } = require('./module');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import { foo as aliasedFoo, bar } from './module';"
|
||||
|
||||
def test_convert_simple_require(self):
|
||||
"""Test converting simple require to default import."""
|
||||
code = "const module = require('./module');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import module from './module';"
|
||||
|
||||
def test_convert_property_access_require(self):
|
||||
"""Test converting require with property access to named import."""
|
||||
code = "const foo = require('./module').bar;"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import { bar as foo } from './module';"
|
||||
|
||||
def test_convert_property_access_default(self):
|
||||
"""Test converting require().default to default import."""
|
||||
code = "const foo = require('./module').default;"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
assert result == "import foo from './module';"
|
||||
|
||||
def test_convert_multiple_requires(self):
|
||||
"""Test converting multiple requires in one code block."""
|
||||
code = """const { db: dbCore, cache } = require('@budibase/backend-core');
|
||||
const utils = require('./utils');
|
||||
const { process } = require('./processor');"""
|
||||
result = convert_commonjs_to_esm(code)
|
||||
expected = """import { db as dbCore, cache } from '@budibase/backend-core';
|
||||
import utils from './utils';
|
||||
import { process } from './processor';"""
|
||||
assert result == expected
|
||||
|
||||
def test_convert_esm_to_commonjs_named(self):
|
||||
"""Test converting named imports to destructured require."""
|
||||
code = "import { foo, bar } from './module';"
|
||||
result = convert_esm_to_commonjs(code)
|
||||
assert result == "const { foo, bar } = require('./module');"
|
||||
|
||||
def test_convert_esm_to_commonjs_default(self):
|
||||
"""Test converting default import to simple require."""
|
||||
code = "import module from './module';"
|
||||
result = convert_esm_to_commonjs(code)
|
||||
assert result == "const module = require('./module');"
|
||||
|
||||
def test_convert_esm_to_commonjs_with_alias(self):
|
||||
"""Test converting import with 'as' to destructured require.
|
||||
|
||||
Note: ESM uses 'as' but the regex keeps it as-is in the output.
|
||||
This is acceptable since the test is primarily for CommonJS -> ESM conversion.
|
||||
"""
|
||||
code = "import { foo as aliasedFoo } from './module';"
|
||||
result = convert_esm_to_commonjs(code)
|
||||
# The current implementation preserves 'as' syntax which works for our use case
|
||||
assert result == "const { foo as aliasedFoo } = require('./module');"
|
||||
|
||||
def test_real_world_budibase_import(self):
|
||||
"""Test the real-world case from Budibase that was failing."""
|
||||
code = "const { queue, context, db: dbCore, cache, events } = require('@budibase/backend-core');"
|
||||
result = convert_commonjs_to_esm(code)
|
||||
expected = "import { queue, context, db as dbCore, cache, events } from '@budibase/backend-core';"
|
||||
assert result == expected
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from codeflash.code_utils.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.current import set_current_language
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ModuleSystem,
|
||||
convert_commonjs_to_esm,
|
||||
|
|
@ -300,12 +302,144 @@ export function calculate(x, y) {
|
|||
assert "return add(x, y);" in result
|
||||
|
||||
|
||||
class TestModuleSystemCompatibility:
|
||||
"""Tests for module system compatibility."""
|
||||
class TestTsJestSkipsConversion:
|
||||
"""Tests verifying that module system conversion is skipped when ts-jest is installed.
|
||||
|
||||
def test_convert_mixed_code_to_esm(self):
|
||||
"""Test converting mixed CJS/ESM code to pure ESM - exact output."""
|
||||
code = """\
|
||||
When ts-jest is installed, it handles module interoperability internally,
|
||||
so we skip conversion to avoid breaking valid imports.
|
||||
"""
|
||||
def __init__(self):
|
||||
set_current_language(Language.TYPESCRIPT)
|
||||
|
||||
def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
|
||||
"""Test that CommonJS is NOT converted to ESM when ts-jest is installed."""
|
||||
# Create a project with ts-jest
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"ts-jest": "^29.0.0"}}')
|
||||
|
||||
commonjs_test = """\
|
||||
const Logger = require('../utils/logger');
|
||||
const { helper } = require('../utils/helpers');
|
||||
|
||||
describe('Logger', () => {
|
||||
test('should work', () => {
|
||||
const logger = new Logger();
|
||||
expect(logger).toBeDefined();
|
||||
});
|
||||
});
|
||||
"""
|
||||
# With ts-jest, no conversion should happen
|
||||
result = ensure_module_system_compatibility(commonjs_test, ModuleSystem.ES_MODULE, tmp_path)
|
||||
|
||||
assert result == commonjs_test, (
|
||||
f"CommonJS should NOT be converted when ts-jest is installed.\n"
|
||||
f"Expected (unchanged):\n{commonjs_test}\n\nGot:\n{result}"
|
||||
)
|
||||
|
||||
def test_esm_not_converted_when_ts_jest_installed(self, tmp_path):
|
||||
"""Test that ESM is NOT converted to CommonJS when ts-jest is installed."""
|
||||
# Create a project with ts-jest
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"ts-jest": "^29.0.0"}}')
|
||||
|
||||
esm_test = """\
|
||||
import Logger from '../utils/logger';
|
||||
import { helper } from '../utils/helpers';
|
||||
|
||||
describe('Logger', () => {
|
||||
test('should work', () => {
|
||||
const logger = new Logger();
|
||||
expect(logger).toBeDefined();
|
||||
});
|
||||
});
|
||||
"""
|
||||
# With ts-jest, no conversion should happen
|
||||
result = ensure_module_system_compatibility(esm_test, ModuleSystem.COMMONJS, tmp_path)
|
||||
|
||||
assert result == esm_test, (
|
||||
f"ESM should NOT be converted when ts-jest is installed.\n"
|
||||
f"Expected (unchanged):\n{esm_test}\n\nGot:\n{result}"
|
||||
)
|
||||
|
||||
def test_ts_jest_detected_in_jest_config(self, tmp_path):
|
||||
"""Test that ts-jest is detected from jest.config.js content."""
|
||||
# Create a project with ts-jest in jest.config.js (not package.json)
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {}}')
|
||||
jest_config = tmp_path / "jest.config.js"
|
||||
jest_config.write_text("module.exports = { preset: 'ts-jest' };")
|
||||
|
||||
commonjs_test = "const x = require('./module');"
|
||||
|
||||
result = ensure_module_system_compatibility(commonjs_test, ModuleSystem.ES_MODULE, tmp_path)
|
||||
|
||||
assert result == commonjs_test, "Should skip conversion when ts-jest is in jest.config.js"
|
||||
|
||||
|
||||
class TestModuleSystemConversion:
|
||||
"""Tests for module system conversion when ts-jest is NOT installed.
|
||||
|
||||
Without ts-jest, we convert between CommonJS and ESM as needed.
|
||||
"""
|
||||
|
||||
def test_commonjs_converted_to_esm_without_ts_jest(self, tmp_path):
|
||||
"""Test that CommonJS is converted to ESM when ts-jest is NOT installed."""
|
||||
# Create a project WITHOUT ts-jest
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
|
||||
|
||||
commonjs_code = """\
|
||||
const { helper } = require('./helpers');
|
||||
const logger = require('./logger');
|
||||
|
||||
function process() {
|
||||
return helper();
|
||||
}
|
||||
"""
|
||||
result = ensure_module_system_compatibility(commonjs_code, ModuleSystem.ES_MODULE, tmp_path)
|
||||
|
||||
# Should be converted to ESM
|
||||
assert "import { helper } from './helpers';" in result
|
||||
assert "import logger from './logger';" in result
|
||||
assert "require(" not in result
|
||||
|
||||
def test_esm_converted_to_commonjs_without_ts_jest(self, tmp_path):
|
||||
"""Test that ESM is converted to CommonJS when ts-jest is NOT installed."""
|
||||
# Create a project WITHOUT ts-jest
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
|
||||
|
||||
esm_code = """\
|
||||
import { helper } from './helpers';
|
||||
import logger from './logger';
|
||||
|
||||
function process() {
|
||||
return helper();
|
||||
}
|
||||
"""
|
||||
result = ensure_module_system_compatibility(esm_code, ModuleSystem.COMMONJS, tmp_path)
|
||||
|
||||
# Should be converted to CommonJS
|
||||
assert "const { helper } = require('./helpers');" in result
|
||||
assert "const logger = require('./logger');" in result
|
||||
assert "import " not in result
|
||||
|
||||
def test_no_conversion_when_project_root_is_none(self):
|
||||
"""Test that conversion happens when project_root is None (can't detect ts-jest)."""
|
||||
commonjs_code = "const x = require('./module');"
|
||||
|
||||
# Without project_root, we can't detect ts-jest, so conversion should happen
|
||||
result = ensure_module_system_compatibility(commonjs_code, ModuleSystem.ES_MODULE, None)
|
||||
|
||||
# Should be converted to ESM
|
||||
assert "import x from './module';" in result
|
||||
|
||||
def test_mixed_code_not_converted(self, tmp_path):
|
||||
"""Test that mixed CJS/ESM code is NOT converted (already has both)."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
|
||||
|
||||
mixed_code = """\
|
||||
import { existing } from './module.js';
|
||||
const { helper } = require('./helpers');
|
||||
|
||||
|
|
@ -313,32 +447,16 @@ function process() {
|
|||
return existing() + helper();
|
||||
}
|
||||
"""
|
||||
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
|
||||
# Mixed code has both import and require, so no conversion
|
||||
result = ensure_module_system_compatibility(mixed_code, ModuleSystem.ES_MODULE, tmp_path)
|
||||
|
||||
# Should convert require to import
|
||||
assert "import { helper } from './helpers';" in result
|
||||
assert "require" not in result, f"require should be converted to import. Got:\n{result}"
|
||||
assert result == mixed_code, "Mixed code should not be converted"
|
||||
|
||||
def test_convert_mixed_code_to_commonjs(self):
|
||||
"""Test converting mixed ESM/CJS code to pure CommonJS - exact output."""
|
||||
code = """\
|
||||
const { existing } = require('./module');
|
||||
import { helper } from './helpers.js';
|
||||
|
||||
function process() {
|
||||
return existing() + helper();
|
||||
}
|
||||
"""
|
||||
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS)
|
||||
|
||||
# Should convert import to require
|
||||
assert "const { helper } = require('./helpers');" in result
|
||||
assert "import " not in result.split("\n")[0] or "import " not in result, (
|
||||
f"import should be converted to require. Got:\n{result}"
|
||||
)
|
||||
|
||||
def test_pure_esm_unchanged(self):
|
||||
def test_pure_esm_unchanged_for_esm_target(self, tmp_path):
|
||||
"""Test that pure ESM code is unchanged when targeting ESM."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
|
||||
|
||||
code = """\
|
||||
import { add } from './math.js';
|
||||
|
||||
|
|
@ -346,11 +464,14 @@ export function sum(a, b) {
|
|||
return add(a, b);
|
||||
}
|
||||
"""
|
||||
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
|
||||
assert result == code, f"Pure ESM code should be unchanged.\nExpected:\n{code}\n\nGot:\n{result}"
|
||||
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE, tmp_path)
|
||||
assert result == code, "Pure ESM code should be unchanged for ESM target"
|
||||
|
||||
def test_pure_commonjs_unchanged(self):
|
||||
def test_pure_commonjs_unchanged_for_commonjs_target(self, tmp_path):
|
||||
"""Test that pure CommonJS code is unchanged when targeting CommonJS."""
|
||||
package_json = tmp_path / "package.json"
|
||||
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
|
||||
|
||||
code = """\
|
||||
const { add } = require('./math');
|
||||
|
||||
|
|
@ -360,8 +481,8 @@ function sum(a, b) {
|
|||
|
||||
module.exports = { sum };
|
||||
"""
|
||||
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS)
|
||||
assert result == code, f"Pure CommonJS code should be unchanged.\nExpected:\n{code}\n\nGot:\n{result}"
|
||||
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS, tmp_path)
|
||||
assert result == code, "Pure CommonJS code should be unchanged for CommonJS target"
|
||||
|
||||
|
||||
class TestImportStatementGeneration:
|
||||
|
|
|
|||
|
|
@ -555,3 +555,125 @@ def standalone():
|
|||
|
||||
standalone_funcs = [f for f in functions if f.class_name is None]
|
||||
assert len(standalone_funcs) == 1
|
||||
|
||||
|
||||
# === Tests for find_references method ===
|
||||
# These tests verify that PythonSupport correctly finds references to functions
|
||||
# using jedi, including the fix for using function.function_name instead of function.name.
|
||||
|
||||
|
||||
def test_find_references_simple_function(python_support, tmp_path):
|
||||
"""Test finding references to a simple function.
|
||||
|
||||
This test specifically exercises the code path that was fixed in the
|
||||
regression where function.name was used instead of function.function_name.
|
||||
"""
|
||||
from codeflash.models.function_types import FunctionToOptimize
|
||||
|
||||
# Create source file with function definition
|
||||
source_file = tmp_path / "utils.py"
|
||||
source_file.write_text("""def helper_function(x):
|
||||
return x * 2
|
||||
""")
|
||||
|
||||
# Create a file that imports and uses the function
|
||||
consumer_file = tmp_path / "consumer.py"
|
||||
consumer_file.write_text("""from utils import helper_function
|
||||
|
||||
def process(value):
|
||||
return helper_function(value) + 1
|
||||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="helper_function",
|
||||
file_path=source_file,
|
||||
starting_line=1,
|
||||
ending_line=2,
|
||||
)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
||||
assert len(refs) >= 1
|
||||
ref_files = {str(r.file_path) for r in refs}
|
||||
assert any("consumer.py" in f for f in ref_files)
|
||||
|
||||
|
||||
def test_find_references_class_method(python_support, tmp_path):
|
||||
"""Test finding references to a class method.
|
||||
|
||||
This verifies the class_name attribute is correctly used to disambiguate methods.
|
||||
"""
|
||||
from codeflash.models.function_types import FunctionParent, FunctionToOptimize
|
||||
|
||||
# Create source file with class and method
|
||||
source_file = tmp_path / "calculator.py"
|
||||
source_file.write_text("""class Calculator:
|
||||
def add(self, a, b):
|
||||
return a + b
|
||||
""")
|
||||
|
||||
# Create a file that uses the class method
|
||||
consumer_file = tmp_path / "main.py"
|
||||
consumer_file.write_text("""from calculator import Calculator
|
||||
|
||||
def compute():
|
||||
calc = Calculator()
|
||||
return calc.add(1, 2)
|
||||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=source_file,
|
||||
parents=[FunctionParent(name="Calculator", type="ClassDef")],
|
||||
starting_line=2,
|
||||
ending_line=3,
|
||||
is_method=True,
|
||||
)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
||||
assert len(refs) >= 1
|
||||
ref_files = {str(r.file_path) for r in refs}
|
||||
assert any("main.py" in f for f in ref_files)
|
||||
|
||||
|
||||
def test_find_references_no_references(python_support, tmp_path):
|
||||
"""Test that find_references returns empty list when no references exist."""
|
||||
from codeflash.models.function_types import FunctionToOptimize
|
||||
|
||||
source_file = tmp_path / "isolated.py"
|
||||
source_file.write_text("""def isolated_function():
|
||||
return 42
|
||||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="isolated_function",
|
||||
file_path=source_file,
|
||||
starting_line=1,
|
||||
ending_line=2,
|
||||
)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
||||
assert refs == []
|
||||
|
||||
|
||||
def test_find_references_nonexistent_function(python_support, tmp_path):
|
||||
"""Test that find_references handles nonexistent functions gracefully."""
|
||||
from codeflash.models.function_types import FunctionToOptimize
|
||||
|
||||
source_file = tmp_path / "source.py"
|
||||
source_file.write_text("""def existing_function():
|
||||
return 1
|
||||
""")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="nonexistent_function",
|
||||
file_path=source_file,
|
||||
starting_line=1,
|
||||
ending_line=2,
|
||||
)
|
||||
|
||||
refs = python_support.find_references(func, project_root=tmp_path)
|
||||
|
||||
assert refs == []
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from codeflash.setup.detector import (
|
|||
_find_project_root,
|
||||
detect_project,
|
||||
has_existing_config,
|
||||
is_build_output_dir,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -139,15 +140,15 @@ class TestDetectModuleRoot:
|
|||
assert "pyproject.toml" in detail
|
||||
|
||||
def test_js_detects_from_exports(self, tmp_path):
|
||||
"""Should detect module root from package.json exports."""
|
||||
"""Should detect module root from package.json exports when no common src dir exists."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"exports": {".": "./src/index.js"}
|
||||
"exports": {".": "./packages/core/index.js"}
|
||||
}))
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "packages" / "core").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "src"
|
||||
assert module_root == tmp_path / "packages" / "core"
|
||||
assert "exports" in detail
|
||||
|
||||
def test_js_detects_src_convention(self, tmp_path):
|
||||
|
|
@ -158,6 +159,214 @@ class TestDetectModuleRoot:
|
|||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "src"
|
||||
|
||||
def test_js_prefers_src_over_build_src(self, tmp_path):
|
||||
"""Should prefer src/ over build/src/ even when package.json points to build/."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "build/src/index.js",
|
||||
"module": "build/src/index.js"
|
||||
}))
|
||||
(tmp_path / "src").mkdir()
|
||||
(tmp_path / "build" / "src").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "src"
|
||||
assert "src/ directory" in detail
|
||||
|
||||
def test_js_skips_build_dir_from_main(self, tmp_path):
|
||||
"""Should skip build output directories from package.json main field."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "build/index.js"
|
||||
}))
|
||||
(tmp_path / "build").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path
|
||||
assert "project root" in detail
|
||||
|
||||
def test_js_skips_dist_dir_from_exports(self, tmp_path):
|
||||
"""Should skip dist output directories from package.json exports field."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"exports": {".": "./dist/index.js"}
|
||||
}))
|
||||
(tmp_path / "dist").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path
|
||||
assert "project root" in detail
|
||||
|
||||
def test_js_skips_out_dir_from_module(self, tmp_path):
|
||||
"""Should skip out output directories from package.json module field."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"module": "out/esm/index.js"
|
||||
}))
|
||||
(tmp_path / "out" / "esm").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path
|
||||
assert "project root" in detail
|
||||
|
||||
def test_js_prefers_lib_over_build_dir(self, tmp_path):
|
||||
"""Should prefer lib/ over build output directories."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "dist/index.js"
|
||||
}))
|
||||
(tmp_path / "lib").mkdir()
|
||||
(tmp_path / "dist").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "lib"
|
||||
assert "lib/ directory" in detail
|
||||
|
||||
def test_js_prefers_source_over_build_dir(self, tmp_path):
|
||||
"""Should prefer source/ over build output directories."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "build/index.js"
|
||||
}))
|
||||
(tmp_path / "source").mkdir()
|
||||
(tmp_path / "build").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "source"
|
||||
assert "source/ directory" in detail
|
||||
|
||||
def test_js_falls_back_to_valid_exports_path(self, tmp_path):
|
||||
"""Should use exports path when no common source dirs exist and path is not build output."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"exports": {".": "./packages/core/index.js"}
|
||||
}))
|
||||
(tmp_path / "packages" / "core").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "packages" / "core"
|
||||
assert "exports" in detail
|
||||
|
||||
def test_js_falls_back_to_valid_main_path(self, tmp_path):
|
||||
"""Should use main path when no common source dirs exist and path is not build output."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "packages/main/index.js"
|
||||
}))
|
||||
(tmp_path / "packages" / "main").mkdir(parents=True)
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "packages" / "main"
|
||||
assert "main" in detail
|
||||
|
||||
def test_js_falls_back_to_valid_module_path(self, tmp_path):
|
||||
"""Should use module path when no common source dirs exist and path is not build output."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"module": "esm/index.js"
|
||||
}))
|
||||
(tmp_path / "esm").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path / "esm"
|
||||
assert "module" in detail
|
||||
|
||||
def test_js_returns_project_root_when_all_paths_are_build_output(self, tmp_path):
|
||||
"""Should return project root when all package.json paths point to build outputs."""
|
||||
(tmp_path / "package.json").write_text(json.dumps({
|
||||
"name": "test",
|
||||
"main": "dist/cjs/index.js",
|
||||
"module": "dist/esm/index.js",
|
||||
"exports": {".": "./build/index.js"}
|
||||
}))
|
||||
(tmp_path / "dist" / "cjs").mkdir(parents=True)
|
||||
(tmp_path / "dist" / "esm").mkdir(parents=True)
|
||||
(tmp_path / "build").mkdir()
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path
|
||||
assert "project root" in detail
|
||||
|
||||
def test_js_handles_malformed_package_json(self, tmp_path):
|
||||
"""Should handle malformed package.json gracefully."""
|
||||
(tmp_path / "package.json").write_text("{ invalid json }")
|
||||
|
||||
module_root, detail = _detect_js_module_root(tmp_path)
|
||||
assert module_root == tmp_path
|
||||
assert "project root" in detail
|
||||
|
||||
|
||||
class TestIsBuildOutputDir:
|
||||
"""Tests for is_build_output_dir function."""
|
||||
|
||||
def test_detects_build_dir(self):
|
||||
"""Should detect build/ as build output."""
|
||||
from pathlib import Path
|
||||
assert is_build_output_dir(Path("build"))
|
||||
assert is_build_output_dir(Path("build/src"))
|
||||
assert is_build_output_dir(Path("build/src/index.js"))
|
||||
|
||||
def test_detects_dist_dir(self):
|
||||
"""Should detect dist/ as build output."""
|
||||
from pathlib import Path
|
||||
assert is_build_output_dir(Path("dist"))
|
||||
assert is_build_output_dir(Path("dist/esm"))
|
||||
assert is_build_output_dir(Path("dist/cjs/index.js"))
|
||||
|
||||
def test_detects_out_dir(self):
|
||||
"""Should detect out/ as build output."""
|
||||
from pathlib import Path
|
||||
assert is_build_output_dir(Path("out"))
|
||||
assert is_build_output_dir(Path("out/src"))
|
||||
|
||||
def test_detects_next_dir(self):
|
||||
"""Should detect .next/ as build output."""
|
||||
from pathlib import Path
|
||||
assert is_build_output_dir(Path(".next"))
|
||||
assert is_build_output_dir(Path(".next/static"))
|
||||
|
||||
def test_detects_nuxt_dir(self):
|
||||
"""Should detect .nuxt/ as build output."""
|
||||
from pathlib import Path
|
||||
assert is_build_output_dir(Path(".nuxt"))
|
||||
assert is_build_output_dir(Path(".nuxt/dist"))
|
||||
|
||||
def test_detects_nested_build_dir(self):
|
||||
"""Should detect build dir nested in path."""
|
||||
from pathlib import Path
|
||||
assert is_build_output_dir(Path("packages/build/index.js"))
|
||||
assert is_build_output_dir(Path("foo/dist/bar"))
|
||||
|
||||
def test_does_not_detect_src(self):
|
||||
"""Should not detect src/ as build output."""
|
||||
from pathlib import Path
|
||||
assert not is_build_output_dir(Path("src"))
|
||||
assert not is_build_output_dir(Path("src/index.js"))
|
||||
|
||||
def test_does_not_detect_lib(self):
|
||||
"""Should not detect lib/ as build output."""
|
||||
from pathlib import Path
|
||||
assert not is_build_output_dir(Path("lib"))
|
||||
assert not is_build_output_dir(Path("lib/utils"))
|
||||
|
||||
def test_does_not_detect_source(self):
|
||||
"""Should not detect source/ as build output."""
|
||||
from pathlib import Path
|
||||
assert not is_build_output_dir(Path("source"))
|
||||
|
||||
def test_does_not_detect_packages(self):
|
||||
"""Should not detect packages/ as build output."""
|
||||
from pathlib import Path
|
||||
assert not is_build_output_dir(Path("packages"))
|
||||
assert not is_build_output_dir(Path("packages/core"))
|
||||
|
||||
def test_does_not_detect_similar_names(self):
|
||||
"""Should not detect directories with similar but different names."""
|
||||
from pathlib import Path
|
||||
assert not is_build_output_dir(Path("builder"))
|
||||
assert not is_build_output_dir(Path("distribution"))
|
||||
assert not is_build_output_dir(Path("output"))
|
||||
|
||||
|
||||
class TestDetectTestsRoot:
|
||||
"""Tests for tests root detection."""
|
||||
|
|
|
|||
|
|
@ -857,6 +857,48 @@ class TestE2ECLIFlags:
|
|||
# Should complete without error
|
||||
_handle_show_config()
|
||||
|
||||
def test_show_config_displays_config_path_when_saved(self, project_with_existing_config, monkeypatch):
|
||||
"""Should display config file path when saved config exists."""
|
||||
monkeypatch.chdir(project_with_existing_config)
|
||||
|
||||
# Track what gets printed
|
||||
printed_messages = []
|
||||
|
||||
def mock_print(msg="", *args, **kwargs):
|
||||
printed_messages.append(str(msg))
|
||||
|
||||
from codeflash.cli_cmds import console
|
||||
monkeypatch.setattr(console.console, "print", mock_print)
|
||||
|
||||
from codeflash.cli_cmds.cli import _handle_show_config
|
||||
_handle_show_config()
|
||||
|
||||
# Verify config path is displayed
|
||||
all_output = "\n".join(printed_messages)
|
||||
assert "pyproject.toml" in all_output
|
||||
assert "Config file:" in all_output
|
||||
|
||||
def test_show_config_no_path_when_auto_detected(self, python_src_layout, monkeypatch):
|
||||
"""Should not display config file path when config is auto-detected."""
|
||||
monkeypatch.chdir(python_src_layout)
|
||||
|
||||
# Track what gets printed
|
||||
printed_messages = []
|
||||
|
||||
def mock_print(msg="", *args, **kwargs):
|
||||
printed_messages.append(str(msg))
|
||||
|
||||
from codeflash.cli_cmds import console
|
||||
monkeypatch.setattr(console.console, "print", mock_print)
|
||||
|
||||
from codeflash.cli_cmds.cli import _handle_show_config
|
||||
_handle_show_config()
|
||||
|
||||
# Verify no config path line is displayed
|
||||
all_output = "\n".join(printed_messages)
|
||||
assert "Config file:" not in all_output
|
||||
assert "Auto-detected" in all_output
|
||||
|
||||
def test_reset_config_removes_from_pyproject(self, project_with_existing_config, monkeypatch):
|
||||
"""Should remove codeflash config from pyproject.toml."""
|
||||
monkeypatch.chdir(project_with_existing_config)
|
||||
|
|
|
|||
Loading…
Reference in a new issue