Merge branch 'main' into fix/jest-junit-and-misc

This commit is contained in:
Sarthak Agarwal 2026-03-02 22:46:13 +05:30 committed by GitHub
commit c53740df2e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
82 changed files with 3335 additions and 3441 deletions

View file

@@ -1,5 +1,7 @@
# Architecture
When adding, moving, or deleting source files, update this doc to match.
```
codeflash/
├── main.py # CLI entry point
@@ -14,7 +16,21 @@ codeflash/
├── api/ # AI service communication
├── code_utils/ # Code parsing, git utilities
├── models/ # Pydantic models and types
├── languages/ # Multi-language support (Python, JavaScript/TypeScript)
├── languages/ # Multi-language support (Python, JavaScript/TypeScript, Java planned)
│ ├── base.py # LanguageSupport protocol and shared data types
│ ├── registry.py # Language registration and lookup by extension/enum
│ ├── current.py # Current language singleton (set_current_language / current_language_support)
│ ├── code_replacer.py # Language-agnostic code replacement
│ ├── python/
│ │ ├── support.py # PythonSupport (LanguageSupport implementation)
│ │ ├── function_optimizer.py # PythonFunctionOptimizer subclass
│ │ ├── optimizer.py # Python module preparation & AST resolution
│ │ └── normalizer.py # Python code normalization for deduplication
│ └── javascript/
│ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
│ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
│ ├── optimizer.py # JS project root finding & module preparation
│ └── normalizer.py # JS/TS code normalization for deduplication
├── setup/ # Config schema, auto-detection, first-run experience
├── picklepatch/ # Serialization/deserialization utilities
├── tracing/ # Function call tracing
@@ -32,10 +48,36 @@ codeflash/
|------|------------|
| CLI arguments & commands | `cli_cmds/cli.py` |
| Optimization orchestration | `optimization/optimizer.py` `run()` |
| Per-function optimization | `optimization/function_optimizer.py` |
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
| Function discovery | `discovery/functions_to_optimize.py` |
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
| Test execution | `languages/<lang>/support.py` (`run_behavioral_tests`, etc.), `verification/pytest_plugin.py` |
| Performance ranking | `benchmarking/function_ranker.py` |
| Domain types | `models/models.py`, `models/function_types.py` |
| Result handling | `either.py` (`Result`, `Success`, `Failure`, `is_successful`) |
## LanguageSupport Protocol Methods
Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScriptSupport`) implements these.
| Category | Method/Property | Purpose |
|----------|----------------|---------|
| Identity | `language`, `file_extensions`, `default_file_extension` | Language identification |
| Identity | `comment_prefix`, `dir_excludes` | Language conventions |
| AI service | `default_language_version` | Language version for API payloads (`None` for Python, `"ES2022"` for JS) |
| AI service | `valid_test_frameworks` | Allowed test frameworks for validation |
| Discovery | `discover_functions`, `discover_tests` | Find optimizable functions and their tests |
| Discovery | `adjust_test_config_for_discovery` | Pre-discovery config adjustment (no-op default) |
| Context | `extract_code_context`, `find_helper_functions`, `find_references` | Code dependency extraction |
| Transform | `replace_function`, `format_code`, `normalize_code` | Code modification |
| Validation | `validate_syntax` | Syntax checking |
| Test execution | `run_behavioral_tests`, `run_benchmarking_tests`, `run_line_profile_tests` | Test runners |
| Test results | `test_result_serialization_format` | `"pickle"` (Python) or `"json"` (JS) |
| Test results | `load_coverage` | Load coverage from language-specific format |
| Test results | `compare_test_results` | Equivalence checking between original and candidate |
| Test gen | `postprocess_generated_tests` | Post-process `GeneratedTestsList` objects |
| Test gen | `process_generated_test_strings` | Instrument/transform raw generated test strings |
| Module | `detect_module_system` | Detect project module system (`None` for Python, `"esm"`/`"commonjs"` for JS) |
| Module | `prepare_module` | Parse/validate module before optimization |
| Setup | `setup_test_config` | One-time project setup after language detection |
| Optimizer | `function_optimizer_class` | Return `FunctionOptimizer` subclass for this language |
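As a rough illustration of the table above, the protocol shape might look like the following minimal sketch. Only the member names come from the table; the signatures, the `Protocol` base, and the `FakePythonSupport` class are assumptions for demonstration, not the real codeflash API:

```python
from __future__ import annotations

from typing import Protocol


class LanguageSupport(Protocol):
    """Illustrative subset of the protocol; signatures are assumptions."""

    @property
    def language(self) -> str: ...

    @property
    def valid_test_frameworks(self) -> tuple[str, ...]: ...

    def validate_syntax(self, code: str) -> bool: ...


class FakePythonSupport:
    """Toy implementation used only to demonstrate structural typing."""

    language = "python"
    valid_test_frameworks = ("pytest", "unittest")

    def validate_syntax(self, code: str) -> bool:
        # Compile without executing to check syntax only.
        try:
            compile(code, "<string>", "exec")
        except SyntaxError:
            return False
        return True


# Callers depend on the protocol, not on a concrete language class.
support: LanguageSupport = FakePythonSupport()
print(support.validate_syntax("def f():\n    return 1"))  # True
```

Because `Protocol` uses structural typing, any class exposing these members satisfies the contract without inheriting from a shared base.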

View file

@@ -7,4 +7,5 @@
- **Comments**: Minimal - only explain "why", not "what"
- **Docstrings**: Do not add unless explicitly requested
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
- **Paths**: Always use absolute paths
- **Encoding**: Always pass `encoding="utf-8"` to `open()`, `read_text()`, `write_text()`, etc. in new or changed code — Windows defaults to `cp1252` which breaks on non-ASCII content. Don't flag pre-existing code that lacks it unless you're already modifying that line.
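The encoding rule above in practice; a minimal sketch (the file name is arbitrary):

```python
import tempfile
from pathlib import Path

# Pass encoding explicitly so reads and writes behave the same on every
# platform; Windows often defaults to cp1252, which fails on non-ASCII text.
tmp = Path(tempfile.mkdtemp()) / "notes.txt"
tmp.write_text("café ☕", encoding="utf-8")

with open(tmp, encoding="utf-8") as f:
    content = f.read()

print(content)  # café ☕
```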

View file

@@ -9,4 +9,5 @@ paths:
- Use `get_language_support(identifier)` from `languages/registry.py` to get a `LanguageSupport` instance — never import language classes directly
- New language support classes must use the `@register_language` decorator to register with the extension and language registries
- `languages/__init__.py` uses `__getattr__` for lazy imports to avoid circular dependencies — follow this pattern when adding new exports
- `is_javascript()` returns `True` for both JavaScript and TypeScript
- Prefer `LanguageSupport` protocol dispatch over `is_python()`/`is_javascript()` guards — remaining guards are being migrated to protocol methods
- `is_javascript()` returns `True` for both JavaScript and TypeScript (still used in ~15 call sites pending migration)
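The registry pattern these rules describe can be sketched in miniature as follows. The real `languages/registry.py` likely differs in its internals; the dict names and the lookup-by-extension fallback here are assumptions, only `register_language` and `get_language_support` come from the rules above:

```python
from __future__ import annotations

# Toy registries; the real module's storage is an implementation detail.
_BY_NAME: dict[str, type] = {}
_BY_EXTENSION: dict[str, type] = {}


def register_language(cls: type) -> type:
    """Decorator: register a support class by name and file extensions."""
    _BY_NAME[cls.language] = cls
    for ext in cls.file_extensions:
        _BY_EXTENSION[ext] = cls
    return cls


def get_language_support(identifier: str):
    """Look up by language name or extension; never import classes directly."""
    cls = _BY_NAME.get(identifier) or _BY_EXTENSION.get(identifier)
    if cls is None:
        raise ValueError(f"No language support registered for {identifier!r}")
    return cls()


@register_language
class JavaScriptSupport:
    language = "javascript"
    file_extensions = (".js", ".jsx", ".ts", ".tsx")


print(get_language_support(".ts").language)  # javascript
```

Registering via a decorator keeps the registry as the single lookup point, which is what makes the lazy-import `__getattr__` pattern in `languages/__init__.py` safe.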

View file

@@ -3,7 +3,7 @@ paths:
- "codeflash/optimization/**/*.py"
- "codeflash/verification/**/*.py"
- "codeflash/benchmarking/**/*.py"
- "codeflash/context/**/*.py"
- "codeflash/languages/*/context/**/*.py"
---
# Optimization Pipeline Patterns

View file

@@ -56,132 +56,116 @@ jobs:
use_sticky_comment: true
allowed_bots: "claude[bot],codeflash-ai[bot]"
prompt: |
REPO: ${{ github.repository }}
PR NUMBER: ${{ github.event.pull_request.number }}
EVENT: ${{ github.event.action }}
<context>
repo: ${{ github.repository }}
pr_number: ${{ github.event.pull_request.number }}
event: ${{ github.event.action }}
is_re_review: ${{ github.event.action == 'synchronize' }}
</context>
## STEP 1: Run prek and mypy checks, fix issues
<commitment>
Execute these steps in order. If a step has no work, state that and continue to the next step.
Post all review findings in a single summary comment only — never as inline PR review comments.
</commitment>
First, run these checks on files changed in this PR:
1. `uv run prek run --from-ref origin/main` - linting/formatting issues
2. `uv run mypy <changed_files>` - type checking issues
<step name="lint_and_typecheck">
Run checks on files changed in this PR and auto-fix what you can.
If there are prek issues:
- For SAFE auto-fixable issues (formatting, import sorting, trailing whitespace, etc.), run `uv run prek run --from-ref origin/main` again to auto-fix them
- For issues that prek cannot auto-fix, do NOT attempt to fix them manually — report them as remaining issues in your summary
1. Run `uv run prek run --from-ref origin/main` to check linting/formatting.
If there are auto-fixable issues, run it again to fix them.
Report any issues prek cannot auto-fix in your summary.
If there are mypy issues:
- Fix type annotation issues (missing return types, Optional/None unions, import errors for type hints, incorrect types)
- Do NOT add `type: ignore` comments - always fix the root cause
2. Run `uv run mypy <changed_files>` to check types.
Fix type annotation issues (missing return types, Optional unions, import errors).
Always fix the root cause instead of adding `type: ignore` comments.
Leave alone: type errors requiring logic changes, complex generics, anything changing runtime behavior.
After fixing issues:
- Stage the fixed files with `git add`
- Commit with message "style: auto-fix linting issues" or "fix: resolve mypy type errors" as appropriate
- Push the changes with `git push`
3. After fixes: stage with `git add`, commit ("style: auto-fix linting issues" or "fix: resolve mypy type errors"), push.
IMPORTANT - Verification after fixing:
- After committing fixes, run `uv run prek run --from-ref origin/main` ONE MORE TIME to verify all issues are resolved
- If errors remain, either fix them or report them honestly as unfixed in your summary
- NEVER claim issues are fixed without verifying. If you cannot fix an issue, say so
4. Verify by running `uv run prek run --from-ref origin/main` one more time. Report honestly if issues remain.
</step>
Do NOT attempt to fix:
- Type errors that require logic changes or refactoring
- Complex generic type issues
- Anything that could change runtime behavior
<step name="resolve_stale_threads">
Before reviewing, resolve any stale review threads from previous runs.
## STEP 2: Review the PR
1. Fetch unresolved threads you created:
`gh api graphql -f query='{ repository(owner: "${{ github.repository_owner }}", name: "${{ github.event.repository.name }}") { pullRequest(number: ${{ github.event.pull_request.number }}) { reviewThreads(first: 100) { nodes { id isResolved path comments(first: 1) { nodes { body author { login } } } } } } } }' --jq '.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false) | select(.comments.nodes[0].author.login == "claude") | {id: .id, path: .path, body: .comments.nodes[0].body}'`
${{ github.event.action == 'synchronize' && 'This is a RE-REVIEW after new commits. First, get the list of changed files in this latest push using `gh pr diff`. Review ONLY the changed files. Check ALL existing review comments and resolve ones that are now fixed.' || 'This is the INITIAL REVIEW.' }}
2. For each unresolved thread:
a. Read the file at that path to check if the issue still exists
b. If fixed → resolve it: `gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "<THREAD_ID>"}) { thread { isResolved } } }'`
c. If still present → leave it
Review this PR focusing ONLY on:
1. Critical bugs or logic errors
Read the actual code before deciding. If there are no unresolved threads, skip to the next step.
</step>
<step name="review">
Review the diff (`gh pr diff ${{ github.event.pull_request.number }}`) for:
1. Bugs that will crash at runtime
2. Security vulnerabilities
3. Breaking API changes
4. Test failures (methods with typos that won't run)
5. Stale documentation — if files or directories were moved, renamed, or deleted, check that `.claude/rules/`, `CLAUDE.md`, and `AGENTS.md` don't reference paths that no longer exist
6. New language support — if new language modules are added under `languages/`, check that `.github/workflows/duplicate-code-detector.yml` includes the new language in its file filters, search patterns, and cross-module checks
IMPORTANT:
- First check existing review comments using `gh api repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/comments`. For each existing comment, check if the issue still exists in the current code.
- If an issue is fixed, use `gh api --method PATCH repos/${{ github.repository }}/pulls/comments/COMMENT_ID -f body="✅ Fixed in latest commit"` to resolve it.
- Only create NEW inline comments for HIGH-PRIORITY issues found in changed files.
- Limit to 5-7 NEW comments maximum per review.
- Use CLAUDE.md for project-specific guidance.
- Use `mcp__github_inline_comment__create_inline_comment` sparingly for critical code issues only.
## STEP 3: Coverage analysis
Ignore style issues, type hints, and log message wording.
Record findings for the summary comment. Refer to CLAUDE.md for project conventions.
</step>
<step name="coverage">
Analyze test coverage for changed files:
1. Get the list of Python files changed in this PR (excluding tests):
`git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
1. Get changed Python files (excluding tests): `git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
2. Run coverage on PR branch: `uv run coverage run -m pytest tests/ -q --tb=no` then `uv run coverage json -o coverage-pr.json`
3. Get per-file coverage: `uv run coverage report --include="<changed_files>"`
4. Compare with main: checkout main, run coverage, checkout back
5. Flag: new files below 75%, decreased coverage, untested changed lines
</step>
2. Run tests with coverage on the PR branch:
`uv run coverage run -m pytest tests/ -q --tb=no`
`uv run coverage json -o coverage-pr.json`
<step name="summary_comment">
Post exactly one summary comment containing all results from previous steps.
3. Get coverage for changed files only:
`uv run coverage report --include="<changed_files_comma_separated>"`
4. Compare with main branch coverage:
- Checkout main: `git checkout origin/main`
- Run coverage: `uv run coverage run -m pytest tests/ -q --tb=no && uv run coverage json -o coverage-main.json`
- Checkout back: `git checkout -`
5. Analyze the diff to identify:
- NEW FILES: Files that don't exist on main (require good test coverage)
- MODIFIED FILES: Files with changes (changes must be covered by tests)
6. Report in PR comment with a markdown table:
- Coverage % for each changed file (PR vs main)
- Overall coverage change
- For NEW files: Flag if coverage is below 75%
- For MODIFIED files: Flag if the changed lines are not covered by tests
- Flag if overall coverage decreased
Coverage requirements:
- New implementations/files: Must have ≥75% test coverage
- Modified code: Changed lines should be exercised by existing or new tests
- No coverage regressions: Overall coverage should not decrease
## STEP 4: Post ONE consolidated summary comment
CRITICAL: You must post exactly ONE summary comment containing ALL results (pre-commit, review, coverage).
DO NOT post multiple separate comments. Use this format:
To ensure one comment: find an existing claude[bot] comment and update it, or create one if none exists.
Delete any duplicate claude[bot] comments.
```
gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | head -1
```
Format:
```
## PR Review Summary
### Prek Checks
[status and any fixes made]
### Code Review
[critical issues found, if any]
### Test Coverage
[coverage table and analysis]
---
*Last updated: <timestamp>*
```
</step>
To ensure only ONE comment exists:
1. Find existing claude[bot] comment: `gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | head -1`
2. If found, UPDATE it: `gh api --method PATCH repos/${{ github.repository }}/issues/comments/<ID> -f body="<content>"`
3. If not found, CREATE: `gh pr comment ${{ github.event.pull_request.number }} --body "<content>"`
4. Delete any OTHER claude[bot] comments to clean up duplicates: `gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | tail -n +2 | xargs -I {} gh api --method DELETE repos/${{ github.repository }}/issues/comments/{}`
<step name="simplify">
Run /simplify to review recently changed code for reuse, quality, and efficiency opportunities.
If improvements are found, commit with "refactor: simplify <description>" and push.
Only make behavior-preserving changes.
</step>
## STEP 5: Merge pending codeflash optimization PRs
<step name="merge_optimization_prs">
Check for open PRs from codeflash-ai[bot]:
`gh pr list --author "codeflash-ai[bot]" --state open --json number,title,headRefName,createdAt,mergeable`
Check for open optimization PRs from codeflash and merge if CI passes:
For each PR:
- If CI passes and the PR is mergeable → merge with `--squash --delete-branch`
- Close the PR as stale if ANY of these apply:
- Older than 7 days
- Has merge conflicts (mergeable state is "CONFLICTING")
- CI is failing
- The optimized function no longer exists in the target file (check the diff)
Close with: `gh pr close <number> --comment "Closing stale optimization PR." --delete-branch`
</step>
1. List open PRs from codeflash bot:
`gh pr list --author "codeflash-ai[bot]" --state open --json number,title,headRefName`
2. For each optimization PR:
- Check if CI is passing: `gh pr checks <number>`
- If all checks pass, merge it: `gh pr merge <number> --squash --delete-branch`
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
<verification>
Before finishing, confirm:
- All steps were attempted (even if some had no work)
- Stale review threads were checked and resolved where appropriate
- All findings are in a single summary comment (no inline review comments were created)
- If fixes were made, they were verified with prek
</verification>
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit,Skill"'
additional_permissions: |
actions: read

View file

@@ -14,7 +14,8 @@ from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.env_utils import get_codeflash_api_key
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.languages import is_javascript, is_python
from codeflash.languages import Language, current_language
from codeflash.languages.current import current_language_support
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
AIServiceRefinerRequest,
@@ -51,6 +52,18 @@ class AiServiceClient:
"""Get the next LLM call sequence number."""
return next(self.llm_call_counter)
@staticmethod
def add_language_metadata(
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
) -> None:
"""Add language version and module system metadata to an API payload."""
payload["python_version"] = platform.python_version()
default_lang_version = current_language_support().default_language_version
if default_lang_version is not None:
payload["language_version"] = language_version or default_lang_version
if module_system:
payload["module_system"] = module_system
def get_aiservice_base_url(self) -> str:
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
logger.info("Using local AI Service at http://localhost:8000")
@@ -177,16 +190,7 @@ class AiServiceClient:
"is_numerical_code": is_numerical_code,
}
# Add language-specific version fields
# Always include python_version for backward compatibility with older backend
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
if module_system:
payload["module_system"] = module_system
self.add_language_metadata(payload, language_version, module_system)
# DEBUG: Print payload language field
logger.debug(
@@ -432,14 +436,7 @@ class AiServiceClient:
"language": opt.language,
}
# Add language version - always include python_version for backward compatibility
item["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif opt.language_version:
item["language_version"] = opt.language_version
else:
item["language_version"] = "ES2022" # Default for JS/TS
self.add_language_metadata(item, opt.language_version)
# Add multi-file context if provided
if opt.additional_context_files:
@@ -752,16 +749,11 @@ class AiServiceClient:
"""
# Validate test framework based on language
python_frameworks = ["pytest", "unittest"]
javascript_frameworks = ["jest", "mocha", "vitest"]
if is_python():
assert test_framework in python_frameworks, (
f"Invalid test framework for Python, got {test_framework} but expected one of {python_frameworks}"
)
elif is_javascript():
assert test_framework in javascript_frameworks, (
f"Invalid test framework for JavaScript, got {test_framework} but expected one of {javascript_frameworks}"
)
lang_support = current_language_support()
valid_frameworks = lang_support.valid_test_frameworks
assert test_framework in valid_frameworks, (
f"Invalid test framework for {current_language()}, got {test_framework} but expected one of {list(valid_frameworks)}"
)
payload: dict[str, Any] = {
"source_code_being_tested": source_code_being_tested,
@@ -780,16 +772,7 @@ class AiServiceClient:
"is_numerical_code": is_numerical_code,
}
# Add language-specific version fields
# Always include python_version for backward compatibility with older backend
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
if module_system:
payload["module_system"] = module_system
self.add_language_metadata(payload, language_version, module_system)
# DEBUG: Print payload language field
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
@@ -875,7 +858,7 @@ class AiServiceClient:
"codeflash_version": codeflash_version,
"calling_fn_details": calling_fn_details,
"language": language,
"python_version": platform.python_version() if is_python() else None,
"python_version": platform.python_version() if current_language() == Language.PYTHON else None,
"call_sequence": self.get_next_sequence(),
}
console.rule()

View file

@@ -408,9 +408,10 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
def get_run_tmp_file(file_path: Path | str) -> Path:
if isinstance(file_path, str):
file_path = Path(file_path)
if not hasattr(get_run_tmp_file, "tmpdir"):
if not hasattr(get_run_tmp_file, "tmpdir_path"):
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
return Path(get_run_tmp_file.tmpdir.name) / file_path
get_run_tmp_file.tmpdir_path = Path(get_run_tmp_file.tmpdir.name)
return get_run_tmp_file.tmpdir_path / file_path
def path_belongs_to_site_packages(file_path: Path) -> bool:
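The change above caches the resolved `Path` on the function object alongside the `TemporaryDirectory`. Holding the `TemporaryDirectory` itself (not just its name) matters: if it were garbage-collected, the directory would be deleted while returned paths still point at it. A standalone sketch of the pattern, mirroring the new code:

```python
from __future__ import annotations

from pathlib import Path
from tempfile import TemporaryDirectory


def get_run_tmp_file(file_path: Path | str) -> Path:
    if isinstance(file_path, str):
        file_path = Path(file_path)
    if not hasattr(get_run_tmp_file, "tmpdir_path"):
        # Keep the TemporaryDirectory alive on the function object so the
        # directory is not cleaned up while paths into it are still in use.
        get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
        get_run_tmp_file.tmpdir_path = Path(get_run_tmp_file.tmpdir.name)
    return get_run_tmp_file.tmpdir_path / file_path


a = get_run_tmp_file("test_results.bin")
b = get_run_tmp_file("coverage.json")
print(a.parent == b.parent)  # True: both calls reuse the same temp directory
```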

View file

@@ -6,7 +6,10 @@ from typing import Any, Union
MAX_TEST_RUN_ITERATIONS = 5
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 64000
TESTGEN_CONTEXT_TOKEN_LIMIT = 64000
READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed"
TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed"
INDIVIDUAL_TESTCASE_TIMEOUT = 15
JAVA_TESTCASE_TIMEOUT = 120
MAX_FUNCTION_TEST_SECONDS = 60
MIN_IMPROVEMENT_THRESHOLD = 0.05
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput

View file

@@ -1,126 +0,0 @@
"""Code deduplication utilities using language-specific normalizers.
This module provides functions to normalize code, generate fingerprints,
and detect duplicate code segments across different programming languages.
"""
from __future__ import annotations
import hashlib
import re
from codeflash.code_utils.normalizers import get_normalizer
from codeflash.languages import current_language, is_python
def normalize_code(
code: str, remove_docstrings: bool = True, return_ast_dump: bool = False, language: str | None = None
) -> str:
"""Normalize code by parsing, cleaning, and normalizing variable names.
Function names, class names, and parameters are preserved.
Args:
code: Source code as string
remove_docstrings: Whether to remove docstrings (Python only)
return_ast_dump: Return AST dump instead of unparsed code (Python only)
language: Language of the code. If None, uses the current session language.
Returns:
Normalized code as string
"""
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
# Python has additional options
if is_python():
if return_ast_dump:
return normalizer.normalize_for_hash(code)
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
# For other languages, use standard normalization
return normalizer.normalize(code)
except ValueError:
# Unknown language - fall back to basic normalization
return _basic_normalize(code)
except Exception:
# Parsing error - try other languages or fall back
if is_python():
# Try JavaScript as fallback
try:
js_normalizer = get_normalizer("javascript")
js_result = js_normalizer.normalize(code)
if js_result != _basic_normalize(code):
return js_result
except Exception:
pass
return _basic_normalize(code)
def _basic_normalize(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments (// and #)
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
def get_code_fingerprint(code: str, language: str | None = None) -> str:
"""Generate a fingerprint for normalized code.
Args:
code: Source code
language: Language of the code. If None, uses the current session language.
Returns:
SHA-256 hash of normalized code
"""
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
return normalizer.get_fingerprint(code)
except ValueError:
# Unknown language - use basic normalization
normalized = _basic_normalize(code)
return hashlib.sha256(normalized.encode()).hexdigest()
def are_codes_duplicate(code1: str, code2: str, language: str | None = None) -> bool:
"""Check if two code segments are duplicates after normalization.
Args:
code1: First code segment
code2: Second code segment
language: Language of the code. If None, uses the current session language.
Returns:
True if codes are structurally identical (ignoring local variable names)
"""
if language is None:
language = current_language().value
try:
normalizer = get_normalizer(language)
return normalizer.are_duplicates(code1, code2)
except ValueError:
# Unknown language - use basic comparison
return _basic_normalize(code1) == _basic_normalize(code2)
except Exception:
return False
# Re-export for backward compatibility
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]

View file

@@ -1,106 +0,0 @@
"""Code normalizers for different programming languages.
This module provides language-specific code normalizers that transform source code
into canonical forms for duplicate detection. The normalizers:
- Replace local variable names with canonical forms (var_0, var_1, etc.)
- Preserve function names, class names, parameters, and imports
- Remove or normalize comments and docstrings
- Produce consistent output for structurally identical code
Usage:
>>> normalizer = get_normalizer("python")
>>> normalized = normalizer.normalize(code)
>>> fingerprint = normalizer.get_fingerprint(code)
>>> are_same = normalizer.are_duplicates(code1, code2)
"""
from __future__ import annotations
from codeflash.code_utils.normalizers.base import CodeNormalizer
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
from codeflash.code_utils.normalizers.python import PythonNormalizer
__all__ = [
"CodeNormalizer",
"JavaScriptNormalizer",
"PythonNormalizer",
"TypeScriptNormalizer",
"get_normalizer",
"get_normalizer_for_extension",
]
# Registry of normalizers by language
_NORMALIZERS: dict[str, type[CodeNormalizer]] = {
"python": PythonNormalizer,
"javascript": JavaScriptNormalizer,
"typescript": TypeScriptNormalizer,
}
# Singleton cache for normalizer instances
_normalizer_instances: dict[str, CodeNormalizer] = {}
def get_normalizer(language: str) -> CodeNormalizer:
"""Get a code normalizer for the specified language.
Args:
language: Language name ('python', 'javascript', 'typescript')
Returns:
CodeNormalizer instance for the language
Raises:
ValueError: If no normalizer exists for the language
"""
language = language.lower()
# Check cache first
if language in _normalizer_instances:
return _normalizer_instances[language]
# Get normalizer class
if language not in _NORMALIZERS:
supported = ", ".join(sorted(_NORMALIZERS.keys()))
msg = f"No normalizer available for language '{language}'. Supported: {supported}"
raise ValueError(msg)
# Create and cache instance
normalizer = _NORMALIZERS[language]()
_normalizer_instances[language] = normalizer
return normalizer
def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
"""Get a code normalizer based on file extension.
Args:
extension: File extension including dot (e.g., '.py', '.js')
Returns:
CodeNormalizer instance if found, None otherwise
"""
extension = extension.lower()
if not extension.startswith("."):
extension = f".{extension}"
for language in _NORMALIZERS:
normalizer = get_normalizer(language)
if extension in normalizer.supported_extensions:
return normalizer
return None
def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -> None:
"""Register a new normalizer for a language.
Args:
language: Language name
normalizer_class: CodeNormalizer subclass
"""
_NORMALIZERS[language.lower()] = normalizer_class
# Clear cached instance if it exists
_normalizer_instances.pop(language.lower(), None)

View file

@@ -1,104 +0,0 @@
"""Abstract base class for code normalizers.
Code normalizers transform source code into a canonical form for duplicate detection.
They normalize variable names, remove comments/docstrings, and produce consistent output
that can be compared across different implementations of the same algorithm.
"""
# TODO:{claude} move to base.py in language folder
from __future__ import annotations
from abc import ABC, abstractmethod
class CodeNormalizer(ABC):
"""Abstract base class for language-specific code normalizers.
Subclasses must implement the normalize() method for their specific language.
The normalization should:
- Normalize local variable names to canonical forms (var_0, var_1, etc.)
- Preserve function names, class names, parameters, and imports
- Remove or normalize comments and docstrings
- Produce consistent output for structurally identical code
Example:
>>> normalizer = PythonNormalizer()
>>> code1 = "def foo(x): y = x + 1; return y"
>>> code2 = "def foo(x): z = x + 1; return z"
>>> normalizer.normalize(code1) == normalizer.normalize(code2)
True
"""
@property
@abstractmethod
def language(self) -> str:
"""Return the language this normalizer handles."""
...
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return ()
@abstractmethod
def normalize(self, code: str) -> str:
"""Normalize code to a canonical form for comparison.
Args:
code: Source code to normalize
Returns:
Normalized representation of the code
"""
...
@abstractmethod
def normalize_for_hash(self, code: str) -> str:
"""Normalize code optimized for hashing/fingerprinting.
This may return a more compact representation than normalize().
Args:
code: Source code to normalize
Returns:
Normalized representation suitable for hashing
"""
...
def are_duplicates(self, code1: str, code2: str) -> bool:
"""Check if two code segments are duplicates after normalization.
Args:
code1: First code segment
code2: Second code segment
Returns:
True if codes are structurally identical
"""
try:
normalized1 = self.normalize_for_hash(code1)
normalized2 = self.normalize_for_hash(code2)
except Exception:
return False
else:
return normalized1 == normalized2
def get_fingerprint(self, code: str) -> str:
"""Generate a fingerprint hash for normalized code.
Args:
code: Source code to fingerprint
Returns:
SHA-256 hash of normalized code
"""
import hashlib
normalized = self.normalize_for_hash(code)
return hashlib.sha256(normalized.encode()).hexdigest()
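The ABC's concrete helpers (`are_duplicates`, `get_fingerprint`) build entirely on `normalize_for_hash`; a minimal runnable sketch with a toy comment-and-whitespace normalizer (the abstract surface is trimmed to one method for brevity):

```python
from abc import ABC, abstractmethod
import hashlib
import re

class CodeNormalizer(ABC):
    @abstractmethod
    def normalize_for_hash(self, code: str) -> str: ...

    def are_duplicates(self, code1: str, code2: str) -> bool:
        # Structural equality after normalization; errors count as "not duplicates".
        try:
            return self.normalize_for_hash(code1) == self.normalize_for_hash(code2)
        except Exception:
            return False

    def get_fingerprint(self, code: str) -> str:
        # SHA-256 over the normalized form, usable as a dedup key.
        return hashlib.sha256(self.normalize_for_hash(code).encode()).hexdigest()

class WhitespaceNormalizer(CodeNormalizer):
    """Toy subclass: strips # comments and collapses whitespace."""
    def normalize_for_hash(self, code: str) -> str:
        code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
        return " ".join(code.split())

n = WhitespaceNormalizer()
assert n.are_duplicates("x = 1  # init", "x = 1")
assert len(n.get_fingerprint("x = 1")) == 64  # hex SHA-256 digest length
```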

View file

@@ -641,17 +641,15 @@ def discover_unit_tests(
discover_only_these_tests: list[Path] | None = None,
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
from codeflash.languages import is_javascript, is_python
from codeflash.languages import is_python
from codeflash.languages.current import current_language_support
# Detect language from functions being optimized
language = _detect_language_from_functions(file_to_funcs_to_optimize)
# Route to language-specific test discovery for non-Python languages
if not is_python():
# For JavaScript/TypeScript, tests_project_rootdir should be tests_root itself
# The Jest helper will be configured to NOT include "tests." prefix to match
if is_javascript():
cfg.tests_project_rootdir = cfg.tests_root
current_language_support().adjust_test_config_for_discovery(cfg)
return discover_tests_for_language(cfg, language, file_to_funcs_to_optimize)
# Existing Python logic

View file

@@ -1,6 +1,7 @@
from __future__ import annotations
import ast
import contextlib
import os
import random
import warnings
@@ -10,8 +11,8 @@ from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import git
import libcst as cst
from pydantic.dataclasses import dataclass
from rich.text import Text
from rich.tree import Tree
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
@@ -37,14 +38,8 @@ __all__ = ["FunctionParent", "FunctionToOptimize"]
if TYPE_CHECKING:
from argparse import Namespace
from libcst import CSTNode
from libcst.metadata import CodeRange
from codeflash.models.models import CodeOptimizationContext
from codeflash.verification.verification_utils import TestConfig
import contextlib
from rich.text import Text
@dataclass(frozen=True)
@@ -56,70 +51,6 @@ class FunctionProperties:
staticmethod_class_name: Optional[str]
class ReturnStatementVisitor(cst.CSTVisitor):
def __init__(self) -> None:
super().__init__()
self.has_return_statement: bool = False
def visit_Return(self, node: cst.Return) -> None:
self.has_return_statement = True
class FunctionVisitor(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
def __init__(self, file_path: Path) -> None:
super().__init__()
self.file_path: Path = file_path
self.functions: list[FunctionToOptimize] = []
@staticmethod
def is_pytest_fixture(node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
dec = decorator.decorator
if isinstance(dec, cst.Call):
dec = dec.func
if isinstance(dec, cst.Attribute) and dec.attr.value == "fixture":
if isinstance(dec.value, cst.Name) and dec.value.value == "pytest":
return True
if isinstance(dec, cst.Name) and dec.value == "fixture":
return True
return False
@staticmethod
def is_property(node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
dec = decorator.decorator
if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
return True
return False
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
node.visit(return_visitor)
if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
ast_parents: list[FunctionParent] = []
while parents is not None:
if isinstance(parents, cst.FunctionDef):
# Skip nested functions — only discover top-level and class-level functions
return
if isinstance(parents, cst.ClassDef):
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
self.functions.append(
FunctionToOptimize(
function_name=node.name.value,
file_path=self.file_path,
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
is_async=bool(node.asynchronous),
)
)
# =============================================================================
# Multi-language support helpers
# =============================================================================
@@ -480,22 +411,14 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Path,
path = Path(path_str)
if not path.exists():
continue
with path.open(encoding="utf8") as f:
file_content = f.read()
try:
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
except Exception as e:
logger.exception(e)
continue
function_lines = FunctionVisitor(file_path=path)
wrapper.visit(function_lines)
functions[path] = [
function_to_optimize
for function_to_optimize in function_lines.functions
if (start_line := function_to_optimize.starting_line) is not None
and (end_line := function_to_optimize.ending_line) is not None
and any(start_line <= line <= end_line for line in lines_in_file)
]
all_functions = find_all_functions_in_file(path)
functions[path] = [
func
for func in all_functions.get(path, [])
if func.starting_line is not None
and func.ending_line is not None
and any(func.starting_line <= line <= func.ending_line for line in lines_in_file)
]
return functions
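The filter above keeps a function only when at least one modified line falls inside its `[starting_line, ending_line]` span; in isolation the overlap test is:

```python
# Overlap test used when mapping modified lines to functions: a function is
# "touched" if any modified line number lies within its 1-indexed line span.
def touches(start: int, end: int, modified_lines: list[int]) -> bool:
    return any(start <= line <= end for line in modified_lines)

assert touches(10, 20, [5, 15])       # line 15 is inside the span
assert not touches(10, 20, [5, 25])   # no modified line inside
```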
@@ -524,24 +447,19 @@ def get_all_files_and_functions(
def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
"""Find all optimizable functions in a file, routing to the appropriate language handler.
This function checks if the file extension is supported and routes to either
the Python-specific implementation (for backward compatibility) or the
language support abstraction for other languages.
Args:
file_path: Path to the source file.
Returns:
Dictionary mapping file path to list of FunctionToOptimize.
"""
# Check if the file extension is supported
"""Find all optimizable functions in a file using the language support abstraction."""
if not is_language_supported(file_path):
return {}
try:
from codeflash.languages.base import FunctionFilterCriteria
return _find_all_functions_via_language_support(file_path)
lang_support = get_language_support(file_path)
criteria = FunctionFilterCriteria(require_return=True)
source = file_path.read_text(encoding="utf-8")
return {file_path: lang_support.discover_functions(source, file_path, criteria)}
except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}")
return {}
def get_all_replay_test_functions(

View file

@@ -11,7 +11,7 @@ Usage:
lang = get_language_support(Path("example.py"))
# Discover functions
functions = lang.discover_functions(file_path)
functions = lang.discover_functions(source, file_path)
# Replace a function
new_source = lang.replace_function(file_path, function, new_code)
@@ -31,6 +31,7 @@ from codeflash.languages.base import (
from codeflash.languages.current import (
current_language,
current_language_support,
is_java,
is_javascript,
is_python,
is_typescript,
@@ -78,6 +79,10 @@ def __getattr__(name: str):
from codeflash.languages.python.support import PythonSupport
return PythonSupport
if name == "JavaSupport":
from codeflash.languages.java.support import JavaSupport
return JavaSupport
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
@@ -101,6 +106,7 @@ __all__ = [
"get_language_support",
"get_supported_extensions",
"get_supported_languages",
"is_java",
"is_javascript",
"is_jest",
"is_mocha",

View file

@@ -11,11 +11,13 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
if TYPE_CHECKING:
import ast
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
from codeflash.verification.verification_utils import TestConfig
from codeflash.languages.language_enum import Language
from codeflash.models.function_types import FunctionParent
@@ -91,6 +93,7 @@ class CodeContext:
target_file: Path
helper_functions: list[HelperFunction] = field(default_factory=list)
read_only_context: str = ""
imported_type_skeletons: str = ""
imports: list[str] = field(default_factory=list)
language: Language = Language.PYTHON
@@ -166,6 +169,7 @@ class FunctionFilterCriteria:
include_patterns: list[str] = field(default_factory=list)
exclude_patterns: list[str] = field(default_factory=list)
require_return: bool = True
require_export: bool = True
include_async: bool = True
include_methods: bool = True
min_lines: int | None = None
@@ -251,7 +255,7 @@ class LanguageSupport(Protocol):
def language(self) -> Language:
return Language.PYTHON
def discover_functions(self, file_path: Path, ...) -> list[FunctionInfo]:
def discover_functions(self, source: str, file_path: Path, ...) -> list[FunctionInfo]:
# Python-specific implementation using LibCST
...
@@ -302,15 +306,49 @@ class LanguageSupport(Protocol):
"""
...
@property
def default_language_version(self) -> str | None:
"""Default language version string sent to AI service.
Returns None for languages where the runtime version is auto-detected (e.g. Python).
Returns a version string (e.g. "ES2022") for languages that need an explicit default.
"""
return None
@property
def valid_test_frameworks(self) -> tuple[str, ...]:
"""Valid test frameworks for this language."""
...
@property
def test_result_serialization_format(self) -> str:
"""How test return values are serialized: "pickle" or "json"."""
return "pickle"
def load_coverage(
self,
coverage_database_file: Path,
function_name: str,
code_context: Any,
source_file: Path,
coverage_config_file: Path | None = None,
) -> Any:
"""Load coverage data from language-specific format.
Returns a CoverageData instance.
"""
...
# === Discovery ===
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a file.
"""Find all optimizable functions in source code.
Args:
file_path: Path to the source file to analyze.
source: Source code to analyze.
file_path: Path to the source file (used for context and language detection).
filter_criteria: Optional criteria to filter functions.
Returns:
@@ -637,6 +675,45 @@ class LanguageSupport(Protocol):
"""
...
@property
def function_optimizer_class(self) -> type:
"""Return the FunctionOptimizer subclass for this language."""
from codeflash.optimization.function_optimizer import FunctionOptimizer
return FunctionOptimizer
def prepare_module(
self, module_code: str, module_path: Path, project_root: Path
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
"""Parse/validate a module before optimization."""
...
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
"""One-time project setup after language detection. Default: no-op."""
def adjust_test_config_for_discovery(self, test_cfg: TestConfig) -> None:
"""Adjust test config before test discovery. Default: no-op."""
def detect_module_system(self, project_root: Path, source_file: Path) -> str | None:
"""Detect the module system used by the project. Default: None (not applicable)."""
return None
def process_generated_test_strings(
self,
generated_test_source: str,
instrumented_behavior_test_source: str,
instrumented_perf_test_source: str,
function_to_optimize: FunctionToOptimize,
test_path: Path,
test_cfg: Any,
project_module_system: str | None,
) -> tuple[str, str, str]:
"""Process raw generated test strings (instrumentation, placeholder replacement, etc.).
Returns (generated_test_source, instrumented_behavior_source, instrumented_perf_source).
"""
...
# === Configuration ===
def get_test_file_suffix(self) -> str:
@@ -732,6 +809,20 @@ class LanguageSupport(Protocol):
# === Test Execution ===
def generate_concolic_tests(
self,
test_cfg: TestConfig,
project_root: Path,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: Any,
) -> tuple[dict, str]:
"""Generate concolic tests for a function.
Default implementation returns empty results. Override for languages
that support concolic testing (e.g. Python via CrossHair).
"""
return {}, ""
def run_behavioral_tests(
self,
test_paths: Any,
@@ -788,6 +879,31 @@ class LanguageSupport(Protocol):
"""
...
def run_line_profile_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
line_profile_output_file: Path | None = None,
) -> tuple[Path, Any]:
"""Run tests for line profiling.
Args:
test_paths: TestFiles object containing test file information.
test_env: Environment variables for the test run.
cwd: Working directory for running tests.
timeout: Optional timeout in seconds.
project_root: Project root directory.
line_profile_output_file: Path where line profile results will be written.
Returns:
Tuple of (result_file_path, subprocess_result).
"""
...
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
"""Convert a list of parent objects to a tuple of FunctionParent.

View file

@@ -0,0 +1,135 @@
"""Language-agnostic code replacement utilities.
Used by non-Python language optimizers to replace function definitions
via the LanguageSupport protocol.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.languages.base import FunctionFilterCriteria
if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import LanguageSupport
from codeflash.models.models import CodeStringsMarkdown
# Permissive criteria for discovering functions in code snippets (no export/return filtering)
_SOURCE_CRITERIA = FunctionFilterCriteria(require_return=False, require_export=False)
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
file_to_code_context = optimized_code.file_to_path()
module_optimized_code = file_to_code_context.get(str(relative_path))
if module_optimized_code is None:
# Fallback: if there's only one code block with None file path,
# use it regardless of the expected path (the AI server doesn't always include file paths)
if "None" in file_to_code_context and len(file_to_code_context) == 1:
module_optimized_code = file_to_code_context["None"]
logger.debug(f"Using code block with None file_path for {relative_path}")
else:
logger.warning(
f"Optimized code not found for {relative_path} in the context\n-------\n{optimized_code}\n-------\n"
"Re-check the markdown code structure; "
f"existing files are {file_to_code_context.keys()}"
)
module_optimized_code = ""
return module_optimized_code
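The `None`-key fallback above can be exercised in isolation (a hypothetical `pick_code` helper mirroring the same branches):

```python
# Mirror of the lookup logic above: prefer the exact relative path; if the
# response has exactly one code block with no file path (keyed "None"), use
# it anyway; otherwise fall back to an empty string.
def pick_code(file_to_code: dict[str, str], relative_path: str) -> str:
    code = file_to_code.get(relative_path)
    if code is None:
        if "None" in file_to_code and len(file_to_code) == 1:
            return file_to_code["None"]
        return ""
    return code

assert pick_code({"src/a.py": "x"}, "src/a.py") == "x"
assert pick_code({"None": "y"}, "src/a.py") == "y"          # single unlabeled block
assert pick_code({"None": "y", "src/b.py": "z"}, "src/a.py") == ""
```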
def replace_function_definitions_for_language(
function_names: list[str],
optimized_code: CodeStringsMarkdown,
module_abspath: Path,
project_root_path: Path,
lang_support: LanguageSupport,
function_to_optimize: FunctionToOptimize | None = None,
) -> bool:
"""Replace function definitions using the LanguageSupport protocol.
Works for any language that implements LanguageSupport.replace_function
and LanguageSupport.discover_functions.
"""
original_source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
if not code_to_apply.strip():
return False
original_source_code = lang_support.add_global_declarations(
optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath
)
if (
function_to_optimize
and function_to_optimize.starting_line
and function_to_optimize.ending_line
and function_to_optimize.file_path == module_abspath
):
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
else:
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
else:
new_code = original_source_code
modified = False
functions_to_replace = list(function_names)
for func_name in functions_to_replace:
current_functions = lang_support.discover_functions(new_code, module_abspath, _SOURCE_CRITERIA)
func = None
for f in current_functions:
if func_name in (f.qualified_name, f.function_name):
func = f
break
if func is None:
continue
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, func.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(new_code, func, optimized_func)
modified = True
if not modified:
logger.warning(f"Could not find any of the functions {function_names} in {module_abspath}")
return False
if original_source_code.strip() == new_code.strip():
return False
module_abspath.write_text(new_code, encoding="utf8")
return True
def _extract_function_from_code(
lang_support: LanguageSupport, source_code: str, function_name: str, file_path: Path
) -> str | None:
"""Extract a specific function's source code from a code string.
Includes JSDoc/docstring comments if present.
"""
try:
functions = lang_support.discover_functions(source_code, file_path, _SOURCE_CRITERIA)
for func in functions:
if func.function_name == function_name:
lines = source_code.splitlines(keepends=True)
effective_start = func.doc_start_line or func.starting_line
if effective_start and func.ending_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.ending_line]
return "".join(func_lines)
except Exception as e:
logger.debug(f"Error extracting function {function_name}: {e}")
return None
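The line-slice extraction above (1-indexed, inclusive, widened to a leading doc comment when `doc_start_line` is set) reduces to:

```python
# Toy sketch of the slicing in _extract_function_from_code: given 1-indexed
# start/end lines, cut the function's source (including its doc comment)
# out of the file text.
source = """\
// helper docs
function add(a, b) {
  return a + b;
}
"""
lines = source.splitlines(keepends=True)
doc_start_line, starting_line, ending_line = 1, 2, 4  # hypothetical FunctionInfo fields
effective_start = doc_start_line or starting_line
extracted = "".join(lines[effective_start - 1 : ending_line])
assert extracted.startswith("// helper docs")
assert extracted.rstrip().endswith("}")
```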

View file

@@ -103,6 +103,16 @@ def is_typescript() -> bool:
return _current_language == Language.TYPESCRIPT
def is_java() -> bool:
"""Check if the current language is Java.
Returns:
True if the current language is Java.
"""
return _current_language == Language.JAVA
def current_language_support() -> LanguageSupport:
"""Get the LanguageSupport instance for the current language.

View file

@@ -0,0 +1,228 @@
from __future__ import annotations
import hashlib
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING, Any
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import encoded_tokens_len, get_run_tmp_file
from codeflash.code_utils.config_consts import (
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
READ_WRITABLE_LIMIT_ERROR,
TESTGEN_CONTEXT_TOKEN_LIMIT,
TESTGEN_LIMIT_ERROR,
TOTAL_LOOPING_TIME_EFFECTIVE,
)
from codeflash.either import Failure, Success
from codeflash.models.models import (
CodeOptimizationContext,
CodeString,
CodeStringsMarkdown,
FunctionSource,
TestingMode,
TestResults,
)
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.equivalence import compare_test_results
if TYPE_CHECKING:
from codeflash.either import Result
from codeflash.languages.base import CodeContext, HelperFunction
from codeflash.models.models import CoverageData, OriginalCodeBaseline, TestDiff
class JavaScriptFunctionOptimizer(FunctionOptimizer):
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
language = Language(self.function_to_optimize.language)
lang_support = get_language_support(language)
try:
code_context = lang_support.extract_code_context(
self.function_to_optimize, self.project_root, self.project_root
)
return Success(
self._build_optimization_context(
code_context,
self.function_to_optimize.file_path,
self.function_to_optimize.language,
self.project_root,
)
)
except ValueError as e:
return Failure(str(e))
@staticmethod
def _build_optimization_context(
code_context: CodeContext,
file_path: Path,
language: str,
project_root: Path,
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
) -> CodeOptimizationContext:
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
try:
target_relative_path = file_path.resolve().relative_to(project_root.resolve())
except ValueError:
target_relative_path = file_path
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
helper_function_sources = []
for helper in code_context.helper_functions:
helpers_by_file[helper.file_path].append(helper)
helper_function_sources.append(
FunctionSource(
file_path=helper.file_path,
qualified_name=helper.qualified_name,
fully_qualified_name=helper.qualified_name,
only_function_name=helper.name,
source_code=helper.source_code,
)
)
target_file_code = code_context.target_code
same_file_helpers = helpers_by_file.get(file_path, [])
if same_file_helpers:
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
target_file_code = target_file_code + "\n\n" + helper_code
if imports_code:
target_file_code = imports_code + "\n\n" + target_file_code
read_writable_code_strings = [
CodeString(code=target_file_code, file_path=target_relative_path, language=language)
]
for helper_file_path, file_helpers in helpers_by_file.items():
if helper_file_path == file_path:
continue
try:
helper_relative_path = helper_file_path.resolve().relative_to(project_root.resolve())
except ValueError:
helper_relative_path = helper_file_path
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
read_writable_code_strings.append(
CodeString(code=combined_helper_code, file_path=helper_relative_path, language=language)
)
read_writable_code = CodeStringsMarkdown(code_strings=read_writable_code_strings, language=language)
testgen_context = CodeStringsMarkdown(code_strings=read_writable_code_strings.copy(), language=language)
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
if read_writable_tokens > optim_token_limit:
raise ValueError(READ_WRITABLE_LIMIT_ERROR)
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
if testgen_tokens > testgen_token_limit:
raise ValueError(TESTGEN_LIMIT_ERROR)
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
return CodeOptimizationContext(
testgen_context=testgen_context,
read_writable_code=read_writable_code,
read_only_context_code=code_context.read_only_context,
hashing_code_context=read_writable_code.flat,
hashing_code_context_hash=code_hash,
helper_functions=helper_function_sources,
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
preexisting_objects=set(),
)
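The relative-path handling above relies on `Path.relative_to` raising `ValueError` for paths outside the root; sketched as a standalone helper (hypothetical name `to_relative`):

```python
from pathlib import Path

# relative_to() raises ValueError when the file is outside the project root,
# in which case the original path is kept as-is.
def to_relative(file_path: Path, project_root: Path) -> Path:
    try:
        return file_path.resolve().relative_to(project_root.resolve())
    except ValueError:
        return file_path

assert to_relative(Path("/proj/src/a.js"), Path("/proj")) == Path("src/a.js")
assert to_relative(Path("/elsewhere/b.js"), Path("/proj")) == Path("/elsewhere/b.js")
```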
def compare_candidate_results(
self,
baseline_results: OriginalCodeBaseline,
candidate_behavior_results: TestResults,
optimization_candidate_index: int,
) -> tuple[bool, list[TestDiff]]:
original_sqlite = get_run_tmp_file(Path("test_return_values_0.sqlite"))
candidate_sqlite = get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite"))
if original_sqlite.exists() and candidate_sqlite.exists():
js_root = self.test_cfg.js_project_root or self.project_root
match, diffs = self.language_support.compare_test_results(
original_sqlite, candidate_sqlite, project_root=js_root
)
candidate_sqlite.unlink(missing_ok=True)
else:
match, diffs = compare_test_results(
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
)
return match, diffs
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
return testing_type == TestingMode.BEHAVIOR or optimization_iteration == 0
def parse_line_profile_test_results(
self, line_profiler_output_file: Path | None
) -> tuple[TestResults | dict[str, Any], CoverageData | None]:
if line_profiler_output_file is None or not line_profiler_output_file.exists():
return TestResults(test_results=[]), None
if hasattr(self.language_support, "parse_line_profile_results"):
return self.language_support.parse_line_profile_results(line_profiler_output_file), None
return TestResults(test_results=[]), None
def line_profiler_step(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict[str, Any]:
if not hasattr(self.language_support, "instrument_source_for_line_profiler"):
logger.warning(f"Language support for {self.language_support.language} doesn't support line profiling")
return {"timings": {}, "unit": 0, "str_out": ""}
original_source = self.function_to_optimize.file_path.read_text(encoding="utf-8")
try:
line_profiler_output_path = get_run_tmp_file(Path("line_profiler_output.json"))
success = self.language_support.instrument_source_for_line_profiler(
func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path
)
if not success:
return {"timings": {}, "unit": 0, "str_out": ""}
test_env = self.get_test_env(
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
)
_test_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.LINE_PROFILE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
code_context=code_context,
line_profiler_output_file=line_profiler_output_path,
)
return self.language_support.parse_line_profile_results(line_profiler_output_path)
except Exception as e:
logger.warning(f"Failed to run line profiling: {e}")
return {"timings": {}, "unit": 0, "str_out": ""}
finally:
self.function_to_optimize.file_path.write_text(original_source, encoding="utf-8")
def replace_function_and_helpers_with_optimized_code(
self,
code_context: CodeOptimizationContext,
optimized_code: CodeStringsMarkdown,
original_helper_code: dict[Path, str],
) -> bool:
from codeflash.languages.code_replacer import replace_function_definitions_for_language
did_update = False
for module_abspath, qualified_names in self.group_functions_by_file(code_context).items():
did_update |= replace_function_definitions_for_language(
function_names=list(qualified_names),
optimized_code=optimized_code,
module_abspath=module_abspath,
project_root_path=self.project_root,
lang_support=self.language_support,
function_to_optimize=self.function_to_optimize,
)
return did_update

View file

@@ -499,6 +499,18 @@ class MultiFileHelperFinder:
# Split source into lines for JSDoc extraction
lines = source.splitlines(keepends=True)
def helper_from_func(func):
effective_start = func.doc_start_line or func.start_line
helper_source = "".join(lines[effective_start - 1 : func.end_line])
return HelperFunction(
name=func.name,
qualified_name=func.name,
file_path=file_path,
source_code=helper_source,
start_line=effective_start,
end_line=func.end_line,
)
# Handle "default" export - look for default exported function
if function_name == "default":
# Find the default export
@@ -506,38 +518,14 @@ class MultiFileHelperFinder:
# For now, return first function if looking for default
# TODO: Implement proper default export detection
for func in functions:
# Extract source including JSDoc if present
effective_start = func.doc_start_line or func.start_line
helper_lines = lines[effective_start - 1 : func.end_line]
helper_source = "".join(helper_lines)
return HelperFunction(
name=func.name,
qualified_name=func.name,
file_path=file_path,
source_code=helper_source,
start_line=effective_start,
end_line=func.end_line,
)
return helper_from_func(func)
return None
# Find the function by name
functions = file_analyzer.find_functions(source, include_methods=True)
for func in functions:
if func.name == function_name:
# Extract source including JSDoc if present
effective_start = func.doc_start_line or func.start_line
helper_lines = lines[effective_start - 1 : func.end_line]
helper_source = "".join(helper_lines)
return HelperFunction(
name=func.name,
qualified_name=func.name,
file_path=file_path,
source_code=helper_source,
start_line=effective_start,
end_line=func.end_line,
)
return helper_from_func(func)
logger.debug("Function %s not found in %s", function_name, file_path)
return None

View file

@@ -1,17 +1,52 @@
"""JavaScript/TypeScript code normalizer using tree-sitter."""
"""JavaScript/TypeScript code normalizer using tree-sitter.
Not currently wired into JavaScriptSupport.normalize_code; kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
"""
from __future__ import annotations
import re
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from codeflash.code_utils.normalizers.base import CodeNormalizer
if TYPE_CHECKING:
from tree_sitter import Node
# TODO:{claude} move to language support directory to keep the directory structure clean
# ---------------------------------------------------------------------------
# Reference: the old CodeNormalizer ABC that was deleted from base.py.
# Kept here so the interface contract is visible if we re-introduce a
# normalizer hierarchy later.
# ---------------------------------------------------------------------------
class CodeNormalizer(ABC):
@property
@abstractmethod
def language(self) -> str: ...
@abstractmethod
def normalize(self, code: str) -> str: ...
@abstractmethod
def normalize_for_hash(self, code: str) -> str: ...
def are_duplicates(self, code1: str, code2: str) -> bool:
try:
return self.normalize_for_hash(code1) == self.normalize_for_hash(code2)
except Exception:
return False
def get_fingerprint(self, code: str) -> str:
import hashlib
return hashlib.sha256(self.normalize_for_hash(code).encode()).hexdigest()
# ---------------------------------------------------------------------------
class JavaScriptVariableNormalizer:
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.
@@ -188,103 +223,35 @@ class JavaScriptVariableNormalizer:
parts.append(")")
def _basic_normalize(code: str) -> str:
def _basic_normalize_js(code: str) -> str:
"""Basic normalization: remove comments and normalize whitespace."""
# Remove single-line comments
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
# Normalize whitespace
return " ".join(code.split())
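The regex fallback above is intentionally crude; a standalone copy shows it treats comment-only differences as equal (note the regex would also strip `//` inside string literals, which the tree-sitter path avoids):

```python
import re

def basic_normalize_js(code: str) -> str:
    # Mirror of the fallback above: strip // and /* */ comments, collapse whitespace.
    code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
    code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
    return " ".join(code.split())

a = "function f(x) { // inc\n  return x + 1; }"
b = "/* doc */ function f(x) { return x + 1; }"
assert basic_normalize_js(a) == basic_normalize_js(b)
```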
class JavaScriptNormalizer(CodeNormalizer):
"""JavaScript code normalizer using tree-sitter.
def normalize_js_code(code: str, typescript: bool = False) -> str:
"""Normalize JavaScript/TypeScript code to a canonical form for comparison.
Normalizes JavaScript code by:
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
- Preserving function names, class names, parameters, and imports
- Removing comments
- Normalizing string and number literals
Uses tree-sitter to parse and normalize variable names. Falls back to
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.
Not currently wired into JavaScriptSupport.normalize_code; kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
"""
try:
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
@property
def language(self) -> str:
"""Return the language this normalizer handles."""
return "javascript"
lang = TreeSitterLanguage.TYPESCRIPT if typescript else TreeSitterLanguage.JAVASCRIPT
analyzer = TreeSitterAnalyzer(lang)
tree = analyzer.parse(code)
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return (".js", ".jsx", ".mjs", ".cjs")
if tree.root_node.has_error:
return _basic_normalize_js(code)
def _get_tree_sitter_language(self) -> str:
"""Get the tree-sitter language identifier."""
return "javascript"
def normalize(self, code: str) -> str:
"""Normalize JavaScript code to a canonical form.
Args:
code: JavaScript source code to normalize
Returns:
Normalized representation of the code
"""
try:
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
analyzer = TreeSitterAnalyzer(lang)
tree = analyzer.parse(code)
if tree.root_node.has_error:
return _basic_normalize(code)
normalizer = JavaScriptVariableNormalizer()
source_bytes = code.encode("utf-8")
# First pass: collect preserved names
normalizer.collect_preserved_names(tree.root_node, source_bytes)
# Second pass: normalize and build representation
return normalizer.normalize_tree(tree.root_node, source_bytes)
except Exception:
return _basic_normalize(code)
def normalize_for_hash(self, code: str) -> str:
"""Normalize JavaScript code optimized for hashing.
For JavaScript, this is the same as normalize().
Args:
code: JavaScript source code to normalize
Returns:
Normalized representation suitable for hashing
"""
return self.normalize(code)
class TypeScriptNormalizer(JavaScriptNormalizer):
"""TypeScript code normalizer using tree-sitter.
Inherits from JavaScriptNormalizer and overrides language-specific settings.
"""
@property
def language(self) -> str:
"""Return the language this normalizer handles."""
return "typescript"
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return (".ts", ".tsx", ".mts", ".cts")
def _get_tree_sitter_language(self) -> str:
"""Get the tree-sitter language identifier."""
return "typescript"
normalizer = JavaScriptVariableNormalizer()
source_bytes = code.encode("utf-8")
normalizer.collect_preserved_names(tree.root_node, source_bytes)
return normalizer.normalize_tree(tree.root_node, source_bytes)
except Exception:
return _basic_normalize_js(code)


@@ -0,0 +1,52 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.models.models import ValidCode
if TYPE_CHECKING:
from pathlib import Path
from codeflash.verification.verification_utils import TestConfig
def prepare_javascript_module(
original_module_code: str, original_module_path: Path
) -> tuple[dict[Path, ValidCode], None]:
"""Prepare a JavaScript/TypeScript module for optimization.
Unlike Python, JS/TS doesn't need AST parsing or import analysis at this stage.
Returns a mapping of the file path to ValidCode with the source as-is.
"""
validated_original_code: dict[Path, ValidCode] = {
original_module_path: ValidCode(source_code=original_module_code, normalized_code=original_module_code)
}
return validated_original_code, None
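The returned mapping has a simple shape; a standalone sketch, assuming `ValidCode` is essentially a two-field record (the real Pydantic model lives in `codeflash.models.models`):

```python
from dataclasses import dataclass
from pathlib import Path


# Stand-in for codeflash.models.models.ValidCode (assumed shape).
@dataclass
class ValidCode:
    source_code: str
    normalized_code: str


def prepare_javascript_module(code: str, path: Path):
    # JS/TS needs no AST pass at this stage; the source is stored verbatim,
    # and the second tuple element (Python's module AST slot) is None.
    return {path: ValidCode(source_code=code, normalized_code=code)}, None


validated, extra = prepare_javascript_module("export const x = 1;", Path("src/x.ts"))
assert extra is None
assert validated[Path("src/x.ts")].source_code == "export const x = 1;"
```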
def verify_js_requirements(test_cfg: TestConfig) -> None:
"""Verify JavaScript/TypeScript requirements before optimization.
Checks that Node.js, npm, and the test framework are available.
Logs warnings if requirements are not met but does not abort.
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.test_framework import get_js_test_framework_or_default
js_project_root = test_cfg.js_project_root
if not js_project_root:
return
try:
js_support = get_language_support(Language.JAVASCRIPT)
test_framework = get_js_test_framework_or_default()
success, errors = js_support.verify_requirements(js_project_root, test_framework)
if not success:
logger.warning("JavaScript requirements check found issues:")
for error in errors:
logger.warning(f" - {error}")
except Exception as e:
logger.debug(f"Failed to verify JS requirements: {e}")
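The "warn but don't abort" pattern above boils down to collecting human-readable errors for missing executables. A minimal availability-check sketch (`check_js_requirements` is a hypothetical name; the real check delegates to `js_support.verify_requirements`):

```python
import shutil


def check_js_requirements(binaries=("node", "npm")) -> list[str]:
    # Returns one message per executable missing from PATH; an empty
    # list means all requirements were found.
    return [f"{b} not found on PATH" for b in binaries if shutil.which(b) is None]
```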


@@ -23,7 +23,8 @@ if TYPE_CHECKING:
from codeflash.languages.base import ReferenceInfo
from codeflash.languages.javascript.treesitter import TypeDefinition
from codeflash.models.models import GeneratedTestsList, InvocationId
from codeflash.models.models import GeneratedTestsList, InvocationId, ValidCode
from codeflash.verification.verification_utils import TestConfig
logger = logging.getLogger(__name__)
@@ -50,8 +51,7 @@ class JavaScriptSupport:
@property
def default_file_extension(self) -> str:
"""Default file extension for JavaScript."""
return ".js"
return self.file_extensions[0]
@property
def test_framework(self) -> str:
@@ -68,17 +68,47 @@ class JavaScriptSupport:
def dir_excludes(self) -> frozenset[str]:
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
@property
def default_language_version(self) -> str | None:
return "ES2022"
@property
def valid_test_frameworks(self) -> tuple[str, ...]:
return ("jest", "mocha", "vitest")
@property
def test_result_serialization_format(self) -> str:
return "json"
def load_coverage(
self,
coverage_database_file: Path,
function_name: str,
code_context: Any,
source_file: Path,
coverage_config_file: Path | None = None,
) -> Any:
from codeflash.verification.coverage_utils import JestCoverageUtils
return JestCoverageUtils.load_from_jest_json(
coverage_json_path=coverage_database_file,
function_name=function_name,
code_context=code_context,
source_code_path=source_file,
)
# === Discovery ===
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a JavaScript file.
"""Find all optimizable functions in JavaScript/TypeScript source code.
Uses tree-sitter to parse the file and find functions.
Uses tree-sitter to parse the source and find functions.
Args:
file_path: Path to the JavaScript file to analyze.
source: Source code to analyze.
file_path: Path to the source file (used for language detection).
filter_criteria: Optional criteria to filter functions.
Returns:
@@ -87,12 +117,6 @@ class JavaScriptSupport:
"""
criteria = filter_criteria or FunctionFilterCriteria()
try:
source = file_path.read_text(encoding="utf-8")
except Exception as e:
logger.warning("Failed to read %s: %s", file_path, e)
return []
try:
analyzer = get_analyzer_for_file(file_path)
tree_functions = analyzer.find_functions(
@@ -111,7 +135,7 @@ class JavaScriptSupport:
# Skip non-exported functions (can't be imported in tests)
# Exception: nested functions and methods are allowed if their parent is exported
if not func.is_exported and not func.parent_function:
if criteria.require_export and not func.is_exported and not func.parent_function:
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
continue
@@ -144,61 +168,6 @@ class JavaScriptSupport:
logger.warning("Failed to parse %s: %s", file_path, e)
return []
def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionToOptimize]:
"""Find all functions in source code string.
Uses tree-sitter to parse the source and find functions.
Args:
source: The source code to analyze.
file_path: Optional file path for context (used for language detection).
Returns:
List of FunctionToOptimize objects for discovered functions.
"""
try:
# Use JavaScript analyzer by default, or detect from file path
if file_path:
analyzer = get_analyzer_for_file(file_path)
else:
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
tree_functions = analyzer.find_functions(
source, include_methods=True, include_arrow_functions=True, require_name=True
)
functions: list[FunctionToOptimize] = []
for func in tree_functions:
# Build parents list
parents: list[FunctionParent] = []
if func.class_name:
parents.append(FunctionParent(name=func.class_name, type="ClassDef"))
if func.parent_function:
parents.append(FunctionParent(name=func.parent_function, type="FunctionDef"))
functions.append(
FunctionToOptimize(
function_name=func.name,
file_path=file_path or Path("unknown"),
parents=parents,
starting_line=func.start_line,
ending_line=func.end_line,
starting_col=func.start_col,
ending_col=func.end_col,
is_async=func.is_async,
is_method=func.is_method,
language=str(self.language),
doc_start_line=func.doc_start_line,
)
)
return functions
except Exception as e:
logger.warning("Failed to parse source: %s", e)
return []
def _get_test_patterns(self) -> list[str]:
"""Get test file patterns for this language.
@@ -1508,7 +1477,7 @@ class JavaScriptSupport:
return "".join(result_lines)
def format_code(self, source: str, file_path: Path | None = None) -> str:
"""Format JavaScript code using prettier (if available).
"""Format JavaScript/TypeScript code using prettier (if available).
Args:
source: Source code to format.
@@ -1519,9 +1488,10 @@
"""
try:
# Try to use prettier via npx
stdin_filepath = str(file_path.name) if file_path else f"file{self.default_file_extension}"
result = subprocess.run(
["npx", "prettier", "--stdin-filepath", "file.js"],
["npx", "prettier", "--stdin-filepath", stdin_filepath],
check=False,
input=source,
capture_output=True,
@@ -1702,22 +1672,15 @@ class JavaScriptSupport:
# === Validation ===
@property
def treesitter_language(self) -> TreeSitterLanguage:
return TreeSitterLanguage.JAVASCRIPT
def validate_syntax(self, source: str) -> bool:
"""Check if JavaScript source code is syntactically valid.
Uses tree-sitter to parse and check for errors.
Args:
source: Source code to validate.
Returns:
True if valid, False otherwise.
"""
"""Check if source code is syntactically valid using tree-sitter."""
try:
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
analyzer = TreeSitterAnalyzer(self.treesitter_language)
tree = analyzer.parse(source)
# Check if tree has errors
return not tree.root_node.has_error
except Exception:
return False
@@ -1744,6 +1707,11 @@ class JavaScriptSupport:
normalized_lines.append(stripped)
return "\n".join(normalized_lines)
def generate_concolic_tests(
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any
) -> tuple[dict, str]:
return {}, ""
# === Test Editing ===
def add_runtime_comments(
@@ -1909,6 +1877,92 @@ class JavaScriptSupport:
return compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
@property
def function_optimizer_class(self) -> type:
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
return JavaScriptFunctionOptimizer
def prepare_module(
self, module_code: str, module_path: Path, project_root: Path
) -> tuple[dict[Path, ValidCode], None]:
from codeflash.languages.javascript.optimizer import prepare_javascript_module
return prepare_javascript_module(module_code, module_path)
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
from codeflash.languages.javascript.optimizer import verify_js_requirements
from codeflash.languages.javascript.test_runner import find_node_project_root
test_cfg.js_project_root = find_node_project_root(file_path)
verify_js_requirements(test_cfg)
def adjust_test_config_for_discovery(self, test_cfg: TestConfig) -> None:
test_cfg.tests_project_rootdir = test_cfg.tests_root
def detect_module_system(self, project_root: Path, source_file: Path) -> str | None:
from codeflash.languages.javascript.module_system import detect_module_system
return detect_module_system(project_root, source_file)
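Module-system detection in Node projects conventionally keys off the `"type"` field of `package.json` (`"module"` means ESM, anything else defaults to CommonJS). A hypothetical simplification of what `detect_module_system` does; the real implementation also considers the source file (e.g. `.mjs`/`.cjs` extensions) and nearest-`package.json` resolution, and the `"esm"`/`"commonjs"` return strings here are assumed, not taken from the codebase:

```python
import json
from pathlib import Path


def sketch_detect_module_system(project_root: Path) -> str:
    # Simplification: ignores per-file extensions and malformed package.json.
    pkg = project_root / "package.json"
    if pkg.exists():
        data = json.loads(pkg.read_text(encoding="utf-8"))
        if data.get("type") == "module":
            return "esm"
    return "commonjs"
```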
def process_generated_test_strings(
self,
generated_test_source: str,
instrumented_behavior_test_source: str,
instrumented_perf_test_source: str,
function_to_optimize: Any,
test_path: Path,
test_cfg: Any,
project_module_system: str | None,
) -> tuple[str, str, str]:
from codeflash.languages.javascript.instrument import (
TestingMode,
fix_imports_inside_test_blocks,
fix_jest_mock_paths,
instrument_generated_js_test,
validate_and_fix_import_style,
)
from codeflash.languages.javascript.module_system import (
ensure_module_system_compatibility,
ensure_vitest_imports,
)
source_file = Path(function_to_optimize.file_path)
# Fix import statements that appear inside test blocks (invalid JS syntax)
generated_test_source = fix_imports_inside_test_blocks(generated_test_source)
# Fix relative paths in jest.mock() calls
generated_test_source = fix_jest_mock_paths(
generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
)
# Validate and fix import styles (default vs named exports)
generated_test_source = validate_and_fix_import_style(
generated_test_source, source_file, function_to_optimize.function_name
)
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
generated_test_source = ensure_module_system_compatibility(
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
)
# Ensure vitest imports are present when using vitest framework
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
# Instrument for behavior verification (writes to SQLite)
instrumented_behavior_test_source = instrument_generated_js_test(
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR
)
# Instrument for performance measurement (prints to stdout)
instrumented_perf_test_source = instrument_generated_js_test(
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE
)
logger.debug("Instrumented JS/TS tests locally for %s", function_to_optimize.function_name)
return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source
# === Configuration ===
def get_test_file_suffix(self) -> str:
@@ -2488,62 +2542,9 @@ class TypeScriptSupport(JavaScriptSupport):
]
def get_test_file_suffix(self) -> str:
"""Get the test file suffix for TypeScript.
Returns:
Jest test file suffix for TypeScript.
"""
"""Get the test file suffix for TypeScript."""
return ".test.ts"
def validate_syntax(self, source: str) -> bool:
"""Check if TypeScript source code is syntactically valid.
Uses tree-sitter TypeScript parser to parse and check for errors.
Args:
source: Source code to validate.
Returns:
True if valid, False otherwise.
"""
try:
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
tree = analyzer.parse(source)
return not tree.root_node.has_error
except Exception:
return False
def format_code(self, source: str, file_path: Path | None = None) -> str:
"""Format TypeScript code using prettier (if available).
Args:
source: Source code to format.
file_path: Optional file path for context.
Returns:
Formatted source code.
"""
try:
# Determine file extension for prettier
stdin_filepath = str(file_path.name) if file_path else "file.ts"
# Try to use prettier via npx
result = subprocess.run(
["npx", "prettier", "--stdin-filepath", stdin_filepath],
check=False,
input=source,
capture_output=True,
text=True,
timeout=30,
)
if result.returncode == 0:
return result.stdout
except (subprocess.TimeoutExpired, FileNotFoundError):
pass
except Exception as e:
logger.debug("Prettier formatting failed: %s", e)
return source
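Both `format_code` variants hinge on prettier's real `--stdin-filepath` flag: prettier infers its parser (babel, typescript, ...) from that filename's extension, which is why passing the actual file name instead of a hard-coded `file.js` fixes formatting of TypeScript input. A sketch of just the command construction, without invoking prettier (since `npx` may be unavailable):

```python
from __future__ import annotations

from pathlib import Path


def prettier_command(file_path: Path | None, default_ext: str = ".js") -> list[str]:
    # Prettier selects its parser from this filename's extension, so the
    # real name must be forwarded when formatting via stdin.
    stdin_filepath = file_path.name if file_path else f"file{default_ext}"
    return ["npx", "prettier", "--stdin-filepath", stdin_filepath]


assert prettier_command(Path("src/app.tsx"))[-1] == "app.tsx"
assert prettier_command(None)[-1] == "file.js"
```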
@property
def treesitter_language(self) -> TreeSitterLanguage:
return TreeSitterLanguage.TYPESCRIPT


@@ -369,7 +369,7 @@ def _get_jest_config_for_project(project_root: Path) -> Path | None:
return original_jest_config
def _find_node_project_root(file_path: Path) -> Path | None:
def find_node_project_root(file_path: Path) -> Path | None:
"""Find the Node.js project root by looking for package.json.
Traverses up from the given file path to find the nearest directory
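The traversal this docstring describes can be sketched independently of the codebase (`find_node_project_root_sketch` is a hypothetical stand-in for the function above):

```python
from __future__ import annotations

from pathlib import Path


def find_node_project_root_sketch(file_path: Path) -> Path | None:
    # Walk from the file's directory up to the filesystem root, returning
    # the first directory that contains a package.json.
    current = file_path if file_path.is_dir() else file_path.parent
    for directory in [current, *current.parents]:
        if (directory / "package.json").exists():
            return directory
    return None
```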
@@ -686,7 +686,7 @@ def run_jest_behavioral_tests(
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
first_test_file = Path(test_files[0])
project_root = _find_node_project_root(first_test_file)
project_root = find_node_project_root(first_test_file)
# Use the project root, or fall back to provided cwd
effective_cwd = project_root if project_root else cwd
@@ -936,7 +936,7 @@ def run_jest_benchmarking_tests(
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
first_test_file = Path(test_files[0])
project_root = _find_node_project_root(first_test_file)
project_root = find_node_project_root(first_test_file)
effective_cwd = project_root if project_root else cwd
@@ -1106,7 +1106,7 @@ def run_jest_line_profile_tests(
# Use provided project_root, or detect it as fallback
if project_root is None and test_files:
first_test_file = Path(test_files[0])
project_root = _find_node_project_root(first_test_file)
project_root = find_node_project_root(first_test_file)
effective_cwd = project_root if project_root else cwd
logger.debug(f"Jest line profiling working directory: {effective_cwd}")


@@ -12,6 +12,7 @@ class Language(str, Enum):
PYTHON = "python"
JAVASCRIPT = "javascript"
TYPESCRIPT = "typescript"
JAVA = "java"
def __str__(self) -> str:
return self.value
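Because `Language` mixes in `str`, members compare equal to their plain string values and round-trip through the constructor, which is what makes string-typed `language` fields interchangeable with enum members elsewhere in the codebase:

```python
from enum import Enum


class Language(str, Enum):
    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"
    JAVA = "java"

    def __str__(self) -> str:
        return self.value


# str-mixin enums compare equal to their value and stringify cleanly.
assert Language.JAVA == "java"
assert str(Language.TYPESCRIPT) == "typescript"
assert Language("python") is Language.PYTHON
```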


@@ -12,11 +12,13 @@ import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT
from codeflash.code_utils.config_consts import (
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
READ_WRITABLE_LIMIT_ERROR,
TESTGEN_CONTEXT_TOKEN_LIMIT,
TESTGEN_LIMIT_ERROR,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
# Language support imports for multi-language code context extraction
from codeflash.languages import Language, is_python
from codeflash.languages.python.context.unused_definition_remover import (
collect_top_level_defs_with_usages,
get_section_names,
@@ -40,13 +42,9 @@ from codeflash.optimization.function_context import belongs_to_function_qualifie
if TYPE_CHECKING:
from jedi.api.classes import Name
from codeflash.languages.base import DependencyResolver, HelperFunction
from codeflash.languages.base import DependencyResolver
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
# Error message constants
READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed"
TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed"
def build_testgen_context(
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
@@ -91,12 +89,6 @@ def get_code_optimization_context(
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
call_graph: DependencyResolver | None = None,
) -> CodeOptimizationContext:
# Route to language-specific implementation for non-Python languages
if not is_python():
return get_code_optimization_context_for_language(
function_to_optimize, project_root_path, optim_token_limit, testgen_token_limit
)
# Get FunctionSource representation of helpers of FTO
fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}}
if call_graph is not None:
@@ -216,140 +208,6 @@
)
def get_code_optimization_context_for_language(
function_to_optimize: FunctionToOptimize,
project_root_path: Path,
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
) -> CodeOptimizationContext:
"""Extract code optimization context for non-Python languages.
Uses the language support abstraction to extract code context and converts
it to the CodeOptimizationContext format expected by the pipeline.
This function supports multi-file context extraction, grouping helpers by file
and creating proper CodeStringsMarkdown with file paths for multi-file replacement.
Args:
function_to_optimize: The function to extract context for.
project_root_path: Root of the project.
optim_token_limit: Token limit for optimization context.
testgen_token_limit: Token limit for testgen context.
Returns:
CodeOptimizationContext with target code and dependencies.
"""
from codeflash.languages import get_language_support
# Get language support for this function
language = Language(function_to_optimize.language)
lang_support = get_language_support(language)
# Extract code context using language support
code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path)
# Build imports string if available
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
# Get relative path for target file
try:
target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
target_relative_path = function_to_optimize.file_path
# Group helpers by file path
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
helper_function_sources = []
for helper in code_context.helper_functions:
helpers_by_file[helper.file_path].append(helper)
# Convert to FunctionSource for pipeline compatibility
helper_function_sources.append(
FunctionSource(
file_path=helper.file_path,
qualified_name=helper.qualified_name,
fully_qualified_name=helper.qualified_name,
only_function_name=helper.name,
source_code=helper.source_code,
)
)
# Build read-writable code (target file + same-file helpers + global variables)
read_writable_code_strings = []
# Combine target code with same-file helpers
target_file_code = code_context.target_code
same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, [])
if same_file_helpers:
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
target_file_code = target_file_code + "\n\n" + helper_code
# Note: code_context.read_only_context contains type definitions and global variables
# These should be passed as read-only context to the AI, not prepended to the target code
# If prepended to target code, the AI treats them as code to optimize and includes them in output
# Add imports to target file code
if imports_code:
target_file_code = imports_code + "\n\n" + target_file_code
read_writable_code_strings.append(
CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language)
)
# Add helper files (cross-file helpers)
for file_path, file_helpers in helpers_by_file.items():
if file_path == function_to_optimize.file_path:
continue # Already included in target file
try:
helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
helper_relative_path = file_path
# Combine all helpers from this file
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
read_writable_code_strings.append(
CodeString(
code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language
)
)
read_writable_code = CodeStringsMarkdown(
code_strings=read_writable_code_strings, language=function_to_optimize.language
)
# Build testgen context (same as read_writable for non-Python)
testgen_context = CodeStringsMarkdown(
code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language
)
# Check token limits
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
if read_writable_tokens > optim_token_limit:
raise ValueError(READ_WRITABLE_LIMIT_ERROR)
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
if testgen_tokens > testgen_token_limit:
raise ValueError(TESTGEN_LIMIT_ERROR)
# Generate code hash from all read-writable code
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
return CodeOptimizationContext(
testgen_context=testgen_context,
read_writable_code=read_writable_code,
read_only_context_code=code_context.read_only_context,
hashing_code_context=read_writable_code.flat,
hashing_code_context_hash=code_hash,
helper_functions=helper_function_sources,
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
preexisting_objects=set(),
)
def process_file_context(
file_path: Path,
primary_qualified_names: set[str],


@@ -10,7 +10,8 @@ from typing import TYPE_CHECKING, Optional, Union
import libcst as cst
from codeflash.cli_cmds.console import logger
from codeflash.languages import is_python
from codeflash.languages import current_language
from codeflash.languages.base import Language
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeString, CodeStringsMarkdown
@@ -747,7 +748,7 @@ def detect_unused_helper_functions(
"""
# Skip this analysis for non-Python languages since we use Python's ast module
if not is_python():
if current_language() != Language.PYTHON:
logger.debug("Skipping unused helper function detection for non-Python languages")
return []


@@ -0,0 +1,215 @@
from __future__ import annotations
import ast
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.either import Failure, Success
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.languages.python.optimizer import resolve_python_function_ast
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
from codeflash.languages.python.static_analysis.code_replacer import (
add_custom_marker_to_all_tests,
modify_autouse_fixture,
)
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.models.models import TestingMode, TestResults
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results
if TYPE_CHECKING:
from typing import Any
from codeflash.either import Result
from codeflash.languages.base import Language
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import (
CodeOptimizationContext,
CodeStringsMarkdown,
ConcurrencyMetrics,
CoverageData,
OriginalCodeBaseline,
TestDiff,
)
class PythonFunctionOptimizer(FunctionOptimizer):
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
from codeflash.languages.python.context import code_context_extractor
try:
return Success(
code_context_extractor.get_code_optimization_context(
self.function_to_optimize, self.project_root, call_graph=self.call_graph
)
)
except ValueError as e:
return Failure(str(e))
def _resolve_function_ast(
self, source_code: str, function_name: str, parents: list[FunctionParent]
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
original_module_ast = ast.parse(source_code)
return resolve_python_function_ast(function_name, parents, original_module_ast)
def requires_function_ast(self) -> bool:
return True
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
def get_optimization_review_metrics(
self,
source_code: str,
file_path: Path,
qualified_name: str,
project_root: Path,
tests_root: Path,
language: Language,
) -> str:
return get_opt_review_metrics(source_code, file_path, qualified_name, project_root, tests_root, language)
def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
logger.info("Disabling all autouse fixtures associated with the generated test files")
original_conftest_content = modify_autouse_fixture(test_paths)
logger.info("Add custom marker to generated test files")
add_custom_marker_to_all_tests(test_paths)
return original_conftest_content
def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None:
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
instrument_codeflash_capture(self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root)
def should_check_coverage(self) -> bool:
return True
def collect_async_metrics(
self,
benchmarking_results: TestResults,
code_context: CodeOptimizationContext,
helper_code: dict[Path, str],
test_env: dict[str, str],
) -> tuple[int | None, ConcurrencyMetrics | None]:
if not self.function_to_optimize.is_async:
return None, None
async_throughput = calculate_function_throughput_from_test_results(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Async function throughput: {async_throughput} calls/second")
concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=helper_code, test_env=test_env
)
if concurrency_metrics:
logger.debug(
f"Concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
)
return async_throughput, concurrency_metrics
def instrument_async_for_mode(self, mode: TestingMode) -> None:
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path, self.function_to_optimize, mode, project_root=self.project_root
)
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
return False
def parse_line_profile_test_results(
self, line_profiler_output_file: Path | None
) -> tuple[TestResults | dict, CoverageData | None]:
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
return parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
def compare_candidate_results(
self,
baseline_results: OriginalCodeBaseline,
candidate_behavior_results: TestResults,
optimization_candidate_index: int,
) -> tuple[bool, list[TestDiff]]:
from codeflash.verification.equivalence import compare_test_results
return compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
def replace_function_and_helpers_with_optimized_code(
self,
code_context: CodeOptimizationContext,
optimized_code: CodeStringsMarkdown,
original_helper_code: dict[Path, str],
) -> bool:
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
did_update = False
for module_abspath, qualified_names in self.group_functions_by_file(code_context).items():
did_update |= replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=optimized_code,
module_abspath=module_abspath,
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,
)
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
if unused_helpers:
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
return did_update
def line_profiler_step(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict[str, Any]:
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
if contains_jit_decorator(candidate_fto_code):
logger.info(
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
)
return {"timings": {}, "unit": 0, "str_out": ""}
for module_abspath in original_helper_code:
candidate_helper_code = Path(module_abspath).read_text("utf-8")
if contains_jit_decorator(candidate_helper_code):
logger.info(
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
)
return {"timings": {}, "unit": 0, "str_out": ""}
try:
console.rule()
test_env = self.get_test_env(
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
)
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
line_profile_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.LINE_PROFILE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
code_context=code_context,
line_profiler_output_file=line_profiler_output_file,
)
finally:
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
logger.warning(
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
)
return {"timings": {}, "unit": 0, "str_out": ""}
if line_profile_results["str_out"] == "":
logger.warning(
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
)
return line_profile_results
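The guard above skips line profiling whenever the candidate or helper code contains a JIT decorator, since JIT-compiled functions bypass line-level tracing. A minimal sketch of such a decorator check (an assumption — the real `contains_jit_decorator` may be more thorough) using only the stdlib `ast` module:

```python
import ast


def contains_jit_decorator(source: str) -> bool:
    """Toy sketch: detect common JIT decorators (e.g. numba's @jit / @njit)
    on any function definition. The real check is assumed to cover more cases."""
    jit_names = {"jit", "njit"}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            for dec in node.decorator_list:
                # @numba.njit(...) is a Call wrapping the decorator expression
                target = dec.func if isinstance(dec, ast.Call) else dec
                if isinstance(target, ast.Attribute) and target.attr in jit_names:
                    return True
                if isinstance(target, ast.Name) and target.id in jit_names:
                    return True
    return False


print(contains_jit_decorator("import numba\n@numba.njit\ndef f(x):\n    return x\n"))  # True
```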


@@ -4,8 +4,6 @@ from __future__ import annotations
import ast
from codeflash.code_utils.normalizers.base import CodeNormalizer
class VariableNormalizer(ast.NodeTransformer):
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
@@ -164,63 +162,19 @@ def _remove_docstrings_from_ast(node: ast.AST) -> None:
stack.extend([child for child in body if isinstance(child, node_types)])
class PythonNormalizer(CodeNormalizer):
"""Python code normalizer using AST transformation.
def normalize_python_code(code: str, remove_docstrings: bool = True) -> str:
"""Normalize Python code to a canonical form for comparison.
Normalizes Python code by:
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
- Preserving function names, class names, parameters, and imports
- Optionally removing docstrings
Replaces local variable names with canonical forms (var_0, var_1, etc.)
while preserving function names, class names, parameters, and imports.
"""
tree = ast.parse(code)
@property
def language(self) -> str:
"""Return the language this normalizer handles."""
return "python"
@property
def supported_extensions(self) -> tuple[str, ...]:
"""Return file extensions this normalizer can handle."""
return (".py", ".pyw", ".pyi")
def normalize(self, code: str, remove_docstrings: bool = True) -> str:
"""Normalize Python code to a canonical form.
Args:
code: Python source code to normalize
remove_docstrings: Whether to remove docstrings
Returns:
Normalized Python code as a string
"""
tree = ast.parse(code)
if remove_docstrings:
_remove_docstrings_from_ast(tree)
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)
ast.fix_missing_locations(normalized_tree)
return ast.unparse(normalized_tree)
def normalize_for_hash(self, code: str) -> str:
"""Normalize Python code optimized for hashing.
Returns AST dump which is faster than unparsing.
Args:
code: Python source code to normalize
Returns:
AST dump string suitable for hashing
"""
tree = ast.parse(code)
if remove_docstrings:
_remove_docstrings_from_ast(tree)
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)
normalizer = VariableNormalizer()
normalized_tree = normalizer.visit(tree)
ast.fix_missing_locations(normalized_tree)
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
return ast.unparse(normalized_tree)
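The normalizer above replaces local variable names with canonical forms so that two optimizations differing only in naming hash to the same value. A self-contained toy version of that idea (an assumption — the real `VariableNormalizer` also preserves parameters, globals, and imports more carefully):

```python
import ast


class _LocalRenamer(ast.NodeTransformer):
    """Toy sketch: rename every name that is assigned to var_0, var_1, ..."""

    def __init__(self) -> None:
        self.mapping: dict[str, str] = {}

    def visit_Name(self, node: ast.Name) -> ast.Name:
        # Assign targets (Store) establish the mapping; later loads reuse it.
        if isinstance(node.ctx, ast.Store) and node.id not in self.mapping:
            self.mapping[node.id] = f"var_{len(self.mapping)}"
        node.id = self.mapping.get(node.id, node.id)
        return node


def toy_normalize(code: str) -> str:
    tree = ast.parse(code)
    tree = _LocalRenamer().visit(tree)
    ast.fix_missing_locations(tree)
    return ast.unparse(tree)


a = toy_normalize("def f(x):\n    total = x + 1\n    return total\n")
b = toy_normalize("def f(x):\n    acc = x + 1\n    return acc\n")
print(a == b)  # True: both normalize to the same canonical form
```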


@@ -0,0 +1,63 @@
from __future__ import annotations
import ast
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.models.models import ValidCode
if TYPE_CHECKING:
from pathlib import Path
from codeflash.models.function_types import FunctionParent
def prepare_python_module(
original_module_code: str, original_module_path: Path, project_root: Path
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
"""Parse a Python module, normalize its code, and validate imported callee modules.
Returns a mapping of file paths to ValidCode (for the module and its imported callees)
plus the parsed AST, or None on syntax error.
"""
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
try:
original_module_ast = ast.parse(original_module_code)
except SyntaxError as e:
logger.warning(f"Syntax error parsing code in {original_module_path}: {e}")
logger.info("Skipping optimization due to file error.")
return None
normalized_original_module_code = ast.unparse(normalize_node(original_module_ast))
validated_original_code: dict[Path, ValidCode] = {
original_module_path: ValidCode(
source_code=original_module_code, normalized_code=normalized_original_module_code
)
}
imported_module_analyses = analyze_imported_modules(original_module_code, original_module_path, project_root)
for analysis in imported_module_analyses:
callee_original_code = analysis.file_path.read_text(encoding="utf8")
try:
normalized_callee_original_code = normalize_code(callee_original_code)
except SyntaxError as e:
logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}")
logger.info("Skipping optimization due to helper file error.")
return None
validated_original_code[analysis.file_path] = ValidCode(
source_code=callee_original_code, normalized_code=normalized_callee_original_code
)
return validated_original_code, original_module_ast
def resolve_python_function_ast(
function_name: str, parents: list[FunctionParent], module_ast: ast.Module
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
"""Look up a function/method AST node in a parsed Python module."""
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
return get_first_top_level_function_or_method_ast(function_name, parents, module_ast)
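`resolve_python_function_ast` delegates to a static-analysis helper; a minimal stand-in for the top-level case (hypothetical — the real helper also handles methods via the `parents` list) looks like this:

```python
import ast


def resolve_top_level_function(name: str, module_ast: ast.Module):
    """Sketch: find a top-level (async) function definition by name.
    Returns the AST node, or None if the module has no such function."""
    for node in module_ast.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name:
            return node
    return None


module_ast = ast.parse("def add(a, b):\n    return a + b\n")
node = resolve_top_level_function("add", module_ast)
print(node.name)  # add
```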


@@ -4,7 +4,7 @@ import ast
from collections import defaultdict
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Optional, TypeVar
from typing import TYPE_CHECKING, TypeVar
import libcst as cst
from libcst.metadata import PositionProvider
@@ -12,7 +12,7 @@ from libcst.metadata import PositionProvider
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.languages import is_python
from codeflash.languages.code_replacer import get_optimized_code_for_module
from codeflash.languages.python.static_analysis.code_extractor import (
add_global_assignments,
add_needed_imports_from_module,
@@ -24,9 +24,7 @@ from codeflash.models.models import FunctionParent
if TYPE_CHECKING:
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import LanguageSupport
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
from codeflash.models.models import CodeStringsMarkdown
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
@@ -391,14 +389,7 @@ def replace_function_definitions_in_module(
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
project_root_path: Path,
should_add_global_assignments: bool = True,
function_to_optimize: Optional[FunctionToOptimize] = None,
) -> bool:
# Route to language-specific implementation for non-Python languages
if not is_python():
return replace_function_definitions_for_language(
function_names, optimized_code, module_abspath, project_root_path, function_to_optimize
)
source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
@@ -420,270 +411,5 @@ def replace_function_definitions_in_module(
return True
def replace_function_definitions_for_language(
function_names: list[str],
optimized_code: CodeStringsMarkdown,
module_abspath: Path,
project_root_path: Path,
function_to_optimize: Optional[FunctionToOptimize] = None,
) -> bool:
"""Replace function definitions for non-Python languages.
Uses the language support abstraction to perform code replacement.
Args:
function_names: List of qualified function names to replace.
optimized_code: The optimized code to apply.
module_abspath: Path to the module file.
project_root_path: Root of the project.
function_to_optimize: The function being optimized (needed for line info).
Returns:
True if the code was modified, False if no changes.
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
original_source_code: str = module_abspath.read_text(encoding="utf8")
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
if not code_to_apply.strip():
return False
# Get language support
language = Language(optimized_code.language)
lang_support = get_language_support(language)
# Add any new global declarations from the optimized code to the original source
original_source_code = lang_support.add_global_declarations(
optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath
)
# If we have function_to_optimize with line info and this is the main file, use it for precise replacement
if (
function_to_optimize
and function_to_optimize.starting_line
and function_to_optimize.ending_line
and function_to_optimize.file_path == module_abspath
):
# Extract just the target function from the optimized code
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
else:
# Fallback: use the entire optimized code (for simple single-function files)
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
else:
# For helper files or when we don't have precise line info:
# Find each function by name in both original and optimized code
# Then replace with the corresponding optimized version
new_code = original_source_code
modified = False
# Get the list of function names to replace
functions_to_replace = list(function_names)
for func_name in functions_to_replace:
# Re-discover functions from current code state to get correct line numbers
current_functions = lang_support.discover_functions_from_source(new_code, module_abspath)
# Find the function in current code
func = None
for f in current_functions:
if func_name in (f.qualified_name, f.function_name):
func = f
break
if func is None:
continue
# Extract just this function from the optimized code
optimized_func = _extract_function_from_code(
lang_support, code_to_apply, func.function_name, module_abspath
)
if optimized_func:
new_code = lang_support.replace_function(new_code, func, optimized_func)
modified = True
if not modified:
logger.warning(f"Could not find function {function_names} in {module_abspath}")
return False
# Check if there was actually a change
if original_source_code.strip() == new_code.strip():
return False
module_abspath.write_text(new_code, encoding="utf8")
return True
def _extract_function_from_code(
lang_support: LanguageSupport, source_code: str, function_name: str, file_path: Path | None = None
) -> str | None:
"""Extract a specific function's source code from a code string.
Includes JSDoc/docstring comments if present.
Args:
lang_support: Language support instance.
source_code: The full source code containing the function.
function_name: Name of the function to extract.
file_path: Path to the file (used to determine correct analyzer for JS/TS).
Returns:
The function's source code (including doc comments), or None if not found.
"""
try:
# Use the language support to find functions in the source
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
functions = lang_support.discover_functions_from_source(source_code, file_path)
for func in functions:
if func.function_name == function_name:
# Extract the function's source using line numbers
# Use doc_start_line if available to include JSDoc/docstring
lines = source_code.splitlines(keepends=True)
effective_start = func.doc_start_line or func.starting_line
if effective_start and func.ending_line and effective_start <= len(lines):
func_lines = lines[effective_start - 1 : func.ending_line]
return "".join(func_lines)
except Exception as e:
logger.debug(f"Error extracting function {function_name}: {e}")
return None
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
file_to_code_context = optimized_code.file_to_path()
module_optimized_code = file_to_code_context.get(str(relative_path))
if module_optimized_code is None:
# Fallback: if there's only one code block with None file path,
# use it regardless of the expected path (the AI server doesn't always include file paths)
if "None" in file_to_code_context and len(file_to_code_context) == 1:
module_optimized_code = file_to_code_context["None"]
logger.debug(f"Using code block with None file_path for {relative_path}")
else:
logger.warning(
f"Optimized code not found for {relative_path} in the context\n-------\n{optimized_code}\n-------\n"
"Re-check the markdown code structure; "
f"existing files are {file_to_code_context.keys()}"
)
module_optimized_code = ""
return module_optimized_code
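The lookup above keys code blocks by relative file path, with a fallback for responses that contain a single block with no path. A stripped-down sketch of that fallback logic (the `"None"` key is an assumption carried over from the code above):

```python
def pick_code_block(file_to_code: dict[str, str], relative_path: str) -> str:
    """Sketch: return the block for relative_path, or — when the response has
    exactly one path-less block (keyed "None") — use that block for any file."""
    code = file_to_code.get(relative_path)
    if code is None and len(file_to_code) == 1 and "None" in file_to_code:
        code = file_to_code["None"]
    return code or ""


blocks = {"None": "def fast(x):\n    return x * 2\n"}
print(pick_code_block(blocks, "src/mod.py"))  # the single path-less block is used
```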
def is_zero_diff(original_code: str, new_code: str) -> bool:
return normalize_code(original_code) == normalize_code(new_code)
def replace_optimized_code(
callee_module_paths: set[Path],
candidates: list[OptimizedCandidate],
code_context: CodeOptimizationContext,
function_to_optimize: FunctionToOptimize,
validated_original_code: dict[Path, ValidCode],
project_root: Path,
) -> tuple[set[Path], dict[str, dict[Path, str]]]:
initial_optimized_code = {
candidate.optimization_id: replace_functions_and_add_imports(
validated_original_code[function_to_optimize.file_path].source_code,
[function_to_optimize.qualified_name],
candidate.source_code,
function_to_optimize.file_path,
function_to_optimize.file_path,
code_context.preexisting_objects,
project_root,
)
for candidate in candidates
}
callee_original_code = {
module_path: validated_original_code[module_path].source_code for module_path in callee_module_paths
}
intermediate_original_code: dict[str, dict[Path, str]] = {
candidate.optimization_id: (
callee_original_code | {function_to_optimize.file_path: initial_optimized_code[candidate.optimization_id]}
)
for candidate in candidates
}
module_paths = callee_module_paths | {function_to_optimize.file_path}
optimized_code = {
candidate.optimization_id: {
module_path: replace_functions_and_add_imports(
intermediate_original_code[candidate.optimization_id][module_path],
(
[
callee.qualified_name
for callee in code_context.helper_functions
if callee.file_path == module_path and callee.definition_type != "class"
]
),
candidate.source_code,
function_to_optimize.file_path,
module_path,
[],
project_root,
)
for module_path in module_paths
}
for candidate in candidates
}
return module_paths, optimized_code
def is_optimized_module_code_zero_diff(
candidates: list[OptimizedCandidate],
validated_original_code: dict[Path, ValidCode],
optimized_code: dict[str, dict[Path, str]],
module_paths: set[Path],
) -> dict[str, dict[Path, bool]]:
return {
candidate.optimization_id: {
callee_module_path: normalize_code(optimized_code[candidate.optimization_id][callee_module_path])
== validated_original_code[callee_module_path].normalized_code
for callee_module_path in module_paths
}
for candidate in candidates
}
def candidates_with_diffs(
candidates: list[OptimizedCandidate],
validated_original_code: ValidCode,
optimized_code: dict[str, dict[Path, str]],
module_paths: set[Path],
) -> list[OptimizedCandidate]:
return [
candidate
for candidate in candidates
if not all(
is_optimized_module_code_zero_diff(candidates, validated_original_code, optimized_code, module_paths)[
candidate.optimization_id
].values()
)
]
def replace_optimized_code_in_worktrees(
optimized_code: dict[str, dict[Path, str]],
candidates: list[OptimizedCandidate], # Should be candidates_with_diffs
worktrees: list[Path],
git_root: Path, # Handle None case
) -> None:
for candidate, worktree in zip(candidates, worktrees[1:]):
for module_path in optimized_code[candidate.optimization_id]:
(worktree / module_path.relative_to(git_root)).write_text(
optimized_code[candidate.optimization_id][module_path], encoding="utf8"
) # Check with is_optimized_module_code_zero_diff
def function_to_optimize_original_worktree_fqn(
function_to_optimize: FunctionToOptimize, worktrees: list[Path], git_root: Path
) -> str:
return (
str(worktrees[0].name / function_to_optimize.file_path.relative_to(git_root).with_suffix("")).replace("/", ".")
+ "."
+ function_to_optimize.qualified_name
)
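`function_to_optimize_original_worktree_fqn` builds a dotted fully-qualified name from a worktree name, a project-relative file path, and the function's qualified name. A self-contained sketch of the same path-to-dotted-name conversion (the helper name here is hypothetical):

```python
from pathlib import Path


def module_fqn(worktree_name: str, rel_file: Path, qualified_name: str) -> str:
    """Sketch: drop the suffix, join worktree + path, and dot-separate."""
    dotted = str(Path(worktree_name) / rel_file.with_suffix("")).replace("/", ".")
    return dotted + "." + qualified_name


print(module_fqn("wt0", Path("pkg/mod.py"), "MyClass.method"))
# wt0.pkg.mod.MyClass.method (on POSIX path separators)
```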


@@ -6,6 +6,8 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any
import libcst as cst
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import (
CodeContext,
@@ -17,12 +19,18 @@ from codeflash.languages.base import (
TestResult,
)
from codeflash.languages.registry import register_language
from codeflash.models.function_types import FunctionParent
if TYPE_CHECKING:
import ast
from collections.abc import Sequence
from libcst import CSTNode
from libcst.metadata import CodeRange
from codeflash.languages.base import DependencyResolver
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
from codeflash.verification.verification_utils import TestConfig
logger = logging.getLogger(__name__)
@@ -41,6 +49,70 @@ def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFun
]
class ReturnStatementVisitor(cst.CSTVisitor):
def __init__(self) -> None:
super().__init__()
self.has_return_statement: bool = False
def visit_Return(self, node: cst.Return) -> None:
self.has_return_statement = True
class FunctionVisitor(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
def __init__(self, file_path: Path) -> None:
super().__init__()
self.file_path: Path = file_path
self.functions: list[FunctionToOptimize] = []
@staticmethod
def is_pytest_fixture(node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
dec = decorator.decorator
if isinstance(dec, cst.Call):
dec = dec.func
if isinstance(dec, cst.Attribute) and dec.attr.value == "fixture":
if isinstance(dec.value, cst.Name) and dec.value.value == "pytest":
return True
if isinstance(dec, cst.Name) and dec.value == "fixture":
return True
return False
@staticmethod
def is_property(node: cst.FunctionDef) -> bool:
for decorator in node.decorators:
dec = decorator.decorator
if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
return True
return False
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
node.visit(return_visitor)
if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
ast_parents: list[FunctionParent] = []
while parents is not None:
if isinstance(parents, cst.FunctionDef):
# Skip nested functions — only discover top-level and class-level functions
return
if isinstance(parents, cst.ClassDef):
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
self.functions.append(
FunctionToOptimize(
function_name=node.name.value,
file_path=self.file_path,
parents=list(reversed(ast_parents)),
starting_line=pos.start.line,
ending_line=pos.end.line,
is_async=bool(node.asynchronous),
)
)
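The `FunctionVisitor` above uses libcst to keep only functions that actually return a value (and skips pytest fixtures, properties, and nested functions). The same filtering idea can be sketched with the stdlib `ast` module — a simplified analogue, not the real implementation (note this toy version would also count returns inside nested functions, which the libcst visitor excludes):

```python
import ast


class ReturnFinder(ast.NodeVisitor):
    """Record whether any Return statement is reachable from the visited nodes."""

    def __init__(self) -> None:
        self.has_return = False

    def visit_Return(self, node: ast.Return) -> None:
        self.has_return = True


def functions_with_returns(source: str) -> list[str]:
    names = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            finder = ReturnFinder()
            for child in node.body:
                finder.visit(child)
            if finder.has_return:
                names.append(node.name)
    return names


print(functions_with_returns("def a():\n    return 1\n\ndef b():\n    pass\n"))  # ['a']
```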
@register_language
class PythonSupport:
"""Python language support implementation.
@@ -107,30 +179,70 @@ class PythonSupport:
}
)
@property
def default_language_version(self) -> str | None:
return None
@property
def valid_test_frameworks(self) -> tuple[str, ...]:
return ("pytest", "unittest")
@property
def test_result_serialization_format(self) -> str:
return "pickle"
def load_coverage(
self,
coverage_database_file: Path,
function_name: str,
code_context: Any,
source_file: Path,
coverage_config_file: Path | None = None,
) -> Any:
from codeflash.verification.coverage_utils import CoverageUtils
return CoverageUtils.load_from_sqlite_database(
database_path=coverage_database_file,
config_path=coverage_config_file,
source_code_path=source_file,
code_context=code_context,
function_name=function_name,
)
def process_generated_test_strings(
self,
generated_test_source: str,
instrumented_behavior_test_source: str,
instrumented_perf_test_source: str,
function_to_optimize: Any,
test_path: Path,
test_cfg: Any,
project_module_system: str | None,
) -> tuple[str, str, str]:
from codeflash.code_utils.code_utils import get_run_tmp_file
temp_run_dir = get_run_tmp_file(Path()).as_posix()
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
)
instrumented_perf_test_source = instrumented_perf_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
)
return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source
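`process_generated_test_strings` substitutes the `{codeflash_run_tmp_dir_client_side}` placeholder in instrumented tests with the actual temp directory. A minimal sketch of that substitution (the helper name is hypothetical):

```python
def inject_tmp_dir(test_source: str, tmp_dir: str) -> str:
    """Sketch: replace the client-side tmp-dir placeholder in a test source."""
    return test_source.replace("{codeflash_run_tmp_dir_client_side}", tmp_dir)


print(inject_tmp_dir("open('{codeflash_run_tmp_dir_client_side}/out.bin')", "/tmp/cf"))
# open('/tmp/cf/out.bin')
```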
def adjust_test_config_for_discovery(self, test_cfg: Any) -> None:
pass
def detect_module_system(self, project_root: Path, source_file: Path) -> str | None:
return None
# === Discovery ===
def discover_functions(
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
) -> list[FunctionToOptimize]:
"""Find all optimizable functions in a Python file.
Uses libcst to parse the file and find functions with return statements.
Args:
file_path: Path to the Python file to analyze.
filter_criteria: Optional criteria to filter functions.
Returns:
List of FunctionToOptimize objects for discovered functions.
"""
import libcst as cst
from codeflash.discovery.functions_to_optimize import FunctionVisitor
criteria = filter_criteria or FunctionFilterCriteria()
source = file_path.read_text(encoding="utf-8")
tree = cst.parse_module(source)
wrapper = cst.metadata.MetadataWrapper(tree)
@@ -572,21 +684,10 @@ class PythonSupport:
return False
def normalize_code(self, source: str) -> str:
"""Normalize Python code for deduplication.
Removes comments, normalizes whitespace, and replaces variable names.
Args:
source: Source code to normalize.
Returns:
Normalized source code.
"""
from codeflash.code_utils.deduplicate_code import normalize_code
from codeflash.languages.python.normalizer import normalize_python_code
try:
return normalize_code(source, remove_docstrings=True, language=Language.PYTHON)
return normalize_python_code(source, remove_docstrings=True)
except Exception:
return source
@@ -861,8 +962,335 @@ class PythonSupport:
# Python uses line_profiler which has its own output format
return {"timings": {}, "unit": 0, "str_out": ""}
@property
def function_optimizer_class(self) -> type:
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
return PythonFunctionOptimizer
def prepare_module(
self, module_code: str, module_path: Path, project_root: Path
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
from codeflash.languages.python.optimizer import prepare_python_module
return prepare_python_module(module_code, module_path, project_root)
pytest_cmd: str = "pytest"
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
self.pytest_cmd = test_cfg.pytest_cmd or "pytest"
def pytest_cmd_tokens(self, is_posix: bool) -> list[str]:
import shlex
return shlex.split(self.pytest_cmd, posix=is_posix)
def build_pytest_cmd(self, safe_sys_executable: str, is_posix: bool) -> list[str]:
return [safe_sys_executable, "-m", *self.pytest_cmd_tokens(is_posix)]
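`build_pytest_cmd` tokenizes the configured pytest command with `shlex` so that user-supplied flags survive intact, then prefixes the interpreter invocation. A standalone sketch of that assembly:

```python
import shlex


def pytest_cmd_tokens(pytest_cmd: str, is_posix: bool) -> list[str]:
    """Split a user-configured pytest command into argv tokens."""
    return shlex.split(pytest_cmd, posix=is_posix)


def build_pytest_cmd(executable: str, pytest_cmd: str, is_posix: bool) -> list[str]:
    """Prefix the tokens with `<python> -m` so the configured pytest runs
    under the chosen interpreter."""
    return [executable, "-m", *pytest_cmd_tokens(pytest_cmd, is_posix)]


print(build_pytest_cmd("python", "pytest -x --tb=short", True))
# ['python', '-m', 'pytest', '-x', '--tb=short']
```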
# === Test Execution (Full Protocol) ===
# Note: For Python, test execution is handled by the main test_runner.py
# which has special Python-specific logic. These methods are not called
# for Python as the test_runner checks is_python() and uses the existing path.
# They are defined here only for protocol compliance.
def run_behavioral_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
enable_coverage: bool = False,
candidate_index: int = 0,
) -> tuple[Path, Any, Path | None, Path | None]:
import contextlib
import shlex
import sys
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
from codeflash.models.models import TestType
from codeflash.verification.test_runner import execute_test_subprocess
blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"]
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type == TestType.REPLAY_TEST:
if file.tests_in_file:
test_files.extend(
[
str(file.instrumented_behavior_file_path) + "::" + test.test_function
for test in file.tests_in_file
]
)
else:
test_files.append(str(file.instrumented_behavior_file_path))
pytest_cmd_list = self.build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
test_files = list(set(test_files))
common_pytest_args = [
"--capture=tee-sys",
"-q",
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
f"--codeflash_seconds={TOTAL_LOOPING_TIME_EFFECTIVE}",
]
if timeout is not None:
common_pytest_args.append(f"--timeout={timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
coverage_database_file: Path | None = None
coverage_config_file: Path | None = None
if enable_coverage:
coverage_database_file, coverage_config_file = prepare_coverage_files()
pytest_test_env["NUMBA_DISABLE_JIT"] = str(1)
pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1)
pytest_test_env["PYTORCH_JIT"] = str(0)
pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
pytest_test_env["JAX_DISABLE_JIT"] = str(1)  # 1 disables JAX JIT, matching the other JIT-disabling flags above
is_windows = sys.platform == "win32"
if is_windows:
if coverage_database_file.exists():
with contextlib.suppress(PermissionError, OSError):
coverage_database_file.unlink()
else:
cov_erase = execute_test_subprocess(
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30
)
logger.debug(cov_erase)
coverage_cmd = [
SAFE_SYS_EXECUTABLE,
"-m",
"coverage",
"run",
f"--rcfile={coverage_config_file.as_posix()}",
"-m",
]
coverage_cmd.extend(self.pytest_cmd_tokens(IS_POSIX))
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"]
results = execute_test_subprocess(
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600,
)
logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
else:
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
results = execute_test_subprocess(
pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600,
)
logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
return result_file_path, results, coverage_database_file, coverage_config_file
def run_benchmarking_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
min_loops: int = 5,
max_loops: int = 100_000,
target_duration_seconds: float = 10.0,
) -> tuple[Path, Any]:
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.verification.test_runner import execute_test_subprocess
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
pytest_cmd_list = self.build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files})
pytest_args = [
"--capture=tee-sys",
"-q",
"--codeflash_loops_scope=session",
f"--codeflash_min_loops={min_loops}",
f"--codeflash_max_loops={max_loops}",
f"--codeflash_seconds={target_duration_seconds}",
"--codeflash_stability_check=true",
]
if timeout is not None:
pytest_args.append(f"--timeout={timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
results = execute_test_subprocess(
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600,
)
return result_file_path, results
def run_line_profile_tests(
self,
test_paths: Any,
test_env: dict[str, str],
cwd: Path,
timeout: int | None = None,
project_root: Path | None = None,
line_profile_output_file: Path | None = None,
) -> tuple[Path, Any]:
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.verification.test_runner import execute_test_subprocess
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
pytest_cmd_list = self.build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files})
pytest_args = [
"--capture=tee-sys",
"-q",
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
f"--codeflash_seconds={TOTAL_LOOPING_TIME_EFFECTIVE}",
]
if timeout is not None:
pytest_args.append(f"--timeout={timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
pytest_test_env["LINE_PROFILE"] = "1"
results = execute_test_subprocess(
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600,
)
return result_file_path, results
def generate_concolic_tests(
self, test_cfg: Any, project_root: Path, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: Any
) -> tuple[dict, str]:
import ast
import importlib.util
import subprocess
import tempfile
import time
from codeflash.cli_cmds.console import console
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.shell_utils import make_env_with_project_root
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.languages.python.static_analysis.concolic_utils import (
clean_concolic_tests,
is_valid_concolic_test,
)
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
crosshair_available = importlib.util.find_spec("crosshair") is not None
start_time = time.perf_counter()
function_to_concolic_tests: dict = {}
concolic_test_suite_code = ""
if not crosshair_available:
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
return function_to_concolic_tests, concolic_test_suite_code
if is_LSP_enabled():
logger.debug("Skipping concolic test generation in LSP mode")
return function_to_concolic_tests, concolic_test_suite_code
if (
test_cfg.concolic_test_root_dir
and isinstance(function_to_optimize_ast, ast.FunctionDef)
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
):
logger.info("Generating concolic opcode coverage tests for the original code…")
console.rule()
try:
env = make_env_with_project_root(project_root)
cover_result = subprocess.run(
[
SAFE_SYS_EXECUTABLE,
"-m",
"crosshair",
"cover",
"--example_output_format=pytest",
"--per_condition_timeout=20",
".".join(
[
function_to_optimize.file_path.relative_to(project_root)
.with_suffix("")
.as_posix()
.replace("/", "."),
function_to_optimize.qualified_name,
]
),
],
capture_output=True,
text=True,
cwd=project_root,
check=False,
timeout=600,
env=env,
)
except subprocess.TimeoutExpired:
logger.debug("CrossHair Cover test generation timed out")
return function_to_concolic_tests, concolic_test_suite_code
if cover_result.returncode == 0:
generated_concolic_test: str = cover_result.stdout
if not is_valid_concolic_test(generated_concolic_test, project_root=str(project_root)):
logger.debug("CrossHair generated invalid test, skipping")
console.rule()
return function_to_concolic_tests, concolic_test_suite_code
concolic_test_suite_code = clean_concolic_tests(generated_concolic_test)
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
concolic_test_cfg = TestConfig(
tests_root=concolic_test_suite_dir,
tests_project_rootdir=test_cfg.concolic_test_root_dir,
project_root_path=project_root,
)
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
logger.info(
"Created %d concolic unit test case%s",
num_discovered_concolic_tests,
"s" if num_discovered_concolic_tests != 1 else "",
)
console.rule()
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
else:
logger.debug(
"Error running CrossHair Cover%s", ": " + cover_result.stderr if cover_result.stderr else "."
)
console.rule()
end_time = time.perf_counter()
logger.debug("Generated concolic tests in %.2f seconds", end_time - start_time)
return function_to_concolic_tests, concolic_test_suite_code
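The CrossHair target above is derived by turning the function's file path into a dotted module path. Pulled out as a standalone helper (a sketch; the name `module_dotted_path` is hypothetical, the chain of calls mirrors the hunk):

```python
from pathlib import Path

def module_dotted_path(file_path: Path, project_root: Path) -> str:
    # Same chain as the CrossHair "cover" target above: make the path
    # relative to the project root, strip the suffix, replace "/" with ".".
    return file_path.relative_to(project_root).with_suffix("").as_posix().replace("/", ".")

# The full target joins the module path with the function's qualified name.
target = ".".join([module_dotted_path(Path("/proj/pkg/mod.py"), Path("/proj")), "my_func"])
```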


@@ -53,7 +53,10 @@ def _ensure_languages_registered() -> None:
from codeflash.languages.python import support as _
with contextlib.suppress(ImportError):
from codeflash.languages.javascript import support as _ # noqa: F401
from codeflash.languages.javascript import support as _
with contextlib.suppress(ImportError):
from codeflash.languages.java import support as _ # noqa: F401
_languages_registered = True
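The registration pattern in this hunk — import each language's `support` module and swallow `ImportError` for languages that are not installed — can be sketched with a hypothetical `_try_register` standing in for the real imports:

```python
import contextlib

registered: list[str] = []

def _try_register(name: str) -> None:
    # Hypothetical stand-in for "from codeflash.languages.<name> import support";
    # simulate a language whose support module is absent.
    if name == "java":
        raise ImportError(name)
    registered.append(name)

for lang in ("python", "javascript", "java"):
    with contextlib.suppress(ImportError):
        _try_register(lang)
```

`contextlib.suppress(ImportError)` keeps registration going when an optional language package is missing, which is why adding Java support here cannot break Python-only installs.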


@@ -30,7 +30,7 @@ from __future__ import annotations
from typing import Literal
TestFramework = Literal["jest", "vitest", "mocha", "pytest", "unittest"]
TestFramework = Literal["jest", "vitest", "mocha", "pytest", "unittest", "junit5", "junit4", "testng"]
# Module-level singleton for the current test framework
_current_test_framework: TestFramework | None = None
@@ -63,11 +63,11 @@ def set_current_test_framework(framework: TestFramework | str | None) -> None:
if framework is not None:
framework = framework.lower()
if framework not in ("jest", "vitest", "mocha", "pytest", "unittest"):
if framework not in ("jest", "vitest", "mocha", "pytest", "unittest", "junit5", "junit4", "testng"):
# Default to jest for unknown JS frameworks, pytest for unknown Python
from codeflash.languages.current import is_javascript
from codeflash.languages.current import current_language_support
framework = "jest" if is_javascript() else "pytest"
framework = current_language_support().test_framework
_current_test_framework = framework
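A standalone sketch of the normalization logic above (the helper name `normalize_framework` is hypothetical; the real code consults `current_language_support().test_framework` for the fallback rather than taking it as a parameter):

```python
from __future__ import annotations

KNOWN_FRAMEWORKS = ("jest", "vitest", "mocha", "pytest", "unittest", "junit5", "junit4", "testng")

def normalize_framework(framework: str | None, language_default: str = "pytest") -> str | None:
    # Mirrors set_current_test_framework: lowercase the name, then fall back
    # to the current language's default when the name is unrecognized.
    if framework is None:
        return None
    framework = framework.lower()
    if framework not in KNOWN_FRAMEWORKS:
        framework = language_default
    return framework
```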


@@ -463,14 +463,10 @@ def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedIni
"message": "Failed to prepare module for optimization",
}
validated_original_code, original_module_ast = module_prep_result
validated_original_code, _original_module_ast = module_prep_result
function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
fto, function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, function_to_tests={}
)
server.optimizer.current_function_optimizer = function_optimizer


@@ -245,8 +245,7 @@ class CodeString(BaseModel):
"""Validate code syntax for the specified language."""
if self.language == "python":
validate_python_code(self.code)
elif self.language in ("javascript", "typescript"):
# Validate JavaScript/TypeScript syntax using language support
else:
from codeflash.languages.registry import get_language_support
lang_support = get_language_support(self.language)


@@ -1,6 +1,5 @@
from __future__ import annotations
import ast
import concurrent.futures
import dataclasses
import logging
@@ -59,32 +58,16 @@ from codeflash.code_utils.config_consts import (
EffortLevel,
get_effort_value,
)
from codeflash.code_utils.deduplicate_code import normalize_code
from codeflash.code_utils.env_utils import get_pr_number
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.code_utils.shell_utils import make_env_with_project_root
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
from codeflash.either import Failure, Success, is_successful
from codeflash.languages import is_python
from codeflash.languages.base import Language
from codeflash.languages.current import current_language_support
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
from codeflash.languages.python.context import code_context_extractor
from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
from codeflash.languages.python.static_analysis.code_replacer import (
add_custom_marker_to_all_tests,
modify_autouse_fixture,
replace_function_definitions_in_module,
)
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode, report_to_markdown_table, tree_to_markdown
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
from codeflash.models.ExperimentMetadata import ExperimentMetadata
@@ -94,7 +77,6 @@ from codeflash.models.models import (
AIServiceCodeRepairRequest,
BestOptimization,
CandidateEvaluationContext,
CodeOptimizationContext,
GeneratedTests,
GeneratedTestsList,
OptimizationReviewResult,
@@ -121,27 +103,23 @@ from codeflash.result.critic import (
)
from codeflash.result.explanation import Explanation
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.concolic_testing import generate_concolic_tests
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
from codeflash.verification.parse_test_output import (
calculate_function_throughput_from_test_results,
parse_concurrency_metrics,
parse_test_results,
)
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
from codeflash.verification.parse_test_output import parse_concurrency_metrics, parse_test_results
from codeflash.verification.verification_utils import get_test_file_path
from codeflash.verification.verifier import generate_tests
if TYPE_CHECKING:
import ast
from argparse import Namespace
from typing import Any
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Result
from codeflash.languages.base import DependencyResolver
from codeflash.models.function_types import FunctionParent
from codeflash.models.models import (
BenchmarkKey,
CodeOptimizationContext,
CodeStringsMarkdown,
ConcurrencyMetrics,
CoverageData,
@@ -459,14 +437,9 @@ class FunctionOptimizer:
)
self.language_support = current_language_support()
if not function_to_optimize_ast:
# Skip Python AST parsing for non-Python languages
if not is_python():
self.function_to_optimize_ast = None
else:
original_module_ast = ast.parse(function_to_optimize_source_code)
self.function_to_optimize_ast = get_first_top_level_function_or_method_ast(
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
)
self.function_to_optimize_ast = self._resolve_function_ast(
self.function_to_optimize_source_code, function_to_optimize.function_name, function_to_optimize.parents
)
else:
self.function_to_optimize_ast = function_to_optimize_ast
self.function_to_tests = function_to_tests if function_to_tests else {}
@@ -503,6 +476,71 @@ class FunctionOptimizer:
self.is_numerical_code: bool | None = None
self.code_already_exists: bool = False
# --- Hooks for language-specific subclasses ---
def _resolve_function_ast(
self, source_code: str, function_name: str, parents: list[FunctionParent]
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
return None
def requires_function_ast(self) -> bool:
return False
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
pass
def get_optimization_review_metrics(
self,
source_code: str,
file_path: Path,
qualified_name: str,
project_root: Path,
tests_root: Path,
language: Language,
) -> str:
return ""
def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
return None
def instrument_async_for_mode(self, mode: TestingMode) -> None:
pass
def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None:
pass
def should_check_coverage(self) -> bool:
return False
def collect_async_metrics(
self,
benchmarking_results: TestResults,
code_context: CodeOptimizationContext,
helper_code: dict[Path, str],
test_env: dict[str, str],
) -> tuple[int | None, ConcurrencyMetrics | None]:
return None, None
def compare_candidate_results(
self,
baseline_results: OriginalCodeBaseline,
candidate_behavior_results: TestResults,
optimization_candidate_index: int,
) -> tuple[bool, list[TestDiff]]:
return compare_test_results(
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
)
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
return False
def parse_line_profile_test_results(
self, line_profiler_output_file: Path | None
) -> tuple[TestResults | dict, CoverageData | None]:
return TestResults(test_results=[]), None
# --- End hooks ---
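These hooks follow the template-method pattern: the base `FunctionOptimizer` returns conservative defaults, and a language-specific subclass overrides only the hooks its language needs. A minimal sketch under that assumption (class names here are illustrative stand-ins, not codeflash's real classes):

```python
class BaseOptimizer:
    # Base class supplies safe defaults for every hook.
    def should_check_coverage(self) -> bool:
        return False  # default: no coverage gating

    def requires_function_ast(self) -> bool:
        return False

class PyOptimizer(BaseOptimizer):
    # A language subclass overrides only what its language supports.
    def should_check_coverage(self) -> bool:
        return True  # e.g. Python's coverage infrastructure is trusted
```

This keeps the shared optimization pipeline free of `is_python()` branches: the pipeline calls the hook and each language decides its own behavior.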
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
should_run_experiment = self.experiment_id is not None
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
@@ -656,10 +694,7 @@ class FunctionOptimizer:
original_conftest_content = None
if self.args.override_fixtures:
logger.info("Disabling all autouse fixtures associated with the generated test files")
original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths)
logger.info("Add custom marker to generated test files")
add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths)
original_conftest_content = self.instrument_test_fixtures(generated_test_paths + generated_perf_test_paths)
return Success(
(
@@ -679,7 +714,7 @@ class FunctionOptimizer:
if not is_successful(initialization_result):
return Failure(initialization_result.failure())
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
self.analyze_code_characteristics(code_context)
code_print(
code_context.read_writable_code.flat,
file_name=self.function_to_optimize.file_path,
@@ -929,7 +964,9 @@ class FunctionOptimizer:
runtimes_list = []
for valid_opt in eval_ctx.valid_optimizations:
valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
valid_opt_normalized_code = self.language_support.normalize_code(
valid_opt.candidate.source_code.flat.strip()
)
new_candidate_with_shorter_code = OptimizedCandidate(
source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
optimization_id=valid_opt.candidate.optimization_id,
@@ -1036,7 +1073,7 @@ class FunctionOptimizer:
candidate = candidate_node.candidate
normalized_code = normalize_code(candidate.source_code.flat.strip())
normalized_code = self.language_support.normalize_code(candidate.source_code.flat.strip())
if normalized_code == normalized_original:
logger.info(f"h3|Candidate {candidate_index}/{total_candidates}: Identical to original code, skipping.")
@@ -1248,7 +1285,7 @@ class FunctionOptimizer:
self.future_adaptive_optimizations,
)
candidate_index = 0
normalized_original = normalize_code(code_context.read_writable_code.flat.strip())
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
# Process candidates using queue-based approach
while not processor.is_done():
@@ -1493,57 +1530,24 @@ class FunctionOptimizer:
return new_code, new_helper_code
def group_functions_by_file(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]:
functions_by_file: dict[Path, set[str]] = defaultdict(set)
functions_by_file[self.function_to_optimize.file_path].add(self.function_to_optimize.qualified_name)
for helper in code_context.helper_functions:
if helper.definition_type != "class":
functions_by_file[helper.file_path].add(helper.qualified_name)
return functions_by_file
def replace_function_and_helpers_with_optimized_code(
self,
code_context: CodeOptimizationContext,
optimized_code: CodeStringsMarkdown,
original_helper_code: dict[Path, str],
) -> bool:
did_update = False
read_writable_functions_by_file_path = defaultdict(set)
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
self.function_to_optimize.qualified_name
)
for helper_function in code_context.helper_functions:
# Skip class definitions (definition_type may be None for non-Python languages)
if helper_function.definition_type != "class":
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
did_update |= replace_function_definitions_in_module(
function_names=list(qualified_names),
optimized_code=optimized_code,
module_abspath=module_abspath,
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,
)
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
# Revert unused helper functions to their original definitions
if unused_helpers:
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
return did_update
raise NotImplementedError
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
try:
new_code_ctx = code_context_extractor.get_code_optimization_context(
self.function_to_optimize, self.project_root, call_graph=self.call_graph
)
except ValueError as e:
return Failure(str(e))
return Success(
CodeOptimizationContext(
testgen_context=new_code_ctx.testgen_context,
read_writable_code=new_code_ctx.read_writable_code,
read_only_context_code=new_code_ctx.read_only_context_code,
hashing_code_context=new_code_ctx.hashing_code_context,
hashing_code_context_hash=new_code_ctx.hashing_code_context_hash,
helper_functions=new_code_ctx.helper_functions,
testgen_helper_fqns=new_code_ctx.testgen_helper_fqns,
preexisting_objects=new_code_ctx.preexisting_objects,
)
)
raise NotImplementedError
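`get_code_optimization_context` uses the `Result`/`Success`/`Failure` convention seen throughout this file. A minimal sketch of that convention (hypothetical stand-ins, not `codeflash.either`'s real classes):

```python
class Success:
    def __init__(self, value):
        self._value = value

    def unwrap(self):
        return self._value

class Failure:
    def __init__(self, error):
        self._error = error

    def failure(self):
        return self._error

def is_successful(result) -> bool:
    return isinstance(result, Success)

def get_context(valid: bool):
    # Mirrors the control flow above: wrap the happy path in Success,
    # convert a raised ValueError into a Failure carrying the message.
    try:
        if not valid:
            raise ValueError("no context")
        return Success({"helpers": []})
    except ValueError as e:
        return Failure(str(e))
```

Callers branch on `is_successful(...)` and then `unwrap()` or `failure()`, instead of letting exceptions propagate through the optimization pipeline.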
@staticmethod
def cleanup_leftover_test_return_values() -> None:
@@ -1560,178 +1564,27 @@ class FunctionOptimizer:
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
if func_qualname not in function_to_all_tests:
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
# Handle non-Python existing test instrumentation
elif not is_python():
test_file_invocation_positions = defaultdict(list)
for tests_in_file in function_to_all_tests.get(func_qualname):
test_file_invocation_positions[
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
].append(tests_in_file)
return unique_instrumented_test_files
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
path_obj_test_file = Path(test_file)
if test_type == TestType.EXISTING_UNIT_TEST:
existing_test_files_count += 1
elif test_type == TestType.REPLAY_TEST:
replay_test_files_count += 1
elif test_type == TestType.CONCOLIC_COVERAGE_TEST:
concolic_coverage_test_files_count += 1
else:
msg = f"Unexpected test type: {test_type}"
raise ValueError(msg)
test_file_invocation_positions = defaultdict(list)
for tests_in_file in function_to_all_tests.get(func_qualname):
test_file_invocation_positions[
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
].append(tests_in_file)
# Use language-specific instrumentation
success, injected_behavior_test = self.language_support.instrument_existing_test(
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
mode="behavior",
)
if not success:
logger.debug(f"Failed to instrument test file {test_file} for behavior testing")
continue
success, injected_perf_test = self.language_support.instrument_existing_test(
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
mode="performance",
)
if not success:
logger.debug(f"Failed to instrument test file {test_file} for performance testing")
continue
# Generate instrumented test file paths
# For JS/TS, preserve .test.ts or .spec.ts suffix for Jest pattern matching
def get_instrumented_path(original_path: str, suffix: str) -> Path:
"""Generate instrumented test file path preserving .test/.spec pattern."""
path_obj = Path(original_path)
stem = path_obj.stem # e.g., "fibonacci.test"
ext = path_obj.suffix # e.g., ".ts"
# Check for .test or .spec in stem (JS/TS pattern)
if ".test" in stem:
# fibonacci.test -> fibonacci__suffix.test
base, _ = stem.rsplit(".test", 1)
new_stem = f"{base}{suffix}.test"
elif ".spec" in stem:
base, _ = stem.rsplit(".spec", 1)
new_stem = f"{base}{suffix}.spec"
else:
# Default Python-style: add suffix before extension
new_stem = f"{stem}{suffix}"
return path_obj.parent / f"{new_stem}{ext}"
new_behavioral_test_path = get_instrumented_path(test_file, "__perfinstrumented")
new_perf_test_path = get_instrumented_path(test_file, "__perfonlyinstrumented")
if injected_behavior_test is not None:
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_behavior_test)
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
else:
msg = "injected_behavior_test is None"
raise ValueError(msg)
if injected_perf_test is not None:
with new_perf_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_perf_test)
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
unique_instrumented_test_files.add(new_behavioral_test_path)
unique_instrumented_test_files.add(new_perf_test_path)
if not self.test_files.get_by_original_file_path(path_obj_test_file):
self.test_files.add(
TestFile(
instrumented_behavior_file_path=new_behavioral_test_path,
benchmarking_file_path=new_perf_test_path,
original_source=None,
original_file_path=Path(test_file),
test_type=test_type,
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
)
)
if existing_test_files_count > 0 or replay_test_files_count > 0 or concolic_coverage_test_files_count > 0:
logger.info(
f"Instrumented {existing_test_files_count} existing unit test file"
f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file"
f"{'s' if replay_test_files_count != 1 else ''}, and "
f"{concolic_coverage_test_files_count} concolic coverage test file"
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
)
console.rule()
else:
test_file_invocation_positions = defaultdict(list)
for tests_in_file in function_to_all_tests.get(func_qualname):
test_file_invocation_positions[
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
].append(tests_in_file)
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
path_obj_test_file = Path(test_file)
if test_type == TestType.EXISTING_UNIT_TEST:
existing_test_files_count += 1
elif test_type == TestType.REPLAY_TEST:
replay_test_files_count += 1
elif test_type == TestType.CONCOLIC_COVERAGE_TEST:
concolic_coverage_test_files_count += 1
else:
msg = f"Unexpected test type: {test_type}"
raise ValueError(msg)
success, injected_behavior_test = inject_profiling_into_existing_test(
mode=TestingMode.BEHAVIOR,
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
)
if not success:
continue
success, injected_perf_test = inject_profiling_into_existing_test(
mode=TestingMode.PERFORMANCE,
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
)
if not success:
continue
# TODO: this naming logic should be moved to a function and made more standard
new_behavioral_test_path = Path(
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}" # noqa: PTH122
)
new_perf_test_path = Path(
f"{os.path.splitext(test_file)[0]}__perfonlyinstrumented{os.path.splitext(test_file)[1]}" # noqa: PTH122
)
if injected_behavior_test is not None:
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_behavior_test)
else:
msg = "injected_behavior_test is None"
raise ValueError(msg)
if injected_perf_test is not None:
with new_perf_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_perf_test)
unique_instrumented_test_files.add(new_behavioral_test_path)
unique_instrumented_test_files.add(new_perf_test_path)
if not self.test_files.get_by_original_file_path(path_obj_test_file):
self.test_files.add(
TestFile(
instrumented_behavior_file_path=new_behavioral_test_path,
benchmarking_file_path=new_perf_test_path,
original_source=None,
original_file_path=Path(test_file),
test_type=test_type,
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
)
)
for test_file, test_type in test_file_invocation_positions:
path_obj_test_file = Path(test_file)
if test_type == TestType.EXISTING_UNIT_TEST:
existing_test_files_count += 1
elif test_type == TestType.REPLAY_TEST:
replay_test_files_count += 1
elif test_type == TestType.CONCOLIC_COVERAGE_TEST:
concolic_coverage_test_files_count += 1
else:
msg = f"Unexpected test type: {test_type}"
raise ValueError(msg)
if existing_test_files_count > 0 or replay_test_files_count > 0 or concolic_coverage_test_files_count > 0:
logger.info(
f"Discovered {existing_test_files_count} existing unit test file"
f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file"
@@ -1740,6 +1593,87 @@ class FunctionOptimizer:
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
)
console.rule()
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
path_obj_test_file = Path(test_file)
# Use language-specific instrumentation
success, injected_behavior_test = self.language_support.instrument_existing_test(
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
mode="behavior",
)
if not success:
logger.debug(f"Failed to instrument test file {test_file} for behavior testing")
continue
success, injected_perf_test = self.language_support.instrument_existing_test(
test_path=path_obj_test_file,
call_positions=[test.position for test in tests_in_file_list],
function_to_optimize=self.function_to_optimize,
tests_project_root=self.test_cfg.tests_project_rootdir,
mode="performance",
)
if not success:
logger.debug(f"Failed to instrument test file {test_file} for performance testing")
continue
# For JS/TS, preserve .test.ts or .spec.ts suffix for Jest pattern matching
def get_instrumented_path(original_path: str, suffix: str) -> Path:
path_obj = Path(original_path)
stem = path_obj.stem
ext = path_obj.suffix
if ".test" in stem:
base, _ = stem.rsplit(".test", 1)
new_stem = f"{base}{suffix}.test"
elif ".spec" in stem:
base, _ = stem.rsplit(".spec", 1)
new_stem = f"{base}{suffix}.spec"
else:
new_stem = f"{stem}{suffix}"
return path_obj.parent / f"{new_stem}{ext}"
new_behavioral_test_path = get_instrumented_path(test_file, "__perfinstrumented")
new_perf_test_path = get_instrumented_path(test_file, "__perfonlyinstrumented")
if injected_behavior_test is not None:
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_behavior_test)
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
else:
msg = "injected_behavior_test is None"
raise ValueError(msg)
if injected_perf_test is not None:
with new_perf_test_path.open("w", encoding="utf8") as _f:
_f.write(injected_perf_test)
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
unique_instrumented_test_files.add(new_behavioral_test_path)
unique_instrumented_test_files.add(new_perf_test_path)
if not self.test_files.get_by_original_file_path(path_obj_test_file):
self.test_files.add(
TestFile(
instrumented_behavior_file_path=new_behavioral_test_path,
benchmarking_file_path=new_perf_test_path,
original_source=None,
original_file_path=Path(test_file),
test_type=test_type,
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
)
)
instrumented_count = len(unique_instrumented_test_files) // 2 # each test produces behavior + perf files
if instrumented_count > 0:
logger.info(
f"Instrumented {instrumented_count} existing unit test file"
f"{'s' if instrumented_count != 1 else ''} for {func_qualname}"
)
console.rule()
return unique_instrumented_test_files
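The nested `get_instrumented_path` helper above is small enough to exercise standalone. Reproduced here with the same logic so the suffix-preserving rename can be checked (keeping Jest's `.test`/`.spec` marker at the end of the stem so its file-pattern matching still picks up the instrumented copy):

```python
from pathlib import Path

def get_instrumented_path(original_path: str, suffix: str) -> Path:
    # Insert the suffix before a trailing ".test"/".spec" marker (JS/TS
    # convention); otherwise append it before the extension (Python style).
    path_obj = Path(original_path)
    stem, ext = path_obj.stem, path_obj.suffix
    if ".test" in stem:
        base, _ = stem.rsplit(".test", 1)
        new_stem = f"{base}{suffix}.test"
    elif ".spec" in stem:
        base, _ = stem.rsplit(".spec", 1)
        new_stem = f"{base}{suffix}.spec"
    else:
        new_stem = f"{stem}{suffix}"
    return path_obj.parent / f"{new_stem}{ext}"
```

For example, `fibonacci.test.ts` becomes `fibonacci__perfinstrumented.test.ts`, while `test_mod.py` becomes `test_mod__perfinstrumented.py`.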
def generate_tests(
@@ -1764,9 +1698,9 @@ class FunctionOptimizer:
future_concolic_tests = None
else:
future_concolic_tests = self.executor.submit(
generate_concolic_tests,
self.language_support.generate_concolic_tests,
self.test_cfg,
self.args,
self.args.project_root,
self.function_to_optimize,
self.function_to_optimize_ast,
)
@@ -1842,7 +1776,7 @@ class FunctionOptimizer:
)
future_references = self.executor.submit(
get_opt_review_metrics,
self.get_optimization_review_metrics,
self.function_to_optimize_source_code,
self.function_to_optimize.file_path,
self.function_to_optimize.qualified_name,
@@ -1933,8 +1867,7 @@ class FunctionOptimizer:
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
# Check test quantity for all languages
quantity_ok = quantity_of_tests_critic(original_code_baseline)
# TODO: {Self} Only check coverage for Python - coverage infrastructure not yet reliable for JS/TS
coverage_ok = coverage_critic(original_code_baseline.coverage_results) if is_python() else True
coverage_ok = coverage_critic(original_code_baseline.coverage_results) if self.should_check_coverage() else True
if isinstance(original_code_baseline, OriginalCodeBaseline) and (not coverage_ok or not quantity_ok):
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
@@ -2140,7 +2073,6 @@ class FunctionOptimizer:
if (
self.function_to_optimize.is_async
and is_python()
and original_code_baseline.async_throughput is not None
and best_optimization.async_throughput is not None
):
@@ -2344,25 +2276,13 @@ class FunctionOptimizer:
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
if self.function_to_optimize.is_async and is_python():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
success = add_async_decorator_to_function(
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
# Instrument codeflash capture
with progress_bar("Running tests to establish original code behavior..."):
try:
# Only instrument Python code here - non-Python languages use their own runtime helpers
# which are already included in the generated/instrumented tests
if is_python():
instrument_codeflash_capture(
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
)
self.instrument_capture(file_path_to_helper_classes)
total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
logger.debug(f"[PIPELINE] Establishing baseline with {len(self.test_files)} test files")
@@ -2391,7 +2311,7 @@ class FunctionOptimizer:
console.rule()
return Failure("Failed to establish a baseline for the original code - behavioral tests failed.")
# Skip coverage check for non-Python languages (coverage not yet supported)
if is_python() and not coverage_critic(coverage_results):
if self.should_check_coverage() and not coverage_critic(coverage_results):
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
if not did_pass_all_tests:
return Failure("Tests failed to pass for the original code.")
@@ -2412,15 +2332,8 @@ class FunctionOptimizer:
for idx, tf in enumerate(self.test_files.test_files):
logger.debug(f"[BENCHMARK-FILES] Test file {idx}: perf_file={tf.benchmarking_file_path}")
if self.function_to_optimize.is_async and is_python():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.PERFORMANCE)
try:
benchmarking_results, _ = self.run_and_parse_tests(
@@ -2472,22 +2385,9 @@ class FunctionOptimizer:
console.rule()
logger.debug(f"Total original code runtime (ns): {total_timing}")
async_throughput = None
concurrency_metrics = None
if self.function_to_optimize.is_async and is_python():
async_throughput = calculate_function_throughput_from_test_results(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
)
if concurrency_metrics:
logger.debug(
f"Original concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
)
async_throughput, concurrency_metrics = self.collect_async_metrics(
benchmarking_results, code_context, original_helper_code, test_env
)
if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
@@ -2589,22 +2489,11 @@ class FunctionOptimizer:
candidate_helper_code = {}
for module_abspath in original_helper_code:
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
if self.function_to_optimize.is_async and is_python():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.BEHAVIOR,
project_root=self.project_root,
)
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
try:
# Only instrument Python code here - non-Python languages use their own runtime helpers
if is_python():
instrument_codeflash_capture(
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
)
self.instrument_capture(file_path_to_helper_classes)
total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
candidate_behavior_results, _ = self.run_and_parse_tests(
@@ -2615,13 +2504,10 @@ class FunctionOptimizer:
testing_time=total_looping_time,
enable_coverage=False,
)
# Remove instrumentation
finally:
# Only restore code for Python - non-Python tests are self-contained
if is_python():
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
console.print(
TestResults.report_to_tree(
candidate_behavior_results.get_test_pass_fail_report_by_type(),
@@ -2630,30 +2516,9 @@ class FunctionOptimizer:
)
console.rule()
# Use language-appropriate comparison
if not is_python():
# Non-Python: Compare using language support with SQLite results if available
original_sqlite = get_run_tmp_file(Path("test_return_values_0.sqlite"))
candidate_sqlite = get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite"))
if original_sqlite.exists() and candidate_sqlite.exists():
# Full comparison using captured return values via language support
# Use js_project_root where node_modules is located
js_root = self.test_cfg.js_project_root or self.args.project_root
match, diffs = self.language_support.compare_test_results(
original_sqlite, candidate_sqlite, project_root=js_root
)
# Cleanup SQLite files after comparison
candidate_sqlite.unlink(missing_ok=True)
else:
# Fallback: compare test pass/fail status (tests aren't instrumented yet)
# If all tests that passed for original also pass for candidate, consider it a match
match, diffs = compare_test_results(
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
)
else:
# Python: Compare using Python comparator
match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
match, diffs = self.compare_candidate_results(
baseline_results, candidate_behavior_results, optimization_candidate_index
)
if match:
logger.info("h3|Test results matched ✅")
@@ -2667,16 +2532,8 @@ class FunctionOptimizer:
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
console.rule()
# For async functions, instrument at definition site for performance benchmarking
if self.function_to_optimize.is_async and is_python():
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
add_async_decorator_to_function(
self.function_to_optimize.file_path,
self.function_to_optimize,
TestingMode.PERFORMANCE,
project_root=self.project_root,
)
if self.function_to_optimize.is_async:
self.instrument_async_for_mode(TestingMode.PERFORMANCE)
try:
candidate_benchmarking_results, _ = self.run_and_parse_tests(
@@ -2688,8 +2545,7 @@ class FunctionOptimizer:
enable_coverage=False,
)
finally:
# Restore original source if we instrumented it
if self.function_to_optimize.is_async and is_python():
if self.function_to_optimize.is_async:
self.write_code_and_helpers(
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
)
@@ -2704,23 +2560,9 @@ class FunctionOptimizer:
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
candidate_async_throughput = None
candidate_concurrency_metrics = None
if self.function_to_optimize.is_async and is_python():
candidate_async_throughput = calculate_function_throughput_from_test_results(
candidate_benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
# Run concurrency benchmark for candidate
candidate_concurrency_metrics = self.run_concurrency_benchmark(
code_context=code_context, original_helper_code=candidate_helper_code, test_env=test_env
)
if candidate_concurrency_metrics:
logger.debug(
f"Candidate concurrency metrics: ratio={candidate_concurrency_metrics.concurrency_ratio:.2f}, "
f"seq={candidate_concurrency_metrics.sequential_time_ns}ns, conc={candidate_concurrency_metrics.concurrent_time_ns}ns"
)
candidate_async_throughput, candidate_concurrency_metrics = self.collect_async_metrics(
candidate_benchmarking_results, code_context, candidate_helper_code, test_env
)
if self.args.benchmark:
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
@@ -2764,40 +2606,36 @@ class FunctionOptimizer:
coverage_config_file = None
try:
if testing_type == TestingMode.BEHAVIOR:
result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests(
test_files,
test_framework=self.test_cfg.test_framework,
cwd=self.project_root,
test_env=test_env,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
enable_coverage=enable_coverage,
js_project_root=self.test_cfg.js_project_root,
candidate_index=optimization_iteration,
result_file_path, run_result, coverage_database_file, coverage_config_file = (
self.language_support.run_behavioral_tests(
test_paths=test_files,
test_env=test_env,
cwd=self.project_root,
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
project_root=self.test_cfg.js_project_root,
enable_coverage=enable_coverage,
candidate_index=optimization_iteration,
)
)
elif testing_type == TestingMode.LINE_PROFILE:
result_file_path, run_result = run_line_profile_tests(
test_files,
cwd=self.project_root,
result_file_path, run_result = self.language_support.run_line_profile_tests(
test_paths=test_files,
test_env=test_env,
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_target_runtime_seconds=testing_time,
test_framework=self.test_cfg.test_framework,
js_project_root=self.test_cfg.js_project_root,
line_profiler_output_file=line_profiler_output_file,
cwd=self.project_root,
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
project_root=self.test_cfg.js_project_root,
line_profile_output_file=line_profiler_output_file,
)
elif testing_type == TestingMode.PERFORMANCE:
result_file_path, run_result = run_benchmarking_tests(
test_files,
cwd=self.project_root,
result_file_path, run_result = self.language_support.run_benchmarking_tests(
test_paths=test_files,
test_env=test_env,
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_target_runtime_seconds=testing_time,
pytest_min_loops=pytest_min_loops,
pytest_max_loops=pytest_max_loops,
test_framework=self.test_cfg.test_framework,
js_project_root=self.test_cfg.js_project_root,
cwd=self.project_root,
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
project_root=self.test_cfg.js_project_root,
min_loops=pytest_min_loops,
max_loops=pytest_max_loops,
target_duration_seconds=testing_time,
)
else:
msg = f"Unexpected testing type: {testing_type}"
@@ -2829,9 +2667,7 @@ class FunctionOptimizer:
console.print(panel)
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
# For non-Python behavior tests, skip SQLite cleanup - files needed for language-native comparison
non_python_original_code = not is_python() and optimization_iteration == 0
skip_cleanup = (not is_python() and testing_type == TestingMode.BEHAVIOR) or non_python_original_code
skip_cleanup = self.should_skip_sqlite_cleanup(testing_type, optimization_iteration)
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
@@ -2849,13 +2685,7 @@ class FunctionOptimizer:
if testing_type == TestingMode.PERFORMANCE:
results.perf_stdout = run_result.stdout
return results, coverage_results
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON
# Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON
if not is_python():
# Return TestResults to indicate tests ran, actual parsing happens in _line_profiler_step_javascript
return TestResults(test_results=[]), None
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
return results, coverage_results
return self.parse_line_profile_test_results(line_profiler_output_file)
def submit_test_generation_tasks(
self,
@@ -2912,108 +2742,9 @@ class FunctionOptimizer:
def line_profiler_step(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict:
# Dispatch to language-specific implementation
if is_python():
return self._line_profiler_step_python(code_context, original_helper_code, candidate_index)
if self.language_support is not None and hasattr(self.language_support, "instrument_source_for_line_profiler"):
try:
line_profiler_output_path = get_run_tmp_file(Path("line_profiler_output.json"))
# NOTE: currently this instruments a single file only; add multi-file instrumentation support (or keep it limited to the main file)
original_source = Path(self.function_to_optimize.file_path).read_text()
# Instrument source code
success = self.language_support.instrument_source_for_line_profiler(
func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path
)
if not success:
return {"timings": {}, "unit": 0, "str_out": ""}
test_env = self.get_test_env(
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
)
_test_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.LINE_PROFILE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
code_context=code_context,
line_profiler_output_file=line_profiler_output_path,
)
if not hasattr(self.language_support, "parse_line_profile_results"):
raise ValueError("Language support does not implement parse_line_profile_results") # noqa: TRY301
return self.language_support.parse_line_profile_results(line_profiler_output_path)
except Exception as e:
logger.warning(f"Failed to run line profiling: {e}")
return {"timings": {}, "unit": 0, "str_out": ""}
finally:
# restore original source
Path(self.function_to_optimize.file_path).write_text(original_source)
logger.warning(f"Language support for {self.language_support.language} doesn't support line profiling")
) -> dict[str, Any]:
return {"timings": {}, "unit": 0, "str_out": ""}
def _line_profiler_step_python(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
) -> dict:
"""Python-specific line profiler using decorator imports."""
# Check if candidate code contains JIT decorators - line profiler doesn't work with JIT compiled code
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
if contains_jit_decorator(candidate_fto_code):
logger.info(
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
)
return {"timings": {}, "unit": 0, "str_out": ""}
# Check helper code for JIT decorators
for module_abspath in original_helper_code:
candidate_helper_code = Path(module_abspath).read_text("utf-8")
if contains_jit_decorator(candidate_helper_code):
logger.info(
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
)
return {"timings": {}, "unit": 0, "str_out": ""}
try:
console.rule()
test_env = self.get_test_env(
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
)
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
line_profile_results, _ = self.run_and_parse_tests(
testing_type=TestingMode.LINE_PROFILE,
test_env=test_env,
test_files=self.test_files,
optimization_iteration=0,
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage=False,
code_context=code_context,
line_profiler_output_file=line_profiler_output_file,
)
finally:
# Remove codeflash capture
self.write_code_and_helpers(
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
)
# This happens when a TimeoutExpired exception occurs
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
logger.warning(
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
)
# set default value for line profiler results
return {"timings": {}, "unit": 0, "str_out": ""}
if line_profile_results["str_out"] == "":
logger.warning(
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
)
return line_profile_results
def run_concurrency_benchmark(
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
) -> ConcurrencyMetrics | None:


@@ -1,6 +1,5 @@
from __future__ import annotations
import ast
import copy
import os
import tempfile
@@ -30,20 +29,20 @@ from codeflash.code_utils.git_worktree_utils import (
)
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.either import is_successful
from codeflash.languages import current_language_support, is_javascript, set_current_language
from codeflash.languages import current_language_support, set_current_language
from codeflash.lsp.helpers import is_subagent_mode
from codeflash.models.models import ValidCode
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
if TYPE_CHECKING:
import ast
from argparse import Namespace
from codeflash.benchmarking.function_ranker import FunctionRanker
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import DependencyResolver
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
@@ -72,59 +71,6 @@ class Optimizer:
self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
self.patch_files: list[Path] = []
@staticmethod
def _find_js_project_root(file_path: Path) -> Path | None:
"""Find the JavaScript/TypeScript project root by looking for package.json.
Traverses up from the given file path to find the nearest directory
containing package.json or jest.config.js.
Args:
file_path: A file path within the JavaScript project.
Returns:
The project root directory, or None if not found.
"""
current = file_path.parent if file_path.is_file() else file_path
while current != current.parent: # Stop at filesystem root
if (
(current / "package.json").exists()
or (current / "jest.config.js").exists()
or (current / "jest.config.ts").exists()
or (current / "tsconfig.json").exists()
):
return current
current = current.parent
return None
def _verify_js_requirements(self) -> None:
"""Verify JavaScript/TypeScript requirements before optimization.
Checks that Node.js, npm, and the test framework are available.
Logs warnings if requirements are not met but does not abort.
"""
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.test_framework import get_js_test_framework_or_default
js_project_root = self.test_cfg.js_project_root
if not js_project_root:
return
try:
js_support = get_language_support(Language.JAVASCRIPT)
test_framework = get_js_test_framework_or_default()
success, errors = js_support.verify_requirements(js_project_root, test_framework)
if not success:
logger.warning("JavaScript requirements check found issues:")
for error in errors:
logger.warning(f" - {error}")
except Exception as e:
logger.debug(f"Failed to verify JS requirements: {e}")
def run_benchmarks(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
@@ -247,26 +193,8 @@ class Optimizer:
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
original_module_ast: ast.Module | None = None,
original_module_path: Path | None = None,
call_graph: DependencyResolver | None = None,
) -> FunctionOptimizer | None:
from codeflash.languages.python.static_analysis.static_analysis import (
get_first_top_level_function_or_method_ast,
)
from codeflash.optimization.function_optimizer import FunctionOptimizer
if function_to_optimize_ast is None and original_module_ast is not None:
function_to_optimize_ast = get_first_top_level_function_or_method_ast(
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
)
if function_to_optimize_ast is None:
logger.info(
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
f"Skipping optimization."
)
return None
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
function_specific_timings = None
@@ -279,7 +207,11 @@ class Optimizer:
):
function_specific_timings = function_benchmark_timings[qualified_name_w_module]
return FunctionOptimizer(
cls = current_language_support().function_optimizer_class
# TODO: _resolve_function_ast re-parses source via ast.parse() per function, even when the caller already
# has a parsed module AST. Consider passing the pre-parsed AST through to avoid redundant parsing.
function_optimizer = cls(
function_to_optimize=function_to_optimize,
test_cfg=self.test_cfg,
function_to_optimize_source_code=function_to_optimize_source_code,
@@ -292,62 +224,26 @@ class Optimizer:
replay_tests_dir=self.replay_tests_dir,
call_graph=call_graph,
)
if function_optimizer.function_to_optimize_ast is None and function_optimizer.requires_function_ast():
logger.info(
f"Function {function_to_optimize.qualified_name} not found in "
f"{function_to_optimize.file_path}.\nSkipping optimization."
)
return None
return function_optimizer
def prepare_module_for_optimization(
self, original_module_path: Path
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
logger.info(f"loading|Examining file {original_module_path!s}")
console.rule()
original_module_code: str = original_module_path.read_text(encoding="utf8")
# For JavaScript/TypeScript, skip Python-specific AST parsing
if is_javascript():
validated_original_code: dict[Path, ValidCode] = {
original_module_path: ValidCode(source_code=original_module_code, normalized_code=original_module_code)
}
return validated_original_code, None
# Python-specific parsing
try:
original_module_ast = ast.parse(original_module_code)
except SyntaxError as e:
logger.warning(f"Syntax error parsing code in {original_module_path}: {e}")
logger.info("Skipping optimization due to file error.")
return None
normalized_original_module_code = ast.unparse(normalize_node(original_module_ast))
validated_original_code = {
original_module_path: ValidCode(
source_code=original_module_code, normalized_code=normalized_original_module_code
)
}
imported_module_analyses = analyze_imported_modules(
return current_language_support().prepare_module(
original_module_code, original_module_path, self.args.project_root
)
has_syntax_error = False
for analysis in imported_module_analyses:
callee_original_code = analysis.file_path.read_text(encoding="utf8")
try:
normalized_callee_original_code = normalize_code(callee_original_code)
except SyntaxError as e:
logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}")
logger.info("Skipping optimization due to helper file error.")
has_syntax_error = True
break
validated_original_code[analysis.file_path] = ValidCode(
source_code=callee_original_code, normalized_code=normalized_callee_original_code
)
if has_syntax_error:
return None
return validated_original_code, original_module_ast
def discover_tests(
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
@@ -556,11 +452,7 @@ class Optimizer:
if funcs and funcs[0].language:
set_current_language(funcs[0].language)
self.test_cfg.set_language(funcs[0].language)
# For JavaScript, also set js_project_root for test execution
if is_javascript():
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
# Verify JS requirements before proceeding
self._verify_js_requirements()
current_language_support().setup_test_config(self.test_cfg, file_path)
break
if self.args.all:
@@ -624,7 +516,7 @@ class Optimizer:
continue
prepared_modules[original_module_path] = module_prep_result
validated_original_code, original_module_ast = prepared_modules[original_module_path]
validated_original_code, _original_module_ast = prepared_modules[original_module_path]
function_iterator_count = i + 1
logger.info(
@@ -640,8 +532,6 @@ class Optimizer:
function_to_optimize_source_code=validated_original_code[original_module_path].source_code,
function_benchmark_timings=function_benchmark_timings,
total_benchmark_timings=total_benchmark_timings,
original_module_ast=original_module_ast,
original_module_path=original_module_path,
call_graph=resolver,
)
if function_optimizer is None:
@@ -762,6 +652,8 @@ class Optimizer:
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
del get_run_tmp_file.tmpdir
if hasattr(get_run_tmp_file, "tmpdir_path"):
del get_run_tmp_file.tmpdir_path
# Always clean up concolic test directory
cleanup_paths([self.test_cfg.concolic_test_root_dir])


@@ -1,133 +0,0 @@
from __future__ import annotations
import ast
import importlib.util
import subprocess
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.code_utils.shell_utils import make_env_with_project_root
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests, is_valid_concolic_test
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
from codeflash.lsp.helpers import is_LSP_enabled
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
CROSSHAIR_AVAILABLE = importlib.util.find_spec("crosshair") is not None
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionCalledInTest
def generate_concolic_tests(
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
"""Generate concolic tests using CrossHair (Python only).
CrossHair is a Python-specific symbolic execution tool. For non-Python languages
(JavaScript, TypeScript, etc.), this function returns early with empty results.
Args:
test_cfg: Test configuration
args: Command line arguments
function_to_optimize: The function being optimized
function_to_optimize_ast: AST of the function (Python ast.FunctionDef)
Returns:
Tuple of (function_to_tests mapping, concolic test suite code)
"""
start_time = time.perf_counter()
function_to_concolic_tests = {}
concolic_test_suite_code = ""
# CrossHair is Python-only - skip for other languages
if not is_python():
logger.debug("Skipping concolic test generation for non-Python languages (CrossHair is Python-only)")
return function_to_concolic_tests, concolic_test_suite_code
if not CROSSHAIR_AVAILABLE:
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
return function_to_concolic_tests, concolic_test_suite_code
if is_LSP_enabled():
logger.debug("Skipping concolic test generation in LSP mode")
return function_to_concolic_tests, concolic_test_suite_code
if (
test_cfg.concolic_test_root_dir
and isinstance(function_to_optimize_ast, ast.FunctionDef)
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
):
logger.info("Generating concolic opcode coverage tests for the original code…")
console.rule()
try:
env = make_env_with_project_root(args.project_root)
cover_result = subprocess.run(
[
SAFE_SYS_EXECUTABLE,
"-m",
"crosshair",
"cover",
"--example_output_format=pytest",
"--per_condition_timeout=20",
".".join(
[
function_to_optimize.file_path.relative_to(args.project_root)
.with_suffix("")
.as_posix()
.replace("/", "."),
function_to_optimize.qualified_name,
]
),
],
capture_output=True,
text=True,
cwd=args.project_root,
check=False,
timeout=600,
env=env,
)
except subprocess.TimeoutExpired:
logger.debug("CrossHair Cover test generation timed out")
return function_to_concolic_tests, concolic_test_suite_code
if cover_result.returncode == 0:
generated_concolic_test: str = cover_result.stdout
if not is_valid_concolic_test(generated_concolic_test, project_root=str(args.project_root)):
logger.debug("CrossHair generated invalid test, skipping")
console.rule()
return function_to_concolic_tests, concolic_test_suite_code
concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test)
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
concolic_test_cfg = TestConfig(
tests_root=concolic_test_suite_dir,
tests_project_rootdir=test_cfg.concolic_test_root_dir,
project_root_path=args.project_root,
)
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
logger.info(
f"Created {num_discovered_concolic_tests} "
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "
)
console.rule()
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
else:
logger.debug(f"Error running CrossHair Cover {': ' + cover_result.stderr if cover_result.stderr else '.'}")
console.rule()
end_time = time.perf_counter()
logger.debug(f"Generated concolic tests in {end_time - start_time:.2f} seconds")
return function_to_concolic_tests, concolic_test_suite_code


@@ -20,7 +20,7 @@ from codeflash.code_utils.code_utils import (
module_name_from_file_path,
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.languages import is_javascript
from codeflash.languages import Language
# Import Jest-specific parsing from the JavaScript language module
from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml
@@ -32,7 +32,6 @@ from codeflash.models.models import (
TestType,
VerificationType,
)
from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils
if TYPE_CHECKING:
import subprocess
@@ -438,8 +437,10 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
finally:
db.close()
# Check if this is a JavaScript test (use JSON) or Python test (use pickle)
is_jest = is_javascript()
# Check serialization format: JavaScript uses JSON, Python uses pickle
from codeflash.languages.current import current_language_support
is_json_format = current_language_support().test_result_serialization_format == "json"
for val in data:
try:
@@ -452,7 +453,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# - A module-style path: "tests.fibonacci.test.ts" (dots as separators)
# - A file path: "tests/fibonacci.test.ts" (slashes as separators)
# For Python, it's a module path (e.g., "tests.test_foo") that needs conversion
if is_jest:
if is_json_format:
# Jest test file extensions (including .test.ts, .spec.ts patterns)
jest_test_extensions = (
".test.ts",
@@ -517,7 +518,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
logger.debug(f"[PARSE-DEBUG] by_instrumented_file_path: {test_type}")
# Default to GENERATED_REGRESSION for Jest tests when test type can't be determined
if test_type is None and is_jest:
if test_type is None and is_json_format:
test_type = TestType.GENERATED_REGRESSION
logger.debug("[PARSE-DEBUG] defaulting to GENERATED_REGRESSION (Jest)")
elif test_type is None:
@@ -532,7 +533,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
ret_val = None
if loop_index == 1 and val[7]:
try:
if is_jest:
if is_json_format:
# Jest comparison happens via Node.js script (language_support.compare_test_results)
# Store a marker indicating data exists but is not deserialized in Python
ret_val = ("__serialized__", val[7])
@@ -578,7 +579,9 @@ def parse_test_xml(
run_result: subprocess.CompletedProcess | None = None,
) -> TestResults:
# Route to Jest-specific parser for JavaScript/TypeScript tests
if is_javascript():
from codeflash.languages.current import current_language
if current_language() in (Language.JAVASCRIPT, Language.TYPESCRIPT):
return _parse_jest_test_xml(
test_xml_file_path,
test_files,
@@ -1000,7 +1003,9 @@ def parse_test_results(
# Also try to read legacy binary format for Python tests
# Binary file may contain additional results (e.g., from codeflash_wrap) even if SQLite has data
# from @codeflash_capture. We need to merge both sources.
if not is_javascript():
from codeflash.languages.current import current_language_support as _cls
if _cls().test_result_serialization_format == "pickle":
try:
bin_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin"))
if bin_results_file.exists():
@@ -1036,23 +1041,13 @@ def parse_test_results(
coverage = None
if coverage_database_file and source_file and code_context and function_name:
all_args = True
if is_javascript():
# Jest uses coverage-final.json (coverage_database_file points to this)
coverage = JestCoverageUtils.load_from_jest_json(
coverage_json_path=coverage_database_file,
function_name=function_name,
code_context=code_context,
source_code_path=source_file,
)
else:
# Python uses coverage.py SQLite database
coverage = CoverageUtils.load_from_sqlite_database(
database_path=coverage_database_file,
config_path=coverage_config_file,
source_code_path=source_file,
code_context=code_context,
function_name=function_name,
)
coverage = _cls().load_coverage(
coverage_database_file=coverage_database_file,
function_name=function_name,
code_context=code_context,
source_file=source_file,
coverage_config_file=coverage_config_file,
)
coverage.log_coverage()
try:
failures = parse_test_failures_from_stdout(run_result.stdout)


@@ -1,29 +1,17 @@
from __future__ import annotations
import contextlib
import re
import shlex
import shutil
import subprocess
import sys
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages import is_python
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
from codeflash.languages.registry import get_language_support, get_language_support_by_framework
from codeflash.models.models import TestFiles, TestType
if TYPE_CHECKING:
from codeflash.models.models import TestFiles
from pathlib import Path
BEHAVIORAL_BLOCKLISTED_PLUGINS = ["benchmark", "codspeed", "xdist", "sugar"]
BENCHMARKING_BLOCKLISTED_PLUGINS = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import custom_addopts
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
from codeflash.languages.registry import get_language_support
# Pattern to extract timing from stdout markers: !######...:<duration_ns>######!
# Jest markers have multiple colons: !######module:test:func:loop:id:duration######!
@@ -112,267 +100,3 @@ def execute_test_subprocess(
cwd=cwd, env=env, timeout=timeout, check=False, text=True, capture_output=True
)
return subprocess.run(cmd_list, **run_args) # noqa: PLW1510
def run_behavioral_tests(
test_paths: TestFiles,
test_framework: str,
test_env: dict[str, str],
cwd: Path,
*,
pytest_timeout: int | None = None,
pytest_cmd: str = "pytest",
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
enable_coverage: bool = False,
js_project_root: Path | None = None,
candidate_index: int = 0,
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
"""Run behavioral tests with optional coverage."""
# Check if there's a language support for this test framework that implements run_behavioral_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_behavioral_tests"):
return language_support.run_behavioral_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=pytest_timeout,
project_root=js_project_root,
enable_coverage=enable_coverage,
candidate_index=candidate_index,
)
if is_python():
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type == TestType.REPLAY_TEST:
# Replay tests need specific test targeting because one file contains tests for multiple functions
if file.tests_in_file:
test_files.extend(
[
str(file.instrumented_behavior_file_path) + "::" + test.test_function
for test in file.tests_in_file
]
)
else:
test_files.append(str(file.instrumented_behavior_file_path))
pytest_cmd_list = (
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
if pytest_cmd == "pytest"
else [SAFE_SYS_EXECUTABLE, "-m", *shlex.split(pytest_cmd, posix=IS_POSIX)]
)
test_files = list(set(test_files)) # remove multiple calls in the same test function
common_pytest_args = [
"--capture=tee-sys",
"-q",
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
common_pytest_args.append(f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
if enable_coverage:
coverage_database_file, coverage_config_file = prepare_coverage_files()
# disable jit for coverage
pytest_test_env["NUMBA_DISABLE_JIT"] = str(1)
pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1)
pytest_test_env["PYTORCH_JIT"] = str(0)
pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
pytest_test_env["JAX_DISABLE_JIT"] = str(0)
is_windows = sys.platform == "win32"
if is_windows:
# On Windows, delete coverage database file directly instead of using 'coverage erase', to avoid locking issues
if coverage_database_file.exists():
with contextlib.suppress(PermissionError, OSError):
coverage_database_file.unlink()
else:
cov_erase = execute_test_subprocess(
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30
) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, then the current run will be appended to the previous data, which skews the results
logger.debug(cov_erase)
coverage_cmd = [
SAFE_SYS_EXECUTABLE,
"-m",
"coverage",
"run",
f"--rcfile={coverage_config_file.as_posix()}",
"-m",
]
if pytest_cmd == "pytest":
coverage_cmd.extend(["pytest"])
else:
coverage_cmd.extend(shlex.split(pytest_cmd, posix=IS_POSIX)[1:])
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"]
results = execute_test_subprocess(
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600,
)
logger.debug(
f"Result return code: {results.returncode}, "
f"{'Result stderr:' + str(results.stderr) if results.stderr else ''}"
)
else:
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS]
results = execute_test_subprocess(
pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600, # TODO: Make this dynamic
)
logger.debug(
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}"""
)
else:
msg = f"Unsupported test framework: {test_framework}"
raise ValueError(msg)
return (
result_file_path,
results,
coverage_database_file if enable_coverage else None,
coverage_config_file if enable_coverage else None,
)
def run_line_profile_tests(
test_paths: TestFiles,
pytest_cmd: str,
test_env: dict[str, str],
cwd: Path,
test_framework: str,
*,
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
pytest_timeout: int | None = None,
pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000,
js_project_root: Path | None = None,
line_profiler_output_file: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
# Check if there's a language support for this test framework that implements run_line_profile_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_line_profile_tests"):
return language_support.run_line_profile_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=pytest_timeout,
project_root=js_project_root,
line_profile_output_file=line_profiler_output_file,
)
if is_python(): # pytest runs both pytest and unittest tests
pytest_cmd_list = (
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
if pytest_cmd == "pytest"
else shlex.split(pytest_cmd)
)
# Always use file path - pytest discovers all tests including parametrized ones
test_files: list[str] = list(
{str(file.benchmarking_file_path) for file in test_paths.test_files}
) # remove multiple calls in the same test function
pytest_args = [
"--capture=tee-sys",
"-q",
"--codeflash_loops_scope=session",
"--codeflash_min_loops=1",
"--codeflash_max_loops=1",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
]
if pytest_timeout is not None:
pytest_args.append(f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
pytest_test_env["LINE_PROFILE"] = "1"
results = execute_test_subprocess(
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600, # TODO: Make this dynamic
)
else:
msg = f"Unsupported test framework: {test_framework}"
raise ValueError(msg)
return result_file_path, results
def run_benchmarking_tests(
test_paths: TestFiles,
pytest_cmd: str,
test_env: dict[str, str],
cwd: Path,
test_framework: str,
*,
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
pytest_timeout: int | None = None,
pytest_min_loops: int = 5,
pytest_max_loops: int = 100_000,
js_project_root: Path | None = None,
) -> tuple[Path, subprocess.CompletedProcess]:
logger.debug(f"run_benchmarking_tests called: framework={test_framework}, num_files={len(test_paths.test_files)}")
# Check if there's a language support for this test framework that implements run_benchmarking_tests
language_support = get_language_support_by_framework(test_framework)
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):
return language_support.run_benchmarking_tests(
test_paths=test_paths,
test_env=test_env,
cwd=cwd,
timeout=pytest_timeout,
project_root=js_project_root,
min_loops=pytest_min_loops,
max_loops=pytest_max_loops,
target_duration_seconds=pytest_target_runtime_seconds,
)
if is_python(): # pytest runs both pytest and unittest tests
pytest_cmd_list = (
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
if pytest_cmd == "pytest"
else shlex.split(pytest_cmd)
)
# Always use file path - pytest discovers all tests including parametrized ones
test_files: list[str] = list(
{str(file.benchmarking_file_path) for file in test_paths.test_files}
) # remove multiple calls in the same test function
pytest_args = [
"--capture=tee-sys",
"-q",
"--codeflash_loops_scope=session",
f"--codeflash_min_loops={pytest_min_loops}",
f"--codeflash_max_loops={pytest_max_loops}",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
"--codeflash_stability_check=true",
]
if pytest_timeout is not None:
pytest_args.append(f"--timeout={pytest_timeout}")
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
pytest_test_env = test_env.copy()
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
results = execute_test_subprocess(
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
cwd=cwd,
env=pytest_test_env,
timeout=600, # TODO: Make this dynamic
)
else:
msg = f"Unsupported test framework: {test_framework}"
raise ValueError(msg)
return result_file_path, results


@@ -6,7 +6,7 @@ from typing import Optional
from pydantic.dataclasses import dataclass
from codeflash.languages import current_language_support, is_javascript
from codeflash.languages import current_language_support
def get_test_file_path(
@@ -19,7 +19,7 @@ def get_test_file_path(
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
# Use appropriate file extension based on language
extension = current_language_support().get_test_file_suffix() if is_javascript() else ".py"
extension = current_language_support().get_test_file_suffix()
# For JavaScript/TypeScript, place generated tests in a subdirectory that matches
# Vitest/Jest include patterns (e.g., test/**/*.test.ts)
@@ -164,16 +164,8 @@ class TestConfig:
@property
def test_framework(self) -> str:
"""Returns the appropriate test framework based on language.
For JavaScript/TypeScript: uses the configured framework (vitest, jest, or mocha).
For Python: uses pytest as default.
"""
if is_javascript():
from codeflash.languages.test_framework import get_js_test_framework_or_default
return get_js_test_framework_or_default()
return "pytest"
"""Returns the appropriate test framework based on language."""
return current_language_support().test_framework
def set_language(self, language: str) -> None:
"""Set the language for this test config.


@@ -6,8 +6,8 @@ from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.languages import is_javascript
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.languages.current import current_language_support
from codeflash.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main
if TYPE_CHECKING:
@@ -35,15 +35,13 @@ def generate_tests(
start_time = time.perf_counter()
test_module_path = Path(module_name_from_file_path(test_path, test_cfg.tests_project_rootdir))
# Detect module system for JavaScript/TypeScript before calling aiservice
project_module_system = None
if is_javascript():
from codeflash.languages.javascript.module_system import detect_module_system
# Detect module system via language support (non-None for JS/TS, None for Python)
lang_support = current_language_support()
source_file = Path(function_to_optimize.file_path)
project_module_system = lang_support.detect_module_system(test_cfg.tests_project_rootdir, source_file)
source_file = Path(function_to_optimize.file_path)
project_module_system = detect_module_system(test_cfg.tests_project_rootdir, source_file)
# For JavaScript, calculate the correct import path from the actual test location
if project_module_system is not None:
# For JavaScript/TypeScript, calculate the correct import path from the actual test location
# (test_path) to the source file, not from tests_root
import os
@@ -73,65 +71,18 @@ def generate_tests(
)
if response and isinstance(response, tuple) and len(response) == 3:
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response
temp_run_dir = get_run_tmp_file(Path()).as_posix()
# For JavaScript/TypeScript, instrumentation is done locally (aiservice returns uninstrumented code)
if is_javascript():
from codeflash.languages.javascript.instrument import (
TestingMode,
fix_imports_inside_test_blocks,
fix_jest_mock_paths,
instrument_generated_js_test,
validate_and_fix_import_style,
)
from codeflash.languages.javascript.module_system import (
ensure_module_system_compatibility,
ensure_vitest_imports,
)
source_file = Path(function_to_optimize.file_path)
# Fix import statements that appear inside test blocks (invalid JS syntax)
generated_test_source = fix_imports_inside_test_blocks(generated_test_source)
# Fix relative paths in jest.mock() calls
generated_test_source = fix_jest_mock_paths(
generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
)
# Validate and fix import styles (default vs named exports)
generated_test_source = validate_and_fix_import_style(
generated_test_source, source_file, function_to_optimize.function_name
)
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
# Skip conversion if ts-jest is installed (handles interop natively)
generated_test_source = ensure_module_system_compatibility(
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
)
# Ensure vitest imports are present when using vitest framework
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
# Instrument for behavior verification (writes to SQLite)
instrumented_behavior_test_source = instrument_generated_js_test(
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR
)
# Instrument for performance measurement (prints to stdout)
instrumented_perf_test_source = instrument_generated_js_test(
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE
)
logger.debug(f"Instrumented JS/TS tests locally for {function_to_optimize.function_name}")
else:
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
)
instrumented_perf_test_source = instrumented_perf_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = (
lang_support.process_generated_test_strings(
generated_test_source=generated_test_source,
instrumented_behavior_test_source=instrumented_behavior_test_source,
instrumented_perf_test_source=instrumented_perf_test_source,
function_to_optimize=function_to_optimize,
test_path=test_path,
test_cfg=test_cfg,
project_module_system=project_module_system,
)
)
else:
logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
return None


@@ -0,0 +1 @@


@@ -0,0 +1 @@


@@ -3,25 +3,31 @@ title: "How Codeflash Works"
description: "Understand Codeflash's generate-and-verify approach to code optimization and correctness verification"
icon: "gear"
sidebarTitle: "How It Works"
keywords: ["architecture", "verification", "correctness", "testing", "optimization", "LLM", "benchmarking"]
keywords: ["architecture", "verification", "correctness", "testing", "optimization", "LLM", "benchmarking", "javascript", "typescript", "python"]
---
# How Codeflash Works
Codeflash follows a "generate and verify" approach to optimizing code. It uses LLMs to generate optimizations, then rigorously verifies that those optimizations are indeed faster and that they behave the same as the original. The basic unit of optimization is a function: Codeflash tries to speed up the function while ensuring it still behaves the same way. That way, if you merge the optimized code, it simply runs faster without breaking any functionality.
Codeflash supports **Python**, **JavaScript**, and **TypeScript** projects.
## Analysis of your code
Codeflash scans your codebase to identify all available functions. It locates existing unit tests in your projects and maps which functions they test. When optimizing a function, Codeflash runs these discovered tests to verify nothing has broken.
For Python, code analysis uses `libcst` and `jedi`. For JavaScript/TypeScript, it uses `tree-sitter` for AST parsing.
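The idea behind this scan can be sketched with the standard library's `ast` module (Codeflash itself uses `libcst` and `jedi`, but the shape is similar): parse a module and collect its function definitions.

```python
import ast

def discover_functions(source: str) -> list[str]:
    """Collect names of function definitions in a module (sketch only)."""
    tree = ast.parse(source)
    return [
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]

discover_functions("def fast(x):\n    return x\n\nasync def fetch():\n    ...")
# -> ['fast', 'fetch']
```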
#### What kind of functions can Codeflash optimize?
Codeflash works best with self-contained functions that have minimal side effects (such as communicating with external systems or sending network requests). Codeflash optimizes a group of functions: an entry point function plus any other functions it directly calls.
Codeflash supports optimizing async functions.
Codeflash supports optimizing async functions in all supported languages.
#### Test Discovery
Codeflash currently only runs tests that directly call the target function in their test body. To discover tests that indirectly call the function, you can use the Codeflash Tracer. The Tracer analyzes your test suite and identifies all tests that eventually call a function.
Codeflash discovers tests that directly call the target function in their test body. For Python, it finds pytest and unittest tests. For JavaScript/TypeScript, it finds Jest and Vitest test files.
To discover tests that indirectly call the function, you can use the Codeflash Tracer. The Tracer analyzes your test suite and identifies all tests that eventually call a function.
## Optimization Generation
@@ -48,12 +54,12 @@ We recommend manually reviewing the optimized code since there might be importan
Codeflash generates two types of tests:
- LLM Generated tests - Codeflash uses LLMs to create several regression test cases that cover typical function usage, edge cases, and large-scale inputs to verify both correctness and performance.
- Concolic coverage tests - Codeflash uses state-of-the-art concolic testing with an SMT Solver (a theorem prover) to explore execution paths and generate function arguments. This aims to maximize code coverage for the function being optimized. Codeflash runs the resulting test file to verify correctness. Currently, this feature only supports pytest.
- **LLM Generated tests** - Codeflash uses LLMs to create several regression test cases that cover typical function usage, edge cases, and large-scale inputs to verify both correctness and performance. This works for Python, JavaScript, and TypeScript.
- **Concolic coverage tests** - Codeflash uses state-of-the-art concolic testing with an SMT Solver (a theorem prover) to explore execution paths and generate function arguments. This aims to maximize code coverage for the function being optimized. Currently, this feature only supports Python (pytest).
## Code Execution
Codeflash runs tests for the target function using either pytest or unittest frameworks. The tests execute on your machine, ensuring access to the Python environment and any other dependencies associated to let Codeflash run your code properly. Running on your machine also ensures accurate performance measurements since runtime varies by system.
Codeflash runs tests for the target function on your machine. For Python, it uses pytest or unittest. For JavaScript/TypeScript, it uses Jest or Vitest. Running on your machine ensures access to your environment and dependencies, and provides accurate performance measurements since runtime varies by system.
#### Performance benchmarking


@@ -1,83 +1,26 @@
---
title: "Manual Configuration"
description: "Configure Codeflash for your project with pyproject.toml settings and advanced options"
description: "Configure Codeflash for your project"
icon: "gear"
sidebarTitle: "Manual Configuration"
keywords:
[
"configuration",
"pyproject.toml",
"setup",
"settings",
"pytest",
"formatter",
]
---
# Manual Configuration
Codeflash is installed and configured on a per-project basis.
`codeflash init` should guide you through the configuration process, but if you need to manually configure Codeflash or set advanced settings, you can do so by editing the `pyproject.toml` file in the root directory of your project.
`codeflash init` should guide you through the configuration process, but if you need to manually configure Codeflash or set advanced settings, follow the guide for your language:
## Configuration Options
Codeflash config looks like the following
```toml
[tool.codeflash]
module-root = "my_module"
tests-root = "tests"
formatter-cmds = ["black $file"]
# optional configuration
benchmarks-root = "tests/benchmarks" # Required when running with --benchmark
ignore-paths = ["my_module/build/"]
pytest-cmd = "pytest"
disable-imports-sorting = false
disable-telemetry = false
git-remote = "origin"
override-fixtures = false
```
All file paths are relative to the directory of the `pyproject.toml` file.
Required Options:
- `module-root`: The Python module you want Codeflash to optimize going forward. Only code under this directory will be optimized. It should also have an `__init__.py` file to make the module importable.
- `tests-root`: The directory where your tests are located. Codeflash will use this directory to discover existing tests as well as generate new tests.
Optional Configuration:
- `benchmarks-root`: The directory where your benchmarks are located. Codeflash will use this directory to discover existing benchmarks. Note that this option is required when running with `--benchmark`.
- `ignore-paths`: A list of paths within the `module-root` to ignore when optimizing code. Codeflash will not optimize code in these paths. Useful for ignoring build directories or other generated code. You can also leave this empty if not needed.
- `pytest-cmd`: The command to run your tests. Defaults to `pytest`. You can specify extra commandline arguments here for pytest.
- `formatter-cmds`: The command line to run your code formatter or linter. Defaults to `["black $file"]`. In the command line `$file` refers to the current file being optimized. The assumption with using tools here is that they overwrite the same file and returns a zero exit code. You can also specify multiple tools here that run in a chain as a toml array. You can also disable code formatting by setting this to `["disabled"]`.
- `ruff` - A recommended way to run ruff linting and formatting is `["ruff check --exit-zero --fix $file", "ruff format $file"]`. To make `ruff check --fix` return a 0 exit code please add a `--exit-zero` argument.
- `disable-imports-sorting`: By default, codeflash uses isort to organize your imports before creating suggestions. You can disable this by setting this field to `true`. This could be useful if you don't sort your imports or while using linters like ruff that sort imports too.
- `disable-telemetry`: Disable telemetry data collection. Defaults to `false`. Set this to `true` to disable telemetry data collection. Codeflash collects anonymized telemetry data to understand how users are using Codeflash and to improve the product. Telemetry does not collect any code data.
- `git-remote`: The git remote to use for pull requests. Defaults to `"origin"`.
- `override-fixtures`: Override pytest fixtures during optimization. Defaults to `false`.
## Example Configuration
Here's an example project with the following structure:
```text
acme-project/
|- foo_module/
| |- __init__.py
| |- foo.py
| |- main.py
|- tests/
| |- __init__.py
| |- test_script.py
|- pyproject.toml
```
Here's a sample `pyproject.toml` file for the above project:
```toml
[tool.codeflash]
module-root = "foo_module"
tests-root = "tests"
ignore-paths = []
```
<CardGroup cols={2}>
<Card title="Python Configuration" icon="python" href="/configuration/python">
Configure via `pyproject.toml`
</Card>
<Card title="JavaScript / TypeScript Configuration" icon="js" href="/configuration/javascript">
Configure via `package.json`
</Card>
</CardGroup>


@@ -0,0 +1,220 @@
---
title: "JavaScript / TypeScript Configuration"
description: "Configure Codeflash for JavaScript and TypeScript projects using package.json"
icon: "js"
sidebarTitle: "JavaScript / TypeScript"
keywords:
[
"configuration",
"package.json",
"javascript",
"typescript",
"jest",
"vitest",
"prettier",
"eslint",
"monorepo",
]
---
# JavaScript / TypeScript Configuration
Codeflash stores its configuration in `package.json` under the `"codeflash"` key.
## Full Reference
```json
{
"name": "my-project",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "tests",
"testRunner": "jest",
"formatterCmds": ["prettier --write $file"],
"ignorePaths": ["src/generated/"],
"disableTelemetry": false,
"gitRemote": "origin"
}
}
```
All file paths are relative to the directory containing `package.json`.
<Info>
Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed.
</Info>
## Auto-Detection
When you run `codeflash init`, Codeflash inspects your project and auto-detects:
| Setting | Detection logic |
|---------|----------------|
| `moduleRoot` | Looks for `src/`, `lib/`, or the main source directory |
| `testsRoot` | Looks for `tests/`, `test/`, `__tests__/`, or files matching `*.test.js` / `*.spec.js` |
| `testRunner` | Checks `devDependencies` for `jest` or `vitest` |
| `formatterCmds` | Checks for `prettier`, `eslint`, or `biome` in dependencies and config files |
| Module system | Reads `"type"` field in `package.json` (ESM vs CommonJS) |
| TypeScript | Detects `tsconfig.json` |
You can always override any auto-detected value in the `"codeflash"` section.
## Required Options
- `moduleRoot`: The source directory to optimize. Only code under this directory will be optimized.
- `testsRoot`: The directory where your tests are located. Codeflash discovers existing tests and generates new ones here.
## Optional Options
- `testRunner`: Test framework to use. Auto-detected from your dependencies. Supported values: `"jest"`, `"vitest"`.
- `formatterCmds`: Formatter commands. `$file` refers to the file being optimized. Disable with `["disabled"]`.
- **Prettier**: `["prettier --write $file"]`
- **ESLint + Prettier**: `["eslint --fix $file", "prettier --write $file"]`
- **Biome**: `["biome check --write $file"]`
- `ignorePaths`: Paths within `moduleRoot` to skip during optimization.
- `disableTelemetry`: Disable anonymized telemetry. Defaults to `false`.
- `gitRemote`: Git remote for pull requests. Defaults to `"origin"`.
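To illustrate the `$file` placeholder, here is a hypothetical helper (not Codeflash's actual implementation) that expands a formatter command for a given file before running it:

```python
import shlex

def render_formatter_cmds(cmds: list[str], file_path: str) -> list[list[str]]:
    """Substitute $file into each command and split it into argv form (sketch)."""
    return [shlex.split(cmd.replace("$file", file_path)) for cmd in cmds]

render_formatter_cmds(["prettier --write $file"], "src/utils.js")
# -> [['prettier', '--write', 'src/utils.js']]
```

Each rendered command is expected to overwrite the file in place and exit with a zero status.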
## Module Systems
Codeflash handles both ES Modules and CommonJS automatically. It detects the module system from your `package.json`:
```json
{
"type": "module"
}
```
- `"type": "module"` — Files are treated as ESM (`import`/`export`)
- `"type": "commonjs"` or omitted — Files are treated as CommonJS (`require`/`module.exports`)
No additional configuration is needed. Codeflash respects `.mjs`/`.cjs` extensions as well.
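The detection rule above can be sketched in a few lines (a simplified model, not the real detector): the `"type"` field decides the project default, and `.mjs`/`.cjs` extensions override it per file.

```python
import json

def detect_module_system(package_json_text: str, filename: str) -> str:
    """Simplified sketch of ESM vs CommonJS detection."""
    # Per-file extensions always win
    if filename.endswith(".mjs"):
        return "esm"
    if filename.endswith(".cjs"):
        return "commonjs"
    # Otherwise fall back to the package-level "type" field
    pkg = json.loads(package_json_text)
    return "esm" if pkg.get("type") == "module" else "commonjs"
```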
## TypeScript
TypeScript projects work out of the box. Codeflash detects TypeScript from the presence of `tsconfig.json` and handles `.ts`/`.tsx` files automatically.
No separate configuration is needed for TypeScript vs JavaScript.
## Test Framework Support
| Framework | Auto-detected from | Notes |
|-----------|-------------------|-------|
| **Jest** | `jest` in dependencies | Default for most projects |
| **Vitest** | `vitest` in dependencies | ESM-native support |
<Info>
**Functions must be exported** to be optimizable. Codeflash uses tree-sitter AST analysis to discover functions and check export status. Supported export patterns:
- `export function foo() {}`
- `export const foo = () => {}`
- `export default function foo() {}`
- `const foo = () => {}; export { foo };`
- `module.exports = { foo }`
- `const utils = { foo() {} }; module.exports = utils;`
</Info>
## Monorepo Configuration
For monorepo projects (Yarn workspaces, pnpm workspaces, Lerna, Nx, Turborepo), configure each package individually:
```text
my-monorepo/
|- packages/
| |- core/
| | |- src/
| | |- tests/
| | |- package.json <-- "codeflash" config here
| |- utils/
| | |- src/
| | |- __tests__/
| | |- package.json <-- "codeflash" config here
|- package.json <-- workspace root (no codeflash config)
```
Run `codeflash init` from within each package:
```bash
cd packages/core
npx codeflash init
```
<Warning>
**Always run codeflash from the package directory**, not the monorepo root. Codeflash needs to find the `package.json` with the `"codeflash"` config in the current working directory.
</Warning>
### Hoisted dependencies
If your monorepo hoists `node_modules` to the root (Yarn Berry with `nodeLinker: node-modules`, pnpm with `shamefully-hoist`), Codeflash resolves modules using Node.js standard resolution. This works automatically.
For **pnpm strict mode** (non-hoisted), ensure `codeflash` is a direct dependency of the package:
```bash
pnpm add --filter @my-org/core --save-dev codeflash
```
## Example
### Standard project
```text
my-app/
|- src/
| |- utils.js
| |- index.js
|- tests/
| |- utils.test.js
|- package.json
```
```json
{
"name": "my-app",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "tests"
}
}
```
### Project with co-located tests
```text
my-app/
|- src/
| |- utils.js
| |- utils.test.js
| |- index.js
|- package.json
```
```json
{
"name": "my-app",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "src"
}
}
```
### CommonJS library with no separate test directory
```text
my-lib/
|- lib/
| |- helpers.js
|- test/
| |- helpers.spec.js
|- package.json
```
```json
{
"name": "my-lib",
"codeflash": {
"moduleRoot": "lib",
"testsRoot": "test"
}
}
```

---
title: "Python Configuration"
description: "Configure Codeflash for Python projects using pyproject.toml"
icon: "python"
sidebarTitle: "Python"
keywords:
[
"configuration",
"pyproject.toml",
"python",
"pytest",
"formatter",
"ruff",
"black",
]
---
# Python Configuration
Codeflash stores its configuration in `pyproject.toml` under the `[tool.codeflash]` section.
## Full Reference
```toml
[tool.codeflash]
# Required
module-root = "my_module"
tests-root = "tests"
# Optional
formatter-cmds = ["black $file"]
benchmarks-root = "tests/benchmarks"
ignore-paths = ["my_module/build/"]
pytest-cmd = "pytest"
disable-imports-sorting = false
disable-telemetry = false
git-remote = "origin"
override-fixtures = false
```
All file paths are relative to the directory of the `pyproject.toml` file.
## Required Options
- `module-root`: The Python module to optimize. Only code under this directory will be optimized. It should have an `__init__.py` file to make the module importable.
- `tests-root`: The directory where your tests are located. Codeflash discovers existing tests and generates new ones here.
## Optional Options
- `benchmarks-root`: Directory for benchmarks. Required when running with `--benchmark`.
- `ignore-paths`: Paths within `module-root` to skip. Useful for build directories or generated code.
- `pytest-cmd`: Command to run your tests. Defaults to `pytest`. You can add extra arguments here.
- `formatter-cmds`: Formatter/linter commands. `$file` refers to the file being optimized. Disable with `["disabled"]`.
- **ruff** (recommended): `["ruff check --exit-zero --fix $file", "ruff format $file"]`
- **black**: `["black $file"]`
- `disable-imports-sorting`: Disable isort import sorting. Defaults to `false`.
- `disable-telemetry`: Disable anonymized telemetry. Defaults to `false`.
- `git-remote`: Git remote for pull requests. Defaults to `"origin"`.
- `override-fixtures`: Override pytest fixtures during optimization. Defaults to `false`.
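As an illustration, a `pyproject.toml` combining the required settings with the recommended ruff formatter commands might look like this (the module and test paths are placeholders):

```toml
[tool.codeflash]
module-root = "my_module"
tests-root = "tests"
# Recommended ruff commands; $file is replaced with the file being optimized
formatter-cmds = ["ruff check --exit-zero --fix $file", "ruff format $file"]
disable-telemetry = true
```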
## Example
```text
acme-project/
|- foo_module/
| |- __init__.py
| |- foo.py
| |- main.py
|- tests/
| |- __init__.py
| |- test_script.py
|- pyproject.toml
```
```toml
[tool.codeflash]
module-root = "foo_module"
tests-root = "tests"
ignore-paths = []
```

---
title: "JavaScript / TypeScript Installation"
description: "Install and configure Codeflash for your JavaScript/TypeScript project"
icon: "node-js"
keywords:
[
"installation",
"javascript",
"typescript",
"npm",
"yarn",
"pnpm",
"bun",
"jest",
"vitest",
"monorepo",
]
---
Codeflash supports JavaScript and TypeScript projects. It uses V8 native serialization for test data capture and works with Jest and Vitest test frameworks.
### Prerequisites
Before installing Codeflash, ensure you have:
1. **Node.js 18 or above** installed
2. **A JavaScript/TypeScript project** with a package manager (npm, yarn, pnpm, or bun)
3. **Project dependencies installed**
Good to have (optional):
1. **Unit tests** (Jest or Vitest) — Codeflash uses them to verify correctness of optimizations
<Warning>
**Node.js 18+ Required**
Codeflash requires Node.js 18 or above. Check your version:
```bash
node --version # Should show v18.0.0 or higher
```
</Warning>
<Steps>
<Step title="Install the Codeflash npm package">
Install Codeflash as a development dependency in your project:
<CodeGroup>
```bash npm
npm install --save-dev codeflash
```
```bash yarn
yarn add --dev codeflash
```
```bash pnpm
pnpm add --save-dev codeflash
```
```bash bun
bun add --dev codeflash
```
</CodeGroup>
<Tip>
**Dev dependency recommended** — Codeflash is for development and CI workflows. Installing as a dev dependency keeps your production bundle clean.
</Tip>
<Info>
**Codeflash also requires a Python installation** (3.9+) to run the CLI optimizer. Install the Python CLI globally:
```bash
pip install codeflash
# or
uv tool install codeflash
```
The Python CLI orchestrates the optimization pipeline, while the npm package provides the JavaScript runtime (test runners, serialization, reporters).
</Info>
</Step>
<Step title="Generate a Codeflash API Key">
Codeflash uses cloud-hosted AI models. You need an API key:
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
2. Sign up with your GitHub account (free tier available)
3. Navigate to the [API Key](https://app.codeflash.ai/app/apikeys) page to generate your key
Set it as an environment variable:
```bash
export CODEFLASH_API_KEY="your-api-key-here"
```
Or add it to your shell profile (`~/.bashrc`, `~/.zshrc`) for persistence.
</Step>
<Step title="Run Automatic Configuration">
Navigate to your project root (where `package.json` is) and run:
<CodeGroup>
```bash npm / yarn / pnpm
npx codeflash init
```
```bash bun
bunx codeflash init
```
```bash Global install
codeflash init
```
</CodeGroup>
### What `codeflash init` does
Codeflash **auto-detects** most settings from your project:
| Setting | How it's detected |
|---------|------------------|
| **Module root** | Looks for `src/`, `lib/`, or the directory containing your source files |
| **Tests root** | Looks for `tests/`, `test/`, `__tests__/`, or files matching `*.test.js` / `*.spec.js` |
| **Test framework** | Checks `devDependencies` for `jest` or `vitest` |
| **Formatter** | Checks for `prettier`, `eslint`, or `biome` in dependencies and config files |
| **Module system** | Reads `"type"` field in `package.json` (ESM vs CommonJS) |
| **TypeScript** | Detects `tsconfig.json` presence |
You'll be prompted to confirm or override the detected values. The configuration is saved in your `package.json` under the `"codeflash"` key:
```json
{
"name": "my-project",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "tests"
}
}
```
<Info>
**No separate config file needed.** Codeflash stores all configuration inside your existing `package.json`.
</Info>
</Step>
<Step title="Install the Codeflash GitHub App (optional)">
To receive optimization PRs automatically, install the Codeflash GitHub App:
<Note>
[Install Codeflash GitHub App](https://github.com/apps/codeflash-ai/installations/select_target)
</Note>
This enables the codeflash-ai bot to open PRs with optimization suggestions. If you skip this step, you can still optimize locally using `--no-pr`.
</Step>
</Steps>
## Monorepo Setup
For monorepos (Yarn workspaces, pnpm workspaces, Lerna, Nx, Turborepo), run `codeflash init` from within each package you want to optimize:
```bash
# Navigate to the specific package
cd packages/my-library

# Run init from the package directory
npx codeflash init
```
Each package gets its own `"codeflash"` section in its `package.json`. The `moduleRoot` and `testsRoot` paths are relative to that package's `package.json`.
### Example: Yarn workspaces monorepo
```text
my-monorepo/
|- packages/
| |- core/
| | |- src/
| | |- tests/
| | |- package.json <-- codeflash config here
| |- utils/
| | |- src/
| | |- __tests__/
| | |- package.json <-- codeflash config here
|- package.json <-- root workspace (no codeflash config needed)
```
```json
// packages/core/package.json
{
"name": "@my-org/core",
"codeflash": {
"moduleRoot": "src",
"testsRoot": "tests"
}
}
```
<Warning>
**Run codeflash from the package directory**, not the monorepo root. Codeflash needs to find the `package.json` with the `"codeflash"` config in the current working directory.
</Warning>
<Info>
**Hoisted dependencies work fine.** If your monorepo hoists `node_modules` to the root (common in Yarn Berry, pnpm with `shamefully-hoist`), Codeflash resolves modules using Node.js standard resolution and will find them correctly.
</Info>
## Test Framework Support
| Framework | Status | Auto-detected from |
|-----------|--------|-------------------|
| **Jest** | Supported | `jest` in dependencies |
| **Vitest** | Supported | `vitest` in dependencies |
| **Mocha** | Coming soon | — |
<Info>
**Functions must be exported** to be optimizable. Codeflash can only discover and optimize functions that are exported from their module (via `export`, `export default`, or `module.exports`).
</Info>
## Try It Out
Once configured, optimize your code:
<CodeGroup>
```bash Optimize a function
codeflash --file src/utils.js --function processData
```
```bash Optimize locally (no PR)
codeflash --file src/utils.ts --function processData --no-pr
```
```bash Optimize entire codebase
codeflash --all
```
```bash Trace and optimize
codeflash optimize --jest
```
</CodeGroup>
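To recap the export requirement mentioned above: Codeflash can only optimize functions reachable through a module's exports. A minimal sketch of an optimizable CommonJS module (file and function names are hypothetical):

```javascript
// utils.js (hypothetical) — a CommonJS module

// Exported, so Codeflash can discover and optimize it
function sumOfSquares(values) {
  let total = 0;
  for (const v of values) total += v * v;
  return total;
}

// Not exported: Codeflash will skip this internal helper
function square(x) {
  return x * x;
}

module.exports = { sumOfSquares };
```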
## Troubleshooting
<AccordionGroup>
<Accordion title="Function not found or not exported">
Codeflash only optimizes **exported** functions. Make sure your function is exported:
```javascript
// ES Modules
export function processData(data) { ... }
// or
const processData = (data) => { ... };
export { processData };

// CommonJS
function processData(data) { ... }
module.exports = { processData };
```
If codeflash reports the function exists but is not exported, add an export statement.
</Accordion>
<Accordion title="codeflash npm package not found / module errors">
Ensure the codeflash npm package is installed in your project:
<CodeGroup>
```bash npm
npm install --save-dev codeflash
```
```bash yarn
yarn add --dev codeflash
```
```bash pnpm
pnpm add --save-dev codeflash
```
</CodeGroup>
For **monorepos**, make sure it's installed in the package you're optimizing, or at the workspace root if dependencies are hoisted.
</Accordion>
<Accordion title="Test framework not detected">
Codeflash auto-detects the test framework from your `devDependencies`. If detection fails:
1. Verify your test framework is in `devDependencies`:
```bash
npm ls jest # or: npm ls vitest
```
2. Or set it manually in `package.json`:
```json
{
"codeflash": {
"testRunner": "jest"
}
}
```
</Accordion>
<Accordion title="Jest tests timing out">
If Jest tests take too long, Codeflash has a default timeout. For large test suites:
- Use `--file` and `--function` to target specific functions instead of `--all`
- Ensure your tests don't have expensive setup/teardown that runs for every test file
- Check if `jest.config.js` has a `setupFiles` that takes a long time
</Accordion>
<Accordion title="TypeScript compilation errors">
Codeflash uses your project's TypeScript configuration. If you see TS errors:
1. Verify `npx tsc --noEmit` passes on its own
2. Check that `tsconfig.json` is in the project root or the module root
3. For projects using `moduleResolution: "bundler"`, Codeflash creates a temporary tsconfig overlay — this is expected behavior
</Accordion>
<Accordion title="Monorepo: wrong package.json detected">
Run codeflash from the correct package directory:
```bash
cd packages/my-library
codeflash --file src/utils.ts --function myFunc
```
If your monorepo tool hoists dependencies, you may need to ensure the `codeflash` npm package is accessible from the package directory. For pnpm, add `.npmrc` with `shamefully-hoist=true` or use `pnpm add --filter my-library --save-dev codeflash`.
</Accordion>
<Accordion title="No optimizations found">
Not all functions can be optimized — some code is already efficient. This is normal.
For better results:
- Target functions with loops, string manipulation, or data transformations
- Ensure the function has existing tests for correctness verification
- Use `codeflash optimize --jest` to trace real execution and capture realistic inputs
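For instance, a loop-heavy transformation like the following sketch (hypothetical function) is a typical candidate, since it does measurable work per call and is easy to cover with a unit test:

```javascript
// Hypothetical example of a good optimization target:
// repeated lowercasing and object lookups inside a loop
function countWordFrequencies(words) {
  const counts = {};
  for (const word of words) {
    const key = word.toLowerCase();
    counts[key] = (counts[key] || 0) + 1;
  }
  return counts;
}

module.exports = { countWordFrequencies };
```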
</Accordion>
</AccordionGroup>
## Configuration Reference
See [JavaScript / TypeScript Configuration](/configuration/javascript) for the full list of options.
### Next Steps
- Learn [how Codeflash works](/codeflash-concepts/how-codeflash-works)
- [Optimize a single function](/optimizing-with-codeflash/one-function)
- Set up [Pull Request Optimization](/optimizing-with-codeflash/codeflash-github-actions)
- Read [configuration options](/configuration) for advanced setups

---
title: "Codeflash is an AI performance optimizer for your code"
icon: "rocket"
sidebarTitle: "Overview"
keywords: ["python", "javascript", "typescript", "performance", "optimization", "AI", "code analysis", "benchmarking"]
---
Codeflash speeds up your code by figuring out the best way to rewrite it while verifying that the behavior is unchanged, and verifying real speed
gains through performance benchmarking. It supports **Python**, **JavaScript**, and **TypeScript**.
The optimizations Codeflash finds are generally better algorithms, removal of wasteful compute, better logic, caching, and use of more efficient library methods. Codeflash
does not modify the system architecture of your code, but it tries to find the most efficient implementation of your current architecture.
### Get Started
<CardGroup cols={2}>
<Card title="Python Setup" icon="python" href="/getting-started/local-installation">
Install via pip, uv, or poetry
</Card>
<Card title="JavaScript / TypeScript Setup" icon="js" href="/getting-started/javascript-installation">
Install via npm, yarn, pnpm, or bun
</Card>
</CardGroup>
### How to use Codeflash
<CardGroup cols={1}>
<Card title="Optimize a Single Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
Target and optimize individual functions for maximum performance gains.
```bash
codeflash --file path/to/file --function my_function
```
</Card>
<Card title="Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
Automatically find optimizations for Pull Requests with GitHub Actions integration.
```bash
codeflash init-actions
</Card>
<Card title="Optimize Workflows with Tracing" icon="route" href="/optimizing-with-codeflash/trace-and-optimize">
End-to-end optimization of entire workflows with execution tracing.
```bash
codeflash optimize myscript.py
```
```
</Card>
</CardGroup>
### How does Codeflash verify correctness?

---
title: "Optimize Performance Benchmarks with every Pull Request"
description: "Configure and use benchmark integration for performance-critical code optimization"
icon: "chart-line"
sidebarTitle: Setup Benchmarks to Optimize
keywords:
---
## Using Codeflash in Benchmark Mode
<Note>
Benchmark mode currently supports Python projects using pytest-benchmark. JavaScript/TypeScript benchmark support is coming soon.
</Note>
1. **Create a benchmarks root:**
Create a directory for benchmarks if it does not already exist.
2. **Define your benchmarks:**
Codeflash supports benchmarks written as pytest-benchmarks. Check out the [pytest-benchmark](https://pytest-benchmark.readthedocs.io/en/stable/index.html) documentation for more information on syntax.
For example:
Note that these benchmarks should be defined in such a way that they don't take a long time to run.
The pytest-benchmark format is simply used as an interface. The plugin is actually not used - Codeflash will run these benchmarks with its own pytest plugin.
3. **Run and Test Codeflash:**
```bash
codeflash --file test_file.py --benchmark --benchmarks-root path/to/benchmarks
```
4. **Run Codeflash with GitHub Actions:**
Benchmark mode is best used together with Codeflash as a GitHub Action. This way,
Codeflash will trace through your benchmark and optimize the functions modified in your Pull Request to speed up the benchmark.

---
title: "Optimize Your Entire Codebase"
description: "Automatically optimize all codepaths in your project with Codeflash's comprehensive analysis"
icon: "database"
sidebarTitle: "Optimize Entire Codebase"
keywords: ["codebase optimization", "all functions", "batch optimization", "github app", "checkpoint", "recovery", "javascript", "typescript", "python"]
---
# Optimize your entire codebase
Codeflash can optimize your entire codebase by analyzing all the functions in your project and generating optimized versions of them.
It iterates through all the functions in your codebase and optimizes them one by one. This works for Python, JavaScript, and TypeScript projects.
To optimize your entire codebase, run the following command in your project directory:
```bash
codeflash --all path/to/dir
```
<Tip>
If your project has a good number of unit tests, tracing them achieves higher quality results.
<Tabs>
<Tab title="Python">
```bash
codeflash optimize --trace-only -m pytest tests/ ; codeflash --all
```
</Tab>
<Tab title="JavaScript / TypeScript">
```bash
codeflash optimize --trace-only --jest ; codeflash --all
# or for Vitest projects
codeflash optimize --trace-only --vitest ; codeflash --all
```
</Tab>
</Tabs>
This runs your test suite, traces all the code covered by your tests, ensuring higher correctness guarantees
and better performance benchmarking, and helps create optimizations for code where the LLMs struggle to generate and run tests.
`codeflash --all` discovers any existing unit tests, but it currently can only discover tests that directly call the
function under optimization. Tracing all the tests helps ensure correctness for code that may be indirectly called by your tests.
</Tip>

✅ A Codeflash API key from the [Codeflash Web App](https://app.codeflash.ai/)
✅ Completed local installation with `codeflash init` ([Python](/getting-started/local-installation) or [JavaScript/TypeScript](/getting-started/javascript-installation))
✅ A configured project (`pyproject.toml` for Python, `package.json` for JavaScript/TypeScript)
</Warning>
## Setup Options
</Step>
<Step title="Choose Your Package Manager">
Customize the dependency installation based on your package manager:
The workflow will need to be set up in such a way that Codeflash can create and run tests for functionality and speed, so the stock YAML may need to be altered to suit the specific codebase. Typically the setup steps for a unit test workflow can be copied.
<CodeGroup>
```yaml Poetry (Python)
- name: Install Project Dependencies
run: |
python -m pip install --upgrade pip
pip install poetry
poetry install --with dev
- name: Run Codeflash to optimize code
run: |
poetry env use python
poetry run codeflash
```
```yaml uv (Python)
- uses: astral-sh/setup-uv@v6
with:
enable-cache: true
- name: Run Codeflash to optimize code
run: uv run codeflash
```
```yaml pip (Python)
- name: Install Project Dependencies
run: |
python -m pip install --upgrade pip
- name: Run Codeflash to optimize code
run: codeflash
```
```yaml npm (JavaScript/TypeScript)
- uses: actions/setup-node@v4
with:
node-version: '18'
- name: Install Project Dependencies
run: npm ci
- name: Run Codeflash to optimize code
run: npx codeflash
```
```yaml yarn (JavaScript/TypeScript)
- uses: actions/setup-node@v4
with:
node-version: '18'
- name: Install Project Dependencies
run: yarn install --immutable
- name: Run Codeflash to optimize code
run: yarn codeflash
```
```yaml pnpm (JavaScript/TypeScript)
- uses: pnpm/action-setup@v4
with:
version: 9
- uses: actions/setup-node@v4
with:
node-version: '18'
cache: 'pnpm'
- name: Install Project Dependencies
run: pnpm install --frozen-lockfile
- name: Run Codeflash to optimize code
run: pnpm codeflash
```
</CodeGroup>
<Info>
**Monorepo?** If your codeflash config is in a subdirectory, add `working-directory` to the steps:
```yaml
- name: Run Codeflash to optimize code
run: npx codeflash
working-directory: packages/my-library
```
</Info>
</Step>
<Step title="Add Repository Secret">

---
title: "Optimize a Single Function"
description: "Target and optimize individual functions for maximum performance gains"
icon: "bullseye"
sidebarTitle: "Optimize Single Function"
keywords:
[
"class methods",
"performance",
"targeted optimization",
"javascript",
"typescript",
"python",
]
---
## How to optimize a function
To optimize a function, run the following command in your project:
<Tabs>
<Tab title="Python">
```bash
codeflash --file path/to/your/file.py --function function_name
```
</Tab>
<Tab title="JavaScript">
```bash
codeflash --file path/to/your/file.js --function functionName
```
</Tab>
<Tab title="TypeScript">
```bash
codeflash --file path/to/your/file.ts --function functionName
```
</Tab>
</Tabs>
If you have installed the GitHub App to your repository, the above command will open a pull request with the optimized function.
If you want to optimize a function locally, add a `--no-pr` argument:
<Tabs>
<Tab title="Python">
```bash
codeflash --file path/to/your/file.py --function function_name --no-pr
```
</Tab>
<Tab title="JavaScript / TypeScript">
```bash
codeflash --file path/to/your/file.ts --function functionName --no-pr
```
</Tab>
</Tabs>
### Optimizing class methods
To optimize a method `method_name` in a class `ClassName`, you can run the following command:
To optimize a method `methodName` in a class `ClassName`:
<Tabs>
<Tab title="Python">
```bash
codeflash --file path/to/your/file.py --function ClassName.method_name
```
</Tab>
<Tab title="JavaScript / TypeScript">
```bash
codeflash --file path/to/your/file.ts --function ClassName.methodName
```
</Tab>
</Tabs>


@ -1,6 +1,6 @@
---
title: "Trace & Optimize E2E Workflows"
description: "End-to-end optimization of entire Python workflows with execution tracing"
description: "End-to-end optimization of entire workflows with execution tracing"
icon: "route"
sidebarTitle: "Optimize E2E Workflows"
keywords:
@ -11,28 +11,50 @@ keywords:
"end-to-end",
"script optimization",
"context manager",
"javascript",
"typescript",
"jest",
"vitest",
]
---
Codeflash can optimize an entire Python script end-to-end by tracing the script's execution and generating Replay Tests.
Tracing follows the execution of a script, profiles it and captures inputs to all functions it called, allowing them to be replayed during optimization.
Codeflash uses these Replay Tests to optimize the most important functions called in the script, delivering the best performance for your workflow.
Codeflash can optimize an entire script or test suite end-to-end by tracing its execution and generating Replay Tests.
Tracing follows the execution of your code, profiles it, and captures the inputs to all functions it calls, allowing them to be replayed during optimization.
Codeflash uses these Replay Tests to optimize the most important functions called in the workflow, delivering the best performance.
![Function Optimization](/images/priority-order.png)
To optimize a script, `python myscript.py`, simply replace `python` with `codeflash optimize` and run the following command:
<Tabs>
<Tab title="Python">
To optimize a script, `python myscript.py`, simply replace `python` with `codeflash optimize`:
```bash
codeflash optimize myscript.py
```
You can also optimize code called by pytest tests that you could normally run like `python -m pytest tests/`, this provides for a good workload to optimize. Run this command:
You can also optimize code called by pytest tests:
```bash
codeflash optimize -m pytest tests/
```
</Tab>
<Tab title="JavaScript / TypeScript">
To trace and optimize your Jest or Vitest tests:
The powerful `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
```bash
# Jest
codeflash optimize --jest
# Vitest
codeflash optimize --vitest
# Or trace a specific script
codeflash optimize --language javascript script.js
```
</Tab>
</Tabs>
The `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results.
The generated replay tests and the trace file are for immediate use during optimization; don't commit them to git.
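One way to keep these artifacts out of version control is an ignore entry like the sketch below (the trace filename matches the default mentioned on this page; the replay-test pattern is assumed from the example filenames shown here, so adjust it to what your project actually generates):
```
# Codeflash trace file (default name per the Tracer options on this page)
codeflash.trace
# Generated replay tests (pattern assumed from the example filenames on this page)
test_replay_test_*.*
```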
@ -61,6 +83,9 @@ This way you can be _sure_ that the optimized function causes no changes of beha
## Using codeflash optimize
<Tabs>
<Tab title="Python">
Codeflash script optimizer can be used in three ways:
1. **As an integrated command**
@ -100,10 +125,10 @@ Codeflash script optimizer can be used in three ways:
- `--timeout`: The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows.
3. **As a Context Manager -**
3. **As a Context Manager**
To trace only specific sections of your code, You can also use the Codeflash Tracer as a context manager.
You can wrap the code you want to trace in a `with` statement as follows -
To trace only specific sections of your code, you can use the Codeflash Tracer as a context manager.
You can wrap the code you want to trace in a `with` statement as follows:
```python
from codeflash.tracer import Tracer
@ -128,3 +153,46 @@ Codeflash script optimizer can be used in three ways:
- `output`: The file to save the trace to. Default is `codeflash.trace`.
- `config_file_path`: The path to the `pyproject.toml` file which stores the Codeflash config. This is auto-discovered by default.
You can also disable the tracer in the code by setting the `disable=True` option in the `Tracer` constructor.
</Tab>
<Tab title="JavaScript / TypeScript">
The JavaScript tracer uses Babel instrumentation to capture function calls while your test suite runs.
1. **Trace your test suite**
```bash
# Jest projects
codeflash optimize --jest
# Vitest projects
codeflash optimize --vitest
# Trace a specific script
codeflash optimize --language javascript src/main.js
```
2. **Trace specific functions only**
```bash
codeflash optimize --jest --only-functions processData,transformInput
```
3. **Trace and optimize as two separate steps**
```bash
# Step 1: Create trace file
codeflash optimize --trace-only --jest --output trace_file.sqlite
# Step 2: Optimize with replay tests
codeflash --replay-test /path/to/test_replay_test_0.test.js
```
More Options:
- `--timeout`: Maximum tracing time in seconds.
- `--max-function-count`: Maximum traces per function (default: 256).
- `--only-functions`: Comma-separated list of function names to trace.
</Tab>
</Tabs>


@ -1,12 +1,12 @@
{
"name": "codeflash",
"version": "0.8.0",
"version": "0.9.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "codeflash",
"version": "0.8.0",
"version": "0.9.0",
"hasInstallScript": true,
"license": "MIT",
"dependencies": {


@ -220,7 +220,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
return False
if config.expected_unit_test_files is not None:
# Match the per-function test discovery message from function_optimizer.py
# Match the per-function discovery message from function_optimizer.py
# Format: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
unit_test_files_match = re.search(r"Discovered (\d+) existing unit test files?", stdout)
if not unit_test_files_match:


@ -1,4 +1,4 @@
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
def test_deduplicate1():
@ -23,7 +23,7 @@ def compute_sum(numbers):
"""
assert normalize_code(code1) == normalize_code(code2)
assert are_codes_duplicate(code1, code2)
assert normalize_code(code1) == normalize_code(code2)
# Example 3: Same function and parameter names, different local variables (should match)
code3 = """
@ -43,7 +43,7 @@ def calculate_sum(numbers):
"""
assert normalize_code(code3) == normalize_code(code4)
assert are_codes_duplicate(code3, code4)
assert normalize_code(code3) == normalize_code(code4)
# Example 4: Nested functions and classes (preserving names)
code5 = """


@ -18,7 +18,7 @@ from codeflash.languages.python.static_analysis.code_replacer import (
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
@ -54,7 +54,7 @@ def sorter(arr):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -834,7 +834,7 @@ class MainClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
code_context = func_optimizer.get_code_optimization_context().unwrap()
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
@ -1745,7 +1745,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -1824,7 +1824,7 @@ a=2
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -1904,7 +1904,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -1983,7 +1983,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -2063,7 +2063,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -2153,7 +2153,7 @@ class NewClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
@ -3453,7 +3453,7 @@ def hydrate_input_text_actions_with_field_names(
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}


@ -8,7 +8,7 @@ from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.verification.test_runner import execute_test_subprocess
@ -459,7 +459,7 @@ class MyClass:
file_path=sample_code_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
@ -582,7 +582,7 @@ class MyClass(ParentClass):
file_path=sample_code_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
@ -709,7 +709,7 @@ class MyClass:
file_path=sample_code_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
@ -872,7 +872,7 @@ class AnotherHelperClass:
file_path=fto_file_path,
parents=[FunctionParent(name="MyClass", type="ClassDef")],
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
@ -1021,7 +1021,7 @@ class AnotherHelperClass:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
@ -1055,7 +1055,7 @@ class AnotherHelperClass:
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
assert len(test_results.test_results) == 4
assert test_results[0].id.test_function_name == "test_helper_classes"
@ -1106,7 +1106,7 @@ class MyClass:
testing_time=0.1,
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
# Now, this fto_code mutates the instance so it should fail
mutated_fto_code = """
@ -1145,7 +1145,7 @@ class MyClass:
testing_time=0.1,
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
match, _ = compare_test_results(test_results, mutated_test_results)
assert not match
@ -1184,7 +1184,7 @@ class MyClass:
testing_time=0.1,
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
match, _ = compare_test_results(test_results, no_helper1_test_results)
assert match
@ -1446,7 +1446,7 @@ def calculate_portfolio_metrics(
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
@ -1477,7 +1477,7 @@ def calculate_portfolio_metrics(
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
# Now, let's say we optimize the code and make changes.
new_fto_code = """import math
@ -1543,7 +1543,7 @@ def calculate_portfolio_metrics(
testing_time=0.1,
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
matched, diffs = compare_test_results(test_results, modified_test_results)
assert not matched
@ -1606,7 +1606,7 @@ def calculate_portfolio_metrics(
testing_time=0.1,
)
# Remove instrumentation
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
matched, diffs = compare_test_results(test_results, modified_test_results_2)
# now the test should match and no diffs should be found
assert len(diffs) == 0
@ -1671,7 +1671,7 @@ class SlotsClass:
file_path=sample_code_path,
parents=[FunctionParent(name="SlotsClass", type="ClassDef")],
)
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(


@ -5,7 +5,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -132,7 +132,7 @@ def test_class_method_dependencies() -> None:
starting_line=None,
ending_line=None,
)
func_optimizer = FunctionOptimizer(
func_optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=TestConfig(
tests_root=file_path,
@ -202,7 +202,7 @@ def test_recursive_function_context() -> None:
starting_line=None,
ending_line=None,
)
func_optimizer = FunctionOptimizer(
func_optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=TestConfig(
tests_root=file_path,


@ -680,8 +680,14 @@ def test_in_dunder_tests():
# Combine all discovered functions
all_functions = {}
for discovered in [discovered_source, discovered_test, discovered_test_underscore,
discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
for discovered in [
discovered_source,
discovered_test,
discovered_test_underscore,
discovered_spec,
discovered_tests_dir,
discovered_dunder_tests,
]:
all_functions.update(discovered)
# Test Case 1: tests_root == module_root (overlapping case)
@ -781,9 +787,7 @@ def test_filter_functions_strict_string_matching():
# Strict check: exactly these 3 files should remain (those with 'test' as substring only)
expected_files = {contest_file, latest_file, attestation_file}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
@ -871,9 +875,7 @@ def test_filter_functions_test_directory_patterns():
# Strict check: exactly these 2 files should remain (those in non-test directories)
expected_files = {contest_file, latest_file}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[contest_file]] == ["get_scores"], (
@ -936,9 +938,7 @@ def test_filter_functions_non_overlapping_tests_root():
# Strict check: exactly these 2 files should remain (both in src/, not in tests/)
expected_files = {source_file, test_in_src}
assert set(filtered.keys()) == expected_files, (
f"Expected files {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
# Strict check: each file should have exactly 1 function with the expected name
assert [fn.function_name for fn in filtered[source_file]] == ["process"], (
@ -1047,20 +1047,15 @@ def test_deep_copy():
)
root_functions = [fn.function_name for fn in filtered.get(root_source_file, [])]
assert root_functions == ["main"], (
f"Expected ['main'], got {root_functions}"
)
assert root_functions == ["main"], f"Expected ['main'], got {root_functions}"
# Strict check: exactly 3 functions (2 from utils.py + 1 from main.py)
assert count == 3, (
f"Expected exactly 3 functions, got {count}. "
f"Some source files may have been incorrectly filtered."
f"Expected exactly 3 functions, got {count}. Some source files may have been incorrectly filtered."
)
# Verify test file was properly filtered (should not be in results)
assert test_file not in filtered, (
f"Test file {test_file} should have been filtered but wasn't"
)
assert test_file not in filtered, f"Test file {test_file} should have been filtered but wasn't"
def test_filter_functions_typescript_project_in_tests_folder():
@ -1214,9 +1209,7 @@ def sample_data():
# source_file and file_in_test_dir should remain
# test_prefix_file, conftest_file, and test_in_subdir should be filtered
expected_files = {source_file, file_in_test_dir}
assert set(filtered.keys()) == expected_files, (
f"Expected {expected_files}, got {set(filtered.keys())}"
)
assert set(filtered.keys()) == expected_files, f"Expected {expected_files}, got {set(filtered.keys())}"
assert count == 2, f"Expected exactly 2 functions, got {count}"
@ -1266,7 +1259,8 @@ class TestHelpers:
""")
support = PythonSupport()
functions = support.discover_functions(fixture_file)
source = fixture_file.read_text(encoding="utf-8")
functions = support.discover_functions(source, fixture_file)
function_names = [fn.function_name for fn in functions]
assert "regular_function" in function_names


@ -7,7 +7,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import is_successful
from codeflash.models.models import FunctionParent, get_code_block_splitter
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.verification_utils import TestConfig
@ -233,7 +233,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
with open(file_path) as f:
original_code = f.read()
ctx_result = func_optimizer.get_code_optimization_context()
@ -404,7 +404,7 @@ def test_bubble_sort_deps() -> None:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
with open(file_path) as f:
original_code = f.read()
ctx_result = func_optimizer.get_code_optimization_context()


@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -22,7 +22,7 @@ def test_add_decorator_imports_helper_in_class():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -94,7 +94,7 @@ def test_add_decorator_imports_helper_in_nested_class():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -143,7 +143,7 @@ def test_add_decorator_imports_nodeps():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -194,7 +194,7 @@ def test_add_decorator_imports_helper_outside():
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:
@ -271,7 +271,7 @@ class helper:
pytest_cmd="pytest",
)
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
os.chdir(run_cwd)
# func_optimizer = pass
try:


@ -27,7 +27,7 @@ from codeflash.models.models import (
TestsInFile,
TestType,
)
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
@ -434,7 +434,7 @@ def test_sort():
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
@ -695,7 +695,7 @@ def test_sort_parametrized(input, expected_output):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -984,7 +984,7 @@ def test_sort_parametrized_loop(input, expected_output):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -1341,7 +1341,7 @@ def test_sort():
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -1723,7 +1723,7 @@ class TestPigLatin(unittest.TestCase):
test_framework="unittest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -1973,7 +1973,7 @@ import unittest
test_framework="unittest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -2229,7 +2229,7 @@ import unittest
test_framework="unittest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
test_env=test_env,
testing_type=TestingMode.BEHAVIOR,
@ -2481,7 +2481,7 @@ import unittest
test_framework="unittest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=f, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=f, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
@ -3144,7 +3144,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_files = TestFiles(
test_files=[
TestFile(
@ -3279,7 +3279,7 @@ import unittest
test_framework="unittest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.PERFORMANCE,
test_env=test_env,


@ -20,12 +20,15 @@ All assertions use strict string equality to verify exact extraction output.
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@pytest.fixture
@ -61,7 +64,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
@@ -87,7 +91,8 @@ export const multiply = (a, b) => a * b;
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "multiply"
@@ -121,7 +126,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -173,7 +179,8 @@ export async function processItems(items, callback, options = {}) {
file_path = temp_project / "processor.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -243,7 +250,8 @@ export class CacheManager {
file_path = temp_project / "cache.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
@@ -339,7 +347,8 @@ export function validateUserData(data, validators) {
file_path = temp_project / "validator.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = next(f for f in functions if f.function_name == "validateUserData")
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -429,7 +438,8 @@ export async function fetchWithRetry(endpoint, options = {}) {
file_path = temp_project / "api.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = next(f for f in functions if f.function_name == "fetchWithRetry")
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -515,7 +525,8 @@ export function validateField(value, fieldType) {
file_path = temp_project / "validation.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -578,7 +589,8 @@ export function processUserInput(rawInput) {
file_path = temp_project / "processor.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
process_func = next(f for f in functions if f.function_name == "processUserInput")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@@ -633,7 +645,8 @@ export function generateReport(data) {
file_path = temp_project / "report.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
report_func = next(f for f in functions if f.function_name == "generateReport")
context = js_support.extract_code_context(report_func, temp_project, temp_project)
@@ -731,7 +744,8 @@ export class Graph {
file_path = temp_project / "graph.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
@@ -819,7 +833,8 @@ export class MainClass {
file_path = temp_project / "classes.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
context = js_support.extract_code_context(main_method, temp_project, temp_project)
@@ -875,7 +890,8 @@ module.exports = { sortFromAnotherFile };
main_path = temp_project / "bubble_sort_imported.js"
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
source = main_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, main_path)
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
context = js_support.extract_code_context(main_func, temp_project, temp_project)
@@ -926,7 +942,8 @@ export function processNumber(n) {
main_path = temp_project / "main.js"
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
source = main_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, main_path)
process_func = next(f for f in functions if f.function_name == "processNumber")
context = js_support.extract_code_context(process_func, temp_project, temp_project)
@@ -992,7 +1009,8 @@ export function handleUserInput(rawInput) {
main_path = temp_project / "main.js"
main_path.write_text(main_code, encoding="utf-8")
functions = js_support.discover_functions(main_path)
source = main_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, main_path)
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
@@ -1043,7 +1061,8 @@ export function createEntity<T extends object>(data: T): Entity<T> {
file_path = temp_project / "entity.ts"
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func = functions[0]
context = ts_support.extract_code_context(func, temp_project, temp_project)
@@ -1133,7 +1152,8 @@ export class TypedCache<T> {
file_path = temp_project / "cache.ts"
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
get_method = next(f for f in functions if f.function_name == "get")
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
@@ -1217,7 +1237,8 @@ export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE
service_path = temp_project / "service.ts"
service_path.write_text(service_code, encoding="utf-8")
functions = ts_support.discover_functions(service_path)
source = service_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, service_path)
func = next(f for f in functions if f.function_name == "createUser")
context = ts_support.extract_code_context(func, temp_project, temp_project)
@@ -1271,7 +1292,8 @@ export function factorial(n) {
file_path = temp_project / "math.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -1301,7 +1323,8 @@ export function isOdd(n) {
file_path = temp_project / "parity.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
is_even = next(f for f in functions if f.function_name == "isEven")
context = js_support.extract_code_context(is_even, temp_project, temp_project)
@@ -1319,12 +1342,15 @@ export function isEven(n) {
assert helper_names == ["isOdd"]
# Verify helper source
assert context.helper_functions[0].source_code == """\
assert (
context.helper_functions[0].source_code
== """\
export function isOdd(n) {
if (n === 0) return false;
return isEven(n - 1);
}
"""
)
def test_complex_recursive_tree_traversal(self, js_support, temp_project):
"""Test complex recursive tree traversal with multiple recursive calls."""
@@ -1363,7 +1389,8 @@ export function collectAllValues(root) {
file_path = temp_project / "tree.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
@@ -1428,7 +1455,8 @@ export async function fetchUserProfile(userId) {
file_path = temp_project / "api.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
@@ -1483,7 +1511,8 @@ module.exports = { Counter };
file_path = temp_project / "counter.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context
@@ -1563,7 +1592,8 @@ export function processApiResponse({
file_path = temp_project / "api.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -1605,7 +1635,8 @@ export function* fibonacci(limit) {
file_path = temp_project / "generators.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
range_func = next(f for f in functions if f.function_name == "range")
context = js_support.extract_code_context(range_func, temp_project, temp_project)
@@ -1640,7 +1671,8 @@ export function createUserObject(name, email, age) {
file_path = temp_project / "user.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
context = js_support.extract_code_context(func, temp_project, temp_project)
@@ -1790,7 +1822,8 @@ export const sendSlackMessage = async (
file_path.write_text(code, encoding="utf-8")
target_func = "sendSlackMessage"
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func_info = next(f for f in functions if f.function_name == target_func)
fto = FunctionToOptimize(
function_name=target_func,
@@ -1804,9 +1837,11 @@ export const sendSlackMessage = async (
language="typescript",
)
ctx = get_code_optimization_context_for_language(
fto, temp_project
test_config = TestConfig(
tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project
)
func_optimizer = JavaScriptFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock())
ctx = func_optimizer.get_code_optimization_context().unwrap()
# The read_writable_code should contain the target function AND helper functions
expected_read_writable = """```typescript:slack_util.ts
@@ -1899,7 +1934,6 @@ let web: WebClient | null = null"""
assert ctx.read_only_context_code == expected_read_only
class TestContextProperties:
"""Tests for CodeContext object properties."""
@@ -1913,7 +1947,8 @@ export function test() {
file_path = temp_project / "test.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
context = js_support.extract_code_context(functions[0], temp_project, temp_project)
assert context.language == Language.JAVASCRIPT
@@ -1932,7 +1967,8 @@ export function test(): number {
file_path = temp_project / "test.ts"
file_path.write_text(code, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
context = ts_support.extract_code_context(functions[0], temp_project, temp_project)
# TypeScript uses JavaScript language enum
@@ -1974,7 +2010,8 @@ export class Calculator {
file_path = temp_project / "calculator.js"
file_path.write_text(code, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
for func in functions:
if func.function_name != "constructor":


@@ -107,10 +107,8 @@ class TestJavaScriptCodeContext:
"""Test extracting code context for a JavaScript function."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.JAVASCRIPT
from codeflash.languages import get_language_support
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
fib_file = js_project_dir / "fibonacci.js"
if not fib_file.exists():
@@ -122,7 +120,11 @@ class TestJavaScriptCodeContext:
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
context = get_code_optimization_context(fib_func, js_project_dir)
js_support = get_language_support(Language.JAVASCRIPT)
code_context = js_support.extract_code_context(fib_func, js_project_dir, js_project_dir)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, fib_file, "javascript", js_project_dir
)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "javascript"


@@ -71,10 +71,8 @@ module.exports = { add };
"""Verify language is preserved in code context extraction."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages import get_language_support
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
ts_file = tmp_path / "utils.ts"
ts_file.write_text("""
@@ -86,7 +84,11 @@ export function add(a: number, b: number): number {
functions = find_all_functions_in_file(ts_file)
func = functions[ts_file][0]
context = get_code_optimization_context(func, tmp_path)
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(func, tmp_path, tmp_path)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, ts_file, "typescript", tmp_path
)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "typescript"
@@ -373,10 +375,7 @@ describe('fibonacci', () => {
"""Test get_code_optimization_context for JavaScript."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.JAVASCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
src_file = js_project / "utils.js"
functions = find_all_functions_in_file(src_file)
@@ -398,7 +397,7 @@ describe('fibonacci', () => {
pytest_cmd="jest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
@@ -415,10 +414,7 @@ describe('fibonacci', () => {
"""Test get_code_optimization_context for TypeScript."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
src_file = ts_project / "utils.ts"
functions = find_all_functions_in_file(src_file)
@@ -440,7 +436,7 @@ describe('fibonacci', () => {
pytest_cmd="vitest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
@@ -461,10 +457,7 @@ class TestHelperFunctionLanguageAttribute:
"""Verify helper functions have language='javascript' for .js files."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.JAVASCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
# Create a file with helper functions
src_file = tmp_path / "main.js"
@@ -499,7 +492,7 @@ module.exports = { main };
pytest_cmd="jest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),
@@ -515,10 +508,7 @@ module.exports = { main };
"""Verify helper functions have language='typescript' for .ts files."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
# Create a file with helper functions
src_file = tmp_path / "main.ts"
@@ -551,7 +541,7 @@ export function main(): number {
pytest_cmd="vitest",
)
optimizer = FunctionOptimizer(
optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func_to_optimize,
test_cfg=test_config,
aiservice_client=MagicMock(),


@@ -16,8 +16,6 @@ NOTE: These tests require:
Tests will be skipped if dependencies are not available.
"""
import os
import shutil
import subprocess
from pathlib import Path
from unittest.mock import MagicMock
@@ -26,7 +24,7 @@ import pytest
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestType, TestingMode
from codeflash.models.models import FunctionParent
from codeflash.verification.verification_utils import TestConfig
@@ -58,13 +56,7 @@ def install_dependencies(project_dir: Path) -> bool:
if has_node_modules(project_dir):
return True
try:
result = subprocess.run(
["npm", "install"],
cwd=project_dir,
capture_output=True,
text=True,
timeout=120
)
result = subprocess.run(["npm", "install"], cwd=project_dir, capture_output=True, text=True, timeout=120)
return result.returncode == 0
except Exception:
return False
@@ -82,6 +74,7 @@ def skip_if_js_not_supported():
"""Skip test if JavaScript/TypeScript languages are not supported."""
try:
from codeflash.languages import get_language_support
get_language_support(Language.JAVASCRIPT)
except Exception as e:
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
@@ -157,8 +150,8 @@ module.exports = {
"""Test that JavaScript test instrumentation module can be imported."""
skip_if_js_not_supported()
from codeflash.languages import get_language_support
# Verify the instrumentation module can be imported
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
# Get JavaScript support
js_support = get_language_support(Language.JAVASCRIPT)
@@ -272,8 +265,8 @@ export default defineConfig({
"""Test that TypeScript test instrumentation module can be imported."""
skip_if_js_not_supported()
from codeflash.languages import get_language_support
# Verify the instrumentation module can be imported
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
test_file = ts_project_dir / "tests" / "math.test.ts"
@@ -356,10 +349,7 @@ class TestRunAndParseJavaScriptTests:
"""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.optimization.function_optimizer import FunctionOptimizer
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
# Find the fibonacci function
fib_file = vitest_project / "fibonacci.ts"
@@ -389,10 +379,8 @@ class TestRunAndParseJavaScriptTests:
)
# Create optimizer
func_optimizer = FunctionOptimizer(
function_to_optimize=func,
test_cfg=test_config,
aiservice_client=MagicMock(),
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
)
# Get code context - this should work
@@ -419,8 +407,8 @@ class TestTimingMarkerParsing:
# The marker format used by codeflash for JavaScript
# Start marker: !$######{tag}######$!
# End marker: !######{tag}:{duration}######!
start_pattern = r'!\$######(.+?)######\$!'
end_pattern = r'!######(.+?):(\d+)######!'
start_pattern = r"!\$######(.+?)######\$!"
end_pattern = r"!######(.+?):(\d+)######!"
start_marker = "!$######test/math.test.ts:TestMath.test_add:add:1:0_0######$!"
end_marker = "!######test/math.test.ts:TestMath.test_add:add:1:0_0:12345######!"
@@ -472,6 +460,7 @@ class TestJavaScriptTestResultParsing:
# Parse the XML
import xml.etree.ElementTree as ET
tree = ET.parse(junit_xml)
root = tree.getroot()
@@ -504,6 +493,7 @@ class TestJavaScriptTestResultParsing:
# Parse the XML
import xml.etree.ElementTree as ET
tree = ET.parse(junit_xml)
root = tree.getroot()


@@ -52,7 +52,7 @@ export function add(a, b) {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "add"
@@ -76,7 +76,7 @@ export function multiply(a, b) {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 3
names = {func.function_name for func in functions}
@@ -94,7 +94,7 @@ export const multiply = (x, y) => x * y;
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
names = {func.function_name for func in functions}
@@ -114,7 +114,7 @@ export function withoutReturn() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only the function with return should be discovered
assert len(functions) == 1
@@ -136,7 +136,7 @@ export class Calculator {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
for func in functions:
@@ -157,7 +157,7 @@ export function syncFunction() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
@@ -182,7 +182,7 @@ export function syncFunc() {
f.flush()
criteria = FunctionFilterCriteria(include_async=False)
functions = js_support.discover_functions(Path(f.name), criteria)
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].function_name == "syncFunc"
@@ -204,7 +204,7 @@ export class MyClass {
f.flush()
criteria = FunctionFilterCriteria(include_methods=False)
functions = js_support.discover_functions(Path(f.name), criteria)
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
assert len(functions) == 1
assert functions[0].function_name == "standalone"
@@ -224,7 +224,7 @@ export function func2() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
func1 = next(f for f in functions if f.function_name == "func1")
func2 = next(f for f in functions if f.function_name == "func2")
@@ -246,7 +246,7 @@ export function* numberGenerator() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "numberGenerator"
@@ -257,14 +257,14 @@ export function* numberGenerator() {
f.write("this is not valid javascript {{{{")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Tree-sitter is lenient, so it may still parse partial code
# The important thing is it doesn't crash
assert isinstance(functions, list)
def test_discover_nonexistent_file_returns_empty(self, js_support):
"""Test that nonexistent file returns empty list."""
functions = js_support.discover_functions(Path("/nonexistent/file.js"))
functions = js_support.discover_functions("", Path("/nonexistent/file.js"))
assert functions == []
def test_discover_function_expression(self, js_support):
@@ -277,7 +277,7 @@ export const add = function(a, b) {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "add"
@@ -296,7 +296,7 @@ export function named() {
""")
f.flush()
functions = js_support.discover_functions(Path(f.name))
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only the named function should be discovered
assert len(functions) == 1
@@ -507,7 +507,7 @@ export function main(a) {
file_path = Path(f.name)
# First discover functions to get accurate line numbers
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
main_func = next(f for f in functions if f.function_name == "main")
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
@@ -535,7 +535,7 @@ class TestIntegration:
file_path = Path(f.name)
# Discover
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "fibonacci"
@@ -584,7 +584,7 @@ export function standalone() {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find 4 functions
assert len(functions) == 4
@@ -623,7 +623,7 @@ export default Button;
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find both components
names = {f.function_name for f in functions}
@@ -653,7 +653,7 @@ describe('Math functions', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@@ -687,7 +687,7 @@ class TestClassMethodExtraction:
file_path = Path(f.name)
# Discover the method
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next(f for f in functions if f.function_name == "add")
# Extract code context
@@ -725,7 +725,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@@ -763,7 +763,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
fib_method = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
@@ -802,7 +802,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
@@ -832,7 +832,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
fetch_method = next(f for f in functions if f.function_name == "fetchData")
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
@@ -865,7 +865,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next((f for f in functions if f.function_name == "add"), None)
if add_method:
@@ -894,7 +894,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
method = next(f for f in functions if f.function_name == "simpleMethod")
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
@ -1079,7 +1079,7 @@ class TestClassMethodEdgeCases:
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find constructor and increment
names = {f.function_name for f in functions}
@ -1109,7 +1109,7 @@ class TestClassMethodEdgeCases:
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find at least greet
names = {f.function_name for f in functions}
@ -1137,7 +1137,7 @@ export class Dog extends Animal {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Find Dog's fetch method
fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None)
@ -1172,7 +1172,7 @@ export class Dog extends Animal {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should at least find publicMethod
names = {f.function_name for f in functions}
@ -1192,7 +1192,7 @@ module.exports = { Calculator };
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_method = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
@ -1212,7 +1212,7 @@ module.exports = { Calculator };
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Find the add method
add_method = next((f for f in functions if f.function_name == "add"), None)
@ -1265,7 +1265,7 @@ module.exports = { Counter };
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
increment_func = next(fn for fn in functions if fn.function_name == "increment")
# Step 1: Extract code context (includes constructor for AI context)
@ -1362,7 +1362,7 @@ export class User {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
functions = ts_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
get_name_func = next(fn for fn in functions if fn.function_name == "getName")
# Step 1: Extract code context (includes fields and constructor)
@ -1462,7 +1462,7 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_func = next(fn for fn in functions if fn.function_name == "add")
# Extract context for add
@ -1546,7 +1546,7 @@ export class MathUtils {
f.flush()
file_path = Path(f.name)
functions = js_support.discover_functions(file_path)
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
add_func = next(fn for fn in functions if fn.function_name == "add")
# Extract context


@ -53,7 +53,7 @@ describe('add function', () => {
""")
# Discover functions first
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
assert len(functions) == 1
# Discover tests
@ -90,7 +90,7 @@ describe('multiply', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -124,7 +124,7 @@ test('formats date correctly', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -170,7 +170,7 @@ describe('String Utils', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -208,7 +208,7 @@ describe('sum function', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -242,7 +242,7 @@ test('subtract two numbers', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -270,7 +270,7 @@ test('greets by name', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -316,7 +316,7 @@ describe('Calculator class', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests for class methods
@ -363,7 +363,7 @@ describe('clamp', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -399,7 +399,7 @@ describe('async utilities', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -436,7 +436,7 @@ describe('Button component', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# JSX tests should be discovered
@ -466,7 +466,7 @@ test('other test', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should not find tests for our function
@ -502,7 +502,7 @@ describe('validators', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests for isEmail
@ -546,7 +546,7 @@ test('helper2 returns 2', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -574,7 +574,7 @@ test(`formatNumber with decimal`, () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# May or may not find depending on template literal handling
@ -605,7 +605,7 @@ describe('transform', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should still find tests since original name is imported
@ -626,7 +626,7 @@ it('third test', () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -651,7 +651,7 @@ describe('Suite B', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -675,7 +675,7 @@ describe('Outer', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -699,7 +699,7 @@ describe.skip('skipped describe', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -720,7 +720,7 @@ describe.only('only describe', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -738,7 +738,7 @@ describe('describe single', () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -757,7 +757,7 @@ describe("describe double", () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -773,7 +773,7 @@ describe("describe double", () => {});
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -806,7 +806,7 @@ test('funcA works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# funcA should have tests
@ -833,7 +833,7 @@ test('funcX works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# funcX should have tests
@ -859,7 +859,7 @@ test('mainFunc works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -896,7 +896,7 @@ test('block commented', () => {
*/
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -921,7 +921,7 @@ test('broken test' { // Missing arrow function
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
# Should not crash
tests = js_support.discover_tests(tmpdir, functions)
assert isinstance(tests, dict)
@ -949,7 +949,7 @@ describe('conflict tests', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should still work despite naming conflicts
@ -966,7 +966,7 @@ export function lonelyFunc() { return 'alone'; }
module.exports = { lonelyFunc };
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should return empty dict, not crash
@ -1001,7 +1001,7 @@ test('funcA works', () => {
});
""")
functions_a = js_support.discover_functions(file_a)
functions_a = js_support.discover_functions(file_a.read_text(encoding="utf-8"), file_a)
tests = js_support.discover_tests(tmpdir, functions_a)
# Should handle circular imports gracefully
@ -1047,7 +1047,7 @@ test.each([
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1073,7 +1073,7 @@ describe.each([
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1098,7 +1098,7 @@ describe('Math operations', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1174,7 +1174,7 @@ describe('formatName', () => {
""")
# Discover functions
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
assert len(functions) == 3
# Discover tests
@ -1242,7 +1242,7 @@ describe('Database', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -1280,7 +1280,7 @@ test('funcA works', () => {
""")
# Discover functions from moduleB
functions_b = js_support.discover_functions(source_b)
functions_b = js_support.discover_functions(source_b.read_text(encoding="utf-8"), source_b)
tests = js_support.discover_tests(tmpdir, functions_b)
# funcB should not have any tests since test file doesn't import it
@ -1312,7 +1312,7 @@ test('funcOne returns 1', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Check that tests were found
@ -1340,7 +1340,7 @@ test('mentions targetFunc in string', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Current implementation may still match on string occurrence
@ -1367,7 +1367,7 @@ test('calculate doubles', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests since 'calculate' appears in source
@ -1399,7 +1399,7 @@ describe('MyClass', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should find tests for class methods
@ -1432,7 +1432,7 @@ test('deepHelper works', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
assert len(tests) > 0
@ -1456,7 +1456,7 @@ testCases.forEach(name => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1484,7 +1484,7 @@ describe('conditional tests', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1508,7 +1508,7 @@ test('slow test', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1531,7 +1531,7 @@ test.todo('also needs implementation');
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1554,7 +1554,7 @@ test.concurrent('concurrent test 2', async () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1597,7 +1597,7 @@ describe('subtractNumbers', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# All three functions should be discovered
@ -1628,7 +1628,7 @@ describe('Unrelated name', () => {
});
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
tests = js_support.discover_tests(tmpdir, functions)
# Should still find tests
@ -1653,7 +1653,7 @@ describe('Array', function() {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1684,7 +1684,7 @@ describe('User', () => {
f.flush()
file_path = Path(f.name)
source = file_path.read_text()
source = file_path.read_text(encoding="utf-8")
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
analyzer = get_analyzer_for_file(file_path)
@ -1712,7 +1712,7 @@ export class Calculator {
module.exports = { Calculator };
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
# Check qualified names include class
add_func = next((f for f in functions if f.function_name == "add"), None)
@ -1737,7 +1737,7 @@ export class Outer {
module.exports = { Outer };
""")
functions = js_support.discover_functions(source_file)
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
# Should find at least the Outer class method
assert any(f.class_name == "Outer" for f in functions)
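The hunks in this test file repeat one two-line change around the tree-sitter analyzer — read the source with an explicit UTF-8 encoding before handing the path to `get_analyzer_for_file` — while the surrounding tests exercise Jest test-name discovery (`describe`, `test`, `it`, plus modifiers like `.skip`, `.only`, `.each`, `.todo`, `.concurrent`). A rough, regex-based approximation of that discovery, for illustration only (the real code uses tree-sitter queries, not regexes):

```python
import re

# Matches test('name', ...), it("name", ...), describe('name', ...), including
# modifier chains such as test.skip / describe.only / test.todo. Being
# regex-based, it is only an approximation of a real tree-sitter query:
# template-literal names (backticks) and dynamically built names are missed,
# which mirrors the "may or may not find" caveats in the tests above.
_TEST_CALL = re.compile(r"\b(?:test|it|describe)(?:\.\w+)*\(\s*(['\"])(.+?)\1")


def find_test_names(source: str) -> list[str]:
    """Return Jest test/suite names in source order."""
    return [m.group(2) for m in _TEST_CALL.finditer(source)]
```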


@ -13,7 +13,7 @@ from codeflash.languages.base import Language
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.languages.registry import get_language_support
from codeflash.models.models import FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
FIXTURES_DIR = Path(__file__).parent / "fixtures"
@ -37,7 +37,7 @@ class TestCodeExtractorCJS:
def test_discover_class_methods(self, js_support, cjs_project):
"""Test that class methods are discovered correctly."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
method_names = {f.function_name for f in functions}
@ -47,17 +47,19 @@ class TestCodeExtractorCJS:
def test_class_method_has_correct_parent(self, js_support, cjs_project):
"""Test parent class information for methods."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
for func in functions:
# All methods should belong to Calculator class
assert func.is_method is True, f"{func.function_name} should be a method"
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
assert func.class_name == "Calculator", (
f"{func.function_name} should belong to Calculator, got {func.class_name}"
)
def test_extract_permutation_code(self, js_support, cjs_project):
"""Test permutation method code extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -93,7 +95,7 @@ class Calculator {
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
"""Test that direct helper functions are included in context."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -129,7 +131,7 @@ export function factorial(n) {
def test_extract_compound_interest_code(self, js_support, cjs_project):
"""Test calculateCompoundInterest code extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -175,7 +177,7 @@ class Calculator {
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
"""Test helper extraction for calculateCompoundInterest."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -235,7 +237,7 @@ export function validateInput(value, name) {
def test_extract_context_includes_imports(self, js_support, cjs_project):
"""Test import statement extraction."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -256,7 +258,7 @@ export function validateInput(value, name) {
def test_extract_static_method(self, js_support, cjs_project):
"""Test static method extraction (quickAdd)."""
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
@ -315,7 +317,7 @@ class TestCodeExtractorESM:
def test_discover_esm_methods(self, js_support, esm_project):
"""Test method discovery in ESM project."""
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
method_names = {f.function_name for f in functions}
@ -326,7 +328,7 @@ class TestCodeExtractorESM:
def test_esm_permutation_extraction(self, js_support, esm_project):
"""Test permutation method extraction in ESM."""
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -376,7 +378,7 @@ export function factorial(n) {
def test_esm_compound_interest_extraction(self, js_support, esm_project):
"""Test calculateCompoundInterest extraction in ESM with import syntax."""
calculator_file = esm_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -502,7 +504,7 @@ class TestCodeExtractorTypeScript:
def test_discover_ts_methods(self, ts_support, ts_project):
"""Test method discovery in TypeScript."""
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
method_names = {f.function_name for f in functions}
@ -513,7 +515,7 @@ class TestCodeExtractorTypeScript:
def test_ts_permutation_extraction(self, ts_support, ts_project):
"""Test permutation method extraction in TypeScript."""
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
permutation_func = next(f for f in functions if f.function_name == "permutation")
@ -566,7 +568,7 @@ export function factorial(n: number): number {
def test_ts_compound_interest_extraction(self, ts_support, ts_project):
"""Test calculateCompoundInterest extraction in TypeScript."""
calculator_file = ts_project / "calculator.ts"
functions = ts_support.discover_functions(calculator_file)
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
@ -676,7 +678,7 @@ module.exports = { standalone };
test_file = tmp_path / "standalone.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "standalone")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -709,7 +711,7 @@ module.exports = { processArray };
test_file = tmp_path / "processor.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "processArray")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -744,7 +746,7 @@ module.exports = { fibonacci };
test_file = tmp_path / "recursive.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "fibonacci")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -777,7 +779,7 @@ module.exports = { processValue };
test_file = tmp_path / "arrow.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
func = next(f for f in functions if f.function_name == "processValue")
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
@ -835,7 +837,7 @@ module.exports = { Counter };
test_file = tmp_path / "counter.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
increment_func = next(f for f in functions if f.function_name == "increment")
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
@ -874,7 +876,7 @@ module.exports = { MathUtils };
test_file = tmp_path / "math_utils.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
add_func = next(f for f in functions if f.function_name == "add")
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@ -910,7 +912,7 @@ export class User {
test_file = tmp_path / "user.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_name_func = next(f for f in functions if f.function_name == "getName")
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
@ -949,7 +951,7 @@ export class Config {
test_file = tmp_path / "config.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_url_func = next(f for f in functions if f.function_name == "getUrl")
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
@ -990,7 +992,7 @@ module.exports = { Logger };
test_file = tmp_path / "logger.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
@ -1032,7 +1034,7 @@ module.exports = { Factory };
test_file = tmp_path / "factory.js"
test_file.write_text(source)
functions = js_support.discover_functions(test_file)
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
create_func = next(f for f in functions if f.function_name == "create")
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
@ -1074,7 +1076,7 @@ class TestCodeExtractorIntegration:
js_support = get_language_support("javascript")
calculator_file = cjs_project / "calculator.js"
functions = js_support.discover_functions(calculator_file)
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
target = next(f for f in functions if f.function_name == "permutation")
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
@ -1099,7 +1101,7 @@ class TestCodeExtractorIntegration:
pytest_cmd="jest",
)
func_optimizer = FunctionOptimizer(
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
)
result = func_optimizer.get_code_optimization_context()
@@ -1182,7 +1184,7 @@ export function distance(p1: Point, p2: Point): number {
test_file = tmp_path / "geometry.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
distance_func = next(f for f in functions if f.function_name == "distance")
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
@@ -1224,7 +1226,7 @@ export function processStatus(status: Status): string {
test_file = tmp_path / "status.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
process_func = next(f for f in functions if f.function_name == "processStatus")
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
@@ -1259,7 +1261,7 @@ export function compute(x: number): Result<number> {
test_file = tmp_path / "compute.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
compute_func = next(f for f in functions if f.function_name == "compute")
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
@@ -1301,7 +1303,7 @@ export class Service {
test_file = tmp_path / "service.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
context = ts_support.extract_code_context(
@@ -1332,7 +1334,7 @@ export function add(a: number, b: number): number {
test_file = tmp_path / "add.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
add_func = next(f for f in functions if f.function_name == "add")
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
@@ -1363,7 +1365,7 @@ export function createRect(origin: Point, size: Size): { origin: Point; size: Si
test_file = tmp_path / "rect.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
create_rect_func = next(f for f in functions if f.function_name == "createRect")
context = ts_support.extract_code_context(
@@ -1409,7 +1411,7 @@ export function calculateDistance(p1: Point, p2: Point, config: CalculationConfi
}
""")
functions = ts_support.discover_functions(geometry_file)
functions = ts_support.discover_functions(geometry_file.read_text(encoding="utf-8"), geometry_file)
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
context = ts_support.extract_code_context(
@@ -1460,7 +1462,7 @@ export function greetUser(user: User): string {
test_file = tmp_path / "user.ts"
test_file.write_text(source)
functions = ts_support.discover_functions(test_file)
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
greet_func = next(f for f in functions if f.function_name == "greetUser")
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
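The hunks above all migrate callers to the new two-argument discovery API: tests now read the file themselves and pass the source text alongside the path, instead of letting the support object do the I/O. A minimal, hypothetical stand-in (not the real `LanguageSupport` implementation) for Python sources illustrates the calling convention:

```python
import ast
from pathlib import Path

# Hypothetical sketch of the updated discover_functions signature: the caller
# supplies both the already-read source text and the file path, so discovery
# itself performs no file I/O.
def discover_functions(source: str, file_path: Path) -> list[str]:
    """Return the names of top-level functions defined in Python `source`."""
    return [node.name for node in ast.parse(source).body
            if isinstance(node, ast.FunctionDef)]

path = Path("math_utils.py")  # illustrative path; the file need not exist
source = "def add(a, b):\n    return a + b\n"
print(discover_functions(source, path))  # -> ['add']
```

Decoupling discovery from I/O is what lets the tests pass `""` for a nonexistent path later in this diff and still get a well-defined empty result.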


@@ -7,6 +7,7 @@ These tests verify that code replacement correctly handles:
- ES Modules (import/export) syntax
- TypeScript import handling
"""
from __future__ import annotations
import shutil
@@ -14,8 +15,8 @@ from pathlib import Path
import pytest
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language
from codeflash.languages.base import Language
from codeflash.languages.code_replacer import replace_function_definitions_for_language
from codeflash.languages.current import set_current_language
from codeflash.languages.javascript.module_system import (
ModuleSystem,
@@ -25,7 +26,6 @@ from codeflash.languages.javascript.module_system import (
ensure_module_system_compatibility,
get_import_statement,
)
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
from codeflash.models.models import CodeStringsMarkdown
@@ -50,7 +50,6 @@ def temp_project(tmp_path):
return project_root
FIXTURES_DIR = Path(__file__).parent / "fixtures"
@@ -308,7 +307,9 @@ class TestTsJestSkipsConversion:
When ts-jest is installed, it handles module interoperability internally,
so we skip conversion to avoid breaking valid imports.
"""
def __init__(self):
@pytest.fixture(autouse=True)
def _set_language(self):
set_current_language(Language.TYPESCRIPT)
def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
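The docstring in this hunk describes a simple guard. As a hedged sketch (the names here are illustrative, not the real codeflash API), the decision it implies looks like:

```python
# Toy model of the conversion-skip rule described in the docstring above:
# when ts-jest is installed it performs ESM/CJS module interop itself, so
# rewriting imports to require() would corrupt otherwise-valid code.
def should_convert_imports(ts_jest_installed: bool, source_is_esm: bool) -> bool:
    if ts_jest_installed:
        return False  # ts-jest handles the interop internally
    return source_is_esm  # only ESM sources need conversion for plain jest

print(should_convert_imports(ts_jest_installed=True, source_is_esm=True))   # -> False
print(should_convert_imports(ts_jest_installed=False, source_is_esm=True))  # -> True
```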
@@ -751,6 +752,7 @@ class TestIntegrationWithFixtures:
f"import statements should be converted to require.\nFound import lines: {import_lines}"
)
class TestSimpleFunctionReplacement:
"""Tests for simple function body replacement with strict assertions."""
@@ -764,7 +766,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
# Optimized version with different body
@@ -800,7 +803,8 @@ export function processData(data) {
file_path = temp_project / "processor.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
# Optimized version using map
@@ -839,7 +843,8 @@ module.exports = { targetFunction, otherFunction };
file_path = temp_project / "module.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
target_func = next(f for f in functions if f.function_name == "targetFunction")
optimized_code = """\
@@ -891,7 +896,8 @@ export class Calculator {
file_path = temp_project / "calculator.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
add_method = next(f for f in functions if f.function_name == "add")
# Optimized version provided in class context
@@ -954,7 +960,8 @@ export class DataProcessor {
file_path = temp_project / "processor.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
process_method = next(f for f in functions if f.function_name == "process")
optimized_code = """\
@@ -1016,7 +1023,8 @@ export function add(a, b) {
file_path = temp_project / "math.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1070,7 +1078,8 @@ export class Cache {
file_path = temp_project / "cache.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
@@ -1131,7 +1140,8 @@ export async function fetchData(url) {
file_path = temp_project / "api.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1172,7 +1182,8 @@ export class ApiClient {
file_path = temp_project / "client.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
get_method = next(f for f in functions if f.function_name == "get")
optimized_code = """\
@@ -1223,7 +1234,8 @@ export function* range(start, end) {
file_path = temp_project / "generators.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1262,7 +1274,8 @@ export function processArray(items: number[]): number {
file_path = temp_project / "processor.ts"
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1303,7 +1316,8 @@ export class Container<T> {
file_path = temp_project / "container.ts"
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
get_all_method = next(f for f in functions if f.function_name == "getAll")
optimized_code = """\
@@ -1356,7 +1370,8 @@ export function createUser(name: string, email: string): User {
file_path = temp_project / "user.ts"
file_path.write_text(original_source, encoding="utf-8")
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
func = next(f for f in functions if f.function_name == "createUser")
optimized_code = """\
@@ -1411,7 +1426,8 @@ export function processItems(items) {
file_path = temp_project / "processor.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
process_func = next(f for f in functions if f.function_name == "processItems")
optimized_code = """\
@@ -1458,7 +1474,8 @@ export class MathUtils {
file_path.write_text(original_source, encoding="utf-8")
# First replacement: sum method
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
sum_method = next(f for f in functions if f.function_name == "sum")
optimized_sum = """\
@@ -1505,7 +1522,8 @@ export function processConfig({ server: { host, port }, database: { url, poolSiz
file_path = temp_project / "config.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1544,7 +1562,8 @@ export function minimal() {
file_path = temp_project / "minimal.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1571,7 +1590,8 @@ export function identity(x) { return x; }
file_path = temp_project / "utils.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1598,7 +1618,8 @@ export function formatMessage(name) {
file_path = temp_project / "formatter.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1633,7 +1654,8 @@ export function validateEmail(email) {
file_path = temp_project / "validator.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
optimized_code = """\
@@ -1676,7 +1698,8 @@ module.exports = { main, helper };
file_path = temp_project / "module.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
@@ -1719,7 +1742,8 @@ export function main(data) {
file_path = temp_project / "module.js"
file_path.write_text(original_source, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
main_func = next(f for f in functions if f.function_name == "main")
optimized_code = """\
@@ -1750,20 +1774,16 @@ class TestSyntaxValidation:
"""Test that various replacements all produce valid JavaScript."""
test_cases = [
# (original, optimized, description)
(
"export function f(x) { return x + 1; }",
"export function f(x) { return ++x; }",
"increment replacement"
),
("export function f(x) { return x + 1; }", "export function f(x) { return ++x; }", "increment replacement"),
(
"export function f(arr) { return arr.length > 0; }",
"export function f(arr) { return !!arr.length; }",
"boolean conversion"
"boolean conversion",
),
(
"export function f(a, b) { if (a) { return a; } return b; }",
"export function f(a, b) { return a || b; }",
"logical OR replacement"
"logical OR replacement",
),
]
@@ -1771,7 +1791,8 @@ class TestSyntaxValidation:
file_path = temp_project / f"test_{i}.js"
file_path.write_text(original, encoding="utf-8")
functions = js_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = js_support.discover_functions(source, file_path)
func = functions[0]
result = js_support.replace_function(original, func, optimized)
@@ -1875,7 +1896,8 @@ export class DataProcessor<T> {
target_func = "findDuplicates"
parent_class = "DataProcessor"
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
# find function
target_func_info = None
for func in functions:
@@ -1920,11 +1942,15 @@ class DataProcessor<T> {
```
"""
code_markdown = CodeStringsMarkdown.parse_markdown_code(new_code)
replaced = replace_function_definitions_for_language([f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project)
replaced = replace_function_definitions_for_language(
[f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project, lang_support=ts_support
)
assert replaced
new_code = file_path.read_text()
assert new_code == """/**
assert (
new_code
== """/**
* DataProcessor class - demonstrates class method optimization in TypeScript.
* Contains intentionally inefficient implementations for optimization testing.
*/
@@ -2015,7 +2041,7 @@ export class DataProcessor<T> {
}
}
"""
)
class TestNewVariableFromOptimizedCode:
@@ -2030,9 +2056,9 @@ class TestNewVariableFromOptimizedCode:
1. Add the new variable after the constant it references
2. Replace the function with the optimized version
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString
from codeflash.models.models import CodeString, CodeStringsMarkdown
original_source = '''\
original_source = """\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);
@@ -2040,43 +2066,34 @@ const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
export function isCodeflashEmployee(userId: string): boolean {
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
}
'''
"""
file_path = temp_project / "auth.ts"
file_path.write_text(original_source, encoding="utf-8")
# Optimized code introduces a bound method variable for performance
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
optimized_code = """const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
CODEFLASH_EMPLOYEE_GITHUB_IDS
);
export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
"""
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("auth.ts"),
language="typescript"
)
],
language="typescript"
code_strings=[CodeString(code=optimized_code, file_path=Path("auth.ts"), language="typescript")],
language="typescript",
)
replaced = replace_function_definitions_for_language(
["isCodeflashEmployee"],
code_markdown,
file_path,
temp_project,
["isCodeflashEmployee"], code_markdown, file_path, temp_project, lang_support=ts_support
)
assert replaced
result = file_path.read_text()
# Expected result for strict equality check
expected_result = '''\
expected_result = """\
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
"1234",
]);
@@ -2088,11 +2105,9 @@ const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
export function isCodeflashEmployee(userId: string): boolean {
return _has(userId);
}
'''
"""
assert result == expected_result, (
f"Result does not match expected output.\n"
f"Expected:\n{expected_result}\n\n"
f"Got:\n{result}"
f"Result does not match expected output.\nExpected:\n{expected_result}\n\nGot:\n{result}"
)
@@ -2113,7 +2128,7 @@ class TestImportedTypeNotDuplicated:
contains the TreeNode interface definition (from read-only context),
the replacement should NOT add the interface to the original file.
"""
from codeflash.models.models import CodeStringsMarkdown, CodeString
from codeflash.models.models import CodeString, CodeStringsMarkdown
# Original source imports TreeNode
original_source = """\
@@ -2163,20 +2178,13 @@ export function getNearestAbove(
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code_with_interface,
file_path=Path("helpers.ts"),
language="typescript"
)
CodeString(code=optimized_code_with_interface, file_path=Path("helpers.ts"), language="typescript")
],
language="typescript"
language="typescript",
)
replace_function_definitions_for_language(
["getNearestAbove"],
code_markdown,
file_path,
temp_project,
["getNearestAbove"], code_markdown, file_path, temp_project, lang_support=ts_support
)
result = file_path.read_text()
@@ -2203,7 +2211,7 @@ export function getNearestAbove(
def test_multiple_imported_types_not_duplicated(self, ts_support, temp_project):
"""Test that multiple imported types are not duplicated."""
from codeflash.models.models import CodeStringsMarkdown, CodeString
from codeflash.models.models import CodeString, CodeStringsMarkdown
original_source = """\
import type { TreeNode, NodeSpace } from "./constants";
@@ -2235,21 +2243,12 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
"""
code_markdown = CodeStringsMarkdown(
code_strings=[
CodeString(
code=optimized_code,
file_path=Path("processor.ts"),
language="typescript"
)
],
language="typescript"
code_strings=[CodeString(code=optimized_code, file_path=Path("processor.ts"), language="typescript")],
language="typescript",
)
replace_function_definitions_for_language(
["processNode"],
code_markdown,
file_path,
temp_project,
["processNode"], code_markdown, file_path, temp_project, lang_support=ts_support
)
result = file_path.read_text()
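The replacement calls in the hunks above splice an optimized definition into the original file while leaving surrounding code (imports, sibling functions, constants) untouched. A toy sketch of that core operation, for Python sources rather than the real TypeScript path and deliberately much simpler than `replace_function_definitions_for_language`:

```python
import ast

# Toy function-body replacer: locate the target function's line span via the
# AST, then splice in the optimized definition, keeping the rest of the file.
def replace_function(source: str, name: str, optimized: str) -> str:
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.FunctionDef) and node.name == name:
            lines = source.splitlines(keepends=True)
            start, end = node.lineno - 1, node.end_lineno
            return "".join(lines[:start]) + optimized + "".join(lines[end:])
    return source  # target not found: leave the source unchanged

original = "def add(a, b):\n    return a + b\n\ndef sub(a, b):\n    return a - b\n"
optimized = "def add(a, b):\n    return b + a\n"
print(replace_function(original, "add", optimized))
```

Only `add` is rewritten; `sub` survives verbatim, which is the invariant the strict-equality assertions in these tests are checking.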


@@ -345,8 +345,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find exactly one function
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@@ -365,8 +365,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(MULTIPLE_FUNCTIONS.python, ".py")
js_file = write_temp_file(MULTIPLE_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 3 functions
assert len(py_funcs) == 3, f"Python found {len(py_funcs)}, expected 3"
@@ -384,8 +384,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(WITH_AND_WITHOUT_RETURN.python, ".py")
js_file = write_temp_file(WITH_AND_WITHOUT_RETURN.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find only 1 function (the one with return)
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@@ -400,8 +400,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(CLASS_METHODS.python, ".py")
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 2 methods
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
@@ -421,8 +421,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(ASYNC_FUNCTIONS.python, ".py")
js_file = write_temp_file(ASYNC_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 2 functions
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
@@ -444,8 +444,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(NESTED_FUNCTIONS.python, ".py")
js_file = write_temp_file(NESTED_FUNCTIONS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Python skips nested functions — only outer is discovered
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
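The parity rule in this hunk — only the outer function is discovered — follows from iterating the module body rather than walking the whole tree. A small self-contained illustration of the difference (using the stdlib `ast` module, not the actual codeflash discovery code):

```python
import ast

source = """\
def outer():
    def inner():
        return 1
    return inner()
"""

# Module-body iteration (what the parity test expects): nested defs are skipped.
top_level = [n.name for n in ast.parse(source).body if isinstance(n, ast.FunctionDef)]
# Full tree walk: nested defs are visited too.
all_funcs = [n.name for n in ast.walk(ast.parse(source)) if isinstance(n, ast.FunctionDef)]

print(top_level)   # -> ['outer']
print(all_funcs)   # -> ['outer', 'inner']
```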
@@ -465,8 +465,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(STATIC_METHODS.python, ".py")
js_file = write_temp_file(STATIC_METHODS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 1 function
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@@ -483,8 +483,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(COMPLEX_FILE.python, ".py")
js_file = write_temp_file(COMPLEX_FILE.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find 4 functions
assert len(py_funcs) == 4, f"Python found {len(py_funcs)}, expected 4"
@@ -515,8 +515,8 @@ class TestDiscoverFunctionsParity:
criteria = FunctionFilterCriteria(include_async=False)
py_funcs = python_support.discover_functions(py_file, criteria)
js_funcs = js_support.discover_functions(js_file, criteria)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
# Both should find only 1 function (the sync one)
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@@ -533,8 +533,8 @@ class TestDiscoverFunctionsParity:
criteria = FunctionFilterCriteria(include_methods=False)
py_funcs = python_support.discover_functions(py_file, criteria)
js_funcs = js_support.discover_functions(js_file, criteria)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
# Both should find only 1 function (standalone)
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
@@ -545,11 +545,11 @@ class TestDiscoverFunctionsParity:
assert js_funcs[0].function_name == "standalone"
def test_nonexistent_file_returns_empty(self, python_support, js_support):
"""Python raises on nonexistent files; JavaScript returns empty list."""
with pytest.raises(FileNotFoundError):
python_support.discover_functions(Path("/nonexistent/file.py"))
"""Both languages return empty list for empty source."""
py_funcs = python_support.discover_functions("", Path("/nonexistent/file.py"))
assert py_funcs == []
js_funcs = js_support.discover_functions(Path("/nonexistent/file.js"))
js_funcs = js_support.discover_functions("", Path("/nonexistent/file.js"))
assert js_funcs == []
def test_line_numbers_captured(self, python_support, js_support):
@@ -557,8 +557,8 @@ class TestDiscoverFunctionsParity:
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should have start_line and end_line
assert py_funcs[0].starting_line is not None
@@ -908,8 +908,8 @@ class TestIntegrationParity:
js_file = write_temp_file(js_original, ".js")
# Discover
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert len(py_funcs) == 1
assert len(js_funcs) == 1
@@ -960,8 +960,8 @@ class TestFeatureGaps:
py_file = write_temp_file(CLASS_METHODS.python, ".py")
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
for py_func in py_funcs:
# Check all expected fields are populated
@@ -994,7 +994,7 @@ export const multiply = (x, y) => x * y;
export const identity = x => x;
"""
js_file = write_temp_file(js_code, ".js")
funcs = js_support.discover_functions(js_file)
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Should find all arrow functions
names = {f.function_name for f in funcs}
@@ -1021,8 +1021,8 @@ export function* numberGenerator() {
py_file = write_temp_file(py_code, ".py")
js_file = write_temp_file(js_code, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Both should find the generator
assert len(py_funcs) == 1, f"Python found {len(py_funcs)} generators"
@@ -1045,7 +1045,7 @@ def multi_decorated():
return 3
"""
py_file = write_temp_file(py_code, ".py")
funcs = python_support.discover_functions(py_file)
funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
# Should find all functions regardless of decorators
names = {f.function_name for f in funcs}
@@ -1065,7 +1065,7 @@ export const namedExpr = function myFunc(x) {
};
"""
js_file = write_temp_file(js_code, ".js")
funcs = js_support.discover_functions(js_file)
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
# Should find function expressions
names = {f.function_name for f in funcs}
@@ -1085,8 +1085,8 @@ class TestEdgeCases:
py_file = write_temp_file("", ".py")
js_file = write_temp_file("", ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert py_funcs == []
assert js_funcs == []
@@ -1110,8 +1110,8 @@ Multiline comment
py_file = write_temp_file(py_code, ".py")
js_file = write_temp_file(js_code, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert py_funcs == []
assert js_funcs == []
@@ -1130,8 +1130,8 @@ export function greeting() {
py_file = write_temp_file(py_code, ".py")
js_file = write_temp_file(js_code, ".js")
py_funcs = python_support.discover_functions(py_file)
js_funcs = js_support.discover_functions(js_file)
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
assert len(py_funcs) == 1
assert len(js_funcs) == 1


@@ -82,9 +82,9 @@ from pathlib import Path
from unittest.mock import MagicMock
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
from codeflash.languages.registry import get_language_support
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@@ -110,7 +110,7 @@ def test_js_replcement() -> None:
original_helper = helper_file.read_text("utf-8")
js_support = get_language_support("javascript")
functions = js_support.discover_functions(main_file)
functions = js_support.discover_functions(main_file.read_text(encoding="utf-8"), main_file)
target = None
for func in functions:
if func.function_name == "calculateStats":
@@ -135,7 +135,7 @@ def test_js_replcement() -> None:
project_root_path=root_dir,
pytest_cmd="jest",
)
func_optimizer = FunctionOptimizer(
func_optimizer = JavaScriptFunctionOptimizer(
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
)
result = func_optimizer.get_code_optimization_context()

@ -49,7 +49,7 @@ def add(a, b):
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "add"
@ -70,7 +70,7 @@ def multiply(a, b):
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 3
names = {func.function_name for func in functions}
@ -88,7 +88,7 @@ def without_return():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only the function with return should be discovered
assert len(functions) == 1
@ -107,7 +107,7 @@ class Calculator:
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
for func in functions:
@ -126,7 +126,7 @@ def sync_function():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 2
@ -147,7 +147,7 @@ def outer():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
# Only outer should be discovered; inner is nested and skipped
assert len(functions) == 1
@ -164,7 +164,7 @@ class Utils:
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
assert len(functions) == 1
assert functions[0].function_name == "helper"
@ -183,7 +183,9 @@ def sync_func():
f.flush()
criteria = FunctionFilterCriteria(include_async=False)
functions = python_support.discover_functions(Path(f.name), criteria)
functions = python_support.discover_functions(
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
)
assert len(functions) == 1
assert functions[0].function_name == "sync_func"
@ -202,7 +204,9 @@ class MyClass:
f.flush()
criteria = FunctionFilterCriteria(include_methods=False)
functions = python_support.discover_functions(Path(f.name), criteria)
functions = python_support.discover_functions(
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
)
assert len(functions) == 1
assert functions[0].function_name == "standalone"
@ -220,7 +224,7 @@ def func2():
""")
f.flush()
functions = python_support.discover_functions(Path(f.name))
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
func1 = next(f for f in functions if f.function_name == "func1")
func2 = next(f for f in functions if f.function_name == "func2")
@ -239,12 +243,12 @@ def func2():
f.flush()
with pytest.raises(ParserSyntaxError):
python_support.discover_functions(Path(f.name))
python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
def test_discover_nonexistent_file_raises(self, python_support):
"""Test that nonexistent file raises FileNotFoundError."""
with pytest.raises(FileNotFoundError):
python_support.discover_functions(Path("/nonexistent/file.py"))
def test_discover_empty_source_returns_empty(self, python_support):
"""Test that empty source returns empty list."""
functions = python_support.discover_functions("", Path("/nonexistent/file.py"))
assert functions == []
class TestReplaceFunction:
@ -495,7 +499,7 @@ class TestIntegration:
file_path = Path(f.name)
# Discover
functions = python_support.discover_functions(file_path)
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
assert len(functions) == 1
func = functions[0]
assert func.function_name == "fibonacci"
@ -536,7 +540,7 @@ def standalone():
f.flush()
file_path = Path(f.name)
functions = python_support.discover_functions(file_path)
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
# Should find 4 functions
assert len(functions) == 4
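
The hunks above migrate every call site from the one-argument `discover_functions(path)` to the two-argument form `discover_functions(source, path)`: the caller reads the file once and passes the text explicitly. The pattern can be sketched with a hypothetical minimal stub — the real `LanguageSupport` implementations parse the source properly, while the stub below only scans for top-level `def` lines:

```python
import tempfile
from pathlib import Path


class StubPythonSupport:
    """Hypothetical stand-in for a LanguageSupport implementation."""

    def discover_functions(self, source: str, file_path: Path) -> list[str]:
        # The new signature takes the source text alongside the path, so
        # callers read the file once and discovery never re-reads disk.
        return [
            line.split("def ", 1)[1].split("(", 1)[0]
            for line in source.splitlines()
            if line.startswith("def ")
        ]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "example.py"
    path.write_text("def add(a, b):\n    return a + b\n", encoding="utf-8")
    source = path.read_text(encoding="utf-8")  # read once, reuse for discovery
    functions = StubPythonSupport().discover_functions(source, path)
    print(functions)  # → ['add']
```

Passing the source separately also allows discovery over in-memory text that was never written to disk, which is what the new empty-source test above exercises with `discover_functions("", Path("/nonexistent/file.py"))`.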

@ -13,7 +13,7 @@ from pathlib import Path
import pytest
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
from codeflash.languages.base import Language
from codeflash.languages.javascript.support import TypeScriptSupport
@ -126,14 +126,13 @@ export function add(a: number, b: number): number {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
assert functions[0].function_name == "add"
# Extract code context
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -164,14 +163,13 @@ export async function execMongoEval(queryExpression, appsmithMongoURI) {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
assert functions[0].function_name == "execMongoEval"
# Extract code context
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -215,14 +213,13 @@ export async function figureOutContentsPath(root: string): Promise<string> {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
assert functions[0].function_name == "figureOutContentsPath"
# Extract code context
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -246,12 +243,11 @@ export function readConfig(filename: string): string {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Check that imports are captured
assert len(code_context.imports) > 0
@ -278,12 +274,11 @@ export async function fetchWithRetry(url: string): Promise<any> {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
assert len(functions) == 1
code_context = ts_support.extract_code_context(
functions[0], file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
# Verify extracted code is valid
assert ts_support.validate_syntax(code_context.target_code) is True
@ -324,7 +319,8 @@ export class EndpointGroup {
file_path = Path(f.name)
# Discover the 'post' method
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
post_method = None
for func in functions:
if func.function_name == "post":
@ -334,9 +330,7 @@ export class EndpointGroup {
assert post_method is not None, "post method should be discovered"
# Extract code context
code_context = ts_support.extract_code_context(
post_method, file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(post_method, file_path.parent, file_path.parent)
# The extracted code should be syntactically valid
assert ts_support.validate_syntax(code_context.target_code) is True, (
@ -352,9 +346,7 @@ export class EndpointGroup {
# Check that addEndpoint appears BEFORE the closing brace of the class
class_end_index = code_context.target_code.rfind("}")
add_endpoint_index = code_context.target_code.find("addEndpoint")
assert add_endpoint_index < class_end_index, (
"addEndpoint should be inside the class wrapper"
)
assert add_endpoint_index < class_end_index, "addEndpoint should be inside the class wrapper"
def test_multiple_private_helpers_inside_class(self, ts_support):
"""Test that multiple private helpers are all included inside the class."""
@ -386,7 +378,8 @@ export class Router {
file_path = Path(f.name)
# Discover the 'addRoute' method
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
add_route_method = None
for func in functions:
if func.function_name == "addRoute":
@ -395,9 +388,7 @@ export class Router {
assert add_route_method is not None
code_context = ts_support.extract_code_context(
add_route_method, file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(add_route_method, file_path.parent, file_path.parent)
# Should be valid TypeScript
assert ts_support.validate_syntax(code_context.target_code) is True
@ -424,7 +415,8 @@ export class Calculator {
f.flush()
file_path = Path(f.name)
functions = ts_support.discover_functions(file_path)
source = file_path.read_text(encoding="utf-8")
functions = ts_support.discover_functions(source, file_path)
add_method = None
for func in functions:
if func.function_name == "add":
@ -433,18 +425,14 @@ export class Calculator {
assert add_method is not None
code_context = ts_support.extract_code_context(
add_method, file_path.parent, file_path.parent
)
code_context = ts_support.extract_code_context(add_method, file_path.parent, file_path.parent)
# 'compute' should be in target_code (inside class)
assert "compute" in code_context.target_code
# 'compute' should NOT be in helper_functions (would be duplicate)
helper_names = [h.name for h in code_context.helper_functions]
assert "compute" not in helper_names, (
"Same-class helper 'compute' should not be in helper_functions list"
)
assert "compute" not in helper_names, "Same-class helper 'compute' should not be in helper_functions list"
class TestTypeScriptLanguageProperties:

@ -124,10 +124,8 @@ class TestTypeScriptCodeContext:
"""Test extracting code context for a TypeScript function."""
skip_if_ts_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages import get_language_support
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
fib_file = ts_project_dir / "fibonacci.ts"
if not fib_file.exists():
@ -139,7 +137,11 @@ class TestTypeScriptCodeContext:
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
context = get_code_optimization_context(fib_func, ts_project_dir)
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(fib_func, ts_project_dir, ts_project_dir)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, fib_file, "typescript", ts_project_dir
)
assert context.read_writable_code is not None
# Critical: language should be "typescript", not "javascript"

@ -118,11 +118,9 @@ class TestVitestCodeContext:
"""Test extracting code context for a TypeScript function."""
skip_if_js_not_supported()
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
from codeflash.languages import current as lang_current
from codeflash.languages import get_language_support
from codeflash.languages.base import Language
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
lang_current._current_language = Language.TYPESCRIPT
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
fib_file = vitest_project_dir / "fibonacci.ts"
if not fib_file.exists():
@ -134,7 +132,11 @@ class TestVitestCodeContext:
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
assert fib_func is not None
context = get_code_optimization_context(fib_func, vitest_project_dir)
ts_support = get_language_support(Language.TYPESCRIPT)
code_context = ts_support.extract_code_context(fib_func, vitest_project_dir, vitest_project_dir)
context = JavaScriptFunctionOptimizer._build_optimization_context(
code_context, fib_file, "typescript", vitest_project_dir
)
assert context.read_writable_code is not None
assert context.read_writable_code.language == "typescript"

@ -2,7 +2,7 @@ from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -106,7 +106,7 @@ def _get_string_usage(text: str) -> Usage:
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
original_helper_code: dict[Path, str] = {}
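
The `FunctionOptimizer` → `PythonFunctionOptimizer` (and `JavaScriptFunctionOptimizer`) renames throughout these tests reflect a subclass-per-language design: a shared base carries the orchestration while language-specific subclasses supply parsing and context extraction. A hypothetical minimal sketch of that dispatch (class and method names below mirror the diff, but the bodies are illustrative stubs, not the real codeflash implementation):

```python
class FunctionOptimizer:
    """Base orchestration shared across languages (hypothetical sketch)."""

    language = "generic"

    def get_code_optimization_context(self):
        raise NotImplementedError  # each language subclass supplies this


class PythonFunctionOptimizer(FunctionOptimizer):
    language = "python"

    def get_code_optimization_context(self):
        return f"{self.language} context"


class JavaScriptFunctionOptimizer(FunctionOptimizer):
    language = "javascript"

    def get_code_optimization_context(self):
        return f"{self.language} context"


# Registry-style lookup: tests pick the concrete subclass for their language.
OPTIMIZERS = {cls.language: cls for cls in (PythonFunctionOptimizer, JavaScriptFunctionOptimizer)}
ctx = OPTIMIZERS["python"]().get_code_optimization_context()
print(ctx)  # → python context
```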

@ -3,9 +3,9 @@ import tempfile
from pathlib import Path
from codeflash.code_utils.code_utils import ImportErrorPattern
from codeflash.languages import current_language_support
from codeflash.models.models import TestFile, TestFiles, TestType
from codeflash.verification.parse_test_output import parse_test_xml
from codeflash.verification.test_runner import run_behavioral_tests
from codeflash.verification.verification_utils import TestConfig
@ -48,8 +48,8 @@ class TestUnittestRunnerSorter(unittest.TestCase):
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files, test_framework=config.test_framework, cwd=Path(config.project_root_path), test_env=test_env
result_file, process, _, _ = current_language_support().run_behavioral_tests(
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path)
)
results = parse_test_xml(result_file, test_files, config, process)
assert results[0].did_pass, "Test did not pass as expected"
@ -89,13 +89,8 @@ def test_sort():
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_env=test_env,
pytest_timeout=1,
pytest_target_runtime_seconds=1,
result_file, process, _, _ = current_language_support().run_behavioral_tests(
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
)
results = parse_test_xml(
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
@ -136,13 +131,8 @@ def test_sort():
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
)
test_file_path.write_text(code, encoding="utf-8")
result_file, process, _, _ = run_behavioral_tests(
test_files,
test_framework=config.test_framework,
cwd=Path(config.project_root_path),
test_env=test_env,
pytest_timeout=1,
pytest_target_runtime_seconds=1,
result_file, process, _, _ = current_language_support().run_behavioral_tests(
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
)
results = parse_test_xml(
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
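
The hunks above replace the free `run_behavioral_tests(...)` function — which took pytest-specific keywords such as `pytest_timeout` and `pytest_target_runtime_seconds` — with a method on the current language support object taking a single language-agnostic `timeout`. A hypothetical stub showing the idea (the real method actually executes the command and returns a `(result_file, process, ...)` tuple, as the call sites above show):

```python
from pathlib import Path


class StubLanguageSupport:
    """Hypothetical stand-in: each language support owns its runner details."""

    test_framework = "pytest"  # encapsulated; callers no longer pass it

    def run_behavioral_tests(self, test_paths, test_env, cwd: Path, timeout=None):
        # Framework-specific flags are derived internally from the single
        # language-agnostic `timeout` argument instead of leaking to callers.
        cmd = ["pytest", *map(str, test_paths)]
        if timeout is not None:
            cmd += ["--timeout", str(timeout)]
        return cmd  # stub returns the command it would run


cmd = StubLanguageSupport().run_behavioral_tests(
    test_paths=[Path("test_sort.py")], test_env={}, cwd=Path("."), timeout=1
)
print(cmd)  # → ['pytest', 'test_sort.py', '--timeout', '1']
```

Encapsulating the framework behind the support object is what lets the same test-runner call work unchanged for pytest, jest, or vitest projects.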

@ -10,8 +10,8 @@ from codeflash.languages.python.context.unused_definition_remover import (
detect_unused_helper_functions,
revert_unused_helper_functions,
)
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
from codeflash.models.models import CodeStringsMarkdown
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig
@ -83,7 +83,7 @@ def helper_function_2(x):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -194,7 +194,7 @@ def helper_function_2(x):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -269,7 +269,7 @@ def helper_function_2(x):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -365,7 +365,7 @@ def entrypoint_function(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -559,7 +559,7 @@ class Calculator:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -710,7 +710,7 @@ class Processor:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -895,7 +895,7 @@ class OuterClass:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1051,7 +1051,7 @@ def entrypoint_function(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1215,7 +1215,7 @@ def entrypoint_function(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1442,7 +1442,7 @@ class MathUtils:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1576,7 +1576,7 @@ async def async_entrypoint(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1664,7 +1664,7 @@ def sync_entrypoint(n):
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="sync_entrypoint", parents=[])
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1773,7 +1773,7 @@ async def mixed_entrypoint(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1874,7 +1874,7 @@ class AsyncProcessor:
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -1960,7 +1960,7 @@ async def async_entrypoint(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -2039,7 +2039,7 @@ def gcd_recursive(a: int, b: int) -> int:
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="gcd_recursive", parents=[])
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
@ -2152,7 +2152,7 @@ async def async_entrypoint_with_generators(n):
)
# Create function optimizer
optimizer = FunctionOptimizer(
optimizer = PythonFunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),