Merge branch 'main' into omni-java

This commit is contained in:
Kevin Turcios 2026-02-04 03:22:37 -05:00 committed by GitHub
commit 95cc60397d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 2918 additions and 1301 deletions

View file

@ -1,59 +0,0 @@
# Automated PR review via Claude Code (Azure AI Foundry backend).
# NOTE(review): indentation was lost in the page scrape; reconstructed to
# standard GitHub Actions workflow nesting — confirm against repo history.
name: Claude Code Review

on:
  pull_request:
    types: [opened, synchronize]
    # Optional: Only run on specific file changes
    # paths:
    #   - "src/**/*.ts"
    #   - "src/**/*.tsx"
    #   - "src/**/*.js"
    #   - "src/**/*.jsx"

jobs:
  claude-review:
    # Optional: Filter by PR author
    # if: |
    #   github.event.pull_request.user.login == 'external-contributor' ||
    #   github.event.pull_request.user.login == 'new-developer' ||
    #   github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR'
    runs-on: ubuntu-latest
    permissions:
      contents: read
      pull-requests: write
      issues: read
      id-token: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          # Shallow clone: the review action fetches PR diffs via `gh`, not git history.
          fetch-depth: 1

      - name: Run Claude Code Review
        id: claude-review
        uses: anthropics/claude-code-action@v1
        with:
          use_foundry: "true"
          # Quoted for consistency — action inputs are strings.
          use_sticky_comment: "true"
          prompt: |
            REPO: ${{ github.repository }}
            PR NUMBER: ${{ github.event.pull_request.number }}

            Please review this pull request and provide feedback on:
            - Code quality and best practices
            - Potential bugs or issues
            - Performance considerations
            - Security concerns
            - Test coverage

            Use the repository's CLAUDE.md for guidance on style and conventions. Be constructive and helpful in your feedback.

          # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md
          # or https://code.claude.com/docs/en/cli-reference for available options
          claude_args: '--allowed-tools "Bash(gh issue view:*),Bash(gh search:*),Bash(gh issue list:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*)"'
        env:
          ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
          ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}

View file

@ -1,6 +1,8 @@
name: Claude Code
on:
pull_request:
types: [opened, synchronize, ready_for_review, reopened]
issue_comment:
types: [created]
pull_request_review_comment:
@ -11,19 +13,154 @@ on:
types: [submitted]
jobs:
claude:
# Automatic PR review (can fix linting issues and push)
# Blocked for fork PRs to prevent malicious code execution
pr-review:
if: |
(github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) ||
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) ||
(github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) ||
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
github.event_name == 'pull_request' &&
github.actor != 'claude[bot]' &&
github.event.pull_request.head.repo.full_name == github.repository
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
issues: read
id-token: write
actions: read # Required for Claude to read CI results on PRs
actions: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install dependencies
run: |
uv venv --seed
uv sync
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
use_sticky_comment: true
allowed_bots: "claude[bot]"
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}
EVENT: ${{ github.event.action }}
## STEP 1: Run pre-commit checks and fix issues
First, run `uv run prek run --from-ref origin/main` to check for linting/formatting issues on files changed in this PR.
If there are any issues:
- For SAFE auto-fixable issues (formatting, import sorting, trailing whitespace, etc.), run `uv run prek run --from-ref origin/main` again to auto-fix them
- Stage the fixed files with `git add`
- Commit with message "style: auto-fix linting issues"
- Push the changes with `git push`
Do NOT attempt to fix:
- Type errors that require logic changes
- Complex refactoring suggestions
- Anything that could change behavior
## STEP 2: Review the PR
${{ github.event.action == 'synchronize' && 'This is a RE-REVIEW after new commits. First, get the list of changed files in this latest push using `gh pr diff`. Review ONLY the changed files. Check ALL existing review comments and resolve ones that are now fixed.' || 'This is the INITIAL REVIEW.' }}
Review this PR focusing ONLY on:
1. Critical bugs or logic errors
2. Security vulnerabilities
3. Breaking API changes
4. Test failures (methods with typos that won't run)
IMPORTANT:
- First check existing review comments using `gh api repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/comments`. For each existing comment, check if the issue still exists in the current code.
- If an issue is fixed, use `gh api --method PATCH repos/${{ github.repository }}/pulls/comments/COMMENT_ID -f body="✅ Fixed in latest commit"` to resolve it.
- Only create NEW inline comments for HIGH-PRIORITY issues found in changed files.
- Limit to 5-7 NEW comments maximum per review.
- Use CLAUDE.md for project-specific guidance.
- Use `gh pr comment` for summary-level feedback.
- Use `mcp__github_inline_comment__create_inline_comment` sparingly for critical code issues only.
## STEP 3: Coverage analysis
Analyze test coverage for changed files:
1. Get the list of Python files changed in this PR (excluding tests):
`git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
2. Run tests with coverage on the PR branch:
`uv run coverage run -m pytest tests/ -q --tb=no`
`uv run coverage json -o coverage-pr.json`
3. Get coverage for changed files only:
`uv run coverage report --include="<changed_files_comma_separated>"`
4. Compare with main branch coverage:
- Checkout main: `git checkout origin/main`
- Run coverage: `uv run coverage run -m pytest tests/ -q --tb=no && uv run coverage json -o coverage-main.json`
- Checkout back: `git checkout -`
5. Analyze the diff to identify:
- NEW FILES: Files that don't exist on main (require good test coverage)
- MODIFIED FILES: Files with changes (changes must be covered by tests)
6. Report in PR comment with a markdown table:
- Coverage % for each changed file (PR vs main)
- Overall coverage change
- For NEW files: Flag if coverage is below 75%
- For MODIFIED files: Flag if the changed lines are not covered by tests
- Flag if overall coverage decreased
Coverage requirements:
- New implementations/files: Must have ≥75% test coverage
- Modified code: Changed lines should be exercised by existing or new tests
- No coverage regressions: Overall coverage should not decrease
claude_args: '--allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep"'
additional_permissions: |
actions: read
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}
# @claude mentions (can edit and push) - restricted to maintainers only
claude-mention:
if: |
(
github.event_name == 'issue_comment' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR')
) ||
(
github.event_name == 'pull_request_review_comment' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'OWNER' || github.event.comment.author_association == 'MEMBER' || github.event.comment.author_association == 'COLLABORATOR') &&
github.event.pull_request.head.repo.full_name == github.repository
) ||
(
github.event_name == 'pull_request_review' &&
contains(github.event.review.body, '@claude') &&
(github.event.review.author_association == 'OWNER' || github.event.review.author_association == 'MEMBER' || github.event.review.author_association == 'COLLABORATOR') &&
github.event.pull_request.head.repo.full_name == github.repository
) ||
(
github.event_name == 'issues' &&
(contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')) &&
(github.event.issue.author_association == 'OWNER' || github.event.issue.author_association == 'MEMBER' || github.event.issue.author_association == 'COLLABORATOR')
)
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
issues: read
id-token: write
actions: read
steps:
- name: Get PR head ref
id: pr-ref
@ -44,14 +181,22 @@ jobs:
fetch-depth: 0
ref: ${{ steps.pr-ref.outputs.ref }}
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Install dependencies
run: |
uv venv --seed
uv sync
- name: Run Claude Code
id: claude
uses: anthropics/claude-code-action@v1
with:
use_foundry: "true"
claude_args: '--allowedTools "Read,Edit,Write,Glob,Grep,Bash(git status*),Bash(git diff*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git log*),Bash(git merge*),Bash(git fetch*),Bash(git checkout*),Bash(git branch*),Bash(uv run prek *),Bash(prek *),Bash(uv run ruff *),Bash(uv run pytest *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(gh pr comment*),Bash(gh pr view*),Bash(gh pr diff*),Bash(gh pr merge*),Bash(gh pr close*)"'
additional_permissions: |
actions: read
env:
ANTHROPIC_FOUNDRY_API_KEY: ${{ secrets.AZURE_ANTHROPIC_API_KEY }}
ANTHROPIC_FOUNDRY_BASE_URL: ${{ secrets.AZURE_ANTHROPIC_ENDPOINT }}

View file

@ -1,19 +0,0 @@
# Pre-commit lint gate for PRs and pushes to main.
# NOTE(review): indentation was lost in the page scrape; reconstructed to
# standard GitHub Actions workflow nesting — confirm against repo history.
name: Lint

on:
  pull_request:
  push:
    branches:
      - main

# Cancel superseded runs for the same ref to save CI minutes.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true

jobs:
  lint:
    name: Run pre-commit hooks
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
      - uses: pre-commit/action@v3.0.1

18
.github/workflows/prek.yaml vendored Normal file
View file

@ -0,0 +1,18 @@
# Lint only the files changed in the PR using prek (pre-commit runner).
# NOTE(review): indentation was lost in the page scrape; reconstructed to
# standard GitHub Actions workflow nesting — confirm against repo history.
name: Lint

on: [pull_request]

# Cancel superseded runs for the same ref to save CI minutes.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true

jobs:
  prek:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          # Full history so --from-ref/--to-ref can diff against the base branch.
          fetch-depth: 0
      - uses: astral-sh/setup-uv@v6
      - uses: j178/prek-action@v1
        with:
          extra-args: '--from-ref origin/${{ github.base_ref }} --to-ref ${{ github.sha }}'

View file

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.7
rev: v0.15.0
hooks:
# Run the linter.
- id: ruff-check

View file

@ -62,6 +62,8 @@ codeflash/
- Use libcst, not ast - For Python, always use `libcst` for code parsing/modification to preserve formatting.
- Code context extraction and replacement tests must always assert for full string equality, no substring matching.
- Any new feature or bug fix that can be tested automatically must have test cases.
- If changes affect existing test expectations, update the tests accordingly. Tests must always pass after changes.
- NEVER use leading underscores for function names (e.g., `_helper`). Python has no true private functions. Always use public names.
## Code Style
@ -70,7 +72,7 @@ codeflash/
- **Tooling**: Ruff for linting/formatting, mypy strict mode, pre-commit hooks
- **Comments**: Minimal - only explain "why", not "what"
- **Docstrings**: Do not add unless explicitly requested
- **Naming**: Prefer public functions (no leading underscore) - Python doesn't have true private functions
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
## Git Commits & Pull Requests

View file

@ -20,7 +20,7 @@
}
},
"../../../packages/codeflash": {
"version": "0.3.1",
"version": "0.4.0",
"dev": true,
"hasInstallScript": true,
"license": "MIT",

View file

@ -373,11 +373,13 @@ def _handle_show_config() -> None:
detected = detect_project(project_root)
# Check if config exists or is auto-detected
config_exists, _ = has_existing_config(project_root)
config_exists, config_file = has_existing_config(project_root)
status = "Saved config" if config_exists else "Auto-detected (not saved)"
console.print()
console.print(f"[bold]Codeflash Configuration[/bold] ({status})")
if config_exists and config_file:
console.print(f"[dim]Config file: {project_root / config_file}[/dim]")
console.print()
table = Table(show_header=True, header_style="bold cyan")

View file

@ -1429,7 +1429,9 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
numerical_names: set[str] = set()
modules_used: set[str] = set()
for node in ast.walk(tree):
stack: list[ast.AST] = [tree]
while stack:
node = stack.pop()
if isinstance(node, ast.Import):
for alias in node.names:
# import numpy or import numpy as np
@ -1451,6 +1453,8 @@ def _collect_numerical_imports(tree: ast.Module) -> tuple[set[str], set[str]]:
name = alias.asname if alias.asname else alias.name
numerical_names.add(name)
modules_used.add(module_root)
else:
stack.extend(ast.iter_child_nodes(node))
return numerical_names, modules_used

View file

@ -6,6 +6,8 @@ import json
from pathlib import Path
from typing import Any
from codeflash.setup.detector import is_build_output_dir
PACKAGE_JSON_CACHE: dict[Path, Path] = {}
PACKAGE_JSON_DATA_CACHE: dict[Path, dict[str, Any]] = {}
@ -50,12 +52,15 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
"""Detect module root from package.json fields or directory conventions.
Detection order:
1. package.json "exports" field (extract directory from main export)
2. package.json "module" field (ESM entry point)
3. package.json "main" field (CJS entry point)
4. "src/" directory if it exists
1. src/, lib/, source/ directories (common source directories)
2. package.json "exports" field (if not in build output directory)
3. package.json "module" field (ESM, if not in build output directory)
4. package.json "main" field (CJS, if not in build output directory)
5. Fall back to "." (project root)
Build output directories (build/, dist/, out/) are skipped since they contain
compiled code, not source files.
Args:
project_root: Root directory of the project.
package_data: Parsed package.json data.
@ -64,6 +69,11 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
Detected module root path (relative to project root).
"""
# Check for common source directories first - these are always preferred
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return src_dir
# Check exports field (modern Node.js)
exports = package_data.get("exports")
if exports:
@ -80,27 +90,38 @@ def detect_module_root(project_root: Path, package_data: dict[str, Any]) -> str:
if entry_path and isinstance(entry_path, str):
parent = Path(entry_path).parent
if parent != Path() and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return parent.as_posix()
# Check module field (ESM)
module_field = package_data.get("module")
if module_field and isinstance(module_field, str):
parent = Path(module_field).parent
if parent != Path() and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return parent.as_posix()
# Check main field (CJS)
main_field = package_data.get("main")
if main_field and isinstance(main_field, str):
parent = Path(main_field).parent
if parent != Path() and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return parent.as_posix()
# Check for src/ directory convention
if (project_root / "src").is_dir():
return "src"
# Default to project root
return "."

View file

@ -746,7 +746,11 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
return CodeStringsMarkdown(code_strings=[])
imported_names: dict[str, str] = {}
external_bases: list[tuple[str, str]] = []
# Use a set to deduplicate external base entries to avoid repeated expensive checks/imports.
external_bases_set: set[tuple[str, str]] = set()
# Local cache to avoid repeated _is_project_module calls for the same module_name.
is_project_cache: dict[str, bool] = {}
for node in ast.walk(tree):
if isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
@ -763,21 +767,31 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
if base_name and base_name in imported_names:
module_name = imported_names[base_name]
if not _is_project_module(module_name, project_root_path):
external_bases.append((base_name, module_name))
# Check cache first to avoid repeated expensive checks.
cached = is_project_cache.get(module_name)
if cached is None:
is_project = _is_project_module(module_name, project_root_path)
is_project_cache[module_name] = is_project
else:
is_project = cached
if not external_bases:
if not is_project:
external_bases_set.add((base_name, module_name))
if not external_bases_set:
return CodeStringsMarkdown(code_strings=[])
code_strings: list[CodeString] = []
extracted: set[tuple[str, str]] = set()
for base_name, module_name in external_bases:
if (module_name, base_name) in extracted:
continue
# Cache imported modules to avoid repeated importlib.import_module calls.
imported_module_cache: dict[str, object] = {}
for base_name, module_name in external_bases_set:
try:
module = importlib.import_module(module_name)
module = imported_module_cache.get(module_name)
if module is None:
module = importlib.import_module(module_name)
imported_module_cache[module_name] = module
base_class = getattr(module, base_name, None)
if base_class is None:
continue
@ -799,7 +813,6 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
class_source = f"class {base_name}:\n" + textwrap.indent(init_source, " ")
code_strings.append(CodeString(code=class_source, file_path=class_file))
extracted.add((module_name, base_name))
except (ImportError, ModuleNotFoundError, AttributeError):
logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
@ -854,12 +867,13 @@ def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef,
needed_names.add(decorator.func.value.id)
# Get type annotation names from class body (for dataclass fields)
for item in ast.walk(class_node):
for item in class_node.body:
if isinstance(item, ast.AnnAssign) and item.annotation:
collect_names_from_annotation(item.annotation, needed_names)
# Also check for field() calls which are common in dataclasses
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
needed_names.add(item.func.id)
elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
if isinstance(item.value.func, ast.Name):
needed_names.add(item.value.func.id)
# Find imports that provide these names
import_lines: list[str] = []

View file

@ -656,7 +656,7 @@ def discover_unit_tests(
# Existing Python logic
framework_strategies: dict[str, Callable] = {"pytest": discover_tests_pytest, "unittest": discover_tests_unittest}
strategy = framework_strategies.get(cfg.test_framework, None)
strategy = framework_strategies.get(cfg.test_framework)
if not strategy:
error_message = f"Unsupported test framework: {cfg.test_framework}"
raise ValueError(error_message)

View file

@ -839,11 +839,13 @@ def filter_functions(
if any(pattern in file_lower for pattern in test_file_name_patterns):
return True
# Check directory patterns, but only within the project root
# to avoid false positives from parent directories
relative_path = file_lower
# to avoid false positives from parent directories (e.g., project at /home/user/tests/myproject)
if project_root_str and file_lower.startswith(project_root_str.lower()):
relative_path = file_lower[len(project_root_str) :]
return any(pattern in relative_path for pattern in test_dir_patterns)
return any(pattern in relative_path for pattern in test_dir_patterns)
# If we can't compute relative path from project root, don't check directory patterns
# This avoids false positives when project is inside a folder named "tests"
return False
# Use directory-based filtering when tests are in a separate directory
return file_path_normalized.startswith(tests_root_str + os.sep)

View file

@ -25,11 +25,11 @@ class PrComment:
best_async_throughput: Optional[int] = None
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
report_table = {
test_type.to_name(): result
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
if test_type.to_name()
}
report_table: dict[str, dict[str, int]] = {}
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items():
name = test_type.to_name()
if name:
report_table[name] = result
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
"optimization_explanation": self.optimization_explanation,

View file

@ -58,9 +58,6 @@ def set_current_language(language: Language | str) -> None:
"""
global _current_language
if _current_language is not None:
return
_current_language = Language(language) if isinstance(language, str) else language

View file

@ -210,7 +210,7 @@ class ReferenceFinder:
# Check if this file imports from the re-export file
import_info = self._find_matching_import(imports, reexport_file, file_path, reexported)
trigger_check = True
if import_info:
context.visited_files.add(file_path)
import_name, _original_import = import_info
@ -651,15 +651,18 @@ class ReferenceFinder:
"""
references: list[Reference] = []
export_name = exported.export_name or exported.function_name
# Skip expensive parsing if export name not in source
if export_name not in source_code:
return references
exports = analyzer.find_exports(source_code)
lines = source_code.splitlines()
for exp in exports:
if not exp.is_reexport:
continue
# Check if this re-exports our function
export_name = exported.export_name or exported.function_name
for name, alias in exp.exported_names:
if name == export_name:
# This is a re-export of our function

View file

@ -11,6 +11,8 @@ import logging
import re
from typing import TYPE_CHECKING
from codeflash.languages.current import is_typescript
if TYPE_CHECKING:
from pathlib import Path
@ -44,9 +46,10 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
"""Detect the module system used by a JavaScript/TypeScript project.
Detection strategy:
1. Check package.json for "type" field
2. If file_path provided, check file extension (.mjs = ESM, .cjs = CommonJS)
3. Analyze import statements in the file
1. Check file extension for explicit module type (.mjs, .cjs, .ts, .tsx, .mts)
- TypeScript files always use ESM syntax regardless of package.json
2. Check package.json for explicit "type" field (only if explicitly set)
3. Analyze import/export statements in the file content
4. Default to CommonJS if uncertain
Args:
@ -57,13 +60,29 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
ModuleSystem constant (COMMONJS, ES_MODULE, or UNKNOWN).
"""
# Strategy 1: Check package.json
# Strategy 1: Check file extension first for explicit module type indicators
# TypeScript files always use ESM syntax (import/export)
if file_path:
suffix = file_path.suffix.lower()
if suffix == ".mjs":
logger.debug("Detected ES Module from .mjs extension")
return ModuleSystem.ES_MODULE
if suffix == ".cjs":
logger.debug("Detected CommonJS from .cjs extension")
return ModuleSystem.COMMONJS
if suffix in (".ts", ".tsx", ".mts"):
# TypeScript always uses ESM syntax (import/export)
# even if package.json doesn't have "type": "module"
logger.debug("Detected ES Module from TypeScript file extension")
return ModuleSystem.ES_MODULE
# Strategy 2: Check package.json for explicit type field
package_json = project_root / "package.json"
if package_json.exists():
try:
with package_json.open("r") as f:
pkg = json.load(f)
pkg_type = pkg.get("type", "commonjs")
pkg_type = pkg.get("type") # Don't default - only use if explicitly set
if pkg_type == "module":
logger.debug("Detected ES Module from package.json type field")
@ -71,44 +90,35 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s
if pkg_type == "commonjs":
logger.debug("Detected CommonJS from package.json type field")
return ModuleSystem.COMMONJS
# If type is not explicitly set, continue to file content analysis
except Exception as e:
logger.warning("Failed to parse package.json: %s", e)
# Strategy 2: Check file extension
if file_path:
suffix = file_path.suffix
if suffix == ".mjs":
logger.debug("Detected ES Module from .mjs extension")
return ModuleSystem.ES_MODULE
if suffix == ".cjs":
logger.debug("Detected CommonJS from .cjs extension")
return ModuleSystem.COMMONJS
# Strategy 3: Analyze file content for import/export patterns
if file_path and file_path.exists():
try:
content = file_path.read_text()
# Strategy 3: Analyze file content
if file_path.exists():
try:
content = file_path.read_text()
# Look for ES module syntax
has_import = "import " in content and "from " in content
has_export = "export " in content or "export default" in content or "export {" in content
# Look for ES module syntax
has_import = "import " in content and "from " in content
has_export = "export " in content or "export default" in content or "export {" in content
# Look for CommonJS syntax
has_require = "require(" in content
has_module_exports = "module.exports" in content or "exports." in content
# Look for CommonJS syntax
has_require = "require(" in content
has_module_exports = "module.exports" in content or "exports." in content
# Determine based on what we found
if (has_import or has_export) and not (has_require or has_module_exports):
logger.debug("Detected ES Module from import/export statements")
return ModuleSystem.ES_MODULE
# Determine based on what we found
if (has_import or has_export) and not (has_require or has_module_exports):
logger.debug("Detected ES Module from import/export statements")
return ModuleSystem.ES_MODULE
if (has_require or has_module_exports) and not (has_import or has_export):
logger.debug("Detected CommonJS from require/module.exports")
return ModuleSystem.COMMONJS
if (has_require or has_module_exports) and not (has_import or has_export):
logger.debug("Detected CommonJS from require/module.exports")
return ModuleSystem.COMMONJS
except Exception as e:
logger.warning("Failed to analyze file %s: %s", file_path, e)
except Exception as e:
logger.warning("Failed to analyze file %s: %s", file_path, e)
# Default to CommonJS (more common and backward compatible)
logger.debug("Defaulting to CommonJS")
@ -199,11 +209,42 @@ def add_js_extension(module_path: str) -> str:
return module_path
def _convert_destructuring_to_imports(names_str: str) -> str:
"""Convert destructuring aliases to import aliases.
Converts:
a, b -> a, b
a: aliasA -> a as aliasA
a, b: aliasB -> a, b as aliasB
Args:
names_str: The destructuring pattern string (e.g., "a, b: aliasB")
Returns:
Import names string with aliases using 'as' syntax
"""
# Split by commas and process each name
parts = []
for name in names_str.split(","):
name = name.strip()
if ":" in name:
# Convert destructuring alias to import alias
# "a: aliasA" -> "a as aliasA"
original, alias = name.split(":", 1)
parts.append(f"{original.strip()} as {alias.strip()}")
else:
parts.append(name)
return ", ".join(parts)
# Replace destructured requires with named imports
def replace_destructured(match: re.Match) -> str:
names = match.group(2).strip()
module_path = add_js_extension(match.group(3))
return f"import {{ {names} }} from '{module_path}';"
# Convert destructuring aliases (a: b) to import aliases (a as b)
converted_names = _convert_destructuring_to_imports(names)
return f"import {{ {converted_names} }} from '{module_path}';"
# Replace property access requires with named imports with alias
@ -234,12 +275,14 @@ def convert_commonjs_to_esm(code: str) -> str:
"""Convert CommonJS require statements to ES Module imports.
Converts:
const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
const foo = require('./module'); -> import foo from './module';
const foo = require('./module').default; -> import foo from './module';
const foo = require('./module').bar; -> import { bar as foo } from './module';
const { foo, bar } = require('./module'); -> import { foo, bar } from './module';
const { foo: alias } = require('./module'); -> import { foo as alias } from './module';
const foo = require('./module'); -> import foo from './module';
const foo = require('./module').default; -> import foo from './module';
const foo = require('./module').bar; -> import { bar as foo } from './module';
Special handling:
- Destructuring aliases (a: b) are converted to import aliases (a as b)
- Local codeflash helper (./codeflash-jest-helper) is converted to npm package codeflash
because the local helper uses CommonJS exports which don't work in ESM projects
@ -299,36 +342,89 @@ def convert_esm_to_commonjs(code: str) -> str:
return default_import.sub(replace_default, code)
def ensure_module_system_compatibility(code: str, target_module_system: str) -> str:
def uses_ts_jest(project_root: Path) -> bool:
"""Check if the project uses ts-jest for TypeScript transformation.
ts-jest handles module interoperability internally, allowing mixed
CommonJS/ESM imports without explicit conversion.
Args:
project_root: The project root directory.
Returns:
True if ts-jest is being used, False otherwise.
"""
# Check for ts-jest in devDependencies or dependencies
package_json = project_root / "package.json"
if package_json.exists():
try:
with package_json.open("r") as f:
pkg = json.load(f)
dev_deps = pkg.get("devDependencies", {})
deps = pkg.get("dependencies", {})
if "ts-jest" in dev_deps or "ts-jest" in deps:
return True
except Exception as e:
logger.debug(f"Failed to read package.json for ts-jest detection: {e}") # noqa: G004
# Also check for jest.config with ts-jest preset
for config_file in ["jest.config.js", "jest.config.cjs", "jest.config.ts", "jest.config.mjs"]:
config_path = project_root / config_file
if config_path.exists():
try:
content = config_path.read_text()
if "ts-jest" in content:
return True
except Exception as e:
logger.debug(f"Failed to read {config_file}: {e}") # noqa: G004
return False
def ensure_module_system_compatibility(code: str, target_module_system: str, project_root: Path | None = None) -> str:
"""Ensure code uses the correct module system syntax.
Detects the current module system in the code and converts if needed.
Handles mixed-style code (e.g., ESM imports with CommonJS require for npm packages).
If the project uses ts-jest, no conversion is performed because ts-jest
handles module interoperability internally. Otherwise, converts between
CommonJS and ES Modules as needed.
Args:
code: JavaScript code to check and potentially convert.
target_module_system: Target ModuleSystem (COMMONJS or ES_MODULE).
project_root: Project root directory for ts-jest detection.
Returns:
Code with correct module system syntax.
Converted code, or unchanged if ts-jest handles interop.
"""
# If ts-jest is installed, skip conversion - it handles interop natively
if is_typescript() and project_root and uses_ts_jest(project_root):
logger.debug(
f"Skipping module system conversion (target was {target_module_system}). " # noqa: G004
"ts-jest handles interop natively."
)
return code
# Detect current module system in code
has_require = "require(" in code
has_module_exports = "module.exports" in code or "exports." in code
has_import = "import " in code and "from " in code
has_export = "export " in code
if target_module_system == ModuleSystem.ES_MODULE:
# Convert any require() statements to imports for ESM projects
# This handles mixed code (ESM imports + CommonJS requires for npm packages)
if has_require:
logger.debug("Converting CommonJS requires to ESM imports")
return convert_commonjs_to_esm(code)
elif target_module_system == ModuleSystem.COMMONJS:
# Convert any import statements to requires for CommonJS projects
if has_import:
logger.debug("Converting ESM imports to CommonJS requires")
return convert_esm_to_commonjs(code)
is_commonjs = has_require or has_module_exports
is_esm = has_import or has_export
# Convert if needed
if target_module_system == ModuleSystem.ES_MODULE and is_commonjs and not is_esm:
logger.debug("Converting CommonJS to ES Module syntax")
return convert_commonjs_to_esm(code)
if target_module_system == ModuleSystem.COMMONJS and is_esm and not is_commonjs:
logger.debug("Converting ES Module to CommonJS syntax")
return convert_esm_to_commonjs(code)
logger.debug("No module system conversion needed")
return code
@ -355,12 +451,8 @@ def ensure_vitest_imports(code: str, test_framework: str) -> str:
# Check if the code uses test functions that need to be imported
test_globals = ["describe", "test", "it", "expect", "vi", "beforeEach", "afterEach", "beforeAll", "afterAll"]
needs_import = any(f"{global_name}(" in code or f"{global_name} (" in code for global_name in test_globals)
if not needs_import:
return code
# Determine which globals are actually used in the code
# Combine detection and collection into a single pass
used_globals = [g for g in test_globals if f"{g}(" in code or f"{g} (" in code]
if not used_globals:
return code

View file

@ -296,6 +296,33 @@ def _find_node_project_root(file_path: Path) -> Path | None:
return None
def _find_monorepo_root(start_path: Path) -> Path | None:
"""Find the monorepo workspace root by looking for workspace markers.
Traverses up from the given path to find a directory containing
monorepo workspace markers like yarn.lock, pnpm-workspace.yaml, etc.
Args:
start_path: A path within the monorepo.
Returns:
The monorepo root directory, or None if not found.
"""
monorepo_markers = ["yarn.lock", "pnpm-workspace.yaml", "lerna.json", "package-lock.json"]
current = start_path if start_path.is_dir() else start_path.parent
while current != current.parent:
# Check for monorepo markers
if any((current / marker).exists() for marker in monorepo_markers):
# Verify it has node_modules (it's the workspace root)
if (current / "node_modules").exists():
return current
current = current.parent
return None
def _find_jest_config(project_root: Path) -> Path | None:
"""Find Jest configuration file in the project.
@ -797,6 +824,12 @@ def run_jest_benchmarking_tests(
jest_env["JEST_JUNIT_SUITE_NAME"] = "{filepath}"
jest_env["JEST_JUNIT_ADD_FILE_ATTRIBUTE"] = "true"
jest_env["JEST_JUNIT_INCLUDE_CONSOLE_OUTPUT"] = "true"
# Pass monorepo root to loop-runner for jest-runner resolution
monorepo_root = _find_monorepo_root(effective_cwd)
if monorepo_root:
jest_env["CODEFLASH_MONOREPO_ROOT"] = str(monorepo_root)
logger.debug(f"Detected monorepo root: {monorepo_root}")
codeflash_sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
jest_env["CODEFLASH_OUTPUT_FILE"] = str(codeflash_sqlite_file)
jest_env["CODEFLASH_TEST_ITERATION"] = "0"

View file

@ -321,7 +321,7 @@ class PythonSupport:
function_pos = None
for name in names:
if name.type == "function" and name.name == function.name:
if name.type == "function" and name.name == function.function_name:
# Check for class parent if it's a method
if function.class_name:
parent = name.parent()

View file

@ -258,6 +258,11 @@ class TreeSitterAnalyzer:
if func_info.is_arrow and not include_arrow_functions:
should_include = False
# Skip arrow functions that are object properties (e.g., { foo: () => {} })
# These are not standalone functions - they're values in object literals
if func_info.is_arrow and node.parent and node.parent.type == "pair":
should_include = False
if should_include:
functions.append(func_info)

View file

@ -15,6 +15,7 @@ from codeflash.models.test_type import TestType
if TYPE_CHECKING:
from collections.abc import Iterator
import enum
import re
import sys
@ -875,15 +876,14 @@ class TestResults(BaseModel): # noqa: PLW1641
return max(test_result.loop_index for test_result in self.test_results)
def get_test_pass_fail_report_by_type(self) -> dict[TestType, dict[str, int]]:
report = {}
for test_type in TestType:
report[test_type] = {"passed": 0, "failed": 0}
report: dict[TestType, dict[str, int]] = {tt: {"passed": 0, "failed": 0} for tt in TestType}
for test_result in self.test_results:
if test_result.loop_index == 1:
if test_result.did_pass:
report[test_result.test_type]["passed"] += 1
else:
report[test_result.test_type]["failed"] += 1
if test_result.loop_index != 1:
continue
if test_result.did_pass:
report[test_result.test_type]["passed"] += 1
else:
report[test_result.test_type]["failed"] += 1
return report
@staticmethod

View file

@ -12,11 +12,13 @@ class TestType(Enum):
def to_name(self) -> str:
if self is TestType.INIT_STATE_TEST:
return ""
names = {
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
TestType.REPLAY_TEST: "⏪ Replay Tests",
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
}
return names[self]
return _TO_NAME_MAP[self]
_TO_NAME_MAP: dict[TestType, str] = {
TestType.EXISTING_UNIT_TEST: "⚙️ Existing Unit Tests",
TestType.INSPIRED_REGRESSION: "🎨 Inspired Regression Tests",
TestType.GENERATED_REGRESSION: "🌀 Generated Regression Tests",
TestType.REPLAY_TEST: "⏪ Replay Tests",
TestType.CONCOLIC_COVERAGE_TEST: "🔎 Concolic Coverage Tests",
}

View file

@ -556,6 +556,16 @@ class FunctionOptimizer:
should_run_experiment = self.experiment_id is not None
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})
# Early check: if --no-gen-tests is set, verify there are existing tests for this function
if self.args.no_gen_tests:
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
if not self.function_to_tests.get(func_qualname):
return Failure(
f"No existing tests found for '{self.function_to_optimize.function_name}'. "
f"Cannot optimize without tests when --no-gen-tests is set."
)
self.cleanup_leftover_test_return_values()
file_name_from_test_module_name.cache_clear()
ctx_result = self.get_code_optimization_context()
@ -626,7 +636,7 @@ class FunctionOptimizer:
# Normalize codeflash imports in JS/TS tests to use npm package
if not is_python():
module_system = detect_module_system(self.project_root)
module_system = detect_module_system(self.project_root, self.function_to_optimize.file_path)
if module_system == "esm":
generated_tests = inject_test_globals(generated_tests)
if is_typescript():

View file

@ -21,6 +21,8 @@ from typing import Any
import tomlkit
_BUILD_DIRS = frozenset({"build", "dist", "out", ".next", ".nuxt"})
@dataclass
class DetectedProject:
@ -310,14 +312,21 @@ def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
"""Detect JavaScript/TypeScript module root.
Priority:
1. package.json "exports" field
2. package.json "module" field (ESM)
3. package.json "main" field (CJS)
4. src/ directory
5. lib/ directory
6. Project root
1. src/, lib/, source/ directories (common source directories)
2. package.json "exports" field (if not in build output directory)
3. package.json "module" field (ESM, if not in build output directory)
4. package.json "main" field (CJS, if not in build output directory)
5. Project root
Build output directories (build/, dist/, out/) are skipped since they contain
compiled code, not source files.
"""
# Check for common source directories first - these are always preferred
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return project_root / src_dir, f"{src_dir}/ directory"
package_json_path = project_root / "package.json"
package_data: dict[str, Any] = {}
@ -334,32 +343,52 @@ def _detect_js_module_root(project_root: Path) -> tuple[Path, str]:
entry_path = _extract_entry_path(exports)
if entry_path:
parent = Path(entry_path).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return project_root / parent, f'{parent.as_posix()}/ (from package.json "exports")'
# Check module field (ESM)
module_field = package_data.get("module")
if module_field and isinstance(module_field, str):
parent = Path(module_field).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return project_root / parent, f'{parent.as_posix()}/ (from package.json "module")'
# Check main field (CJS)
main_field = package_data.get("main")
if main_field and isinstance(main_field, str):
parent = Path(main_field).parent
if parent != Path() and parent.as_posix() != "." and (project_root / parent).is_dir():
if (
parent != Path()
and parent.as_posix() != "."
and (project_root / parent).is_dir()
and not is_build_output_dir(parent)
):
return project_root / parent, f'{parent.as_posix()}/ (from package.json "main")'
# Check for common source directories
for src_dir in ["src", "lib", "source"]:
if (project_root / src_dir).is_dir():
return project_root / src_dir, f"{src_dir}/ directory"
# Default to project root
return project_root, "project root"
def is_build_output_dir(path: Path) -> bool:
    """Report whether any component of ``path`` is a known build output directory.

    Build output directories (build/, dist/, out/, ...) contain compiled
    artifacts rather than source files, so callers skip them when choosing
    a module root.
    """
    return any(part in _BUILD_DIRS for part in path.parts)
def _extract_entry_path(exports: Any) -> str | None:
"""Extract entry path from package.json exports field."""
if isinstance(exports, str):

View file

@ -180,19 +180,31 @@ def resolve_test_file_from_class_path(test_class_path: str, base_dir: Path) -> P
# Handle file paths (contain slashes and extensions like .js/.ts)
if "/" in test_class_path or "\\" in test_class_path:
# This is a file path, not a Python module path
# Try the path as-is if it's absolute
potential_path = Path(test_class_path)
if potential_path.is_absolute() and potential_path.exists():
return potential_path
# Try to resolve relative to base_dir's parent (project root)
project_root = base_dir.parent
potential_path = project_root / test_class_path
if potential_path.exists():
return potential_path
# Normalize to resolve .. and . components
try:
potential_path = potential_path.resolve()
if potential_path.exists():
return potential_path
except (OSError, RuntimeError):
pass
# Also try relative to base_dir itself
potential_path = base_dir / test_class_path
if potential_path.exists():
return potential_path
# Try the path as-is if it's absolute
potential_path = Path(test_class_path)
if potential_path.exists():
return potential_path
try:
potential_path = potential_path.resolve()
if potential_path.exists():
return potential_path
except (OSError, RuntimeError):
pass
return None
# First try the full path (Python module path)
@ -795,16 +807,25 @@ def parse_jest_test_xml(
if not test_file_path.exists():
test_file_path = base_dir / test_file_name
if test_file_path is None or not test_file_path.exists():
# For Jest tests in monorepos, test files may not exist after cleanup
# but we can still parse results and infer test type from the path
if test_file_path is None:
logger.warning(f"Could not resolve test file for Jest test: {test_class_path}")
continue
# Get test type if not already set from lookup
if test_type is None:
if test_type is None and test_file_path.exists():
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
if test_type is None:
# Default to GENERATED_REGRESSION for Jest tests
test_type = TestType.GENERATED_REGRESSION
# Infer test type from filename pattern
filename = test_file_path.name
if "__perf_test_" in filename or "_perf_test_" in filename:
test_type = TestType.GENERATED_PERFORMANCE
elif "__unit_test_" in filename or "_unit_test_" in filename:
test_type = TestType.GENERATED_REGRESSION
else:
# Default to GENERATED_REGRESSION for Jest tests
test_type = TestType.GENERATED_REGRESSION
# For Jest tests, keep the relative file path with extension intact
# (Python uses module_name_from_file_path which strips extensions)

View file

@ -82,7 +82,10 @@ def generate_tests(
)
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system)
# Skip conversion if ts-jest is installed (handles interop natively)
generated_test_source = ensure_module_system_compatibility(
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
)
# Ensure vitest imports are present when using vitest framework
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)

View file

@ -1,60 +0,0 @@
# codeflash
AI-powered code performance optimization for JavaScript and TypeScript.
## Installation
```bash
npm install -g codeflash
# or
npx codeflash
```
## Quick Start
1. Get your API key from [codeflash.ai](https://codeflash.ai)
2. Set your API key:
```bash
export CODEFLASH_API_KEY=your-api-key
```
3. Optimize a function:
```bash
codeflash --file src/utils.ts --function slowFunction
```
## Usage
```bash
# Optimize a specific function
codeflash --file <path> --function <name>
# Optimize all functions in a directory
codeflash --all src/
# Initialize GitHub Actions workflow
codeflash init-actions
# Verify setup
codeflash --verify-setup
```
## Requirements
- Node.js >= 16.0.0
- A codeflash API key
## Supported Platforms
- Linux (x64, arm64)
- macOS (x64, arm64)
- Windows (x64)
## Documentation
See [codeflash.ai/docs](https://codeflash.ai/docs) for full documentation.
## License
BSL-1.1

View file

@ -1,47 +0,0 @@
#!/usr/bin/env node
/**
* Wrapper script for codeflash CLI.
* Invokes the downloaded binary with all passed arguments.
*/
const { spawn } = require('child_process');
const path = require('path');
const fs = require('fs');
function getBinaryPath() {
const binDir = __dirname;
const isWindows = process.platform === 'win32';
return path.join(binDir, isWindows ? 'codeflash.exe' : 'codeflash-binary');
}
/**
 * Entry point: locate the platform binary and delegate to it.
 * CLI arguments, stdio, and the environment pass straight through; the
 * wrapper's exit code mirrors the child's (1 on spawn error or signal).
 */
function main() {
  const binaryPath = getBinaryPath();

  if (!fs.existsSync(binaryPath)) {
    console.error('\x1b[31mError: codeflash binary not found.\x1b[0m');
    console.error('Try reinstalling: npm install codeflash');
    process.exit(1);
  }

  const forwardedArgs = process.argv.slice(2);
  const child = spawn(binaryPath, forwardedArgs, {
    stdio: 'inherit',
    env: process.env,
  });

  child.on('error', (error) => {
    console.error(`\x1b[31mError running codeflash: ${error.message}\x1b[0m`);
    process.exit(1);
  });

  child.on('exit', (code, signal) => {
    // A signal-terminated child carries no numeric exit code; report failure.
    if (signal) {
      process.exit(1);
    }
    process.exit(code || 0);
  });
}

main();

View file

@ -1,47 +0,0 @@
{
"name": "codeflash",
"version": "0.0.0",
"description": "AI-powered code performance optimization - automatically find and fix slow code",
"keywords": [
"codeflash",
"performance",
"optimization",
"ai",
"code",
"profiler",
"typescript",
"javascript"
],
"author": "CodeFlash Inc. <contact@codeflash.ai>",
"license": "BSL-1.1",
"homepage": "https://codeflash.ai",
"repository": {
"type": "git",
"url": "git+https://github.com/codeflash-ai/codeflash.git"
},
"bugs": {
"url": "https://github.com/codeflash-ai/codeflash/issues"
},
"bin": {
"codeflash": "./bin/codeflash"
},
"scripts": {
"postinstall": "node lib/install.js"
},
"engines": {
"node": ">=16.0.0"
},
"os": [
"darwin",
"linux",
"win32"
],
"cpu": [
"x64",
"arm64"
],
"files": [
"bin/",
"lib/"
]
}

View file

@ -1,12 +1,12 @@
{
"name": "codeflash",
"version": "0.5.0",
"version": "0.7.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "codeflash",
"version": "0.5.0",
"version": "0.7.0",
"hasInstallScript": true,
"license": "MIT",
"dependencies": {

View file

@ -1,6 +1,6 @@
{
"name": "codeflash",
"version": "0.5.0",
"version": "0.7.0",
"description": "Codeflash - AI-powered code optimization for JavaScript and TypeScript",
"main": "runtime/index.js",
"types": "runtime/index.d.ts",

View file

@ -0,0 +1,73 @@
/**
 * Test: Dynamic environment variable reading
 *
 * Verifies that the performance configuration getters read environment
 * variables at call time rather than at module load time. This matters
 * for Vitest, where cached modules may load before env vars are set.
 *
 * Run with: node __tests__/dynamic-env-vars.test.js
 */
const assert = require('assert');

// Every perf-related env var the capture module reads.
const PERF_ENV_KEYS = [
  'CODEFLASH_PERF_LOOP_COUNT',
  'CODEFLASH_PERF_MIN_LOOPS',
  'CODEFLASH_PERF_TARGET_DURATION_MS',
  'CODEFLASH_PERF_BATCH_SIZE',
  'CODEFLASH_PERF_STABILITY_CHECK',
  'CODEFLASH_PERF_CURRENT_BATCH',
];

// Remove every perf env var so defaults apply.
function clearPerfEnv() {
  for (const key of PERF_ENV_KEYS) {
    delete process.env[key];
  }
}

// Load the module with nothing set — if the getters were constants,
// the values captured here would never change again.
clearPerfEnv();
const capture = require('../capture');

console.log('Testing dynamic environment variable reading...\n');

// Test 1: defaults apply when env vars are absent.
console.log('Test 1: Default values');
assert.strictEqual(capture.getPerfLoopCount(), 1, 'getPerfLoopCount default should be 1');
assert.strictEqual(capture.getPerfMinLoops(), 5, 'getPerfMinLoops default should be 5');
assert.strictEqual(capture.getPerfTargetDurationMs(), 10000, 'getPerfTargetDurationMs default should be 10000');
assert.strictEqual(capture.getPerfBatchSize(), 10, 'getPerfBatchSize default should be 10');
assert.strictEqual(capture.getPerfStabilityCheck(), false, 'getPerfStabilityCheck default should be false');
assert.strictEqual(capture.getPerfCurrentBatch(), 0, 'getPerfCurrentBatch default should be 0');
console.log(' PASS: All defaults correct\n');

// Test 2: the critical case — values set AFTER module load must be observed.
console.log('Test 2: Dynamic reading after module load');
process.env.CODEFLASH_PERF_LOOP_COUNT = '100';
process.env.CODEFLASH_PERF_MIN_LOOPS = '10';
process.env.CODEFLASH_PERF_TARGET_DURATION_MS = '5000';
process.env.CODEFLASH_PERF_BATCH_SIZE = '20';
process.env.CODEFLASH_PERF_STABILITY_CHECK = 'true';
process.env.CODEFLASH_PERF_CURRENT_BATCH = '5';
assert.strictEqual(capture.getPerfLoopCount(), 100, 'getPerfLoopCount should read 100 from env');
assert.strictEqual(capture.getPerfMinLoops(), 10, 'getPerfMinLoops should read 10 from env');
assert.strictEqual(capture.getPerfTargetDurationMs(), 5000, 'getPerfTargetDurationMs should read 5000 from env');
assert.strictEqual(capture.getPerfBatchSize(), 20, 'getPerfBatchSize should read 20 from env');
assert.strictEqual(capture.getPerfStabilityCheck(), true, 'getPerfStabilityCheck should read true from env');
assert.strictEqual(capture.getPerfCurrentBatch(), 5, 'getPerfCurrentBatch should read 5 from env');
console.log(' PASS: Dynamic reading works correctly\n');

// Test 3: subsequent modifications are also picked up.
console.log('Test 3: Values update when env vars change');
process.env.CODEFLASH_PERF_LOOP_COUNT = '500';
process.env.CODEFLASH_PERF_BATCH_SIZE = '50';
assert.strictEqual(capture.getPerfLoopCount(), 500, 'getPerfLoopCount should update to 500');
assert.strictEqual(capture.getPerfBatchSize(), 50, 'getPerfBatchSize should update to 50');
console.log(' PASS: Values update correctly\n');

// Cleanup so other tests see a pristine environment.
clearPerfEnv();
console.log('All tests passed!');

View file

@ -47,14 +47,30 @@ const TEST_MODULE = process.env.CODEFLASH_TEST_MODULE;
// Batch 1: Test1(5 loops) → Test2(5 loops) → Test3(5 loops)
// Batch 2: Test1(5 loops) → Test2(5 loops) → Test3(5 loops)
// ...until time budget exhausted
const PERF_LOOP_COUNT = parseInt(process.env.CODEFLASH_PERF_LOOP_COUNT || '1', 10);
const PERF_MIN_LOOPS = parseInt(process.env.CODEFLASH_PERF_MIN_LOOPS || '5', 10);
const PERF_TARGET_DURATION_MS = parseInt(process.env.CODEFLASH_PERF_TARGET_DURATION_MS || '10000', 10);
const PERF_BATCH_SIZE = parseInt(process.env.CODEFLASH_PERF_BATCH_SIZE || '10', 10);
const PERF_STABILITY_CHECK = (process.env.CODEFLASH_PERF_STABILITY_CHECK || 'false').toLowerCase() === 'true';
//
// IMPORTANT: These are getter functions, NOT constants!
// Vitest caches modules and may load this file before env vars are set.
// Using getter functions ensures we read the env vars at runtime when they're actually needed.
/**
 * Parse an integer environment variable at call time.
 * Falls back to `fallback` when the variable is unset or not a valid
 * integer, so a malformed value cannot leak NaN into loop bookkeeping.
 */
function readIntEnv(name, fallback) {
  const parsed = parseInt(process.env[name] || String(fallback), 10);
  return Number.isNaN(parsed) ? fallback : parsed;
}

function getPerfLoopCount() {
  // Total timed loops per invocation; default 1 means "run once, no looping".
  return readIntEnv('CODEFLASH_PERF_LOOP_COUNT', 1);
}

function getPerfMinLoops() {
  // Minimum loops that must complete before the shared time budget may stop looping.
  return readIntEnv('CODEFLASH_PERF_MIN_LOOPS', 5);
}

function getPerfTargetDurationMs() {
  // Shared benchmarking time budget across tests, in milliseconds.
  return readIntEnv('CODEFLASH_PERF_TARGET_DURATION_MS', 10000);
}

function getPerfBatchSize() {
  // Loops executed per capturePerf call when an external loop-runner batches work.
  return readIntEnv('CODEFLASH_PERF_BATCH_SIZE', 10);
}

function getPerfStabilityCheck() {
  // Opt-in stability checking; only the string "true" (any case) enables it.
  return (process.env.CODEFLASH_PERF_STABILITY_CHECK || 'false').toLowerCase() === 'true';
}
// Current batch number - set by loop-runner before each batch
// This allows continuous loop indices even when Jest resets module state
const PERF_CURRENT_BATCH = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '0', 10);
/**
 * Current batch number (1-based, set by loop-runner before each batch;
 * 0 when unset). Read at call time so values set after module load are
 * observed. Falls back to 0 on a malformed value instead of returning NaN.
 */
function getPerfCurrentBatch() {
  const parsed = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '0', 10);
  return Number.isNaN(parsed) ? 0 : parsed;
}
// Stability constants (matching Python's config_consts.py)
const STABILITY_WINDOW_SIZE = 0.35;
@ -86,7 +102,7 @@ function checkSharedTimeLimit() {
return false;
}
const elapsed = Date.now() - sharedPerfState.startTime;
if (elapsed >= PERF_TARGET_DURATION_MS && sharedPerfState.totalLoopsCompleted >= PERF_MIN_LOOPS) {
if (elapsed >= getPerfTargetDurationMs() && sharedPerfState.totalLoopsCompleted >= getPerfMinLoops()) {
sharedPerfState.shouldStop = true;
return true;
}
@ -111,7 +127,7 @@ function getInvocationLoopIndex(invocationKey) {
// Calculate global loop index using batch number from environment
// PERF_CURRENT_BATCH is 1-based (set by loop-runner before each batch)
const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
const globalIndex = (currentBatch - 1) * PERF_BATCH_SIZE + localIndex;
const globalIndex = (currentBatch - 1) * getPerfBatchSize() + localIndex;
return globalIndex;
}
@ -606,7 +622,7 @@ function capture(funcName, lineId, fn, ...args) {
*/
function capturePerf(funcName, lineId, fn, ...args) {
// Check if we should skip looping entirely (shared time budget exceeded)
const shouldLoop = PERF_LOOP_COUNT > 1 && !checkSharedTimeLimit();
const shouldLoop = getPerfLoopCount() > 1 && !checkSharedTimeLimit();
// Get test context (computed once, reused across batch)
let testModulePath;
@ -636,9 +652,9 @@ function capturePerf(funcName, lineId, fn, ...args) {
// If so, just execute the function once without timing (for test assertions)
const peekLoopIndex = (sharedPerfState.invocationLoopCounts[invocationKey] || 0);
const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10);
const nextGlobalIndex = (currentBatch - 1) * PERF_BATCH_SIZE + peekLoopIndex + 1;
const nextGlobalIndex = (currentBatch - 1) * getPerfBatchSize() + peekLoopIndex + 1;
if (shouldLoop && nextGlobalIndex > PERF_LOOP_COUNT) {
if (shouldLoop && nextGlobalIndex > getPerfLoopCount()) {
// All loops completed, just execute once for test assertion
return fn(...args);
}
@ -654,7 +670,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
// Batched looping: run BATCH_SIZE loops per capturePerf call when using loop-runner
// For Vitest (no loop-runner), do all loops internally in a single call
const batchSize = shouldLoop
? (hasExternalLoopRunner ? PERF_BATCH_SIZE : PERF_LOOP_COUNT)
? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount())
: 1;
for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) {
@ -667,7 +683,7 @@ function capturePerf(funcName, lineId, fn, ...args) {
const loopIndex = getInvocationLoopIndex(invocationKey);
// Check if we've exceeded max loops for this invocation
if (loopIndex > PERF_LOOP_COUNT) {
if (loopIndex > getPerfLoopCount()) {
break;
}
@ -872,7 +888,11 @@ module.exports = {
LOOP_INDEX,
OUTPUT_FILE,
TEST_ITERATION,
// Batch configuration
PERF_BATCH_SIZE,
PERF_LOOP_COUNT,
// Batch configuration (getter functions for dynamic env var reading)
getPerfBatchSize,
getPerfLoopCount,
getPerfMinLoops,
getPerfTargetDurationMs,
getPerfStabilityCheck,
getPerfCurrentBatch,
};

View file

@ -30,13 +30,74 @@
const { createRequire } = require('module');
const path = require('path');
const fs = require('fs');
/**
* Resolve jest-runner with monorepo support.
* Uses CODEFLASH_MONOREPO_ROOT environment variable if available,
* otherwise walks up the directory tree looking for node_modules/jest-runner.
*/
function resolveJestRunner() {
// Try standard resolution first (works in simple projects)
try {
return require.resolve('jest-runner');
} catch (e) {
// Standard resolution failed - try monorepo-aware resolution
}
// If Python detected a monorepo root, check there first
const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT;
if (monorepoRoot) {
const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner');
if (fs.existsSync(jestRunnerPath)) {
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
if (fs.existsSync(packageJsonPath)) {
return jestRunnerPath;
}
}
}
// Fallback: Walk up from cwd looking for node_modules/jest-runner
const monorepoMarkers = ['yarn.lock', 'pnpm-workspace.yaml', 'lerna.json', 'package-lock.json'];
let currentDir = process.cwd();
const visitedDirs = new Set();
while (currentDir !== path.dirname(currentDir)) {
// Avoid infinite loops
if (visitedDirs.has(currentDir)) break;
visitedDirs.add(currentDir);
// Try node_modules/jest-runner at this level
const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner');
if (fs.existsSync(jestRunnerPath)) {
const packageJsonPath = path.join(jestRunnerPath, 'package.json');
if (fs.existsSync(packageJsonPath)) {
return jestRunnerPath;
}
}
// Check if this is a workspace root (has monorepo markers)
const isWorkspaceRoot = monorepoMarkers.some(marker =>
fs.existsSync(path.join(currentDir, marker))
);
if (isWorkspaceRoot) {
// Found workspace root but no jest-runner - stop searching
break;
}
currentDir = path.dirname(currentDir);
}
throw new Error('jest-runner not found');
}
// Try to load jest-runner - it's a peer dependency that must be installed by the user
let runTest;
let jestRunnerAvailable = false;
try {
const jestRunnerPath = require.resolve('jest-runner');
const jestRunnerPath = resolveJestRunner();
const internalRequire = createRequire(jestRunnerPath);
runTest = internalRequire('./runTest').default;
jestRunnerAvailable = true;

View file

@ -79,8 +79,9 @@ dev = [
"types-greenlet>=3.1.0.20241221,<4",
"types-pexpect>=4.9.0.20241208,<5",
"types-unidiff>=0.7.0.20240505,<0.8",
"uv>=0.6.2",
"pre-commit>=4.2.0,<5",
"prek>=0.2.25",
"ty>=0.0.14",
"uv>=0.9.29",
]
tests = [
"black>=25.9.0",
@ -272,6 +273,7 @@ ignore = [
"ANN401", # typing.Any disallowed
"ARG001", # Unused function argument (common in abstract/interface methods)
"TRY300", # Consider moving to else block
"FURB110", # if-exp-instead-of-or-operator - we prefer explicit if-else over "or"
"TRY401", # Redundant exception in logging.exception
"PLR0911", # Too many return statements
"PLW0603", # Global statement

View file

@ -127,13 +127,14 @@ class TestDetectModuleRoot:
assert result == "lib"
def test_detects_from_exports_object_dot(self, tmp_path: Path) -> None:
"""Should detect module root from exports object with '.' key."""
"""Should skip build output dirs and return '.' when no src dir exists."""
(tmp_path / "dist").mkdir()
package_data = {"exports": {".": "./dist/index.js"}}
result = detect_module_root(tmp_path, package_data)
assert result == "dist"
# dist is a build output directory, so it's skipped
assert result == "."
def test_detects_from_exports_object_nested(self, tmp_path: Path) -> None:
"""Should detect module root from nested exports object."""
@ -227,13 +228,14 @@ class TestDetectModuleRoot:
assert result == "src"
def test_handles_deeply_nested_exports(self, tmp_path: Path) -> None:
"""Should handle deeply nested export paths."""
"""Should handle deeply nested export paths but skip build output dirs."""
(tmp_path / "packages" / "core" / "dist").mkdir(parents=True)
package_data = {"exports": {".": {"import": "./packages/core/dist/index.mjs"}}}
result = detect_module_root(tmp_path, package_data)
assert result == "packages/core/dist"
# dist is a build output directory, so it's skipped even when nested
assert result == "."
def test_handles_empty_exports(self, tmp_path: Path) -> None:
"""Should handle empty exports gracefully."""
@ -756,7 +758,7 @@ class TestRealWorldPackageJsonExamples:
assert config["formatter_cmds"] == ["npx eslint --fix $file"]
def test_library_with_exports(self, tmp_path: Path) -> None:
"""Should handle library with modern exports field."""
"""Should handle library with modern exports field, skipping build output dirs."""
(tmp_path / "dist").mkdir()
package_json = tmp_path / "package.json"
package_json.write_text(
@ -773,7 +775,8 @@ class TestRealWorldPackageJsonExamples:
assert result is not None
config, _ = result
assert config["module_root"] == str((tmp_path / "dist").resolve())
# dist is a build output directory, so it's skipped and falls back to project root
assert config["module_root"] == str(tmp_path.resolve())
def test_monorepo_package(self, tmp_path: Path) -> None:
"""Should handle monorepo package configuration."""

View file

@ -24,6 +24,8 @@ from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context_for_language
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
@ -1694,6 +1696,240 @@ const FIELD_KEYS = {
};"""
assert context.read_only_context == expected_read_only
def test_with_tricky_helpers(self, ts_support, temp_project):
"""Test function returning object with computed property names."""
code = """import { WebClient, ChatPostMessageArguments } from "@slack/web-api"
// Dependencies interface for easier testing
export interface SendSlackMessageDependencies {
WebClient: typeof WebClient
getSlackToken: () => string | undefined
getSlackChannelId: () => string | undefined
console: typeof console
}
// Default dependencies
let dependencies: SendSlackMessageDependencies = {
WebClient,
getSlackToken: () => process.env.SLACK_TOKEN,
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
console,
}
// For testing - allow dependency injection
export function setSendSlackMessageDependencies(deps: Partial<SendSlackMessageDependencies>) {
dependencies = { ...dependencies, ...deps }
}
export function resetSendSlackMessageDependencies() {
dependencies = {
WebClient,
getSlackToken: () => process.env.SLACK_TOKEN,
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
console,
}
}
// Initialize web client
let web: WebClient | null = null
export function initializeWebClient() {
const SLACK_TOKEN = dependencies.getSlackToken()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
if (!SLACK_TOKEN) {
throw new Error("Missing SLACK_TOKEN")
}
if (!SLACK_CHANNEL_ID) {
throw new Error("Missing SLACK_CHANNEL_ID")
}
if (!web) {
web = new dependencies.WebClient(SLACK_TOKEN, {})
}
return web
}
// For testing - allow resetting the web client
export function resetWebClient() {
web = null
}
/**
* Send a message to Slack
*
* @param {string|object} message - Text message or Block Kit message object
* @param {string|null} channel - Channel ID, defaults to SLACK_CHANNEL_ID
* @param {boolean} returnData - Whether to return the full Slack API response
* @returns {Promise<boolean|object>} - True or API response
*/
export const sendSlackMessage = async (
message: any,
channel: string | null = null,
returnData: boolean = false,
): Promise<boolean | object> => {
return new Promise(async (resolve, reject) => {
try {
const webClient = initializeWebClient()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
const channelId = channel || SLACK_CHANNEL_ID
// Configure the message payload depending on the input type
let payload: ChatPostMessageArguments
if (typeof message === "string") {
payload = {
channel: channelId,
text: message,
}
} else if (message && typeof message === "object") {
if (message.blocks) {
payload = {
channel: channelId,
text: message.text || "Notification from CodeFlash",
blocks: message.blocks,
}
} else {
dependencies.console.warn("Object passed to sendSlackMessage without blocks property")
payload = {
channel: channelId,
text: JSON.stringify(message),
}
}
} else {
dependencies.console.error("Invalid message type", typeof message)
payload = {
channel: channelId,
text: "Invalid message",
}
}
// console.log("Sending payload to Slack:", JSON.stringify(payload, null, 2));
const resp = await webClient.chat.postMessage(payload)
return resolve(returnData ? resp : true)
} catch (error) {
dependencies.console.error("Error sending Slack message:", error)
return resolve(returnData ? { error } : true)
}
})
}
"""
file_path = temp_project / "slack_util.ts"
file_path.write_text(code, encoding="utf-8")
target_func = "sendSlackMessage"
functions = ts_support.discover_functions(file_path)
func_info = next(f for f in functions if f.function_name == target_func)
fto = FunctionToOptimize(
function_name=target_func,
file_path=file_path,
parents=func_info.parents,
starting_line=func_info.starting_line,
ending_line=func_info.ending_line,
starting_col=func_info.starting_col,
ending_col=func_info.ending_col,
is_async=func_info.is_async,
language="typescript",
)
ctx = get_code_optimization_context_for_language(
fto, temp_project
)
# The read_writable_code should contain the target function AND helper functions
expected_read_writable = """```typescript:slack_util.ts
import { WebClient, ChatPostMessageArguments } from "@slack/web-api"
export const sendSlackMessage = async (
message: any,
channel: string | null = null,
returnData: boolean = false,
): Promise<boolean | object> => {
return new Promise(async (resolve, reject) => {
try {
const webClient = initializeWebClient()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
const channelId = channel || SLACK_CHANNEL_ID
// Configure the message payload depending on the input type
let payload: ChatPostMessageArguments
if (typeof message === "string") {
payload = {
channel: channelId,
text: message,
}
} else if (message && typeof message === "object") {
if (message.blocks) {
payload = {
channel: channelId,
text: message.text || "Notification from CodeFlash",
blocks: message.blocks,
}
} else {
dependencies.console.warn("Object passed to sendSlackMessage without blocks property")
payload = {
channel: channelId,
text: JSON.stringify(message),
}
}
} else {
dependencies.console.error("Invalid message type", typeof message)
payload = {
channel: channelId,
text: "Invalid message",
}
}
// console.log("Sending payload to Slack:", JSON.stringify(payload, null, 2));
const resp = await webClient.chat.postMessage(payload)
return resolve(returnData ? resp : true)
} catch (error) {
dependencies.console.error("Error sending Slack message:", error)
return resolve(returnData ? { error } : true)
}
})
}
export function initializeWebClient() {
const SLACK_TOKEN = dependencies.getSlackToken()
const SLACK_CHANNEL_ID = dependencies.getSlackChannelId()
if (!SLACK_TOKEN) {
throw new Error("Missing SLACK_TOKEN")
}
if (!SLACK_CHANNEL_ID) {
throw new Error("Missing SLACK_CHANNEL_ID")
}
if (!web) {
web = new dependencies.WebClient(SLACK_TOKEN, {})
}
return web
}
```"""
# The read_only_context should contain global variables (dependencies object, web client)
# but NOT have invalid floating object properties
expected_read_only = """let dependencies: SendSlackMessageDependencies = {
WebClient,
getSlackToken: () => process.env.SLACK_TOKEN,
getSlackChannelId: () => process.env.SLACK_CHANNEL_ID,
console,
}
let web: WebClient | null = null"""
assert ctx.read_writable_code.markdown == expected_read_writable
assert ctx.read_only_context_code == expected_read_only
class TestContextProperties:
"""Tests for CodeContext object properties."""

View file

@ -5,7 +5,13 @@ import json
import tempfile
from pathlib import Path
from codeflash.languages.javascript.module_system import ModuleSystem, detect_module_system, get_import_statement
from codeflash.languages.javascript.module_system import (
ModuleSystem,
convert_commonjs_to_esm,
convert_esm_to_commonjs,
detect_module_system,
get_import_statement,
)
class TestModuleSystemDetection:
@ -51,6 +57,39 @@ class TestModuleSystemDetection:
result = detect_module_system(project_root, file_path)
assert result == ModuleSystem.COMMONJS
def test_detect_esm_from_typescript_extension(self):
    """TypeScript file extensions (.ts/.tsx/.mts) always imply ES modules."""
    cases = [
        ("module.ts", "export const foo = 'bar';"),
        ("component.tsx", "export const Component = () => <div />;"),
        ("module.mts", "export const foo = 'bar';"),
    ]
    with tempfile.TemporaryDirectory() as tmpdir:
        root = Path(tmpdir)
        for filename, source in cases:
            target = root / filename
            target.write_text(source)
            # The extension alone should be decisive for module-system detection.
            assert detect_module_system(root, target) == ModuleSystem.ES_MODULE
def test_typescript_ignores_package_json_commonjs(self):
    """A .ts file is detected as ESM even when package.json declares CommonJS."""
    with tempfile.TemporaryDirectory() as tmpdir:
        root = Path(tmpdir)
        # Declare CommonJS explicitly at the package level.
        (root / "package.json").write_text(json.dumps({"type": "commonjs"}))
        module = root / "module.ts"
        module.write_text("export const foo = 'bar';")
        # The TypeScript extension should take precedence over package.json.
        assert detect_module_system(root, module) == ModuleSystem.ES_MODULE
def test_detect_esm_from_import_syntax(self):
"""Test detection of ES modules from import syntax."""
with tempfile.TemporaryDirectory() as tmpdir:
@ -159,3 +198,90 @@ class TestImportStatementGeneration:
result = get_import_statement(ModuleSystem.COMMONJS, target, source, ["foo"])
assert result == "const { foo } = require('../../utils');"
class TestModuleSystemConversion:
    """Tests for CommonJS <-> ESM conversion."""

    def test_convert_simple_destructured_require(self):
        """Test converting simple destructured require to import."""
        code = "const { foo, bar } = require('./module');"
        result = convert_commonjs_to_esm(code)
        assert result == "import { foo, bar } from './module';"

    def test_convert_destructured_require_with_alias(self):
        """Test converting destructured require with alias to import with 'as'."""
        # CJS `{ foo: aliasedFoo }` maps to ESM `{ foo as aliasedFoo }`.
        code = "const { foo: aliasedFoo } = require('./module');"
        result = convert_commonjs_to_esm(code)
        assert result == "import { foo as aliasedFoo } from './module';"

    def test_convert_mixed_destructured_require(self):
        """Test converting mixed destructured require (some aliased, some not)."""
        code = "const { foo, bar: aliasedBar, baz } = require('./module');"
        result = convert_commonjs_to_esm(code)
        assert result == "import { foo, bar as aliasedBar, baz } from './module';"

    def test_convert_destructured_with_whitespace(self):
        """Test that whitespace is handled correctly in destructuring."""
        # Extra spaces around ':' and before ',' must be normalized in the output.
        code = "const { foo : aliasedFoo , bar } = require('./module');"
        result = convert_commonjs_to_esm(code)
        assert result == "import { foo as aliasedFoo, bar } from './module';"

    def test_convert_simple_require(self):
        """Test converting simple require to default import."""
        code = "const module = require('./module');"
        result = convert_commonjs_to_esm(code)
        assert result == "import module from './module';"

    def test_convert_property_access_require(self):
        """Test converting require with property access to named import."""
        # `.bar` on the require result becomes a named import aliased to the local.
        code = "const foo = require('./module').bar;"
        result = convert_commonjs_to_esm(code)
        assert result == "import { bar as foo } from './module';"

    def test_convert_property_access_default(self):
        """Test converting require().default to default import."""
        code = "const foo = require('./module').default;"
        result = convert_commonjs_to_esm(code)
        assert result == "import foo from './module';"

    def test_convert_multiple_requires(self):
        """Test converting multiple requires in one code block."""
        code = """const { db: dbCore, cache } = require('@budibase/backend-core');
const utils = require('./utils');
const { process } = require('./processor');"""
        result = convert_commonjs_to_esm(code)
        expected = """import { db as dbCore, cache } from '@budibase/backend-core';
import utils from './utils';
import { process } from './processor';"""
        assert result == expected

    def test_convert_esm_to_commonjs_named(self):
        """Test converting named imports to destructured require."""
        code = "import { foo, bar } from './module';"
        result = convert_esm_to_commonjs(code)
        assert result == "const { foo, bar } = require('./module');"

    def test_convert_esm_to_commonjs_default(self):
        """Test converting default import to simple require."""
        code = "import module from './module';"
        result = convert_esm_to_commonjs(code)
        assert result == "const module = require('./module');"

    def test_convert_esm_to_commonjs_with_alias(self):
        """Test converting import with 'as' to destructured require.

        Note: ESM uses 'as' but the regex keeps it as-is in the output.
        This is acceptable since the test is primarily for CommonJS -> ESM conversion.
        """
        code = "import { foo as aliasedFoo } from './module';"
        result = convert_esm_to_commonjs(code)
        # The current implementation preserves 'as' syntax which works for our use case
        assert result == "const { foo as aliasedFoo } = require('./module');"

    def test_real_world_budibase_import(self):
        """Test the real-world case from Budibase that was failing."""
        code = "const { queue, context, db: dbCore, cache, events } = require('@budibase/backend-core');"
        result = convert_commonjs_to_esm(code)
        expected = "import { queue, context, db as dbCore, cache, events } from '@budibase/backend-core';"
        assert result == expected

View file

@ -15,6 +15,8 @@ from pathlib import Path
import pytest
from codeflash.code_utils.code_replacer import replace_function_definitions_for_language
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (
ModuleSystem,
convert_commonjs_to_esm,
@ -300,12 +302,144 @@ export function calculate(x, y) {
assert "return add(x, y);" in result
class TestModuleSystemCompatibility:
"""Tests for module system compatibility."""
class TestTsJestSkipsConversion:
    """Tests verifying that module system conversion is skipped when ts-jest is installed.

    When ts-jest is installed, it handles module interoperability internally,
    so we skip conversion to avoid breaking valid imports.
    """

    def setup_method(self, method):
        # BUG FIX: this was previously an __init__ method. pytest refuses to
        # collect test classes that define __init__, which silently skipped
        # every test in this class. setup_method runs before each test and is
        # the supported xunit-style hook for per-test setup.
        set_current_language(Language.TYPESCRIPT)

    def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
        """Test that CommonJS is NOT converted to ESM when ts-jest is installed."""
        # Create a project with ts-jest declared in package.json.
        package_json = tmp_path / "package.json"
        package_json.write_text('{"devDependencies": {"ts-jest": "^29.0.0"}}')
        commonjs_test = """\
const Logger = require('../utils/logger');
const { helper } = require('../utils/helpers');
describe('Logger', () => {
  test('should work', () => {
    const logger = new Logger();
    expect(logger).toBeDefined();
  });
});
"""
        # With ts-jest, no conversion should happen.
        result = ensure_module_system_compatibility(commonjs_test, ModuleSystem.ES_MODULE, tmp_path)
        assert result == commonjs_test, (
            f"CommonJS should NOT be converted when ts-jest is installed.\n"
            f"Expected (unchanged):\n{commonjs_test}\n\nGot:\n{result}"
        )

    def test_esm_not_converted_when_ts_jest_installed(self, tmp_path):
        """Test that ESM is NOT converted to CommonJS when ts-jest is installed."""
        # Create a project with ts-jest declared in package.json.
        package_json = tmp_path / "package.json"
        package_json.write_text('{"devDependencies": {"ts-jest": "^29.0.0"}}')
        esm_test = """\
import Logger from '../utils/logger';
import { helper } from '../utils/helpers';
describe('Logger', () => {
  test('should work', () => {
    const logger = new Logger();
    expect(logger).toBeDefined();
  });
});
"""
        # With ts-jest, no conversion should happen.
        result = ensure_module_system_compatibility(esm_test, ModuleSystem.COMMONJS, tmp_path)
        assert result == esm_test, (
            f"ESM should NOT be converted when ts-jest is installed.\n"
            f"Expected (unchanged):\n{esm_test}\n\nGot:\n{result}"
        )

    def test_ts_jest_detected_in_jest_config(self, tmp_path):
        """Test that ts-jest is detected from jest.config.js content."""
        # ts-jest appears only in jest.config.js, not in package.json.
        package_json = tmp_path / "package.json"
        package_json.write_text('{"devDependencies": {}}')
        jest_config = tmp_path / "jest.config.js"
        jest_config.write_text("module.exports = { preset: 'ts-jest' };")
        commonjs_test = "const x = require('./module');"
        result = ensure_module_system_compatibility(commonjs_test, ModuleSystem.ES_MODULE, tmp_path)
        assert result == commonjs_test, "Should skip conversion when ts-jest is in jest.config.js"
class TestModuleSystemConversion:
    """Tests for module system conversion when ts-jest is NOT installed.

    Without ts-jest, we convert between CommonJS and ESM as needed.
    """

    def test_commonjs_converted_to_esm_without_ts_jest(self, tmp_path):
        """Test that CommonJS is converted to ESM when ts-jest is NOT installed."""
        # Create a project WITHOUT ts-jest (plain jest only).
        package_json = tmp_path / "package.json"
        package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
        commonjs_code = """\
const { helper } = require('./helpers');
const logger = require('./logger');
function process() {
    return helper();
}
"""
        result = ensure_module_system_compatibility(commonjs_code, ModuleSystem.ES_MODULE, tmp_path)
        # Should be converted to ESM
        assert "import { helper } from './helpers';" in result
        assert "import logger from './logger';" in result
        assert "require(" not in result

    def test_esm_converted_to_commonjs_without_ts_jest(self, tmp_path):
        """Test that ESM is converted to CommonJS when ts-jest is NOT installed."""
        # Create a project WITHOUT ts-jest (plain jest only).
        package_json = tmp_path / "package.json"
        package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
        esm_code = """\
import { helper } from './helpers';
import logger from './logger';
function process() {
    return helper();
}
"""
        result = ensure_module_system_compatibility(esm_code, ModuleSystem.COMMONJS, tmp_path)
        # Should be converted to CommonJS
        assert "const { helper } = require('./helpers');" in result
        assert "const logger = require('./logger');" in result
        assert "import " not in result

    def test_no_conversion_when_project_root_is_none(self):
        """Test that conversion happens when project_root is None (can't detect ts-jest)."""
        # NOTE(review): the name says "no_conversion" but the docstring and the
        # assertions below expect conversion to HAPPEN — the name looks stale;
        # consider renaming to match the asserted behavior.
        commonjs_code = "const x = require('./module');"
        # Without project_root, we can't detect ts-jest, so conversion should happen
        result = ensure_module_system_compatibility(commonjs_code, ModuleSystem.ES_MODULE, None)
        # Should be converted to ESM
        assert "import x from './module';" in result
def test_mixed_code_not_converted(self, tmp_path):
"""Test that mixed CJS/ESM code is NOT converted (already has both)."""
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
mixed_code = """\
import { existing } from './module.js';
const { helper } = require('./helpers');
@ -313,32 +447,16 @@ function process() {
return existing() + helper();
}
"""
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
# Mixed code has both import and require, so no conversion
result = ensure_module_system_compatibility(mixed_code, ModuleSystem.ES_MODULE, tmp_path)
# Should convert require to import
assert "import { helper } from './helpers';" in result
assert "require" not in result, f"require should be converted to import. Got:\n{result}"
assert result == mixed_code, "Mixed code should not be converted"
def test_convert_mixed_code_to_commonjs(self):
"""Test converting mixed ESM/CJS code to pure CommonJS - exact output."""
code = """\
const { existing } = require('./module');
import { helper } from './helpers.js';
function process() {
return existing() + helper();
}
"""
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS)
# Should convert import to require
assert "const { helper } = require('./helpers');" in result
assert "import " not in result.split("\n")[0] or "import " not in result, (
f"import should be converted to require. Got:\n{result}"
)
def test_pure_esm_unchanged(self):
def test_pure_esm_unchanged_for_esm_target(self, tmp_path):
"""Test that pure ESM code is unchanged when targeting ESM."""
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
code = """\
import { add } from './math.js';
@ -346,11 +464,14 @@ export function sum(a, b) {
return add(a, b);
}
"""
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE)
assert result == code, f"Pure ESM code should be unchanged.\nExpected:\n{code}\n\nGot:\n{result}"
result = ensure_module_system_compatibility(code, ModuleSystem.ES_MODULE, tmp_path)
assert result == code, "Pure ESM code should be unchanged for ESM target"
def test_pure_commonjs_unchanged(self):
def test_pure_commonjs_unchanged_for_commonjs_target(self, tmp_path):
"""Test that pure CommonJS code is unchanged when targeting CommonJS."""
package_json = tmp_path / "package.json"
package_json.write_text('{"devDependencies": {"jest": "^29.0.0"}}')
code = """\
const { add } = require('./math');
@ -360,8 +481,8 @@ function sum(a, b) {
module.exports = { sum };
"""
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS)
assert result == code, f"Pure CommonJS code should be unchanged.\nExpected:\n{code}\n\nGot:\n{result}"
result = ensure_module_system_compatibility(code, ModuleSystem.COMMONJS, tmp_path)
assert result == code, "Pure CommonJS code should be unchanged for CommonJS target"
class TestImportStatementGeneration:

View file

@ -555,3 +555,125 @@ def standalone():
standalone_funcs = [f for f in functions if f.class_name is None]
assert len(standalone_funcs) == 1
# === Tests for find_references method ===
# These tests verify that PythonSupport correctly finds references to functions
# using jedi, including the fix for using function.function_name instead of function.name.
def test_find_references_simple_function(python_support, tmp_path):
    """Test finding references to a simple function.

    This test specifically exercises the code path that was fixed in the
    regression where ``function.name`` was used instead of
    ``function.function_name``.
    """
    from codeflash.models.function_types import FunctionToOptimize

    # Definition site: a trivial helper in utils.py.
    definition = tmp_path / "utils.py"
    definition.write_text("""def helper_function(x):
    return x * 2
""")

    # Usage site: consumer.py imports and calls the helper.
    consumer = tmp_path / "consumer.py"
    consumer.write_text("""from utils import helper_function
def process(value):
    return helper_function(value) + 1
""")

    target = FunctionToOptimize(
        function_name="helper_function",
        file_path=definition,
        starting_line=1,
        ending_line=2,
    )
    references = python_support.find_references(target, project_root=tmp_path)

    assert len(references) >= 1
    referencing_files = {str(ref.file_path) for ref in references}
    assert any("consumer.py" in path for path in referencing_files)
def test_find_references_class_method(python_support, tmp_path):
    """Test finding references to a class method.

    This verifies the class_name attribute is correctly used to disambiguate methods.
    """
    from codeflash.models.function_types import FunctionParent, FunctionToOptimize

    # Definition site: Calculator.add in calculator.py.
    definition = tmp_path / "calculator.py"
    definition.write_text("""class Calculator:
    def add(self, a, b):
        return a + b
""")

    # Usage site: main.py instantiates Calculator and calls add().
    consumer = tmp_path / "main.py"
    consumer.write_text("""from calculator import Calculator
def compute():
    calc = Calculator()
    return calc.add(1, 2)
""")

    target = FunctionToOptimize(
        function_name="add",
        file_path=definition,
        parents=[FunctionParent(name="Calculator", type="ClassDef")],
        starting_line=2,
        ending_line=3,
        is_method=True,
    )
    references = python_support.find_references(target, project_root=tmp_path)

    assert len(references) >= 1
    referencing_files = {str(ref.file_path) for ref in references}
    assert any("main.py" in path for path in referencing_files)
def test_find_references_no_references(python_support, tmp_path):
    """Test that find_references returns empty list when no references exist."""
    from codeflash.models.function_types import FunctionToOptimize

    # A function that nothing else in the project touches.
    lone_file = tmp_path / "isolated.py"
    lone_file.write_text("""def isolated_function():
    return 42
""")
    target = FunctionToOptimize(
        function_name="isolated_function",
        file_path=lone_file,
        starting_line=1,
        ending_line=2,
    )
    assert python_support.find_references(target, project_root=tmp_path) == []
def test_find_references_nonexistent_function(python_support, tmp_path):
    """Test that find_references handles nonexistent functions gracefully."""
    from codeflash.models.function_types import FunctionToOptimize

    module_file = tmp_path / "source.py"
    module_file.write_text("""def existing_function():
    return 1
""")
    # Ask about a name that does not exist anywhere in the file.
    target = FunctionToOptimize(
        function_name="nonexistent_function",
        file_path=module_file,
        starting_line=1,
        ending_line=2,
    )
    assert python_support.find_references(target, project_root=tmp_path) == []

View file

@ -14,6 +14,7 @@ from codeflash.setup.detector import (
_find_project_root,
detect_project,
has_existing_config,
is_build_output_dir,
)
@ -139,15 +140,15 @@ class TestDetectModuleRoot:
assert "pyproject.toml" in detail
def test_js_detects_from_exports(self, tmp_path):
"""Should detect module root from package.json exports."""
"""Should detect module root from package.json exports when no common src dir exists."""
(tmp_path / "package.json").write_text(json.dumps({
"name": "test",
"exports": {".": "./src/index.js"}
"exports": {".": "./packages/core/index.js"}
}))
(tmp_path / "src").mkdir()
(tmp_path / "packages" / "core").mkdir(parents=True)
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "src"
assert module_root == tmp_path / "packages" / "core"
assert "exports" in detail
def test_js_detects_src_convention(self, tmp_path):
@ -158,6 +159,214 @@ class TestDetectModuleRoot:
module_root, detail = _detect_js_module_root(tmp_path)
assert module_root == tmp_path / "src"
def test_js_prefers_src_over_build_src(self, tmp_path):
    """Should prefer src/ over build/src/ even when package.json points to build/."""
    # package.json deliberately points at compiled output under build/.
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "main": "build/src/index.js",
        "module": "build/src/index.js"
    }))
    (tmp_path / "src").mkdir()
    (tmp_path / "build" / "src").mkdir(parents=True)
    module_root, detail = _detect_js_module_root(tmp_path)
    # The real source tree wins over the build mirror.
    assert module_root == tmp_path / "src"
    assert "src/ directory" in detail
def test_js_skips_build_dir_from_main(self, tmp_path):
    """Should skip build output directories from package.json main field."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "main": "build/index.js"
    }))
    (tmp_path / "build").mkdir()
    module_root, detail = _detect_js_module_root(tmp_path)
    # With only a build output present, detection falls back to the project root.
    assert module_root == tmp_path
    assert "project root" in detail
def test_js_skips_dist_dir_from_exports(self, tmp_path):
    """Should skip dist output directories from package.json exports field."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "exports": {".": "./dist/index.js"}
    }))
    (tmp_path / "dist").mkdir()
    module_root, detail = _detect_js_module_root(tmp_path)
    # dist/ is build output, so the project root is used instead.
    assert module_root == tmp_path
    assert "project root" in detail
def test_js_skips_out_dir_from_module(self, tmp_path):
    """Should skip out output directories from package.json module field."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "module": "out/esm/index.js"
    }))
    (tmp_path / "out" / "esm").mkdir(parents=True)
    module_root, detail = _detect_js_module_root(tmp_path)
    # out/ is build output, so the project root is used instead.
    assert module_root == tmp_path
    assert "project root" in detail
def test_js_prefers_lib_over_build_dir(self, tmp_path):
    """Should prefer lib/ over build output directories."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "main": "dist/index.js"
    }))
    # Both a conventional source dir and a build output dir exist.
    (tmp_path / "lib").mkdir()
    (tmp_path / "dist").mkdir()
    module_root, detail = _detect_js_module_root(tmp_path)
    assert module_root == tmp_path / "lib"
    assert "lib/ directory" in detail
def test_js_prefers_source_over_build_dir(self, tmp_path):
    """Should prefer source/ over build output directories."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "main": "build/index.js"
    }))
    # Both a conventional source dir and a build output dir exist.
    (tmp_path / "source").mkdir()
    (tmp_path / "build").mkdir()
    module_root, detail = _detect_js_module_root(tmp_path)
    assert module_root == tmp_path / "source"
    assert "source/ directory" in detail
def test_js_falls_back_to_valid_exports_path(self, tmp_path):
    """Should use exports path when no common source dirs exist and path is not build output."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "exports": {".": "./packages/core/index.js"}
    }))
    (tmp_path / "packages" / "core").mkdir(parents=True)
    module_root, detail = _detect_js_module_root(tmp_path)
    # No src/lib/source dir exists, so the exports entry point's dir is used.
    assert module_root == tmp_path / "packages" / "core"
    assert "exports" in detail
def test_js_falls_back_to_valid_main_path(self, tmp_path):
    """Should use main path when no common source dirs exist and path is not build output."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "main": "packages/main/index.js"
    }))
    (tmp_path / "packages" / "main").mkdir(parents=True)
    module_root, detail = _detect_js_module_root(tmp_path)
    # No src/lib/source dir exists, so the main entry point's dir is used.
    assert module_root == tmp_path / "packages" / "main"
    assert "main" in detail
def test_js_falls_back_to_valid_module_path(self, tmp_path):
    """Should use module path when no common source dirs exist and path is not build output."""
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "module": "esm/index.js"
    }))
    (tmp_path / "esm").mkdir()
    module_root, detail = _detect_js_module_root(tmp_path)
    # No src/lib/source dir exists, so the module entry point's dir is used.
    assert module_root == tmp_path / "esm"
    assert "module" in detail
def test_js_returns_project_root_when_all_paths_are_build_output(self, tmp_path):
    """Should return project root when all package.json paths point to build outputs."""
    # Every entry point (main, module, exports) points into dist/ or build/.
    (tmp_path / "package.json").write_text(json.dumps({
        "name": "test",
        "main": "dist/cjs/index.js",
        "module": "dist/esm/index.js",
        "exports": {".": "./build/index.js"}
    }))
    (tmp_path / "dist" / "cjs").mkdir(parents=True)
    (tmp_path / "dist" / "esm").mkdir(parents=True)
    (tmp_path / "build").mkdir()
    module_root, detail = _detect_js_module_root(tmp_path)
    assert module_root == tmp_path
    assert "project root" in detail
def test_js_handles_malformed_package_json(self, tmp_path):
    """Should handle malformed package.json gracefully."""
    # Invalid JSON must not raise; detection degrades to the project root.
    (tmp_path / "package.json").write_text("{ invalid json }")
    module_root, detail = _detect_js_module_root(tmp_path)
    assert module_root == tmp_path
    assert "project root" in detail
class TestIsBuildOutputDir:
    """Tests for is_build_output_dir function."""

    def test_detects_build_dir(self):
        """Should detect build/ as build output."""
        from pathlib import Path
        for candidate in ("build", "build/src", "build/src/index.js"):
            assert is_build_output_dir(Path(candidate))

    def test_detects_dist_dir(self):
        """Should detect dist/ as build output."""
        from pathlib import Path
        for candidate in ("dist", "dist/esm", "dist/cjs/index.js"):
            assert is_build_output_dir(Path(candidate))

    def test_detects_out_dir(self):
        """Should detect out/ as build output."""
        from pathlib import Path
        for candidate in ("out", "out/src"):
            assert is_build_output_dir(Path(candidate))

    def test_detects_next_dir(self):
        """Should detect .next/ as build output."""
        from pathlib import Path
        for candidate in (".next", ".next/static"):
            assert is_build_output_dir(Path(candidate))

    def test_detects_nuxt_dir(self):
        """Should detect .nuxt/ as build output."""
        from pathlib import Path
        for candidate in (".nuxt", ".nuxt/dist"):
            assert is_build_output_dir(Path(candidate))

    def test_detects_nested_build_dir(self):
        """Should detect build dir nested in path."""
        from pathlib import Path
        for candidate in ("packages/build/index.js", "foo/dist/bar"):
            assert is_build_output_dir(Path(candidate))

    def test_does_not_detect_src(self):
        """Should not detect src/ as build output."""
        from pathlib import Path
        for candidate in ("src", "src/index.js"):
            assert not is_build_output_dir(Path(candidate))

    def test_does_not_detect_lib(self):
        """Should not detect lib/ as build output."""
        from pathlib import Path
        for candidate in ("lib", "lib/utils"):
            assert not is_build_output_dir(Path(candidate))

    def test_does_not_detect_source(self):
        """Should not detect source/ as build output."""
        from pathlib import Path
        assert not is_build_output_dir(Path("source"))

    def test_does_not_detect_packages(self):
        """Should not detect packages/ as build output."""
        from pathlib import Path
        for candidate in ("packages", "packages/core"):
            assert not is_build_output_dir(Path(candidate))

    def test_does_not_detect_similar_names(self):
        """Should not detect directories with similar but different names."""
        from pathlib import Path
        # Names that merely contain a build-output word must not match.
        for candidate in ("builder", "distribution", "output"):
            assert not is_build_output_dir(Path(candidate))
class TestDetectTestsRoot:
"""Tests for tests root detection."""

View file

@ -857,6 +857,48 @@ class TestE2ECLIFlags:
# Should complete without error
_handle_show_config()
def test_show_config_displays_config_path_when_saved(self, project_with_existing_config, monkeypatch):
    """Should display config file path when saved config exists."""
    monkeypatch.chdir(project_with_existing_config)
    # Track what gets printed
    printed_messages = []
    def mock_print(msg="", *args, **kwargs):
        # Capture every console.print call as a plain string for later inspection.
        printed_messages.append(str(msg))
    from codeflash.cli_cmds import console
    monkeypatch.setattr(console.console, "print", mock_print)
    from codeflash.cli_cmds.cli import _handle_show_config
    _handle_show_config()
    # Verify config path is displayed
    all_output = "\n".join(printed_messages)
    assert "pyproject.toml" in all_output
    assert "Config file:" in all_output
def test_show_config_no_path_when_auto_detected(self, python_src_layout, monkeypatch):
    """Should not display config file path when config is auto-detected."""
    monkeypatch.chdir(python_src_layout)
    # Track what gets printed
    printed_messages = []
    def mock_print(msg="", *args, **kwargs):
        # Capture every console.print call as a plain string for later inspection.
        printed_messages.append(str(msg))
    from codeflash.cli_cmds import console
    monkeypatch.setattr(console.console, "print", mock_print)
    from codeflash.cli_cmds.cli import _handle_show_config
    _handle_show_config()
    # Verify no config path line is displayed
    all_output = "\n".join(printed_messages)
    assert "Config file:" not in all_output
    assert "Auto-detected" in all_output
def test_reset_config_removes_from_pyproject(self, project_with_existing_config, monkeypatch):
"""Should remove codeflash config from pyproject.toml."""
monkeypatch.chdir(project_with_existing_config)

2123
uv.lock

File diff suppressed because it is too large Load diff