mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Merge branch 'main' into fix/jest-junit-and-misc
This commit is contained in:
commit
c53740df2e
82 changed files with 3335 additions and 3441 deletions
|
|
@ -1,5 +1,7 @@
|
|||
# Architecture
|
||||
|
||||
When adding, moving, or deleting source files, update this doc to match.
|
||||
|
||||
```
|
||||
codeflash/
|
||||
├── main.py # CLI entry point
|
||||
|
|
@ -14,7 +16,21 @@ codeflash/
|
|||
├── api/ # AI service communication
|
||||
├── code_utils/ # Code parsing, git utilities
|
||||
├── models/ # Pydantic models and types
|
||||
├── languages/ # Multi-language support (Python, JavaScript/TypeScript)
|
||||
├── languages/ # Multi-language support (Python, JavaScript/TypeScript, Java planned)
|
||||
│ ├── base.py # LanguageSupport protocol and shared data types
|
||||
│ ├── registry.py # Language registration and lookup by extension/enum
|
||||
│ ├── current.py # Current language singleton (set_current_language / current_language_support)
|
||||
│ ├── code_replacer.py # Language-agnostic code replacement
|
||||
│ ├── python/
|
||||
│ │ ├── support.py # PythonSupport (LanguageSupport implementation)
|
||||
│ │ ├── function_optimizer.py # PythonFunctionOptimizer subclass
|
||||
│ │ ├── optimizer.py # Python module preparation & AST resolution
|
||||
│ │ └── normalizer.py # Python code normalization for deduplication
|
||||
│ └── javascript/
|
||||
│ ├── support.py # JavaScriptSupport (LanguageSupport implementation)
|
||||
│ ├── function_optimizer.py # JavaScriptFunctionOptimizer subclass
|
||||
│ ├── optimizer.py # JS project root finding & module preparation
|
||||
│ └── normalizer.py # JS/TS code normalization for deduplication
|
||||
├── setup/ # Config schema, auto-detection, first-run experience
|
||||
├── picklepatch/ # Serialization/deserialization utilities
|
||||
├── tracing/ # Function call tracing
|
||||
|
|
@ -32,10 +48,36 @@ codeflash/
|
|||
|------|------------|
|
||||
| CLI arguments & commands | `cli_cmds/cli.py` |
|
||||
| Optimization orchestration | `optimization/optimizer.py` → `run()` |
|
||||
| Per-function optimization | `optimization/function_optimizer.py` |
|
||||
| Per-function optimization | `optimization/function_optimizer.py` (base), `languages/python/function_optimizer.py`, `languages/javascript/function_optimizer.py` |
|
||||
| Function discovery | `discovery/functions_to_optimize.py` |
|
||||
| Context extraction | `languages/<lang>/context/code_context_extractor.py` |
|
||||
| Test execution | `verification/test_runner.py`, `verification/pytest_plugin.py` |
|
||||
| Test execution | `languages/<lang>/support.py` (`run_behavioral_tests`, etc.), `verification/pytest_plugin.py` |
|
||||
| Performance ranking | `benchmarking/function_ranker.py` |
|
||||
| Domain types | `models/models.py`, `models/function_types.py` |
|
||||
| Result handling | `either.py` (`Result`, `Success`, `Failure`, `is_successful`) |
|
||||
|
||||
## LanguageSupport Protocol Methods
|
||||
|
||||
Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScriptSupport`) implements these.
|
||||
|
||||
| Category | Method/Property | Purpose |
|
||||
|----------|----------------|---------|
|
||||
| Identity | `language`, `file_extensions`, `default_file_extension` | Language identification |
|
||||
| Identity | `comment_prefix`, `dir_excludes` | Language conventions |
|
||||
| AI service | `default_language_version` | Language version for API payloads (`None` for Python, `"ES2022"` for JS) |
|
||||
| AI service | `valid_test_frameworks` | Allowed test frameworks for validation |
|
||||
| Discovery | `discover_functions`, `discover_tests` | Find optimizable functions and their tests |
|
||||
| Discovery | `adjust_test_config_for_discovery` | Pre-discovery config adjustment (no-op default) |
|
||||
| Context | `extract_code_context`, `find_helper_functions`, `find_references` | Code dependency extraction |
|
||||
| Transform | `replace_function`, `format_code`, `normalize_code` | Code modification |
|
||||
| Validation | `validate_syntax` | Syntax checking |
|
||||
| Test execution | `run_behavioral_tests`, `run_benchmarking_tests`, `run_line_profile_tests` | Test runners |
|
||||
| Test results | `test_result_serialization_format` | `"pickle"` (Python) or `"json"` (JS) |
|
||||
| Test results | `load_coverage` | Load coverage from language-specific format |
|
||||
| Test results | `compare_test_results` | Equivalence checking between original and candidate |
|
||||
| Test gen | `postprocess_generated_tests` | Post-process `GeneratedTestsList` objects |
|
||||
| Test gen | `process_generated_test_strings` | Instrument/transform raw generated test strings |
|
||||
| Module | `detect_module_system` | Detect project module system (`None` for Python, `"esm"`/`"commonjs"` for JS) |
|
||||
| Module | `prepare_module` | Parse/validate module before optimization |
|
||||
| Setup | `setup_test_config` | One-time project setup after language detection |
|
||||
| Optimizer | `function_optimizer_class` | Return `FunctionOptimizer` subclass for this language |
|
||||
|
|
|
|||
|
|
@ -7,4 +7,5 @@
|
|||
- **Comments**: Minimal - only explain "why", not "what"
|
||||
- **Docstrings**: Do not add unless explicitly requested
|
||||
- **Naming**: NEVER use leading underscores (`_function_name`) - Python has no true private functions, use public names
|
||||
- **Paths**: Always use absolute paths, handle encoding explicitly (UTF-8)
|
||||
- **Paths**: Always use absolute paths
|
||||
- **Encoding**: Always pass `encoding="utf-8"` to `open()`, `read_text()`, `write_text()`, etc. in new or changed code — Windows defaults to `cp1252` which breaks on non-ASCII content. Don't flag pre-existing code that lacks it unless you're already modifying that line.
|
||||
|
|
|
|||
|
|
@ -9,4 +9,5 @@ paths:
|
|||
- Use `get_language_support(identifier)` from `languages/registry.py` to get a `LanguageSupport` instance — never import language classes directly
|
||||
- New language support classes must use the `@register_language` decorator to register with the extension and language registries
|
||||
- `languages/__init__.py` uses `__getattr__` for lazy imports to avoid circular dependencies — follow this pattern when adding new exports
|
||||
- `is_javascript()` returns `True` for both JavaScript and TypeScript
|
||||
- Prefer `LanguageSupport` protocol dispatch over `is_python()`/`is_javascript()` guards — remaining guards are being migrated to protocol methods
|
||||
- `is_javascript()` returns `True` for both JavaScript and TypeScript (still used in ~15 call sites pending migration)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ paths:
|
|||
- "codeflash/optimization/**/*.py"
|
||||
- "codeflash/verification/**/*.py"
|
||||
- "codeflash/benchmarking/**/*.py"
|
||||
- "codeflash/context/**/*.py"
|
||||
- "codeflash/languages/*/context/**/*.py"
|
||||
---
|
||||
|
||||
# Optimization Pipeline Patterns
|
||||
|
|
|
|||
176
.github/workflows/claude.yml
vendored
176
.github/workflows/claude.yml
vendored
|
|
@ -56,132 +56,116 @@ jobs:
|
|||
use_sticky_comment: true
|
||||
allowed_bots: "claude[bot],codeflash-ai[bot]"
|
||||
prompt: |
|
||||
REPO: ${{ github.repository }}
|
||||
PR NUMBER: ${{ github.event.pull_request.number }}
|
||||
EVENT: ${{ github.event.action }}
|
||||
<context>
|
||||
repo: ${{ github.repository }}
|
||||
pr_number: ${{ github.event.pull_request.number }}
|
||||
event: ${{ github.event.action }}
|
||||
is_re_review: ${{ github.event.action == 'synchronize' }}
|
||||
</context>
|
||||
|
||||
## STEP 1: Run prek and mypy checks, fix issues
|
||||
<commitment>
|
||||
Execute these steps in order. If a step has no work, state that and continue to the next step.
|
||||
Post all review findings in a single summary comment only — never as inline PR review comments.
|
||||
</commitment>
|
||||
|
||||
First, run these checks on files changed in this PR:
|
||||
1. `uv run prek run --from-ref origin/main` - linting/formatting issues
|
||||
2. `uv run mypy <changed_files>` - type checking issues
|
||||
<step name="lint_and_typecheck">
|
||||
Run checks on files changed in this PR and auto-fix what you can.
|
||||
|
||||
If there are prek issues:
|
||||
- For SAFE auto-fixable issues (formatting, import sorting, trailing whitespace, etc.), run `uv run prek run --from-ref origin/main` again to auto-fix them
|
||||
- For issues that prek cannot auto-fix, do NOT attempt to fix them manually — report them as remaining issues in your summary
|
||||
1. Run `uv run prek run --from-ref origin/main` to check linting/formatting.
|
||||
If there are auto-fixable issues, run it again to fix them.
|
||||
Report any issues prek cannot auto-fix in your summary.
|
||||
|
||||
If there are mypy issues:
|
||||
- Fix type annotation issues (missing return types, Optional/None unions, import errors for type hints, incorrect types)
|
||||
- Do NOT add `type: ignore` comments - always fix the root cause
|
||||
2. Run `uv run mypy <changed_files>` to check types.
|
||||
Fix type annotation issues (missing return types, Optional unions, import errors).
|
||||
Always fix the root cause instead of adding `type: ignore` comments.
|
||||
Leave alone: type errors requiring logic changes, complex generics, anything changing runtime behavior.
|
||||
|
||||
After fixing issues:
|
||||
- Stage the fixed files with `git add`
|
||||
- Commit with message "style: auto-fix linting issues" or "fix: resolve mypy type errors" as appropriate
|
||||
- Push the changes with `git push`
|
||||
3. After fixes: stage with `git add`, commit ("style: auto-fix linting issues" or "fix: resolve mypy type errors"), push.
|
||||
|
||||
IMPORTANT - Verification after fixing:
|
||||
- After committing fixes, run `uv run prek run --from-ref origin/main` ONE MORE TIME to verify all issues are resolved
|
||||
- If errors remain, either fix them or report them honestly as unfixed in your summary
|
||||
- NEVER claim issues are fixed without verifying. If you cannot fix an issue, say so
|
||||
4. Verify by running `uv run prek run --from-ref origin/main` one more time. Report honestly if issues remain.
|
||||
</step>
|
||||
|
||||
Do NOT attempt to fix:
|
||||
- Type errors that require logic changes or refactoring
|
||||
- Complex generic type issues
|
||||
- Anything that could change runtime behavior
|
||||
<step name="resolve_stale_threads">
|
||||
Before reviewing, resolve any stale review threads from previous runs.
|
||||
|
||||
## STEP 2: Review the PR
|
||||
1. Fetch unresolved threads you created:
|
||||
`gh api graphql -f query='{ repository(owner: "${{ github.repository_owner }}", name: "${{ github.event.repository.name }}") { pullRequest(number: ${{ github.event.pull_request.number }}) { reviewThreads(first: 100) { nodes { id isResolved path comments(first: 1) { nodes { body author { login } } } } } } } }' --jq '.data.repository.pullRequest.reviewThreads.nodes[] | select(.isResolved == false) | select(.comments.nodes[0].author.login == "claude") | {id: .id, path: .path, body: .comments.nodes[0].body}'`
|
||||
|
||||
${{ github.event.action == 'synchronize' && 'This is a RE-REVIEW after new commits. First, get the list of changed files in this latest push using `gh pr diff`. Review ONLY the changed files. Check ALL existing review comments and resolve ones that are now fixed.' || 'This is the INITIAL REVIEW.' }}
|
||||
2. For each unresolved thread:
|
||||
a. Read the file at that path to check if the issue still exists
|
||||
b. If fixed → resolve it: `gh api graphql -f query='mutation { resolveReviewThread(input: {threadId: "<THREAD_ID>"}) { thread { isResolved } } }'`
|
||||
c. If still present → leave it
|
||||
|
||||
Review this PR focusing ONLY on:
|
||||
1. Critical bugs or logic errors
|
||||
Read the actual code before deciding. If there are no unresolved threads, skip to the next step.
|
||||
</step>
|
||||
|
||||
<step name="review">
|
||||
Review the diff (`gh pr diff ${{ github.event.pull_request.number }}`) for:
|
||||
1. Bugs that will crash at runtime
|
||||
2. Security vulnerabilities
|
||||
3. Breaking API changes
|
||||
4. Test failures (methods with typos that wont run)
|
||||
5. Stale documentation — if files or directories were moved, renamed, or deleted, check that `.claude/rules/`, `CLAUDE.md`, and `AGENTS.md` don't reference paths that no longer exist
|
||||
6. New language support — if new language modules are added under `languages/`, check that `.github/workflows/duplicate-code-detector.yml` includes the new language in its file filters, search patterns, and cross-module checks
|
||||
|
||||
IMPORTANT:
|
||||
- First check existing review comments using `gh api repos/${{ github.repository }}/pulls/${{ github.event.pull_request.number }}/comments`. For each existing comment, check if the issue still exists in the current code.
|
||||
- If an issue is fixed, use `gh api --method PATCH repos/${{ github.repository }}/pulls/comments/COMMENT_ID -f body="✅ Fixed in latest commit"` to resolve it.
|
||||
- Only create NEW inline comments for HIGH-PRIORITY issues found in changed files.
|
||||
- Limit to 5-7 NEW comments maximum per review.
|
||||
- Use CLAUDE.md for project-specific guidance.
|
||||
- Use `mcp__github_inline_comment__create_inline_comment` sparingly for critical code issues only.
|
||||
|
||||
## STEP 3: Coverage analysis
|
||||
Ignore style issues, type hints, and log message wording.
|
||||
Record findings for the summary comment. Refer to CLAUDE.md for project conventions.
|
||||
</step>
|
||||
|
||||
<step name="coverage">
|
||||
Analyze test coverage for changed files:
|
||||
|
||||
1. Get the list of Python files changed in this PR (excluding tests):
|
||||
`git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
|
||||
1. Get changed Python files (excluding tests): `git diff --name-only origin/main...HEAD -- '*.py' | grep -v test`
|
||||
2. Run coverage on PR branch: `uv run coverage run -m pytest tests/ -q --tb=no` then `uv run coverage json -o coverage-pr.json`
|
||||
3. Get per-file coverage: `uv run coverage report --include="<changed_files>"`
|
||||
4. Compare with main: checkout main, run coverage, checkout back
|
||||
5. Flag: new files below 75%, decreased coverage, untested changed lines
|
||||
</step>
|
||||
|
||||
2. Run tests with coverage on the PR branch:
|
||||
`uv run coverage run -m pytest tests/ -q --tb=no`
|
||||
`uv run coverage json -o coverage-pr.json`
|
||||
<step name="summary_comment">
|
||||
Post exactly one summary comment containing all results from previous steps.
|
||||
|
||||
3. Get coverage for changed files only:
|
||||
`uv run coverage report --include="<changed_files_comma_separated>"`
|
||||
|
||||
4. Compare with main branch coverage:
|
||||
- Checkout main: `git checkout origin/main`
|
||||
- Run coverage: `uv run coverage run -m pytest tests/ -q --tb=no && uv run coverage json -o coverage-main.json`
|
||||
- Checkout back: `git checkout -`
|
||||
|
||||
5. Analyze the diff to identify:
|
||||
- NEW FILES: Files that don't exist on main (require good test coverage)
|
||||
- MODIFIED FILES: Files with changes (changes must be covered by tests)
|
||||
|
||||
6. Report in PR comment with a markdown table:
|
||||
- Coverage % for each changed file (PR vs main)
|
||||
- Overall coverage change
|
||||
- For NEW files: Flag if coverage is below 75%
|
||||
- For MODIFIED files: Flag if the changed lines are not covered by tests
|
||||
- Flag if overall coverage decreased
|
||||
|
||||
Coverage requirements:
|
||||
- New implementations/files: Must have ≥75% test coverage
|
||||
- Modified code: Changed lines should be exercised by existing or new tests
|
||||
- No coverage regressions: Overall coverage should not decrease
|
||||
|
||||
## STEP 4: Post ONE consolidated summary comment
|
||||
|
||||
CRITICAL: You must post exactly ONE summary comment containing ALL results (pre-commit, review, coverage).
|
||||
DO NOT post multiple separate comments. Use this format:
|
||||
To ensure one comment: find an existing claude[bot] comment and update it, or create one if none exists.
|
||||
Delete any duplicate claude[bot] comments.
|
||||
|
||||
```
|
||||
gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | head -1
|
||||
```
|
||||
|
||||
Format:
|
||||
## PR Review Summary
|
||||
|
||||
### Prek Checks
|
||||
[status and any fixes made]
|
||||
|
||||
### Code Review
|
||||
[critical issues found, if any]
|
||||
|
||||
### Test Coverage
|
||||
[coverage table and analysis]
|
||||
|
||||
---
|
||||
*Last updated: <timestamp>*
|
||||
```
|
||||
</step>
|
||||
|
||||
To ensure only ONE comment exists:
|
||||
1. Find existing claude[bot] comment: `gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | head -1`
|
||||
2. If found, UPDATE it: `gh api --method PATCH repos/${{ github.repository }}/issues/comments/<ID> -f body="<content>"`
|
||||
3. If not found, CREATE: `gh pr comment ${{ github.event.pull_request.number }} --body "<content>"`
|
||||
4. Delete any OTHER claude[bot] comments to clean up duplicates: `gh api repos/${{ github.repository }}/issues/${{ github.event.pull_request.number }}/comments --jq '.[] | select(.user.login == "claude[bot]") | .id' | tail -n +2 | xargs -I {} gh api --method DELETE repos/${{ github.repository }}/issues/comments/{}`
|
||||
<step name="simplify">
|
||||
Run /simplify to review recently changed code for reuse, quality, and efficiency opportunities.
|
||||
If improvements are found, commit with "refactor: simplify <description>" and push.
|
||||
Only make behavior-preserving changes.
|
||||
</step>
|
||||
|
||||
## STEP 5: Merge pending codeflash optimization PRs
|
||||
<step name="merge_optimization_prs">
|
||||
Check for open PRs from codeflash-ai[bot]:
|
||||
`gh pr list --author "codeflash-ai[bot]" --state open --json number,title,headRefName,createdAt,mergeable`
|
||||
|
||||
Check for open optimization PRs from codeflash and merge if CI passes:
|
||||
For each PR:
|
||||
- If CI passes and the PR is mergeable → merge with `--squash --delete-branch`
|
||||
- Close the PR as stale if ANY of these apply:
|
||||
- Older than 7 days
|
||||
- Has merge conflicts (mergeable state is "CONFLICTING")
|
||||
- CI is failing
|
||||
- The optimized function no longer exists in the target file (check the diff)
|
||||
Close with: `gh pr close <number> --comment "Closing stale optimization PR." --delete-branch`
|
||||
</step>
|
||||
|
||||
1. List open PRs from codeflash bot:
|
||||
`gh pr list --author "codeflash-ai[bot]" --state open --json number,title,headRefName`
|
||||
|
||||
2. For each optimization PR:
|
||||
- Check if CI is passing: `gh pr checks <number>`
|
||||
- If all checks pass, merge it: `gh pr merge <number> --squash --delete-branch`
|
||||
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "mcp__github_inline_comment__create_inline_comment,Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit"'
|
||||
<verification>
|
||||
Before finishing, confirm:
|
||||
- All steps were attempted (even if some had no work)
|
||||
- Stale review threads were checked and resolved where appropriate
|
||||
- All findings are in a single summary comment (no inline review comments were created)
|
||||
- If fixes were made, they were verified with prek
|
||||
</verification>
|
||||
claude_args: '--model us.anthropic.claude-opus-4-6-v1 --allowedTools "Bash(gh pr comment:*),Bash(gh pr diff:*),Bash(gh pr view:*),Bash(gh pr list:*),Bash(gh pr checks:*),Bash(gh pr merge:*),Bash(gh issue view:*),Bash(gh issue list:*),Bash(gh api:*),Bash(uv run prek *),Bash(uv run mypy *),Bash(uv run coverage *),Bash(uv run pytest *),Bash(git status*),Bash(git add *),Bash(git commit *),Bash(git push*),Bash(git diff *),Bash(git checkout *),Read,Glob,Grep,Edit,Skill"'
|
||||
additional_permissions: |
|
||||
actions: read
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from codeflash.cli_cmds.console import console, logger
|
|||
from codeflash.code_utils.env_utils import get_codeflash_api_key
|
||||
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
from codeflash.languages import Language, current_language
|
||||
from codeflash.languages.current import current_language_support
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
from codeflash.models.models import (
|
||||
AIServiceRefinerRequest,
|
||||
|
|
@ -51,6 +52,18 @@ class AiServiceClient:
|
|||
"""Get the next LLM call sequence number."""
|
||||
return next(self.llm_call_counter)
|
||||
|
||||
@staticmethod
|
||||
def add_language_metadata(
|
||||
payload: dict[str, Any], language_version: str | None = None, module_system: str | None = None
|
||||
) -> None:
|
||||
"""Add language version and module system metadata to an API payload."""
|
||||
payload["python_version"] = platform.python_version()
|
||||
default_lang_version = current_language_support().default_language_version
|
||||
if default_lang_version is not None:
|
||||
payload["language_version"] = language_version or default_lang_version
|
||||
if module_system:
|
||||
payload["module_system"] = module_system
|
||||
|
||||
def get_aiservice_base_url(self) -> str:
|
||||
if os.environ.get("CODEFLASH_AIS_SERVER", default="prod").lower() == "local":
|
||||
logger.info("Using local AI Service at http://localhost:8000")
|
||||
|
|
@ -177,16 +190,7 @@ class AiServiceClient:
|
|||
"is_numerical_code": is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language-specific version fields
|
||||
# Always include python_version for backward compatibility with older backend
|
||||
payload["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
else:
|
||||
payload["language_version"] = language_version or "ES2022"
|
||||
# Add module system for JavaScript/TypeScript (esm or commonjs)
|
||||
if module_system:
|
||||
payload["module_system"] = module_system
|
||||
self.add_language_metadata(payload, language_version, module_system)
|
||||
|
||||
# DEBUG: Print payload language field
|
||||
logger.debug(
|
||||
|
|
@ -432,14 +436,7 @@ class AiServiceClient:
|
|||
"language": opt.language,
|
||||
}
|
||||
|
||||
# Add language version - always include python_version for backward compatibility
|
||||
item["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
elif opt.language_version:
|
||||
item["language_version"] = opt.language_version
|
||||
else:
|
||||
item["language_version"] = "ES2022" # Default for JS/TS
|
||||
self.add_language_metadata(item, opt.language_version)
|
||||
|
||||
# Add multi-file context if provided
|
||||
if opt.additional_context_files:
|
||||
|
|
@ -752,16 +749,11 @@ class AiServiceClient:
|
|||
|
||||
"""
|
||||
# Validate test framework based on language
|
||||
python_frameworks = ["pytest", "unittest"]
|
||||
javascript_frameworks = ["jest", "mocha", "vitest"]
|
||||
if is_python():
|
||||
assert test_framework in python_frameworks, (
|
||||
f"Invalid test framework for Python, got {test_framework} but expected one of {python_frameworks}"
|
||||
)
|
||||
elif is_javascript():
|
||||
assert test_framework in javascript_frameworks, (
|
||||
f"Invalid test framework for JavaScript, got {test_framework} but expected one of {javascript_frameworks}"
|
||||
)
|
||||
lang_support = current_language_support()
|
||||
valid_frameworks = lang_support.valid_test_frameworks
|
||||
assert test_framework in valid_frameworks, (
|
||||
f"Invalid test framework for {current_language()}, got {test_framework} but expected one of {list(valid_frameworks)}"
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"source_code_being_tested": source_code_being_tested,
|
||||
|
|
@ -780,16 +772,7 @@ class AiServiceClient:
|
|||
"is_numerical_code": is_numerical_code,
|
||||
}
|
||||
|
||||
# Add language-specific version fields
|
||||
# Always include python_version for backward compatibility with older backend
|
||||
payload["python_version"] = platform.python_version()
|
||||
if is_python():
|
||||
pass # python_version already set
|
||||
else:
|
||||
payload["language_version"] = language_version or "ES2022"
|
||||
# Add module system for JavaScript/TypeScript (esm or commonjs)
|
||||
if module_system:
|
||||
payload["module_system"] = module_system
|
||||
self.add_language_metadata(payload, language_version, module_system)
|
||||
|
||||
# DEBUG: Print payload language field
|
||||
logger.debug(f"Sending testgen request with language='{payload['language']}', framework='{test_framework}'")
|
||||
|
|
@ -875,7 +858,7 @@ class AiServiceClient:
|
|||
"codeflash_version": codeflash_version,
|
||||
"calling_fn_details": calling_fn_details,
|
||||
"language": language,
|
||||
"python_version": platform.python_version() if is_python() else None,
|
||||
"python_version": platform.python_version() if current_language() == Language.PYTHON else None,
|
||||
"call_sequence": self.get_next_sequence(),
|
||||
}
|
||||
console.rule()
|
||||
|
|
|
|||
|
|
@ -408,9 +408,10 @@ def get_all_function_names(code: str) -> tuple[bool, list[str]]:
|
|||
def get_run_tmp_file(file_path: Path | str) -> Path:
|
||||
if isinstance(file_path, str):
|
||||
file_path = Path(file_path)
|
||||
if not hasattr(get_run_tmp_file, "tmpdir"):
|
||||
if not hasattr(get_run_tmp_file, "tmpdir_path"):
|
||||
get_run_tmp_file.tmpdir = TemporaryDirectory(prefix="codeflash_")
|
||||
return Path(get_run_tmp_file.tmpdir.name) / file_path
|
||||
get_run_tmp_file.tmpdir_path = Path(get_run_tmp_file.tmpdir.name)
|
||||
return get_run_tmp_file.tmpdir_path / file_path
|
||||
|
||||
|
||||
def path_belongs_to_site_packages(file_path: Path) -> bool:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,10 @@ from typing import Any, Union
|
|||
MAX_TEST_RUN_ITERATIONS = 5
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT = 64000
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT = 64000
|
||||
READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed"
|
||||
TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed"
|
||||
INDIVIDUAL_TESTCASE_TIMEOUT = 15
|
||||
JAVA_TESTCASE_TIMEOUT = 120
|
||||
MAX_FUNCTION_TEST_SECONDS = 60
|
||||
MIN_IMPROVEMENT_THRESHOLD = 0.05
|
||||
MIN_THROUGHPUT_IMPROVEMENT_THRESHOLD = 0.10 # 10% minimum improvement for async throughput
|
||||
|
|
|
|||
|
|
@ -1,126 +0,0 @@
|
|||
"""Code deduplication utilities using language-specific normalizers.
|
||||
|
||||
This module provides functions to normalize code, generate fingerprints,
|
||||
and detect duplicate code segments across different programming languages.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
|
||||
from codeflash.code_utils.normalizers import get_normalizer
|
||||
from codeflash.languages import current_language, is_python
|
||||
|
||||
|
||||
def normalize_code(
|
||||
code: str, remove_docstrings: bool = True, return_ast_dump: bool = False, language: str | None = None
|
||||
) -> str:
|
||||
"""Normalize code by parsing, cleaning, and normalizing variable names.
|
||||
|
||||
Function names, class names, and parameters are preserved.
|
||||
|
||||
Args:
|
||||
code: Source code as string
|
||||
remove_docstrings: Whether to remove docstrings (Python only)
|
||||
return_ast_dump: Return AST dump instead of unparsed code (Python only)
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
Normalized code as string
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
|
||||
# Python has additional options
|
||||
if is_python():
|
||||
if return_ast_dump:
|
||||
return normalizer.normalize_for_hash(code)
|
||||
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
|
||||
|
||||
# For other languages, use standard normalization
|
||||
return normalizer.normalize(code)
|
||||
except ValueError:
|
||||
# Unknown language - fall back to basic normalization
|
||||
return _basic_normalize(code)
|
||||
except Exception:
|
||||
# Parsing error - try other languages or fall back
|
||||
if is_python():
|
||||
# Try JavaScript as fallback
|
||||
try:
|
||||
js_normalizer = get_normalizer("javascript")
|
||||
js_result = js_normalizer.normalize(code)
|
||||
if js_result != _basic_normalize(code):
|
||||
return js_result
|
||||
except Exception:
|
||||
pass
|
||||
return _basic_normalize(code)
|
||||
|
||||
|
||||
def _basic_normalize(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments (// and #)
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
code = re.sub(r"#.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
code = re.sub(r'""".*?"""', "", code, flags=re.DOTALL)
|
||||
code = re.sub(r"'''.*?'''", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
def get_code_fingerprint(code: str, language: str | None = None) -> str:
|
||||
"""Generate a fingerprint for normalized code.
|
||||
|
||||
Args:
|
||||
code: Source code
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of normalized code
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
return normalizer.get_fingerprint(code)
|
||||
except ValueError:
|
||||
# Unknown language - use basic normalization
|
||||
normalized = _basic_normalize(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
||||
|
||||
def are_codes_duplicate(code1: str, code2: str, language: str | None = None) -> bool:
|
||||
"""Check if two code segments are duplicates after normalization.
|
||||
|
||||
Args:
|
||||
code1: First code segment
|
||||
code2: Second code segment
|
||||
language: Language of the code. If None, uses the current session language.
|
||||
|
||||
Returns:
|
||||
True if codes are structurally identical (ignoring local variable names)
|
||||
|
||||
"""
|
||||
if language is None:
|
||||
language = current_language().value
|
||||
|
||||
try:
|
||||
normalizer = get_normalizer(language)
|
||||
return normalizer.are_duplicates(code1, code2)
|
||||
except ValueError:
|
||||
# Unknown language - use basic comparison
|
||||
return _basic_normalize(code1) == _basic_normalize(code2)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
# Re-export for backward compatibility
|
||||
__all__ = ["are_codes_duplicate", "get_code_fingerprint", "normalize_code"]
|
||||
|
|
@ -1,106 +0,0 @@
|
|||
"""Code normalizers for different programming languages.
|
||||
|
||||
This module provides language-specific code normalizers that transform source code
|
||||
into canonical forms for duplicate detection. The normalizers:
|
||||
- Replace local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserve function names, class names, parameters, and imports
|
||||
- Remove or normalize comments and docstrings
|
||||
- Produce consistent output for structurally identical code
|
||||
|
||||
Usage:
|
||||
>>> normalizer = get_normalizer("python")
|
||||
>>> normalized = normalizer.normalize(code)
|
||||
>>> fingerprint = normalizer.get_fingerprint(code)
|
||||
>>> are_same = normalizer.are_duplicates(code1, code2)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
from codeflash.code_utils.normalizers.javascript import JavaScriptNormalizer, TypeScriptNormalizer
|
||||
from codeflash.code_utils.normalizers.python import PythonNormalizer
|
||||
|
||||
__all__ = [
|
||||
"CodeNormalizer",
|
||||
"JavaScriptNormalizer",
|
||||
"PythonNormalizer",
|
||||
"TypeScriptNormalizer",
|
||||
"get_normalizer",
|
||||
"get_normalizer_for_extension",
|
||||
]
|
||||
|
||||
# Registry of normalizers by language
|
||||
_NORMALIZERS: dict[str, type[CodeNormalizer]] = {
|
||||
"python": PythonNormalizer,
|
||||
"javascript": JavaScriptNormalizer,
|
||||
"typescript": TypeScriptNormalizer,
|
||||
}
|
||||
|
||||
# Singleton cache for normalizer instances
|
||||
_normalizer_instances: dict[str, CodeNormalizer] = {}
|
||||
|
||||
|
||||
def get_normalizer(language: str) -> CodeNormalizer:
|
||||
"""Get a code normalizer for the specified language.
|
||||
|
||||
Args:
|
||||
language: Language name ('python', 'javascript', 'typescript')
|
||||
|
||||
Returns:
|
||||
CodeNormalizer instance for the language
|
||||
|
||||
Raises:
|
||||
ValueError: If no normalizer exists for the language
|
||||
|
||||
"""
|
||||
language = language.lower()
|
||||
|
||||
# Check cache first
|
||||
if language in _normalizer_instances:
|
||||
return _normalizer_instances[language]
|
||||
|
||||
# Get normalizer class
|
||||
if language not in _NORMALIZERS:
|
||||
supported = ", ".join(sorted(_NORMALIZERS.keys()))
|
||||
msg = f"No normalizer available for language '{language}'. Supported: {supported}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Create and cache instance
|
||||
normalizer = _NORMALIZERS[language]()
|
||||
_normalizer_instances[language] = normalizer
|
||||
return normalizer
|
||||
|
||||
|
||||
def get_normalizer_for_extension(extension: str) -> CodeNormalizer | None:
|
||||
"""Get a code normalizer based on file extension.
|
||||
|
||||
Args:
|
||||
extension: File extension including dot (e.g., '.py', '.js')
|
||||
|
||||
Returns:
|
||||
CodeNormalizer instance if found, None otherwise
|
||||
|
||||
"""
|
||||
extension = extension.lower()
|
||||
if not extension.startswith("."):
|
||||
extension = f".{extension}"
|
||||
|
||||
for language in _NORMALIZERS:
|
||||
normalizer = get_normalizer(language)
|
||||
if extension in normalizer.supported_extensions:
|
||||
return normalizer
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def register_normalizer(language: str, normalizer_class: type[CodeNormalizer]) -> None:
|
||||
"""Register a new normalizer for a language.
|
||||
|
||||
Args:
|
||||
language: Language name
|
||||
normalizer_class: CodeNormalizer subclass
|
||||
|
||||
"""
|
||||
_NORMALIZERS[language.lower()] = normalizer_class
|
||||
# Clear cached instance if it exists
|
||||
_normalizer_instances.pop(language.lower(), None)
|
||||
|
|
@ -1,104 +0,0 @@
|
|||
"""Abstract base class for code normalizers.
|
||||
|
||||
Code normalizers transform source code into a canonical form for duplicate detection.
|
||||
They normalize variable names, remove comments/docstrings, and produce consistent output
|
||||
that can be compared across different implementations of the same algorithm.
|
||||
"""
|
||||
|
||||
# TODO:{claude} move to base.py in language folder
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class CodeNormalizer(ABC):
|
||||
"""Abstract base class for language-specific code normalizers.
|
||||
|
||||
Subclasses must implement the normalize() method for their specific language.
|
||||
The normalization should:
|
||||
- Normalize local variable names to canonical forms (var_0, var_1, etc.)
|
||||
- Preserve function names, class names, parameters, and imports
|
||||
- Remove or normalize comments and docstrings
|
||||
- Produce consistent output for structurally identical code
|
||||
|
||||
Example:
|
||||
>>> normalizer = PythonNormalizer()
|
||||
>>> code1 = "def foo(x): y = x + 1; return y"
|
||||
>>> code2 = "def foo(x): z = x + 1; return z"
|
||||
>>> normalizer.normalize(code1) == normalizer.normalize(code2)
|
||||
True
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
...
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return ()
|
||||
|
||||
@abstractmethod
|
||||
def normalize(self, code: str) -> str:
|
||||
"""Normalize code to a canonical form for comparison.
|
||||
|
||||
Args:
|
||||
code: Source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation of the code
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize code optimized for hashing/fingerprinting.
|
||||
|
||||
This may return a more compact representation than normalize().
|
||||
|
||||
Args:
|
||||
code: Source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation suitable for hashing
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
def are_duplicates(self, code1: str, code2: str) -> bool:
|
||||
"""Check if two code segments are duplicates after normalization.
|
||||
|
||||
Args:
|
||||
code1: First code segment
|
||||
code2: Second code segment
|
||||
|
||||
Returns:
|
||||
True if codes are structurally identical
|
||||
|
||||
"""
|
||||
try:
|
||||
normalized1 = self.normalize_for_hash(code1)
|
||||
normalized2 = self.normalize_for_hash(code2)
|
||||
except Exception:
|
||||
return False
|
||||
else:
|
||||
return normalized1 == normalized2
|
||||
|
||||
def get_fingerprint(self, code: str) -> str:
|
||||
"""Generate a fingerprint hash for normalized code.
|
||||
|
||||
Args:
|
||||
code: Source code to fingerprint
|
||||
|
||||
Returns:
|
||||
SHA-256 hash of normalized code
|
||||
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
normalized = self.normalize_for_hash(code)
|
||||
return hashlib.sha256(normalized.encode()).hexdigest()
|
||||
|
|
@ -641,17 +641,15 @@ def discover_unit_tests(
|
|||
discover_only_these_tests: list[Path] | None = None,
|
||||
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] | None = None,
|
||||
) -> tuple[dict[str, set[FunctionCalledInTest]], int, int]:
|
||||
from codeflash.languages import is_javascript, is_python
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.current import current_language_support
|
||||
|
||||
# Detect language from functions being optimized
|
||||
language = _detect_language_from_functions(file_to_funcs_to_optimize)
|
||||
|
||||
# Route to language-specific test discovery for non-Python languages
|
||||
if not is_python():
|
||||
# For JavaScript/TypeScript, tests_project_rootdir should be tests_root itself
|
||||
# The Jest helper will be configured to NOT include "tests." prefix to match
|
||||
if is_javascript():
|
||||
cfg.tests_project_rootdir = cfg.tests_root
|
||||
current_language_support().adjust_test_config_for_discovery(cfg)
|
||||
return discover_tests_for_language(cfg, language, file_to_funcs_to_optimize)
|
||||
|
||||
# Existing Python logic
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import contextlib
|
||||
import os
|
||||
import random
|
||||
import warnings
|
||||
|
|
@ -10,8 +11,8 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import git
|
||||
import libcst as cst
|
||||
from pydantic.dataclasses import dataclass
|
||||
from rich.text import Text
|
||||
from rich.tree import Tree
|
||||
|
||||
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
|
||||
|
|
@ -37,14 +38,8 @@ __all__ = ["FunctionParent", "FunctionToOptimize"]
|
|||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from libcst import CSTNode
|
||||
from libcst.metadata import CodeRange
|
||||
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
import contextlib
|
||||
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -56,70 +51,6 @@ class FunctionProperties:
|
|||
staticmethod_class_name: Optional[str]
|
||||
|
||||
|
||||
class ReturnStatementVisitor(cst.CSTVisitor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.has_return_statement: bool = False
|
||||
|
||||
def visit_Return(self, node: cst.Return) -> None:
|
||||
self.has_return_statement = True
|
||||
|
||||
|
||||
class FunctionVisitor(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
|
||||
|
||||
def __init__(self, file_path: Path) -> None:
|
||||
super().__init__()
|
||||
self.file_path: Path = file_path
|
||||
self.functions: list[FunctionToOptimize] = []
|
||||
|
||||
@staticmethod
|
||||
def is_pytest_fixture(node: cst.FunctionDef) -> bool:
|
||||
for decorator in node.decorators:
|
||||
dec = decorator.decorator
|
||||
if isinstance(dec, cst.Call):
|
||||
dec = dec.func
|
||||
if isinstance(dec, cst.Attribute) and dec.attr.value == "fixture":
|
||||
if isinstance(dec.value, cst.Name) and dec.value.value == "pytest":
|
||||
return True
|
||||
if isinstance(dec, cst.Name) and dec.value == "fixture":
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_property(node: cst.FunctionDef) -> bool:
|
||||
for decorator in node.decorators:
|
||||
dec = decorator.decorator
|
||||
if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
|
||||
node.visit(return_visitor)
|
||||
if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
|
||||
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
|
||||
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
ast_parents: list[FunctionParent] = []
|
||||
while parents is not None:
|
||||
if isinstance(parents, cst.FunctionDef):
|
||||
# Skip nested functions — only discover top-level and class-level functions
|
||||
return
|
||||
if isinstance(parents, cst.ClassDef):
|
||||
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
|
||||
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
|
||||
self.functions.append(
|
||||
FunctionToOptimize(
|
||||
function_name=node.name.value,
|
||||
file_path=self.file_path,
|
||||
parents=list(reversed(ast_parents)),
|
||||
starting_line=pos.start.line,
|
||||
ending_line=pos.end.line,
|
||||
is_async=bool(node.asynchronous),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Multi-language support helpers
|
||||
# =============================================================================
|
||||
|
|
@ -480,22 +411,14 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[Pat
|
|||
path = Path(path_str)
|
||||
if not path.exists():
|
||||
continue
|
||||
with path.open(encoding="utf8") as f:
|
||||
file_content = f.read()
|
||||
try:
|
||||
wrapper = cst.metadata.MetadataWrapper(cst.parse_module(file_content))
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
continue
|
||||
function_lines = FunctionVisitor(file_path=path)
|
||||
wrapper.visit(function_lines)
|
||||
functions[path] = [
|
||||
function_to_optimize
|
||||
for function_to_optimize in function_lines.functions
|
||||
if (start_line := function_to_optimize.starting_line) is not None
|
||||
and (end_line := function_to_optimize.ending_line) is not None
|
||||
and any(start_line <= line <= end_line for line in lines_in_file)
|
||||
]
|
||||
all_functions = find_all_functions_in_file(path)
|
||||
functions[path] = [
|
||||
func
|
||||
for func in all_functions.get(path, [])
|
||||
if func.starting_line is not None
|
||||
and func.ending_line is not None
|
||||
and any(func.starting_line <= line <= func.ending_line for line in lines_in_file)
|
||||
]
|
||||
return functions
|
||||
|
||||
|
||||
|
|
@ -524,24 +447,19 @@ def get_all_files_and_functions(
|
|||
|
||||
|
||||
def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOptimize]]:
|
||||
"""Find all optimizable functions in a file, routing to the appropriate language handler.
|
||||
|
||||
This function checks if the file extension is supported and routes to either
|
||||
the Python-specific implementation (for backward compatibility) or the
|
||||
language support abstraction for other languages.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping file path to list of FunctionToOptimize.
|
||||
|
||||
"""
|
||||
# Check if the file extension is supported
|
||||
"""Find all optimizable functions in a file using the language support abstraction."""
|
||||
if not is_language_supported(file_path):
|
||||
return {}
|
||||
try:
|
||||
from codeflash.languages.base import FunctionFilterCriteria
|
||||
|
||||
return _find_all_functions_via_language_support(file_path)
|
||||
lang_support = get_language_support(file_path)
|
||||
criteria = FunctionFilterCriteria(require_return=True)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
return {file_path: lang_support.discover_functions(source, file_path, criteria)}
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to discover functions in {file_path}: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_all_replay_test_functions(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ Usage:
|
|||
lang = get_language_support(Path("example.py"))
|
||||
|
||||
# Discover functions
|
||||
functions = lang.discover_functions(file_path)
|
||||
functions = lang.discover_functions(source, file_path)
|
||||
|
||||
# Replace a function
|
||||
new_source = lang.replace_function(file_path, function, new_code)
|
||||
|
|
@ -31,6 +31,7 @@ from codeflash.languages.base import (
|
|||
from codeflash.languages.current import (
|
||||
current_language,
|
||||
current_language_support,
|
||||
is_java,
|
||||
is_javascript,
|
||||
is_python,
|
||||
is_typescript,
|
||||
|
|
@ -78,6 +79,10 @@ def __getattr__(name: str):
|
|||
from codeflash.languages.python.support import PythonSupport
|
||||
|
||||
return PythonSupport
|
||||
if name == "JavaSupport":
|
||||
from codeflash.languages.java.support import JavaSupport
|
||||
|
||||
return JavaSupport
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
||||
|
|
@ -101,6 +106,7 @@ __all__ = [
|
|||
"get_language_support",
|
||||
"get_supported_extensions",
|
||||
"get_supported_languages",
|
||||
"is_java",
|
||||
"is_javascript",
|
||||
"is_jest",
|
||||
"is_mocha",
|
||||
|
|
|
|||
|
|
@ -11,11 +11,13 @@ from dataclasses import dataclass, field
|
|||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ast
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
|
||||
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
from codeflash.languages.language_enum import Language
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
|
@ -91,6 +93,7 @@ class CodeContext:
|
|||
target_file: Path
|
||||
helper_functions: list[HelperFunction] = field(default_factory=list)
|
||||
read_only_context: str = ""
|
||||
imported_type_skeletons: str = ""
|
||||
imports: list[str] = field(default_factory=list)
|
||||
language: Language = Language.PYTHON
|
||||
|
||||
|
|
@ -166,6 +169,7 @@ class FunctionFilterCriteria:
|
|||
include_patterns: list[str] = field(default_factory=list)
|
||||
exclude_patterns: list[str] = field(default_factory=list)
|
||||
require_return: bool = True
|
||||
require_export: bool = True
|
||||
include_async: bool = True
|
||||
include_methods: bool = True
|
||||
min_lines: int | None = None
|
||||
|
|
@ -251,7 +255,7 @@ class LanguageSupport(Protocol):
|
|||
def language(self) -> Language:
|
||||
return Language.PYTHON
|
||||
|
||||
def discover_functions(self, file_path: Path, ...) -> list[FunctionInfo]:
|
||||
def discover_functions(self, source: str, file_path: Path, ...) -> list[FunctionInfo]:
|
||||
# Python-specific implementation using LibCST
|
||||
...
|
||||
|
||||
|
|
@ -302,15 +306,49 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
"""Default language version string sent to AI service.
|
||||
|
||||
Returns None for languages where the runtime version is auto-detected (e.g. Python).
|
||||
Returns a version string (e.g. "ES2022") for languages that need an explicit default.
|
||||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def valid_test_frameworks(self) -> tuple[str, ...]:
|
||||
"""Valid test frameworks for this language."""
|
||||
...
|
||||
|
||||
@property
|
||||
def test_result_serialization_format(self) -> str:
|
||||
"""How test return values are serialized: "pickle" or "json"."""
|
||||
return "pickle"
|
||||
|
||||
def load_coverage(
|
||||
self,
|
||||
coverage_database_file: Path,
|
||||
function_name: str,
|
||||
code_context: Any,
|
||||
source_file: Path,
|
||||
coverage_config_file: Path | None = None,
|
||||
) -> Any:
|
||||
"""Load coverage data from language-specific format.
|
||||
|
||||
Returns a CoverageData instance.
|
||||
"""
|
||||
...
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a file.
|
||||
"""Find all optimizable functions in source code.
|
||||
|
||||
Args:
|
||||
file_path: Path to the source file to analyze.
|
||||
source: Source code to analyze.
|
||||
file_path: Path to the source file (used for context and language detection).
|
||||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
|
|
@ -637,6 +675,45 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
@property
|
||||
def function_optimizer_class(self) -> type:
|
||||
"""Return the FunctionOptimizer subclass for this language."""
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
return FunctionOptimizer
|
||||
|
||||
def prepare_module(
|
||||
self, module_code: str, module_path: Path, project_root: Path
|
||||
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
|
||||
"""Parse/validate a module before optimization."""
|
||||
...
|
||||
|
||||
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
|
||||
"""One-time project setup after language detection. Default: no-op."""
|
||||
|
||||
def adjust_test_config_for_discovery(self, test_cfg: TestConfig) -> None:
|
||||
"""Adjust test config before test discovery. Default: no-op."""
|
||||
|
||||
def detect_module_system(self, project_root: Path, source_file: Path) -> str | None:
|
||||
"""Detect the module system used by the project. Default: None (not applicable)."""
|
||||
return None
|
||||
|
||||
def process_generated_test_strings(
|
||||
self,
|
||||
generated_test_source: str,
|
||||
instrumented_behavior_test_source: str,
|
||||
instrumented_perf_test_source: str,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
test_path: Path,
|
||||
test_cfg: Any,
|
||||
project_module_system: str | None,
|
||||
) -> tuple[str, str, str]:
|
||||
"""Process raw generated test strings (instrumentation, placeholder replacement, etc.).
|
||||
|
||||
Returns (generated_test_source, instrumented_behavior_source, instrumented_perf_source).
|
||||
"""
|
||||
...
|
||||
|
||||
# === Configuration ===
|
||||
|
||||
def get_test_file_suffix(self) -> str:
|
||||
|
|
@ -732,6 +809,20 @@ class LanguageSupport(Protocol):
|
|||
|
||||
# === Test Execution ===
|
||||
|
||||
def generate_concolic_tests(
|
||||
self,
|
||||
test_cfg: TestConfig,
|
||||
project_root: Path,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
function_to_optimize_ast: Any,
|
||||
) -> tuple[dict, str]:
|
||||
"""Generate concolic tests for a function.
|
||||
|
||||
Default implementation returns empty results. Override for languages
|
||||
that support concolic testing (e.g. Python via CrossHair).
|
||||
"""
|
||||
return {}, ""
|
||||
|
||||
def run_behavioral_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
|
|
@ -788,6 +879,31 @@ class LanguageSupport(Protocol):
|
|||
"""
|
||||
...
|
||||
|
||||
def run_line_profile_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
line_profile_output_file: Path | None = None,
|
||||
) -> tuple[Path, Any]:
|
||||
"""Run tests for line profiling.
|
||||
|
||||
Args:
|
||||
test_paths: TestFiles object containing test file information.
|
||||
test_env: Environment variables for the test run.
|
||||
cwd: Working directory for running tests.
|
||||
timeout: Optional timeout in seconds.
|
||||
project_root: Project root directory.
|
||||
line_profile_output_file: Path where line profile results will be written.
|
||||
|
||||
Returns:
|
||||
Tuple of (result_file_path, subprocess_result).
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
|
||||
"""Convert a list of parent objects to a tuple of FunctionParent.
|
||||
|
|
|
|||
135
codeflash/languages/code_replacer.py
Normal file
135
codeflash/languages/code_replacer.py
Normal file
|
|
@ -0,0 +1,135 @@
|
|||
"""Language-agnostic code replacement utilities.
|
||||
|
||||
Used by non-Python language optimizers to replace function definitions
|
||||
via the LanguageSupport protocol.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.languages.base import FunctionFilterCriteria
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import LanguageSupport
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
|
||||
# Permissive criteria for discovering functions in code snippets (no export/return filtering)
|
||||
_SOURCE_CRITERIA = FunctionFilterCriteria(require_return=False, require_export=False)
|
||||
|
||||
|
||||
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
|
||||
file_to_code_context = optimized_code.file_to_path()
|
||||
module_optimized_code = file_to_code_context.get(str(relative_path))
|
||||
if module_optimized_code is None:
|
||||
# Fallback: if there's only one code block with None file path,
|
||||
# use it regardless of the expected path (the AI server doesn't always include file paths)
|
||||
if "None" in file_to_code_context and len(file_to_code_context) == 1:
|
||||
module_optimized_code = file_to_code_context["None"]
|
||||
logger.debug(f"Using code block with None file_path for {relative_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
|
||||
"re-check your 'markdown code structure'"
|
||||
f"existing files are {file_to_code_context.keys()}"
|
||||
)
|
||||
module_optimized_code = ""
|
||||
return module_optimized_code
|
||||
|
||||
|
||||
def replace_function_definitions_for_language(
|
||||
function_names: list[str],
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
module_abspath: Path,
|
||||
project_root_path: Path,
|
||||
lang_support: LanguageSupport,
|
||||
function_to_optimize: FunctionToOptimize | None = None,
|
||||
) -> bool:
|
||||
"""Replace function definitions using the LanguageSupport protocol.
|
||||
|
||||
Works for any language that implements LanguageSupport.replace_function
|
||||
and LanguageSupport.discover_functions.
|
||||
"""
|
||||
original_source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
||||
if not code_to_apply.strip():
|
||||
return False
|
||||
|
||||
original_source_code = lang_support.add_global_declarations(
|
||||
optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath
|
||||
)
|
||||
|
||||
if (
|
||||
function_to_optimize
|
||||
and function_to_optimize.starting_line
|
||||
and function_to_optimize.ending_line
|
||||
and function_to_optimize.file_path == module_abspath
|
||||
):
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
|
||||
else:
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
|
||||
else:
|
||||
new_code = original_source_code
|
||||
modified = False
|
||||
|
||||
functions_to_replace = list(function_names)
|
||||
|
||||
for func_name in functions_to_replace:
|
||||
current_functions = lang_support.discover_functions(new_code, module_abspath, _SOURCE_CRITERIA)
|
||||
|
||||
func = None
|
||||
for f in current_functions:
|
||||
if func_name in (f.qualified_name, f.function_name):
|
||||
func = f
|
||||
break
|
||||
|
||||
if func is None:
|
||||
continue
|
||||
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, func.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(new_code, func, optimized_func)
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
logger.warning(f"Could not find function {function_names} in {module_abspath}")
|
||||
return False
|
||||
|
||||
if original_source_code.strip() == new_code.strip():
|
||||
return False
|
||||
|
||||
module_abspath.write_text(new_code, encoding="utf8")
|
||||
return True
|
||||
|
||||
|
||||
def _extract_function_from_code(
|
||||
lang_support: LanguageSupport, source_code: str, function_name: str, file_path: Path
|
||||
) -> str | None:
|
||||
"""Extract a specific function's source code from a code string.
|
||||
|
||||
Includes JSDoc/docstring comments if present.
|
||||
"""
|
||||
try:
|
||||
functions = lang_support.discover_functions(source_code, file_path, _SOURCE_CRITERIA)
|
||||
for func in functions:
|
||||
if func.function_name == function_name:
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
effective_start = func.doc_start_line or func.starting_line
|
||||
if effective_start and func.ending_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.ending_line]
|
||||
return "".join(func_lines)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting function {function_name}: {e}")
|
||||
|
||||
return None
|
||||
|
|
@ -103,6 +103,16 @@ def is_typescript() -> bool:
|
|||
return _current_language == Language.TYPESCRIPT
|
||||
|
||||
|
||||
def is_java() -> bool:
|
||||
"""Check if the current language is Java.
|
||||
|
||||
Returns:
|
||||
True if the current language is Java.
|
||||
|
||||
"""
|
||||
return _current_language == Language.JAVA
|
||||
|
||||
|
||||
def current_language_support() -> LanguageSupport:
|
||||
"""Get the LanguageSupport instance for the current language.
|
||||
|
||||
|
|
|
|||
228
codeflash/languages/javascript/function_optimizer.py
Normal file
228
codeflash/languages/javascript/function_optimizer.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import encoded_tokens_len, get_run_tmp_file
|
||||
from codeflash.code_utils.config_consts import (
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
READ_WRITABLE_LIMIT_ERROR,
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
TESTGEN_LIMIT_ERROR,
|
||||
TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
)
|
||||
from codeflash.either import Failure, Success
|
||||
from codeflash.models.models import (
|
||||
CodeOptimizationContext,
|
||||
CodeString,
|
||||
CodeStringsMarkdown,
|
||||
FunctionSource,
|
||||
TestingMode,
|
||||
TestResults,
|
||||
)
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.either import Result
|
||||
from codeflash.languages.base import CodeContext, HelperFunction
|
||||
from codeflash.models.models import CoverageData, OriginalCodeBaseline, TestDiff
|
||||
|
||||
|
||||
class JavaScriptFunctionOptimizer(FunctionOptimizer):
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
language = Language(self.function_to_optimize.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
try:
|
||||
code_context = lang_support.extract_code_context(
|
||||
self.function_to_optimize, self.project_root, self.project_root
|
||||
)
|
||||
return Success(
|
||||
self._build_optimization_context(
|
||||
code_context,
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize.language,
|
||||
self.project_root,
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
||||
@staticmethod
|
||||
def _build_optimization_context(
|
||||
code_context: CodeContext,
|
||||
file_path: Path,
|
||||
language: str,
|
||||
project_root: Path,
|
||||
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
) -> CodeOptimizationContext:
|
||||
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
|
||||
|
||||
try:
|
||||
target_relative_path = file_path.resolve().relative_to(project_root.resolve())
|
||||
except ValueError:
|
||||
target_relative_path = file_path
|
||||
|
||||
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
|
||||
helper_function_sources = []
|
||||
|
||||
for helper in code_context.helper_functions:
|
||||
helpers_by_file[helper.file_path].append(helper)
|
||||
helper_function_sources.append(
|
||||
FunctionSource(
|
||||
file_path=helper.file_path,
|
||||
qualified_name=helper.qualified_name,
|
||||
fully_qualified_name=helper.qualified_name,
|
||||
only_function_name=helper.name,
|
||||
source_code=helper.source_code,
|
||||
)
|
||||
)
|
||||
|
||||
target_file_code = code_context.target_code
|
||||
same_file_helpers = helpers_by_file.get(file_path, [])
|
||||
if same_file_helpers:
|
||||
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
|
||||
target_file_code = target_file_code + "\n\n" + helper_code
|
||||
|
||||
if imports_code:
|
||||
target_file_code = imports_code + "\n\n" + target_file_code
|
||||
|
||||
read_writable_code_strings = [
|
||||
CodeString(code=target_file_code, file_path=target_relative_path, language=language)
|
||||
]
|
||||
|
||||
for helper_file_path, file_helpers in helpers_by_file.items():
|
||||
if helper_file_path == file_path:
|
||||
continue
|
||||
try:
|
||||
helper_relative_path = helper_file_path.resolve().relative_to(project_root.resolve())
|
||||
except ValueError:
|
||||
helper_relative_path = helper_file_path
|
||||
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
|
||||
read_writable_code_strings.append(
|
||||
CodeString(code=combined_helper_code, file_path=helper_relative_path, language=language)
|
||||
)
|
||||
|
||||
read_writable_code = CodeStringsMarkdown(code_strings=read_writable_code_strings, language=language)
|
||||
testgen_context = CodeStringsMarkdown(code_strings=read_writable_code_strings.copy(), language=language)
|
||||
|
||||
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
|
||||
if read_writable_tokens > optim_token_limit:
|
||||
raise ValueError(READ_WRITABLE_LIMIT_ERROR)
|
||||
|
||||
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
|
||||
if testgen_tokens > testgen_token_limit:
|
||||
raise ValueError(TESTGEN_LIMIT_ERROR)
|
||||
|
||||
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
|
||||
|
||||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=read_writable_code,
|
||||
read_only_context_code=code_context.read_only_context,
|
||||
hashing_code_context=read_writable_code.flat,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helper_function_sources,
|
||||
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
|
||||
preexisting_objects=set(),
|
||||
)
|
||||
|
||||
def compare_candidate_results(
|
||||
self,
|
||||
baseline_results: OriginalCodeBaseline,
|
||||
candidate_behavior_results: TestResults,
|
||||
optimization_candidate_index: int,
|
||||
) -> tuple[bool, list[TestDiff]]:
|
||||
original_sqlite = get_run_tmp_file(Path("test_return_values_0.sqlite"))
|
||||
candidate_sqlite = get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite"))
|
||||
|
||||
if original_sqlite.exists() and candidate_sqlite.exists():
|
||||
js_root = self.test_cfg.js_project_root or self.project_root
|
||||
match, diffs = self.language_support.compare_test_results(
|
||||
original_sqlite, candidate_sqlite, project_root=js_root
|
||||
)
|
||||
candidate_sqlite.unlink(missing_ok=True)
|
||||
else:
|
||||
match, diffs = compare_test_results(
|
||||
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
|
||||
)
|
||||
return match, diffs
|
||||
|
||||
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
|
||||
return testing_type == TestingMode.BEHAVIOR or optimization_iteration == 0
|
||||
|
||||
def parse_line_profile_test_results(
|
||||
self, line_profiler_output_file: Path | None
|
||||
) -> tuple[TestResults | dict[str, Any], CoverageData | None]:
|
||||
if line_profiler_output_file is None or not line_profiler_output_file.exists():
|
||||
return TestResults(test_results=[]), None
|
||||
if hasattr(self.language_support, "parse_line_profile_results"):
|
||||
return self.language_support.parse_line_profile_results(line_profiler_output_file), None
|
||||
return TestResults(test_results=[]), None
|
||||
|
||||
def line_profiler_step(
|
||||
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
|
||||
) -> dict[str, Any]:
|
||||
if not hasattr(self.language_support, "instrument_source_for_line_profiler"):
|
||||
logger.warning(f"Language support for {self.language_support.language} doesn't support line profiling")
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
original_source = self.function_to_optimize.file_path.read_text(encoding="utf-8")
|
||||
try:
|
||||
line_profiler_output_path = get_run_tmp_file(Path("line_profiler_output.json"))
|
||||
|
||||
success = self.language_support.instrument_source_for_line_profiler(
|
||||
func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path
|
||||
)
|
||||
if not success:
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
test_env = self.get_test_env(
|
||||
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
|
||||
)
|
||||
|
||||
_test_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.LINE_PROFILE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=0,
|
||||
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
enable_coverage=False,
|
||||
code_context=code_context,
|
||||
line_profiler_output_file=line_profiler_output_path,
|
||||
)
|
||||
|
||||
return self.language_support.parse_line_profile_results(line_profiler_output_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to run line profiling: {e}")
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
finally:
|
||||
self.function_to_optimize.file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
def replace_function_and_helpers_with_optimized_code(
|
||||
self,
|
||||
code_context: CodeOptimizationContext,
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
original_helper_code: dict[Path, str],
|
||||
) -> bool:
|
||||
from codeflash.languages.code_replacer import replace_function_definitions_for_language
|
||||
|
||||
did_update = False
|
||||
for module_abspath, qualified_names in self.group_functions_by_file(code_context).items():
|
||||
did_update |= replace_function_definitions_for_language(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code,
|
||||
module_abspath=module_abspath,
|
||||
project_root_path=self.project_root,
|
||||
lang_support=self.language_support,
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
)
|
||||
return did_update
|
||||
|
|
@ -499,6 +499,18 @@ class MultiFileHelperFinder:
|
|||
# Split source into lines for JSDoc extraction
|
||||
lines = source.splitlines(keepends=True)
|
||||
|
||||
def helper_from_func(func):
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
helper_source = "".join(lines[effective_start - 1 : func.end_line])
|
||||
return HelperFunction(
|
||||
name=func.name,
|
||||
qualified_name=func.name,
|
||||
file_path=file_path,
|
||||
source_code=helper_source,
|
||||
start_line=effective_start,
|
||||
end_line=func.end_line,
|
||||
)
|
||||
|
||||
# Handle "default" export - look for default exported function
|
||||
if function_name == "default":
|
||||
# Find the default export
|
||||
|
|
@ -506,38 +518,14 @@ class MultiFileHelperFinder:
|
|||
# For now, return first function if looking for default
|
||||
# TODO: Implement proper default export detection
|
||||
for func in functions:
|
||||
# Extract source including JSDoc if present
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
helper_lines = lines[effective_start - 1 : func.end_line]
|
||||
helper_source = "".join(helper_lines)
|
||||
|
||||
return HelperFunction(
|
||||
name=func.name,
|
||||
qualified_name=func.name,
|
||||
file_path=file_path,
|
||||
source_code=helper_source,
|
||||
start_line=effective_start,
|
||||
end_line=func.end_line,
|
||||
)
|
||||
return helper_from_func(func)
|
||||
return None
|
||||
|
||||
# Find the function by name
|
||||
functions = file_analyzer.find_functions(source, include_methods=True)
|
||||
for func in functions:
|
||||
if func.name == function_name:
|
||||
# Extract source including JSDoc if present
|
||||
effective_start = func.doc_start_line or func.start_line
|
||||
helper_lines = lines[effective_start - 1 : func.end_line]
|
||||
helper_source = "".join(helper_lines)
|
||||
|
||||
return HelperFunction(
|
||||
name=func.name,
|
||||
qualified_name=func.name,
|
||||
file_path=file_path,
|
||||
source_code=helper_source,
|
||||
start_line=effective_start,
|
||||
end_line=func.end_line,
|
||||
)
|
||||
return helper_from_func(func)
|
||||
|
||||
logger.debug("Function %s not found in %s", function_name, file_path)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,17 +1,52 @@
|
|||
"""JavaScript/TypeScript code normalizer using tree-sitter."""
|
||||
"""JavaScript/TypeScript code normalizer using tree-sitter.
|
||||
|
||||
Not currently wired into JavaScriptSupport.normalize_code — kept as a
|
||||
ready-to-use upgrade path when AST-based JS deduplication is needed.
|
||||
|
||||
The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tree_sitter import Node
|
||||
|
||||
|
||||
# TODO:{claude} move to language support directory to keep the directory structure clean
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reference: the old CodeNormalizer ABC that was deleted from base.py.
|
||||
# Kept here so the interface contract is visible if we re-introduce a
|
||||
# normalizer hierarchy later.
|
||||
# ---------------------------------------------------------------------------
|
||||
class CodeNormalizer(ABC):
|
||||
@property
|
||||
@abstractmethod
|
||||
def language(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def normalize(self, code: str) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def normalize_for_hash(self, code: str) -> str: ...
|
||||
|
||||
def are_duplicates(self, code1: str, code2: str) -> bool:
|
||||
try:
|
||||
return self.normalize_for_hash(code1) == self.normalize_for_hash(code2)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_fingerprint(self, code: str) -> str:
|
||||
import hashlib
|
||||
|
||||
return hashlib.sha256(self.normalize_for_hash(code).encode()).hexdigest()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JavaScriptVariableNormalizer:
|
||||
"""Normalizes JavaScript/TypeScript code for duplicate detection using tree-sitter.
|
||||
|
||||
|
|
@ -188,103 +223,35 @@ class JavaScriptVariableNormalizer:
|
|||
parts.append(")")
|
||||
|
||||
|
||||
def _basic_normalize(code: str) -> str:
|
||||
def _basic_normalize_js(code: str) -> str:
|
||||
"""Basic normalization: remove comments and normalize whitespace."""
|
||||
# Remove single-line comments
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
return " ".join(code.split())
|
||||
|
||||
|
||||
class JavaScriptNormalizer(CodeNormalizer):
|
||||
"""JavaScript code normalizer using tree-sitter.
|
||||
def normalize_js_code(code: str, typescript: bool = False) -> str:
|
||||
"""Normalize JavaScript/TypeScript code to a canonical form for comparison.
|
||||
|
||||
Normalizes JavaScript code by:
|
||||
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserving function names, class names, parameters, and imports
|
||||
- Removing comments
|
||||
- Normalizing string and number literals
|
||||
Uses tree-sitter to parse and normalize variable names. Falls back to
|
||||
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.
|
||||
|
||||
Not currently wired into JavaScriptSupport.normalize_code — kept as a
|
||||
ready-to-use upgrade path when AST-based JS deduplication is needed.
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "javascript"
|
||||
lang = TreeSitterLanguage.TYPESCRIPT if typescript else TreeSitterLanguage.JAVASCRIPT
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
tree = analyzer.parse(code)
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".js", ".jsx", ".mjs", ".cjs")
|
||||
if tree.root_node.has_error:
|
||||
return _basic_normalize_js(code)
|
||||
|
||||
def _get_tree_sitter_language(self) -> str:
|
||||
"""Get the tree-sitter language identifier."""
|
||||
return "javascript"
|
||||
|
||||
def normalize(self, code: str) -> str:
|
||||
"""Normalize JavaScript code to a canonical form.
|
||||
|
||||
Args:
|
||||
code: JavaScript source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation of the code
|
||||
|
||||
"""
|
||||
try:
|
||||
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
|
||||
|
||||
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
|
||||
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)
|
||||
analyzer = TreeSitterAnalyzer(lang)
|
||||
tree = analyzer.parse(code)
|
||||
|
||||
if tree.root_node.has_error:
|
||||
return _basic_normalize(code)
|
||||
|
||||
normalizer = JavaScriptVariableNormalizer()
|
||||
source_bytes = code.encode("utf-8")
|
||||
|
||||
# First pass: collect preserved names
|
||||
normalizer.collect_preserved_names(tree.root_node, source_bytes)
|
||||
|
||||
# Second pass: normalize and build representation
|
||||
return normalizer.normalize_tree(tree.root_node, source_bytes)
|
||||
except Exception:
|
||||
return _basic_normalize(code)
|
||||
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize JavaScript code optimized for hashing.
|
||||
|
||||
For JavaScript, this is the same as normalize().
|
||||
|
||||
Args:
|
||||
code: JavaScript source code to normalize
|
||||
|
||||
Returns:
|
||||
Normalized representation suitable for hashing
|
||||
|
||||
"""
|
||||
return self.normalize(code)
|
||||
|
||||
|
||||
class TypeScriptNormalizer(JavaScriptNormalizer):
|
||||
"""TypeScript code normalizer using tree-sitter.
|
||||
|
||||
Inherits from JavaScriptNormalizer and overrides language-specific settings.
|
||||
"""
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "typescript"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".ts", ".tsx", ".mts", ".cts")
|
||||
|
||||
def _get_tree_sitter_language(self) -> str:
|
||||
"""Get the tree-sitter language identifier."""
|
||||
return "typescript"
|
||||
normalizer = JavaScriptVariableNormalizer()
|
||||
source_bytes = code.encode("utf-8")
|
||||
normalizer.collect_preserved_names(tree.root_node, source_bytes)
|
||||
return normalizer.normalize_tree(tree.root_node, source_bytes)
|
||||
except Exception:
|
||||
return _basic_normalize_js(code)
|
||||
52
codeflash/languages/javascript/optimizer.py
Normal file
52
codeflash/languages/javascript/optimizer.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.models.models import ValidCode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
def prepare_javascript_module(
|
||||
original_module_code: str, original_module_path: Path
|
||||
) -> tuple[dict[Path, ValidCode], None]:
|
||||
"""Prepare a JavaScript/TypeScript module for optimization.
|
||||
|
||||
Unlike Python, JS/TS doesn't need AST parsing or import analysis at this stage.
|
||||
Returns a mapping of the file path to ValidCode with the source as-is.
|
||||
"""
|
||||
validated_original_code: dict[Path, ValidCode] = {
|
||||
original_module_path: ValidCode(source_code=original_module_code, normalized_code=original_module_code)
|
||||
}
|
||||
return validated_original_code, None
|
||||
|
||||
|
||||
def verify_js_requirements(test_cfg: TestConfig) -> None:
|
||||
"""Verify JavaScript/TypeScript requirements before optimization.
|
||||
|
||||
Checks that Node.js, npm, and the test framework are available.
|
||||
Logs warnings if requirements are not met but does not abort.
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
js_project_root = test_cfg.js_project_root
|
||||
if not js_project_root:
|
||||
return
|
||||
|
||||
try:
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
test_framework = get_js_test_framework_or_default()
|
||||
success, errors = js_support.verify_requirements(js_project_root, test_framework)
|
||||
|
||||
if not success:
|
||||
logger.warning("JavaScript requirements check found issues:")
|
||||
for error in errors:
|
||||
logger.warning(f" - {error}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to verify JS requirements: {e}")
|
||||
|
|
@ -23,7 +23,8 @@ if TYPE_CHECKING:
|
|||
|
||||
from codeflash.languages.base import ReferenceInfo
|
||||
from codeflash.languages.javascript.treesitter import TypeDefinition
|
||||
from codeflash.models.models import GeneratedTestsList, InvocationId
|
||||
from codeflash.models.models import GeneratedTestsList, InvocationId, ValidCode
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -50,8 +51,7 @@ class JavaScriptSupport:
|
|||
|
||||
@property
|
||||
def default_file_extension(self) -> str:
|
||||
"""Default file extension for JavaScript."""
|
||||
return ".js"
|
||||
return self.file_extensions[0]
|
||||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
|
|
@ -68,17 +68,47 @@ class JavaScriptSupport:
|
|||
def dir_excludes(self) -> frozenset[str]:
|
||||
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
return "ES2022"
|
||||
|
||||
@property
|
||||
def valid_test_frameworks(self) -> tuple[str, ...]:
|
||||
return ("jest", "mocha", "vitest")
|
||||
|
||||
@property
|
||||
def test_result_serialization_format(self) -> str:
|
||||
return "json"
|
||||
|
||||
def load_coverage(
|
||||
self,
|
||||
coverage_database_file: Path,
|
||||
function_name: str,
|
||||
code_context: Any,
|
||||
source_file: Path,
|
||||
coverage_config_file: Path | None = None,
|
||||
) -> Any:
|
||||
from codeflash.verification.coverage_utils import JestCoverageUtils
|
||||
|
||||
return JestCoverageUtils.load_from_jest_json(
|
||||
coverage_json_path=coverage_database_file,
|
||||
function_name=function_name,
|
||||
code_context=code_context,
|
||||
source_code_path=source_file,
|
||||
)
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a JavaScript file.
|
||||
"""Find all optimizable functions in JavaScript/TypeScript source code.
|
||||
|
||||
Uses tree-sitter to parse the file and find functions.
|
||||
Uses tree-sitter to parse the source and find functions.
|
||||
|
||||
Args:
|
||||
file_path: Path to the JavaScript file to analyze.
|
||||
source: Source code to analyze.
|
||||
file_path: Path to the source file (used for language detection).
|
||||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
|
|
@ -87,12 +117,6 @@ class JavaScriptSupport:
|
|||
"""
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
||||
try:
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
try:
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
tree_functions = analyzer.find_functions(
|
||||
|
|
@ -111,7 +135,7 @@ class JavaScriptSupport:
|
|||
|
||||
# Skip non-exported functions (can't be imported in tests)
|
||||
# Exception: nested functions and methods are allowed if their parent is exported
|
||||
if not func.is_exported and not func.parent_function:
|
||||
if criteria.require_export and not func.is_exported and not func.parent_function:
|
||||
logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004
|
||||
continue
|
||||
|
||||
|
|
@ -144,61 +168,6 @@ class JavaScriptSupport:
|
|||
logger.warning("Failed to parse %s: %s", file_path, e)
|
||||
return []
|
||||
|
||||
def discover_functions_from_source(self, source: str, file_path: Path | None = None) -> list[FunctionToOptimize]:
|
||||
"""Find all functions in source code string.
|
||||
|
||||
Uses tree-sitter to parse the source and find functions.
|
||||
|
||||
Args:
|
||||
source: The source code to analyze.
|
||||
file_path: Optional file path for context (used for language detection).
|
||||
|
||||
Returns:
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
try:
|
||||
# Use JavaScript analyzer by default, or detect from file path
|
||||
if file_path:
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
else:
|
||||
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
|
||||
tree_functions = analyzer.find_functions(
|
||||
source, include_methods=True, include_arrow_functions=True, require_name=True
|
||||
)
|
||||
|
||||
functions: list[FunctionToOptimize] = []
|
||||
for func in tree_functions:
|
||||
# Build parents list
|
||||
parents: list[FunctionParent] = []
|
||||
if func.class_name:
|
||||
parents.append(FunctionParent(name=func.class_name, type="ClassDef"))
|
||||
if func.parent_function:
|
||||
parents.append(FunctionParent(name=func.parent_function, type="FunctionDef"))
|
||||
|
||||
functions.append(
|
||||
FunctionToOptimize(
|
||||
function_name=func.name,
|
||||
file_path=file_path or Path("unknown"),
|
||||
parents=parents,
|
||||
starting_line=func.start_line,
|
||||
ending_line=func.end_line,
|
||||
starting_col=func.start_col,
|
||||
ending_col=func.end_col,
|
||||
is_async=func.is_async,
|
||||
is_method=func.is_method,
|
||||
language=str(self.language),
|
||||
doc_start_line=func.doc_start_line,
|
||||
)
|
||||
)
|
||||
|
||||
return functions
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse source: %s", e)
|
||||
return []
|
||||
|
||||
def _get_test_patterns(self) -> list[str]:
|
||||
"""Get test file patterns for this language.
|
||||
|
||||
|
|
@ -1508,7 +1477,7 @@ class JavaScriptSupport:
|
|||
return "".join(result_lines)
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
"""Format JavaScript code using prettier (if available).
|
||||
"""Format JavaScript/TypeScript code using prettier (if available).
|
||||
|
||||
Args:
|
||||
source: Source code to format.
|
||||
|
|
@ -1519,9 +1488,10 @@ class JavaScriptSupport:
|
|||
|
||||
"""
|
||||
try:
|
||||
# Try to use prettier via npx
|
||||
stdin_filepath = str(file_path.name) if file_path else f"file{self.default_file_extension}"
|
||||
|
||||
result = subprocess.run(
|
||||
["npx", "prettier", "--stdin-filepath", "file.js"],
|
||||
["npx", "prettier", "--stdin-filepath", stdin_filepath],
|
||||
check=False,
|
||||
input=source,
|
||||
capture_output=True,
|
||||
|
|
@ -1702,22 +1672,15 @@ class JavaScriptSupport:
|
|||
|
||||
# === Validation ===
|
||||
|
||||
@property
|
||||
def treesitter_language(self) -> TreeSitterLanguage:
|
||||
return TreeSitterLanguage.JAVASCRIPT
|
||||
|
||||
def validate_syntax(self, source: str) -> bool:
|
||||
"""Check if JavaScript source code is syntactically valid.
|
||||
|
||||
Uses tree-sitter to parse and check for errors.
|
||||
|
||||
Args:
|
||||
source: Source code to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
|
||||
"""
|
||||
"""Check if source code is syntactically valid using tree-sitter."""
|
||||
try:
|
||||
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.JAVASCRIPT)
|
||||
analyzer = TreeSitterAnalyzer(self.treesitter_language)
|
||||
tree = analyzer.parse(source)
|
||||
# Check if tree has errors
|
||||
return not tree.root_node.has_error
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -1744,6 +1707,11 @@ class JavaScriptSupport:
|
|||
normalized_lines.append(stripped)
|
||||
return "\n".join(normalized_lines)
|
||||
|
||||
def generate_concolic_tests(
|
||||
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any
|
||||
) -> tuple[dict, str]:
|
||||
return {}, ""
|
||||
|
||||
# === Test Editing ===
|
||||
|
||||
def add_runtime_comments(
|
||||
|
|
@ -1909,6 +1877,92 @@ class JavaScriptSupport:
|
|||
|
||||
return compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
|
||||
|
||||
@property
|
||||
def function_optimizer_class(self) -> type:
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
return JavaScriptFunctionOptimizer
|
||||
|
||||
def prepare_module(
|
||||
self, module_code: str, module_path: Path, project_root: Path
|
||||
) -> tuple[dict[Path, ValidCode], None]:
|
||||
from codeflash.languages.javascript.optimizer import prepare_javascript_module
|
||||
|
||||
return prepare_javascript_module(module_code, module_path)
|
||||
|
||||
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
|
||||
from codeflash.languages.javascript.optimizer import verify_js_requirements
|
||||
from codeflash.languages.javascript.test_runner import find_node_project_root
|
||||
|
||||
test_cfg.js_project_root = find_node_project_root(file_path)
|
||||
verify_js_requirements(test_cfg)
|
||||
|
||||
def adjust_test_config_for_discovery(self, test_cfg: TestConfig) -> None:
|
||||
test_cfg.tests_project_rootdir = test_cfg.tests_root
|
||||
|
||||
def detect_module_system(self, project_root: Path, source_file: Path) -> str | None:
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
|
||||
return detect_module_system(project_root, source_file)
|
||||
|
||||
def process_generated_test_strings(
|
||||
self,
|
||||
generated_test_source: str,
|
||||
instrumented_behavior_test_source: str,
|
||||
instrumented_perf_test_source: str,
|
||||
function_to_optimize: Any,
|
||||
test_path: Path,
|
||||
test_cfg: Any,
|
||||
project_module_system: str | None,
|
||||
) -> tuple[str, str, str]:
|
||||
from codeflash.languages.javascript.instrument import (
|
||||
TestingMode,
|
||||
fix_imports_inside_test_blocks,
|
||||
fix_jest_mock_paths,
|
||||
instrument_generated_js_test,
|
||||
validate_and_fix_import_style,
|
||||
)
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ensure_module_system_compatibility,
|
||||
ensure_vitest_imports,
|
||||
)
|
||||
|
||||
source_file = Path(function_to_optimize.file_path)
|
||||
|
||||
# Fix import statements that appear inside test blocks (invalid JS syntax)
|
||||
generated_test_source = fix_imports_inside_test_blocks(generated_test_source)
|
||||
|
||||
# Fix relative paths in jest.mock() calls
|
||||
generated_test_source = fix_jest_mock_paths(
|
||||
generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
|
||||
)
|
||||
|
||||
# Validate and fix import styles (default vs named exports)
|
||||
generated_test_source = validate_and_fix_import_style(
|
||||
generated_test_source, source_file, function_to_optimize.function_name
|
||||
)
|
||||
|
||||
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
|
||||
generated_test_source = ensure_module_system_compatibility(
|
||||
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
|
||||
)
|
||||
|
||||
# Ensure vitest imports are present when using vitest framework
|
||||
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
|
||||
|
||||
# Instrument for behavior verification (writes to SQLite)
|
||||
instrumented_behavior_test_source = instrument_generated_js_test(
|
||||
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
# Instrument for performance measurement (prints to stdout)
|
||||
instrumented_perf_test_source = instrument_generated_js_test(
|
||||
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE
|
||||
)
|
||||
|
||||
logger.debug("Instrumented JS/TS tests locally for %s", function_to_optimize.function_name)
|
||||
return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source
|
||||
|
||||
# === Configuration ===
|
||||
|
||||
def get_test_file_suffix(self) -> str:
|
||||
|
|
@ -2488,62 +2542,9 @@ class TypeScriptSupport(JavaScriptSupport):
|
|||
]
|
||||
|
||||
def get_test_file_suffix(self) -> str:
|
||||
"""Get the test file suffix for TypeScript.
|
||||
|
||||
Returns:
|
||||
Jest test file suffix for TypeScript.
|
||||
|
||||
"""
|
||||
"""Get the test file suffix for TypeScript."""
|
||||
return ".test.ts"
|
||||
|
||||
def validate_syntax(self, source: str) -> bool:
|
||||
"""Check if TypeScript source code is syntactically valid.
|
||||
|
||||
Uses tree-sitter TypeScript parser to parse and check for errors.
|
||||
|
||||
Args:
|
||||
source: Source code to validate.
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise.
|
||||
|
||||
"""
|
||||
try:
|
||||
analyzer = TreeSitterAnalyzer(TreeSitterLanguage.TYPESCRIPT)
|
||||
tree = analyzer.parse(source)
|
||||
return not tree.root_node.has_error
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def format_code(self, source: str, file_path: Path | None = None) -> str:
|
||||
"""Format TypeScript code using prettier (if available).
|
||||
|
||||
Args:
|
||||
source: Source code to format.
|
||||
file_path: Optional file path for context.
|
||||
|
||||
Returns:
|
||||
Formatted source code.
|
||||
|
||||
"""
|
||||
try:
|
||||
# Determine file extension for prettier
|
||||
stdin_filepath = str(file_path.name) if file_path else "file.ts"
|
||||
|
||||
# Try to use prettier via npx
|
||||
result = subprocess.run(
|
||||
["npx", "prettier", "--stdin-filepath", stdin_filepath],
|
||||
check=False,
|
||||
input=source,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug("Prettier formatting failed: %s", e)
|
||||
|
||||
return source
|
||||
@property
|
||||
def treesitter_language(self) -> TreeSitterLanguage:
|
||||
return TreeSitterLanguage.TYPESCRIPT
|
||||
|
|
|
|||
|
|
@ -369,7 +369,7 @@ def _get_jest_config_for_project(project_root: Path) -> Path | None:
|
|||
return original_jest_config
|
||||
|
||||
|
||||
def _find_node_project_root(file_path: Path) -> Path | None:
|
||||
def find_node_project_root(file_path: Path) -> Path | None:
|
||||
"""Find the Node.js project root by looking for package.json.
|
||||
|
||||
Traverses up from the given file path to find the nearest directory
|
||||
|
|
@ -686,7 +686,7 @@ def run_jest_behavioral_tests(
|
|||
# Use provided project_root, or detect it as fallback
|
||||
if project_root is None and test_files:
|
||||
first_test_file = Path(test_files[0])
|
||||
project_root = _find_node_project_root(first_test_file)
|
||||
project_root = find_node_project_root(first_test_file)
|
||||
|
||||
# Use the project root, or fall back to provided cwd
|
||||
effective_cwd = project_root if project_root else cwd
|
||||
|
|
@ -936,7 +936,7 @@ def run_jest_benchmarking_tests(
|
|||
# Use provided project_root, or detect it as fallback
|
||||
if project_root is None and test_files:
|
||||
first_test_file = Path(test_files[0])
|
||||
project_root = _find_node_project_root(first_test_file)
|
||||
project_root = find_node_project_root(first_test_file)
|
||||
|
||||
effective_cwd = project_root if project_root else cwd
|
||||
|
||||
|
|
@ -1106,7 +1106,7 @@ def run_jest_line_profile_tests(
|
|||
# Use provided project_root, or detect it as fallback
|
||||
if project_root is None and test_files:
|
||||
first_test_file = Path(test_files[0])
|
||||
project_root = _find_node_project_root(first_test_file)
|
||||
project_root = find_node_project_root(first_test_file)
|
||||
|
||||
effective_cwd = project_root if project_root else cwd
|
||||
logger.debug(f"Jest line profiling working directory: {effective_cwd}")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ class Language(str, Enum):
|
|||
PYTHON = "python"
|
||||
JAVASCRIPT = "javascript"
|
||||
TYPESCRIPT = "typescript"
|
||||
JAVA = "java"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ import libcst as cst
|
|||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
|
||||
from codeflash.code_utils.config_consts import OPTIMIZATION_CONTEXT_TOKEN_LIMIT, TESTGEN_CONTEXT_TOKEN_LIMIT
|
||||
from codeflash.code_utils.config_consts import (
|
||||
OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
READ_WRITABLE_LIMIT_ERROR,
|
||||
TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
TESTGEN_LIMIT_ERROR,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
|
||||
|
||||
# Language support imports for multi-language code context extraction
|
||||
from codeflash.languages import Language, is_python
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
collect_top_level_defs_with_usages,
|
||||
get_section_names,
|
||||
|
|
@ -40,13 +42,9 @@ from codeflash.optimization.function_context import belongs_to_function_qualifie
|
|||
if TYPE_CHECKING:
|
||||
from jedi.api.classes import Name
|
||||
|
||||
from codeflash.languages.base import DependencyResolver, HelperFunction
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.languages.python.context.unused_definition_remover import UsageInfo
|
||||
|
||||
# Error message constants
|
||||
READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot proceed"
|
||||
TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed"
|
||||
|
||||
|
||||
def build_testgen_context(
|
||||
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
|
||||
|
|
@ -91,12 +89,6 @@ def get_code_optimization_context(
|
|||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
call_graph: DependencyResolver | None = None,
|
||||
) -> CodeOptimizationContext:
|
||||
# Route to language-specific implementation for non-Python languages
|
||||
if not is_python():
|
||||
return get_code_optimization_context_for_language(
|
||||
function_to_optimize, project_root_path, optim_token_limit, testgen_token_limit
|
||||
)
|
||||
|
||||
# Get FunctionSource representation of helpers of FTO
|
||||
fto_input = {function_to_optimize.file_path: {function_to_optimize.qualified_name}}
|
||||
if call_graph is not None:
|
||||
|
|
@ -216,140 +208,6 @@ def get_code_optimization_context(
|
|||
)
|
||||
|
||||
|
||||
def get_code_optimization_context_for_language(
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
project_root_path: Path,
|
||||
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
|
||||
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
|
||||
) -> CodeOptimizationContext:
|
||||
"""Extract code optimization context for non-Python languages.
|
||||
|
||||
Uses the language support abstraction to extract code context and converts
|
||||
it to the CodeOptimizationContext format expected by the pipeline.
|
||||
|
||||
This function supports multi-file context extraction, grouping helpers by file
|
||||
and creating proper CodeStringsMarkdown with file paths for multi-file replacement.
|
||||
|
||||
Args:
|
||||
function_to_optimize: The function to extract context for.
|
||||
project_root_path: Root of the project.
|
||||
optim_token_limit: Token limit for optimization context.
|
||||
testgen_token_limit: Token limit for testgen context.
|
||||
|
||||
Returns:
|
||||
CodeOptimizationContext with target code and dependencies.
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
# Get language support for this function
|
||||
language = Language(function_to_optimize.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
# Extract code context using language support
|
||||
code_context = lang_support.extract_code_context(function_to_optimize, project_root_path, project_root_path)
|
||||
|
||||
# Build imports string if available
|
||||
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
|
||||
|
||||
# Get relative path for target file
|
||||
try:
|
||||
target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
target_relative_path = function_to_optimize.file_path
|
||||
|
||||
# Group helpers by file path
|
||||
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
|
||||
helper_function_sources = []
|
||||
|
||||
for helper in code_context.helper_functions:
|
||||
helpers_by_file[helper.file_path].append(helper)
|
||||
|
||||
# Convert to FunctionSource for pipeline compatibility
|
||||
helper_function_sources.append(
|
||||
FunctionSource(
|
||||
file_path=helper.file_path,
|
||||
qualified_name=helper.qualified_name,
|
||||
fully_qualified_name=helper.qualified_name,
|
||||
only_function_name=helper.name,
|
||||
source_code=helper.source_code,
|
||||
)
|
||||
)
|
||||
|
||||
# Build read-writable code (target file + same-file helpers + global variables)
|
||||
read_writable_code_strings = []
|
||||
|
||||
# Combine target code with same-file helpers
|
||||
target_file_code = code_context.target_code
|
||||
same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, [])
|
||||
if same_file_helpers:
|
||||
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
|
||||
target_file_code = target_file_code + "\n\n" + helper_code
|
||||
|
||||
# Note: code_context.read_only_context contains type definitions and global variables
|
||||
# These should be passed as read-only context to the AI, not prepended to the target code
|
||||
# If prepended to target code, the AI treats them as code to optimize and includes them in output
|
||||
|
||||
# Add imports to target file code
|
||||
if imports_code:
|
||||
target_file_code = imports_code + "\n\n" + target_file_code
|
||||
|
||||
read_writable_code_strings.append(
|
||||
CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language)
|
||||
)
|
||||
|
||||
# Add helper files (cross-file helpers)
|
||||
for file_path, file_helpers in helpers_by_file.items():
|
||||
if file_path == function_to_optimize.file_path:
|
||||
continue # Already included in target file
|
||||
|
||||
try:
|
||||
helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
helper_relative_path = file_path
|
||||
|
||||
# Combine all helpers from this file
|
||||
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
|
||||
|
||||
read_writable_code_strings.append(
|
||||
CodeString(
|
||||
code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language
|
||||
)
|
||||
)
|
||||
|
||||
read_writable_code = CodeStringsMarkdown(
|
||||
code_strings=read_writable_code_strings, language=function_to_optimize.language
|
||||
)
|
||||
|
||||
# Build testgen context (same as read_writable for non-Python)
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language
|
||||
)
|
||||
|
||||
# Check token limits
|
||||
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
|
||||
if read_writable_tokens > optim_token_limit:
|
||||
raise ValueError(READ_WRITABLE_LIMIT_ERROR)
|
||||
|
||||
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
|
||||
if testgen_tokens > testgen_token_limit:
|
||||
raise ValueError(TESTGEN_LIMIT_ERROR)
|
||||
|
||||
# Generate code hash from all read-writable code
|
||||
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
|
||||
|
||||
return CodeOptimizationContext(
|
||||
testgen_context=testgen_context,
|
||||
read_writable_code=read_writable_code,
|
||||
read_only_context_code=code_context.read_only_context,
|
||||
hashing_code_context=read_writable_code.flat,
|
||||
hashing_code_context_hash=code_hash,
|
||||
helper_functions=helper_function_sources,
|
||||
testgen_helper_fqns=[fs.fully_qualified_name for fs in helper_function_sources],
|
||||
preexisting_objects=set(),
|
||||
)
|
||||
|
||||
|
||||
def process_file_context(
|
||||
file_path: Path,
|
||||
primary_qualified_names: set[str],
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||
import libcst as cst
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages import current_language
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
|
|
@ -747,7 +748,7 @@ def detect_unused_helper_functions(
|
|||
|
||||
"""
|
||||
# Skip this analysis for non-Python languages since we use Python's ast module
|
||||
if not is_python():
|
||||
if current_language() != Language.PYTHON:
|
||||
logger.debug("Skipping unused helper function detection for non-Python languages")
|
||||
return []
|
||||
|
||||
|
|
|
|||
215
codeflash/languages/python/function_optimizer.py
Normal file
215
codeflash/languages/python/function_optimizer.py
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
from codeflash.either import Failure, Success
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.languages.python.optimizer import resolve_python_function_ast
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
|
||||
from codeflash.languages.python.static_analysis.code_replacer import (
|
||||
add_custom_marker_to_all_tests,
|
||||
modify_autouse_fixture,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.models.models import TestingMode, TestResults
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.parse_test_output import calculate_function_throughput_from_test_results
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from codeflash.either import Result
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.models.models import (
|
||||
CodeOptimizationContext,
|
||||
CodeStringsMarkdown,
|
||||
ConcurrencyMetrics,
|
||||
CoverageData,
|
||||
OriginalCodeBaseline,
|
||||
TestDiff,
|
||||
)
|
||||
|
||||
|
||||
class PythonFunctionOptimizer(FunctionOptimizer):
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
from codeflash.languages.python.context import code_context_extractor
|
||||
|
||||
try:
|
||||
return Success(
|
||||
code_context_extractor.get_code_optimization_context(
|
||||
self.function_to_optimize, self.project_root, call_graph=self.call_graph
|
||||
)
|
||||
)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
||||
def _resolve_function_ast(
|
||||
self, source_code: str, function_name: str, parents: list[FunctionParent]
|
||||
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
|
||||
original_module_ast = ast.parse(source_code)
|
||||
return resolve_python_function_ast(function_name, parents, original_module_ast)
|
||||
|
||||
def requires_function_ast(self) -> bool:
|
||||
return True
|
||||
|
||||
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
|
||||
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
|
||||
|
||||
def get_optimization_review_metrics(
|
||||
self,
|
||||
source_code: str,
|
||||
file_path: Path,
|
||||
qualified_name: str,
|
||||
project_root: Path,
|
||||
tests_root: Path,
|
||||
language: Language,
|
||||
) -> str:
|
||||
return get_opt_review_metrics(source_code, file_path, qualified_name, project_root, tests_root, language)
|
||||
|
||||
def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
|
||||
logger.info("Disabling all autouse fixtures associated with the generated test files")
|
||||
original_conftest_content = modify_autouse_fixture(test_paths)
|
||||
logger.info("Add custom marker to generated test files")
|
||||
add_custom_marker_to_all_tests(test_paths)
|
||||
return original_conftest_content
|
||||
|
||||
def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None:
|
||||
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
|
||||
|
||||
instrument_codeflash_capture(self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root)
|
||||
|
||||
def should_check_coverage(self) -> bool:
|
||||
return True
|
||||
|
||||
def collect_async_metrics(
|
||||
self,
|
||||
benchmarking_results: TestResults,
|
||||
code_context: CodeOptimizationContext,
|
||||
helper_code: dict[Path, str],
|
||||
test_env: dict[str, str],
|
||||
) -> tuple[int | None, ConcurrencyMetrics | None]:
|
||||
if not self.function_to_optimize.is_async:
|
||||
return None, None
|
||||
|
||||
async_throughput = calculate_function_throughput_from_test_results(
|
||||
benchmarking_results, self.function_to_optimize.function_name
|
||||
)
|
||||
logger.debug(f"Async function throughput: {async_throughput} calls/second")
|
||||
|
||||
concurrency_metrics = self.run_concurrency_benchmark(
|
||||
code_context=code_context, original_helper_code=helper_code, test_env=test_env
|
||||
)
|
||||
if concurrency_metrics:
|
||||
logger.debug(
|
||||
f"Concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
|
||||
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
|
||||
)
|
||||
return async_throughput, concurrency_metrics
|
||||
|
||||
def instrument_async_for_mode(self, mode: TestingMode) -> None:
|
||||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path, self.function_to_optimize, mode, project_root=self.project_root
|
||||
)
|
||||
|
||||
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
|
||||
return False
|
||||
|
||||
def parse_line_profile_test_results(
|
||||
self, line_profiler_output_file: Path | None
|
||||
) -> tuple[TestResults | dict, CoverageData | None]:
|
||||
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
|
||||
|
||||
return parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
|
||||
|
||||
def compare_candidate_results(
|
||||
self,
|
||||
baseline_results: OriginalCodeBaseline,
|
||||
candidate_behavior_results: TestResults,
|
||||
optimization_candidate_index: int,
|
||||
) -> tuple[bool, list[TestDiff]]:
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
|
||||
return compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
|
||||
|
||||
def replace_function_and_helpers_with_optimized_code(
|
||||
self,
|
||||
code_context: CodeOptimizationContext,
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
original_helper_code: dict[Path, str],
|
||||
) -> bool:
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
|
||||
|
||||
did_update = False
|
||||
for module_abspath, qualified_names in self.group_functions_by_file(code_context).items():
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=code_context.preexisting_objects,
|
||||
project_root_path=self.project_root,
|
||||
)
|
||||
|
||||
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
|
||||
if unused_helpers:
|
||||
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
|
||||
return did_update
|
||||
|
||||
def line_profiler_step(
|
||||
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
|
||||
) -> dict[str, Any]:
|
||||
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
|
||||
if contains_jit_decorator(candidate_fto_code):
|
||||
logger.info(
|
||||
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
|
||||
)
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
for module_abspath in original_helper_code:
|
||||
candidate_helper_code = Path(module_abspath).read_text("utf-8")
|
||||
if contains_jit_decorator(candidate_helper_code):
|
||||
logger.info(
|
||||
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
|
||||
)
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
try:
|
||||
console.rule()
|
||||
|
||||
test_env = self.get_test_env(
|
||||
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
|
||||
)
|
||||
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
|
||||
line_profile_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.LINE_PROFILE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=0,
|
||||
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
enable_coverage=False,
|
||||
code_context=code_context,
|
||||
line_profiler_output_file=line_profiler_output_file,
|
||||
)
|
||||
finally:
|
||||
self.write_code_and_helpers(
|
||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
|
||||
logger.warning(
|
||||
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
|
||||
)
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
if line_profile_results["str_out"] == "":
|
||||
logger.warning(
|
||||
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
|
||||
)
|
||||
return line_profile_results
|
||||
|
|
@ -4,8 +4,6 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
|
||||
from codeflash.code_utils.normalizers.base import CodeNormalizer
|
||||
|
||||
|
||||
class VariableNormalizer(ast.NodeTransformer):
|
||||
"""Normalizes only local variable names in AST to canonical forms like var_0, var_1, etc.
|
||||
|
|
@ -164,63 +162,19 @@ def _remove_docstrings_from_ast(node: ast.AST) -> None:
|
|||
stack.extend([child for child in body if isinstance(child, node_types)])
|
||||
|
||||
|
||||
class PythonNormalizer(CodeNormalizer):
|
||||
"""Python code normalizer using AST transformation.
|
||||
def normalize_python_code(code: str, remove_docstrings: bool = True) -> str:
|
||||
"""Normalize Python code to a canonical form for comparison.
|
||||
|
||||
Normalizes Python code by:
|
||||
- Replacing local variable names with canonical forms (var_0, var_1, etc.)
|
||||
- Preserving function names, class names, parameters, and imports
|
||||
- Optionally removing docstrings
|
||||
Replaces local variable names with canonical forms (var_0, var_1, etc.)
|
||||
while preserving function names, class names, parameters, and imports.
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
|
||||
@property
|
||||
def language(self) -> str:
|
||||
"""Return the language this normalizer handles."""
|
||||
return "python"
|
||||
|
||||
@property
|
||||
def supported_extensions(self) -> tuple[str, ...]:
|
||||
"""Return file extensions this normalizer can handle."""
|
||||
return (".py", ".pyw", ".pyi")
|
||||
|
||||
def normalize(self, code: str, remove_docstrings: bool = True) -> str:
|
||||
"""Normalize Python code to a canonical form.
|
||||
|
||||
Args:
|
||||
code: Python source code to normalize
|
||||
remove_docstrings: Whether to remove docstrings
|
||||
|
||||
Returns:
|
||||
Normalized Python code as a string
|
||||
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
|
||||
if remove_docstrings:
|
||||
_remove_docstrings_from_ast(tree)
|
||||
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
ast.fix_missing_locations(normalized_tree)
|
||||
|
||||
return ast.unparse(normalized_tree)
|
||||
|
||||
def normalize_for_hash(self, code: str) -> str:
|
||||
"""Normalize Python code optimized for hashing.
|
||||
|
||||
Returns AST dump which is faster than unparsing.
|
||||
|
||||
Args:
|
||||
code: Python source code to normalize
|
||||
|
||||
Returns:
|
||||
AST dump string suitable for hashing
|
||||
|
||||
"""
|
||||
tree = ast.parse(code)
|
||||
if remove_docstrings:
|
||||
_remove_docstrings_from_ast(tree)
|
||||
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
normalizer = VariableNormalizer()
|
||||
normalized_tree = normalizer.visit(tree)
|
||||
ast.fix_missing_locations(normalized_tree)
|
||||
|
||||
return ast.dump(normalized_tree, annotate_fields=False, include_attributes=False)
|
||||
return ast.unparse(normalized_tree)
|
||||
63
codeflash/languages/python/optimizer.py
Normal file
63
codeflash/languages/python/optimizer.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.models.models import ValidCode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
|
||||
def prepare_python_module(
|
||||
original_module_code: str, original_module_path: Path, project_root: Path
|
||||
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
|
||||
"""Parse a Python module, normalize its code, and validate imported callee modules.
|
||||
|
||||
Returns a mapping of file paths to ValidCode (for the module and its imported callees)
|
||||
plus the parsed AST, or None on syntax error.
|
||||
"""
|
||||
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
|
||||
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
|
||||
|
||||
try:
|
||||
original_module_ast = ast.parse(original_module_code)
|
||||
except SyntaxError as e:
|
||||
logger.warning(f"Syntax error parsing code in {original_module_path}: {e}")
|
||||
logger.info("Skipping optimization due to file error.")
|
||||
return None
|
||||
|
||||
normalized_original_module_code = ast.unparse(normalize_node(original_module_ast))
|
||||
validated_original_code: dict[Path, ValidCode] = {
|
||||
original_module_path: ValidCode(
|
||||
source_code=original_module_code, normalized_code=normalized_original_module_code
|
||||
)
|
||||
}
|
||||
|
||||
imported_module_analyses = analyze_imported_modules(original_module_code, original_module_path, project_root)
|
||||
|
||||
for analysis in imported_module_analyses:
|
||||
callee_original_code = analysis.file_path.read_text(encoding="utf8")
|
||||
try:
|
||||
normalized_callee_original_code = normalize_code(callee_original_code)
|
||||
except SyntaxError as e:
|
||||
logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}")
|
||||
logger.info("Skipping optimization due to helper file error.")
|
||||
return None
|
||||
validated_original_code[analysis.file_path] = ValidCode(
|
||||
source_code=callee_original_code, normalized_code=normalized_callee_original_code
|
||||
)
|
||||
|
||||
return validated_original_code, original_module_ast
|
||||
|
||||
|
||||
def resolve_python_function_ast(
|
||||
function_name: str, parents: list[FunctionParent], module_ast: ast.Module
|
||||
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
|
||||
"""Look up a function/method AST node in a parsed Python module."""
|
||||
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
|
||||
|
||||
return get_first_top_level_function_or_method_ast(function_name, parents, module_ast)
|
||||
|
|
@ -4,7 +4,7 @@ import ast
|
|||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Optional, TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
import libcst as cst
|
||||
from libcst.metadata import PositionProvider
|
||||
|
|
@ -12,7 +12,7 @@ from libcst.metadata import PositionProvider
|
|||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.config_parser import find_conftest_files
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.code_replacer import get_optimized_code_for_module
|
||||
from codeflash.languages.python.static_analysis.code_extractor import (
|
||||
add_global_assignments,
|
||||
add_needed_imports_from_module,
|
||||
|
|
@ -24,9 +24,7 @@ from codeflash.models.models import FunctionParent
|
|||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import LanguageSupport
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
|
||||
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
|
||||
|
||||
|
|
@ -391,14 +389,7 @@ def replace_function_definitions_in_module(
|
|||
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]],
|
||||
project_root_path: Path,
|
||||
should_add_global_assignments: bool = True,
|
||||
function_to_optimize: Optional[FunctionToOptimize] = None,
|
||||
) -> bool:
|
||||
# Route to language-specific implementation for non-Python languages
|
||||
if not is_python():
|
||||
return replace_function_definitions_for_language(
|
||||
function_names, optimized_code, module_abspath, project_root_path, function_to_optimize
|
||||
)
|
||||
|
||||
source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
||||
|
|
@ -420,270 +411,5 @@ def replace_function_definitions_in_module(
|
|||
return True
|
||||
|
||||
|
||||
def replace_function_definitions_for_language(
|
||||
function_names: list[str],
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
module_abspath: Path,
|
||||
project_root_path: Path,
|
||||
function_to_optimize: Optional[FunctionToOptimize] = None,
|
||||
) -> bool:
|
||||
"""Replace function definitions for non-Python languages.
|
||||
|
||||
Uses the language support abstraction to perform code replacement.
|
||||
|
||||
Args:
|
||||
function_names: List of qualified function names to replace.
|
||||
optimized_code: The optimized code to apply.
|
||||
module_abspath: Path to the module file.
|
||||
project_root_path: Root of the project.
|
||||
function_to_optimize: The function being optimized (needed for line info).
|
||||
|
||||
Returns:
|
||||
True if the code was modified, False if no changes.
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
|
||||
original_source_code: str = module_abspath.read_text(encoding="utf8")
|
||||
code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code)
|
||||
|
||||
if not code_to_apply.strip():
|
||||
return False
|
||||
|
||||
# Get language support
|
||||
language = Language(optimized_code.language)
|
||||
lang_support = get_language_support(language)
|
||||
|
||||
# Add any new global declarations from the optimized code to the original source
|
||||
original_source_code = lang_support.add_global_declarations(
|
||||
optimized_code=code_to_apply, original_source=original_source_code, module_abspath=module_abspath
|
||||
)
|
||||
|
||||
# If we have function_to_optimize with line info and this is the main file, use it for precise replacement
|
||||
if (
|
||||
function_to_optimize
|
||||
and function_to_optimize.starting_line
|
||||
and function_to_optimize.ending_line
|
||||
and function_to_optimize.file_path == module_abspath
|
||||
):
|
||||
# Extract just the target function from the optimized code
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, function_to_optimize.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, optimized_func)
|
||||
else:
|
||||
# Fallback: use the entire optimized code (for simple single-function files)
|
||||
new_code = lang_support.replace_function(original_source_code, function_to_optimize, code_to_apply)
|
||||
else:
|
||||
# For helper files or when we don't have precise line info:
|
||||
# Find each function by name in both original and optimized code
|
||||
# Then replace with the corresponding optimized version
|
||||
new_code = original_source_code
|
||||
modified = False
|
||||
|
||||
# Get the list of function names to replace
|
||||
functions_to_replace = list(function_names)
|
||||
|
||||
for func_name in functions_to_replace:
|
||||
# Re-discover functions from current code state to get correct line numbers
|
||||
current_functions = lang_support.discover_functions_from_source(new_code, module_abspath)
|
||||
|
||||
# Find the function in current code
|
||||
func = None
|
||||
for f in current_functions:
|
||||
if func_name in (f.qualified_name, f.function_name):
|
||||
func = f
|
||||
break
|
||||
|
||||
if func is None:
|
||||
continue
|
||||
|
||||
# Extract just this function from the optimized code
|
||||
optimized_func = _extract_function_from_code(
|
||||
lang_support, code_to_apply, func.function_name, module_abspath
|
||||
)
|
||||
if optimized_func:
|
||||
new_code = lang_support.replace_function(new_code, func, optimized_func)
|
||||
modified = True
|
||||
|
||||
if not modified:
|
||||
logger.warning(f"Could not find function {function_names} in {module_abspath}")
|
||||
return False
|
||||
|
||||
# Check if there was actually a change
|
||||
if original_source_code.strip() == new_code.strip():
|
||||
return False
|
||||
|
||||
module_abspath.write_text(new_code, encoding="utf8")
|
||||
return True
|
||||
|
||||
|
||||
def _extract_function_from_code(
|
||||
lang_support: LanguageSupport, source_code: str, function_name: str, file_path: Path | None = None
|
||||
) -> str | None:
|
||||
"""Extract a specific function's source code from a code string.
|
||||
|
||||
Includes JSDoc/docstring comments if present.
|
||||
|
||||
Args:
|
||||
lang_support: Language support instance.
|
||||
source_code: The full source code containing the function.
|
||||
function_name: Name of the function to extract.
|
||||
file_path: Path to the file (used to determine correct analyzer for JS/TS).
|
||||
|
||||
Returns:
|
||||
The function's source code (including doc comments), or None if not found.
|
||||
|
||||
"""
|
||||
try:
|
||||
# Use the language support to find functions in the source
|
||||
# file_path is needed for JS/TS to determine correct analyzer (TypeScript vs JavaScript)
|
||||
functions = lang_support.discover_functions_from_source(source_code, file_path)
|
||||
for func in functions:
|
||||
if func.function_name == function_name:
|
||||
# Extract the function's source using line numbers
|
||||
# Use doc_start_line if available to include JSDoc/docstring
|
||||
lines = source_code.splitlines(keepends=True)
|
||||
effective_start = func.doc_start_line or func.starting_line
|
||||
if effective_start and func.ending_line and effective_start <= len(lines):
|
||||
func_lines = lines[effective_start - 1 : func.ending_line]
|
||||
return "".join(func_lines)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting function {function_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str:
|
||||
file_to_code_context = optimized_code.file_to_path()
|
||||
module_optimized_code = file_to_code_context.get(str(relative_path))
|
||||
if module_optimized_code is None:
|
||||
# Fallback: if there's only one code block with None file path,
|
||||
# use it regardless of the expected path (the AI server doesn't always include file paths)
|
||||
if "None" in file_to_code_context and len(file_to_code_context) == 1:
|
||||
module_optimized_code = file_to_code_context["None"]
|
||||
logger.debug(f"Using code block with None file_path for {relative_path}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n"
|
||||
"re-check your 'markdown code structure'"
|
||||
f"existing files are {file_to_code_context.keys()}"
|
||||
)
|
||||
module_optimized_code = ""
|
||||
return module_optimized_code
|
||||
|
||||
|
||||
def is_zero_diff(original_code: str, new_code: str) -> bool:
|
||||
return normalize_code(original_code) == normalize_code(new_code)
|
||||
|
||||
|
||||
def replace_optimized_code(
|
||||
callee_module_paths: set[Path],
|
||||
candidates: list[OptimizedCandidate],
|
||||
code_context: CodeOptimizationContext,
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
validated_original_code: dict[Path, ValidCode],
|
||||
project_root: Path,
|
||||
) -> tuple[set[Path], dict[str, dict[Path, str]]]:
|
||||
initial_optimized_code = {
|
||||
candidate.optimization_id: replace_functions_and_add_imports(
|
||||
validated_original_code[function_to_optimize.file_path].source_code,
|
||||
[function_to_optimize.qualified_name],
|
||||
candidate.source_code,
|
||||
function_to_optimize.file_path,
|
||||
function_to_optimize.file_path,
|
||||
code_context.preexisting_objects,
|
||||
project_root,
|
||||
)
|
||||
for candidate in candidates
|
||||
}
|
||||
callee_original_code = {
|
||||
module_path: validated_original_code[module_path].source_code for module_path in callee_module_paths
|
||||
}
|
||||
intermediate_original_code: dict[str, dict[Path, str]] = {
|
||||
candidate.optimization_id: (
|
||||
callee_original_code | {function_to_optimize.file_path: initial_optimized_code[candidate.optimization_id]}
|
||||
)
|
||||
for candidate in candidates
|
||||
}
|
||||
module_paths = callee_module_paths | {function_to_optimize.file_path}
|
||||
optimized_code = {
|
||||
candidate.optimization_id: {
|
||||
module_path: replace_functions_and_add_imports(
|
||||
intermediate_original_code[candidate.optimization_id][module_path],
|
||||
(
|
||||
[
|
||||
callee.qualified_name
|
||||
for callee in code_context.helper_functions
|
||||
if callee.file_path == module_path and callee.definition_type != "class"
|
||||
]
|
||||
),
|
||||
candidate.source_code,
|
||||
function_to_optimize.file_path,
|
||||
module_path,
|
||||
[],
|
||||
project_root,
|
||||
)
|
||||
for module_path in module_paths
|
||||
}
|
||||
for candidate in candidates
|
||||
}
|
||||
return module_paths, optimized_code
|
||||
|
||||
|
||||
def is_optimized_module_code_zero_diff(
|
||||
candidates: list[OptimizedCandidate],
|
||||
validated_original_code: dict[Path, ValidCode],
|
||||
optimized_code: dict[str, dict[Path, str]],
|
||||
module_paths: set[Path],
|
||||
) -> dict[str, dict[Path, bool]]:
|
||||
return {
|
||||
candidate.optimization_id: {
|
||||
callee_module_path: normalize_code(optimized_code[candidate.optimization_id][callee_module_path])
|
||||
== validated_original_code[callee_module_path].normalized_code
|
||||
for callee_module_path in module_paths
|
||||
}
|
||||
for candidate in candidates
|
||||
}
|
||||
|
||||
|
||||
def candidates_with_diffs(
|
||||
candidates: list[OptimizedCandidate],
|
||||
validated_original_code: ValidCode,
|
||||
optimized_code: dict[str, dict[Path, str]],
|
||||
module_paths: set[Path],
|
||||
) -> list[OptimizedCandidate]:
|
||||
return [
|
||||
candidate
|
||||
for candidate in candidates
|
||||
if not all(
|
||||
is_optimized_module_code_zero_diff(candidates, validated_original_code, optimized_code, module_paths)[
|
||||
candidate.optimization_id
|
||||
].values()
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def replace_optimized_code_in_worktrees(
|
||||
optimized_code: dict[str, dict[Path, str]],
|
||||
candidates: list[OptimizedCandidate], # Should be candidates_with_diffs
|
||||
worktrees: list[Path],
|
||||
git_root: Path, # Handle None case
|
||||
) -> None:
|
||||
for candidate, worktree in zip(candidates, worktrees[1:]):
|
||||
for module_path in optimized_code[candidate.optimization_id]:
|
||||
(worktree / module_path.relative_to(git_root)).write_text(
|
||||
optimized_code[candidate.optimization_id][module_path], encoding="utf8"
|
||||
) # Check with is_optimized_module_code_zero_diff
|
||||
|
||||
|
||||
def function_to_optimize_original_worktree_fqn(
|
||||
function_to_optimize: FunctionToOptimize, worktrees: list[Path], git_root: Path
|
||||
) -> str:
|
||||
return (
|
||||
str(worktrees[0].name / function_to_optimize.file_path.relative_to(git_root).with_suffix("")).replace("/", ".")
|
||||
+ "."
|
||||
+ function_to_optimize.qualified_name
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ import logging
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import (
|
||||
CodeContext,
|
||||
|
|
@ -17,12 +19,18 @@ from codeflash.languages.base import (
|
|||
TestResult,
|
||||
)
|
||||
from codeflash.languages.registry import register_language
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ast
|
||||
from collections.abc import Sequence
|
||||
|
||||
from libcst import CSTNode
|
||||
from libcst.metadata import CodeRange
|
||||
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
|
||||
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -41,6 +49,70 @@ def function_sources_to_helpers(sources: list[FunctionSource]) -> list[HelperFun
|
|||
]
|
||||
|
||||
|
||||
class ReturnStatementVisitor(cst.CSTVisitor):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.has_return_statement: bool = False
|
||||
|
||||
def visit_Return(self, node: cst.Return) -> None:
|
||||
self.has_return_statement = True
|
||||
|
||||
|
||||
class FunctionVisitor(cst.CSTVisitor):
|
||||
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.ParentNodeProvider)
|
||||
|
||||
def __init__(self, file_path: Path) -> None:
|
||||
super().__init__()
|
||||
self.file_path: Path = file_path
|
||||
self.functions: list[FunctionToOptimize] = []
|
||||
|
||||
@staticmethod
|
||||
def is_pytest_fixture(node: cst.FunctionDef) -> bool:
|
||||
for decorator in node.decorators:
|
||||
dec = decorator.decorator
|
||||
if isinstance(dec, cst.Call):
|
||||
dec = dec.func
|
||||
if isinstance(dec, cst.Attribute) and dec.attr.value == "fixture":
|
||||
if isinstance(dec.value, cst.Name) and dec.value.value == "pytest":
|
||||
return True
|
||||
if isinstance(dec, cst.Name) and dec.value == "fixture":
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_property(node: cst.FunctionDef) -> bool:
|
||||
for decorator in node.decorators:
|
||||
dec = decorator.decorator
|
||||
if isinstance(dec, cst.Name) and dec.value in ("property", "cached_property"):
|
||||
return True
|
||||
return False
|
||||
|
||||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
|
||||
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
|
||||
node.visit(return_visitor)
|
||||
if return_visitor.has_return_statement and not self.is_pytest_fixture(node) and not self.is_property(node):
|
||||
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
|
||||
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
|
||||
ast_parents: list[FunctionParent] = []
|
||||
while parents is not None:
|
||||
if isinstance(parents, cst.FunctionDef):
|
||||
# Skip nested functions — only discover top-level and class-level functions
|
||||
return
|
||||
if isinstance(parents, cst.ClassDef):
|
||||
ast_parents.append(FunctionParent(parents.name.value, parents.__class__.__name__))
|
||||
parents = self.get_metadata(cst.metadata.ParentNodeProvider, parents, default=None)
|
||||
self.functions.append(
|
||||
FunctionToOptimize(
|
||||
function_name=node.name.value,
|
||||
file_path=self.file_path,
|
||||
parents=list(reversed(ast_parents)),
|
||||
starting_line=pos.start.line,
|
||||
ending_line=pos.end.line,
|
||||
is_async=bool(node.asynchronous),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@register_language
|
||||
class PythonSupport:
|
||||
"""Python language support implementation.
|
||||
|
|
@ -107,30 +179,70 @@ class PythonSupport:
|
|||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def default_language_version(self) -> str | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def valid_test_frameworks(self) -> tuple[str, ...]:
|
||||
return ("pytest", "unittest")
|
||||
|
||||
@property
|
||||
def test_result_serialization_format(self) -> str:
|
||||
return "pickle"
|
||||
|
||||
def load_coverage(
|
||||
self,
|
||||
coverage_database_file: Path,
|
||||
function_name: str,
|
||||
code_context: Any,
|
||||
source_file: Path,
|
||||
coverage_config_file: Path | None = None,
|
||||
) -> Any:
|
||||
from codeflash.verification.coverage_utils import CoverageUtils
|
||||
|
||||
return CoverageUtils.load_from_sqlite_database(
|
||||
database_path=coverage_database_file,
|
||||
config_path=coverage_config_file,
|
||||
source_code_path=source_file,
|
||||
code_context=code_context,
|
||||
function_name=function_name,
|
||||
)
|
||||
|
||||
def process_generated_test_strings(
|
||||
self,
|
||||
generated_test_source: str,
|
||||
instrumented_behavior_test_source: str,
|
||||
instrumented_perf_test_source: str,
|
||||
function_to_optimize: Any,
|
||||
test_path: Path,
|
||||
test_cfg: Any,
|
||||
project_module_system: str | None,
|
||||
) -> tuple[str, str, str]:
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
|
||||
temp_run_dir = get_run_tmp_file(Path()).as_posix()
|
||||
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
|
||||
)
|
||||
instrumented_perf_test_source = instrumented_perf_test_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
|
||||
)
|
||||
return generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source
|
||||
|
||||
def adjust_test_config_for_discovery(self, test_cfg: Any) -> None:
|
||||
pass
|
||||
|
||||
def detect_module_system(self, project_root: Path, source_file: Path) -> str | None:
|
||||
return None
|
||||
|
||||
# === Discovery ===
|
||||
|
||||
def discover_functions(
|
||||
self, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
self, source: str, file_path: Path, filter_criteria: FunctionFilterCriteria | None = None
|
||||
) -> list[FunctionToOptimize]:
|
||||
"""Find all optimizable functions in a Python file.
|
||||
|
||||
Uses libcst to parse the file and find functions with return statements.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Python file to analyze.
|
||||
filter_criteria: Optional criteria to filter functions.
|
||||
|
||||
Returns:
|
||||
List of FunctionToOptimize objects for discovered functions.
|
||||
|
||||
"""
|
||||
import libcst as cst
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionVisitor
|
||||
|
||||
criteria = filter_criteria or FunctionFilterCriteria()
|
||||
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
tree = cst.parse_module(source)
|
||||
|
||||
wrapper = cst.metadata.MetadataWrapper(tree)
|
||||
|
|
@ -572,21 +684,10 @@ class PythonSupport:
|
|||
return False
|
||||
|
||||
def normalize_code(self, source: str) -> str:
|
||||
"""Normalize Python code for deduplication.
|
||||
|
||||
Removes comments, normalizes whitespace, and replaces variable names.
|
||||
|
||||
Args:
|
||||
source: Source code to normalize.
|
||||
|
||||
Returns:
|
||||
Normalized source code.
|
||||
|
||||
"""
|
||||
from codeflash.code_utils.deduplicate_code import normalize_code
|
||||
from codeflash.languages.python.normalizer import normalize_python_code
|
||||
|
||||
try:
|
||||
return normalize_code(source, remove_docstrings=True, language=Language.PYTHON)
|
||||
return normalize_python_code(source, remove_docstrings=True)
|
||||
except Exception:
|
||||
return source
|
||||
|
||||
|
|
@ -861,8 +962,335 @@ class PythonSupport:
|
|||
# Python uses line_profiler which has its own output format
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
@property
|
||||
def function_optimizer_class(self) -> type:
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
|
||||
return PythonFunctionOptimizer
|
||||
|
||||
def prepare_module(
|
||||
self, module_code: str, module_path: Path, project_root: Path
|
||||
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
|
||||
from codeflash.languages.python.optimizer import prepare_python_module
|
||||
|
||||
return prepare_python_module(module_code, module_path, project_root)
|
||||
|
||||
pytest_cmd: str = "pytest"
|
||||
|
||||
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
|
||||
self.pytest_cmd = test_cfg.pytest_cmd or "pytest"
|
||||
|
||||
def pytest_cmd_tokens(self, is_posix: bool) -> list[str]:
|
||||
import shlex
|
||||
|
||||
return shlex.split(self.pytest_cmd, posix=is_posix)
|
||||
|
||||
def build_pytest_cmd(self, safe_sys_executable: str, is_posix: bool) -> list[str]:
|
||||
return [safe_sys_executable, "-m", *self.pytest_cmd_tokens(is_posix)]
|
||||
|
||||
# === Test Execution (Full Protocol) ===
|
||||
# Note: For Python, test execution is handled by the main test_runner.py
|
||||
# which has special Python-specific logic. These methods are not called
|
||||
# for Python as the test_runner checks is_python() and uses the existing path.
|
||||
# They are defined here only for protocol compliance.
|
||||
|
||||
def run_behavioral_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
enable_coverage: bool = False,
|
||||
candidate_index: int = 0,
|
||||
) -> tuple[Path, Any, Path | None, Path | None]:
|
||||
import contextlib
|
||||
import shlex
|
||||
import sys
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
|
||||
from codeflash.models.models import TestType
|
||||
from codeflash.verification.test_runner import execute_test_subprocess
|
||||
|
||||
blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"]
|
||||
|
||||
test_files: list[str] = []
|
||||
for file in test_paths.test_files:
|
||||
if file.test_type == TestType.REPLAY_TEST:
|
||||
if file.tests_in_file:
|
||||
test_files.extend(
|
||||
[
|
||||
str(file.instrumented_behavior_file_path) + "::" + test.test_function
|
||||
for test in file.tests_in_file
|
||||
]
|
||||
)
|
||||
else:
|
||||
test_files.append(str(file.instrumented_behavior_file_path))
|
||||
|
||||
pytest_cmd_list = self.build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
|
||||
test_files = list(set(test_files))
|
||||
|
||||
common_pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
"-q",
|
||||
"--codeflash_loops_scope=session",
|
||||
"--codeflash_min_loops=1",
|
||||
"--codeflash_max_loops=1",
|
||||
f"--codeflash_seconds={TOTAL_LOOPING_TIME_EFFECTIVE}",
|
||||
]
|
||||
if timeout is not None:
|
||||
common_pytest_args.append(f"--timeout={timeout}")
|
||||
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
|
||||
coverage_database_file: Path | None = None
|
||||
coverage_config_file: Path | None = None
|
||||
|
||||
if enable_coverage:
|
||||
coverage_database_file, coverage_config_file = prepare_coverage_files()
|
||||
pytest_test_env["NUMBA_DISABLE_JIT"] = str(1)
|
||||
pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1)
|
||||
pytest_test_env["PYTORCH_JIT"] = str(0)
|
||||
pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
|
||||
pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
|
||||
pytest_test_env["JAX_DISABLE_JIT"] = str(0)
|
||||
|
||||
is_windows = sys.platform == "win32"
|
||||
if is_windows:
|
||||
if coverage_database_file.exists():
|
||||
with contextlib.suppress(PermissionError, OSError):
|
||||
coverage_database_file.unlink()
|
||||
else:
|
||||
cov_erase = execute_test_subprocess(
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30
|
||||
)
|
||||
logger.debug(cov_erase)
|
||||
coverage_cmd = [
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
"-m",
|
||||
"coverage",
|
||||
"run",
|
||||
f"--rcfile={coverage_config_file.as_posix()}",
|
||||
"-m",
|
||||
]
|
||||
coverage_cmd.extend(self.pytest_cmd_tokens(IS_POSIX))
|
||||
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"]
|
||||
results = execute_test_subprocess(
|
||||
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600,
|
||||
)
|
||||
logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
|
||||
else:
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
|
||||
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600,
|
||||
)
|
||||
logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
|
||||
|
||||
return result_file_path, results, coverage_database_file, coverage_config_file
|
||||
|
||||
def run_benchmarking_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
min_loops: int = 5,
|
||||
max_loops: int = 100_000,
|
||||
target_duration_seconds: float = 10.0,
|
||||
) -> tuple[Path, Any]:
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
|
||||
from codeflash.verification.test_runner import execute_test_subprocess
|
||||
|
||||
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
|
||||
|
||||
pytest_cmd_list = self.build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
|
||||
test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files})
|
||||
pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
"-q",
|
||||
"--codeflash_loops_scope=session",
|
||||
f"--codeflash_min_loops={min_loops}",
|
||||
f"--codeflash_max_loops={max_loops}",
|
||||
f"--codeflash_seconds={target_duration_seconds}",
|
||||
"--codeflash_stability_check=true",
|
||||
]
|
||||
if timeout is not None:
|
||||
pytest_args.append(f"--timeout={timeout}")
|
||||
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600,
|
||||
)
|
||||
return result_file_path, results
|
||||
|
||||
def run_line_profile_tests(
|
||||
self,
|
||||
test_paths: Any,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
timeout: int | None = None,
|
||||
project_root: Path | None = None,
|
||||
line_profile_output_file: Path | None = None,
|
||||
) -> tuple[Path, Any]:
|
||||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
from codeflash.verification.test_runner import execute_test_subprocess
|
||||
|
||||
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
|
||||
|
||||
pytest_cmd_list = self.build_pytest_cmd(SAFE_SYS_EXECUTABLE, IS_POSIX)
|
||||
test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files})
|
||||
pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
"-q",
|
||||
"--codeflash_loops_scope=session",
|
||||
"--codeflash_min_loops=1",
|
||||
"--codeflash_max_loops=1",
|
||||
f"--codeflash_seconds={TOTAL_LOOPING_TIME_EFFECTIVE}",
|
||||
]
|
||||
if timeout is not None:
|
||||
pytest_args.append(f"--timeout={timeout}")
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
|
||||
pytest_test_env["LINE_PROFILE"] = "1"
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600,
|
||||
)
|
||||
return result_file_path, results
|
||||
|
||||
def generate_concolic_tests(
|
||||
self, test_cfg: Any, project_root: Path, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: Any
|
||||
) -> tuple[dict, str]:
|
||||
import ast
|
||||
import importlib.util
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from codeflash.cli_cmds.console import console
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.languages.python.static_analysis.concolic_utils import (
|
||||
clean_concolic_tests,
|
||||
is_valid_concolic_test,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
crosshair_available = importlib.util.find_spec("crosshair") is not None
|
||||
|
||||
start_time = time.perf_counter()
|
||||
function_to_concolic_tests: dict = {}
|
||||
concolic_test_suite_code = ""
|
||||
|
||||
if not crosshair_available:
|
||||
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if is_LSP_enabled():
|
||||
logger.debug("Skipping concolic test generation in LSP mode")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if (
|
||||
test_cfg.concolic_test_root_dir
|
||||
and isinstance(function_to_optimize_ast, ast.FunctionDef)
|
||||
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
|
||||
):
|
||||
logger.info("Generating concolic opcode coverage tests for the original code…")
|
||||
console.rule()
|
||||
try:
|
||||
env = make_env_with_project_root(project_root)
|
||||
cover_result = subprocess.run(
|
||||
[
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
"-m",
|
||||
"crosshair",
|
||||
"cover",
|
||||
"--example_output_format=pytest",
|
||||
"--per_condition_timeout=20",
|
||||
".".join(
|
||||
[
|
||||
function_to_optimize.file_path.relative_to(project_root)
|
||||
.with_suffix("")
|
||||
.as_posix()
|
||||
.replace("/", "."),
|
||||
function_to_optimize.qualified_name,
|
||||
]
|
||||
),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=project_root,
|
||||
check=False,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.debug("CrossHair Cover test generation timed out")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if cover_result.returncode == 0:
|
||||
generated_concolic_test: str = cover_result.stdout
|
||||
if not is_valid_concolic_test(generated_concolic_test, project_root=str(project_root)):
|
||||
logger.debug("CrossHair generated invalid test, skipping")
|
||||
console.rule()
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
concolic_test_suite_code = clean_concolic_tests(generated_concolic_test)
|
||||
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
|
||||
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
|
||||
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
|
||||
|
||||
concolic_test_cfg = TestConfig(
|
||||
tests_root=concolic_test_suite_dir,
|
||||
tests_project_rootdir=test_cfg.concolic_test_root_dir,
|
||||
project_root_path=project_root,
|
||||
)
|
||||
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
|
||||
logger.info(
|
||||
"Created %d concolic unit test case%s ",
|
||||
num_discovered_concolic_tests,
|
||||
"s" if num_discovered_concolic_tests != 1 else "",
|
||||
)
|
||||
console.rule()
|
||||
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"Error running CrossHair Cover%s", ": " + cover_result.stderr if cover_result.stderr else "."
|
||||
)
|
||||
console.rule()
|
||||
end_time = time.perf_counter()
|
||||
logger.debug("Generated concolic tests in %.2f seconds", end_time - start_time)
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
|
|
|||
|
|
@ -53,7 +53,10 @@ def _ensure_languages_registered() -> None:
|
|||
from codeflash.languages.python import support as _
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
from codeflash.languages.javascript import support as _ # noqa: F401
|
||||
from codeflash.languages.javascript import support as _
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
from codeflash.languages.java import support as _ # noqa: F401
|
||||
|
||||
_languages_registered = True
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Literal
|
||||
|
||||
TestFramework = Literal["jest", "vitest", "mocha", "pytest", "unittest"]
|
||||
TestFramework = Literal["jest", "vitest", "mocha", "pytest", "unittest", "junit5", "junit4", "testng"]
|
||||
|
||||
# Module-level singleton for the current test framework
|
||||
_current_test_framework: TestFramework | None = None
|
||||
|
|
@ -63,11 +63,11 @@ def set_current_test_framework(framework: TestFramework | str | None) -> None:
|
|||
|
||||
if framework is not None:
|
||||
framework = framework.lower()
|
||||
if framework not in ("jest", "vitest", "mocha", "pytest", "unittest"):
|
||||
if framework not in ("jest", "vitest", "mocha", "pytest", "unittest", "junit5", "junit4", "testng"):
|
||||
# Default to jest for unknown JS frameworks, pytest for unknown Python
|
||||
from codeflash.languages.current import is_javascript
|
||||
from codeflash.languages.current import current_language_support
|
||||
|
||||
framework = "jest" if is_javascript() else "pytest"
|
||||
framework = current_language_support().test_framework
|
||||
|
||||
_current_test_framework = framework
|
||||
|
||||
|
|
|
|||
|
|
@ -463,14 +463,10 @@ def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedIni
|
|||
"message": "Failed to prepare module for optimization",
|
||||
}
|
||||
|
||||
validated_original_code, original_module_ast = module_prep_result
|
||||
validated_original_code, _original_module_ast = module_prep_result
|
||||
|
||||
function_optimizer = server.optimizer.create_function_optimizer(
|
||||
fto,
|
||||
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=fto.file_path,
|
||||
function_to_tests={},
|
||||
fto, function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, function_to_tests={}
|
||||
)
|
||||
|
||||
server.optimizer.current_function_optimizer = function_optimizer
|
||||
|
|
|
|||
|
|
@ -245,8 +245,7 @@ class CodeString(BaseModel):
|
|||
"""Validate code syntax for the specified language."""
|
||||
if self.language == "python":
|
||||
validate_python_code(self.code)
|
||||
elif self.language in ("javascript", "typescript"):
|
||||
# Validate JavaScript/TypeScript syntax using language support
|
||||
else:
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
lang_support = get_language_support(self.language)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import logging
|
||||
|
|
@ -59,32 +58,16 @@ from codeflash.code_utils.config_consts import (
|
|||
EffortLevel,
|
||||
get_effort_value,
|
||||
)
|
||||
from codeflash.code_utils.deduplicate_code import normalize_code
|
||||
from codeflash.code_utils.env_utils import get_pr_number
|
||||
from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports
|
||||
from codeflash.code_utils.git_utils import git_root_dir
|
||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.discovery.functions_to_optimize import was_function_previously_optimized
|
||||
from codeflash.either import Failure, Success, is_successful
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.current import current_language_support
|
||||
from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files
|
||||
from codeflash.languages.python.context import code_context_extractor
|
||||
from codeflash.languages.python.context.unused_definition_remover import (
|
||||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.code_extractor import get_opt_review_metrics, is_numerical_code
|
||||
from codeflash.languages.python.static_analysis.code_replacer import (
|
||||
add_custom_marker_to_all_tests,
|
||||
modify_autouse_fixture,
|
||||
replace_function_definitions_in_module,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.languages.python.static_analysis.static_analysis import get_first_top_level_function_or_method_ast
|
||||
from codeflash.lsp.helpers import is_LSP_enabled, is_subagent_mode, report_to_markdown_table, tree_to_markdown
|
||||
from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
|
||||
from codeflash.models.ExperimentMetadata import ExperimentMetadata
|
||||
|
|
@ -94,7 +77,6 @@ from codeflash.models.models import (
|
|||
AIServiceCodeRepairRequest,
|
||||
BestOptimization,
|
||||
CandidateEvaluationContext,
|
||||
CodeOptimizationContext,
|
||||
GeneratedTests,
|
||||
GeneratedTestsList,
|
||||
OptimizationReviewResult,
|
||||
|
|
@ -121,27 +103,23 @@ from codeflash.result.critic import (
|
|||
)
|
||||
from codeflash.result.explanation import Explanation
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.concolic_testing import generate_concolic_tests
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
|
||||
from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results
|
||||
from codeflash.verification.parse_test_output import (
|
||||
calculate_function_throughput_from_test_results,
|
||||
parse_concurrency_metrics,
|
||||
parse_test_results,
|
||||
)
|
||||
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests
|
||||
from codeflash.verification.parse_test_output import parse_concurrency_metrics, parse_test_results
|
||||
from codeflash.verification.verification_utils import get_test_file_path
|
||||
from codeflash.verification.verifier import generate_tests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ast
|
||||
from argparse import Namespace
|
||||
from typing import Any
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import Result
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.function_types import FunctionParent
|
||||
from codeflash.models.models import (
|
||||
BenchmarkKey,
|
||||
CodeOptimizationContext,
|
||||
CodeStringsMarkdown,
|
||||
ConcurrencyMetrics,
|
||||
CoverageData,
|
||||
|
|
@ -459,14 +437,9 @@ class FunctionOptimizer:
|
|||
)
|
||||
self.language_support = current_language_support()
|
||||
if not function_to_optimize_ast:
|
||||
# Skip Python AST parsing for non-Python languages
|
||||
if not is_python():
|
||||
self.function_to_optimize_ast = None
|
||||
else:
|
||||
original_module_ast = ast.parse(function_to_optimize_source_code)
|
||||
self.function_to_optimize_ast = get_first_top_level_function_or_method_ast(
|
||||
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
|
||||
)
|
||||
self.function_to_optimize_ast = self._resolve_function_ast(
|
||||
self.function_to_optimize_source_code, function_to_optimize.function_name, function_to_optimize.parents
|
||||
)
|
||||
else:
|
||||
self.function_to_optimize_ast = function_to_optimize_ast
|
||||
self.function_to_tests = function_to_tests if function_to_tests else {}
|
||||
|
|
@ -503,6 +476,71 @@ class FunctionOptimizer:
|
|||
self.is_numerical_code: bool | None = None
|
||||
self.code_already_exists: bool = False
|
||||
|
||||
# --- Hooks for language-specific subclasses ---
|
||||
|
||||
def _resolve_function_ast(
|
||||
self, source_code: str, function_name: str, parents: list[FunctionParent]
|
||||
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
|
||||
return None
|
||||
|
||||
def requires_function_ast(self) -> bool:
|
||||
return False
|
||||
|
||||
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
|
||||
pass
|
||||
|
||||
def get_optimization_review_metrics(
|
||||
self,
|
||||
source_code: str,
|
||||
file_path: Path,
|
||||
qualified_name: str,
|
||||
project_root: Path,
|
||||
tests_root: Path,
|
||||
language: Language,
|
||||
) -> str:
|
||||
return ""
|
||||
|
||||
def instrument_test_fixtures(self, test_paths: list[Path]) -> dict[Path, list[str]] | None:
|
||||
return None
|
||||
|
||||
def instrument_async_for_mode(self, mode: TestingMode) -> None:
|
||||
pass
|
||||
|
||||
def instrument_capture(self, file_path_to_helper_classes: dict[Path, set[str]]) -> None:
|
||||
pass
|
||||
|
||||
def should_check_coverage(self) -> bool:
|
||||
return False
|
||||
|
||||
def collect_async_metrics(
|
||||
self,
|
||||
benchmarking_results: TestResults,
|
||||
code_context: CodeOptimizationContext,
|
||||
helper_code: dict[Path, str],
|
||||
test_env: dict[str, str],
|
||||
) -> tuple[int | None, ConcurrencyMetrics | None]:
|
||||
return None, None
|
||||
|
||||
def compare_candidate_results(
|
||||
self,
|
||||
baseline_results: OriginalCodeBaseline,
|
||||
candidate_behavior_results: TestResults,
|
||||
optimization_candidate_index: int,
|
||||
) -> tuple[bool, list[TestDiff]]:
|
||||
return compare_test_results(
|
||||
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
|
||||
)
|
||||
|
||||
def should_skip_sqlite_cleanup(self, testing_type: TestingMode, optimization_iteration: int) -> bool:
|
||||
return False
|
||||
|
||||
def parse_line_profile_test_results(
|
||||
self, line_profiler_output_file: Path | None
|
||||
) -> tuple[TestResults | dict, CoverageData | None]:
|
||||
return TestResults(test_results=[]), None
|
||||
|
||||
# --- End hooks ---
|
||||
|
||||
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
|
||||
should_run_experiment = self.experiment_id is not None
|
||||
logger.info(f"!lsp|Function Trace ID: {self.function_trace_id}")
|
||||
|
|
@ -656,10 +694,7 @@ class FunctionOptimizer:
|
|||
|
||||
original_conftest_content = None
|
||||
if self.args.override_fixtures:
|
||||
logger.info("Disabling all autouse fixtures associated with the generated test files")
|
||||
original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths)
|
||||
logger.info("Add custom marker to generated test files")
|
||||
add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths)
|
||||
original_conftest_content = self.instrument_test_fixtures(generated_test_paths + generated_perf_test_paths)
|
||||
|
||||
return Success(
|
||||
(
|
||||
|
|
@ -679,7 +714,7 @@ class FunctionOptimizer:
|
|||
if not is_successful(initialization_result):
|
||||
return Failure(initialization_result.failure())
|
||||
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
|
||||
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
|
||||
self.analyze_code_characteristics(code_context)
|
||||
code_print(
|
||||
code_context.read_writable_code.flat,
|
||||
file_name=self.function_to_optimize.file_path,
|
||||
|
|
@ -929,7 +964,9 @@ class FunctionOptimizer:
|
|||
runtimes_list = []
|
||||
|
||||
for valid_opt in eval_ctx.valid_optimizations:
|
||||
valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip())
|
||||
valid_opt_normalized_code = self.language_support.normalize_code(
|
||||
valid_opt.candidate.source_code.flat.strip()
|
||||
)
|
||||
new_candidate_with_shorter_code = OptimizedCandidate(
|
||||
source_code=eval_ctx.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"],
|
||||
optimization_id=valid_opt.candidate.optimization_id,
|
||||
|
|
@ -1036,7 +1073,7 @@ class FunctionOptimizer:
|
|||
|
||||
candidate = candidate_node.candidate
|
||||
|
||||
normalized_code = normalize_code(candidate.source_code.flat.strip())
|
||||
normalized_code = self.language_support.normalize_code(candidate.source_code.flat.strip())
|
||||
|
||||
if normalized_code == normalized_original:
|
||||
logger.info(f"h3|Candidate {candidate_index}/{total_candidates}: Identical to original code, skipping.")
|
||||
|
|
@ -1248,7 +1285,7 @@ class FunctionOptimizer:
|
|||
self.future_adaptive_optimizations,
|
||||
)
|
||||
candidate_index = 0
|
||||
normalized_original = normalize_code(code_context.read_writable_code.flat.strip())
|
||||
normalized_original = self.language_support.normalize_code(code_context.read_writable_code.flat.strip())
|
||||
|
||||
# Process candidates using queue-based approach
|
||||
while not processor.is_done():
|
||||
|
|
@ -1493,57 +1530,24 @@ class FunctionOptimizer:
|
|||
|
||||
return new_code, new_helper_code
|
||||
|
||||
def group_functions_by_file(self, code_context: CodeOptimizationContext) -> dict[Path, set[str]]:
|
||||
functions_by_file: dict[Path, set[str]] = defaultdict(set)
|
||||
functions_by_file[self.function_to_optimize.file_path].add(self.function_to_optimize.qualified_name)
|
||||
for helper in code_context.helper_functions:
|
||||
if helper.definition_type != "class":
|
||||
functions_by_file[helper.file_path].add(helper.qualified_name)
|
||||
return functions_by_file
|
||||
|
||||
def replace_function_and_helpers_with_optimized_code(
|
||||
self,
|
||||
code_context: CodeOptimizationContext,
|
||||
optimized_code: CodeStringsMarkdown,
|
||||
original_helper_code: dict[Path, str],
|
||||
) -> bool:
|
||||
did_update = False
|
||||
read_writable_functions_by_file_path = defaultdict(set)
|
||||
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add(
|
||||
self.function_to_optimize.qualified_name
|
||||
)
|
||||
for helper_function in code_context.helper_functions:
|
||||
# Skip class definitions (definition_type may be None for non-Python languages)
|
||||
if helper_function.definition_type != "class":
|
||||
read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name)
|
||||
for module_abspath, qualified_names in read_writable_functions_by_file_path.items():
|
||||
did_update |= replace_function_definitions_in_module(
|
||||
function_names=list(qualified_names),
|
||||
optimized_code=optimized_code,
|
||||
module_abspath=module_abspath,
|
||||
preexisting_objects=code_context.preexisting_objects,
|
||||
project_root_path=self.project_root,
|
||||
)
|
||||
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
|
||||
|
||||
# Revert unused helper functions to their original definitions
|
||||
if unused_helpers:
|
||||
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
|
||||
|
||||
return did_update
|
||||
raise NotImplementedError
|
||||
|
||||
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
|
||||
try:
|
||||
new_code_ctx = code_context_extractor.get_code_optimization_context(
|
||||
self.function_to_optimize, self.project_root, call_graph=self.call_graph
|
||||
)
|
||||
except ValueError as e:
|
||||
return Failure(str(e))
|
||||
|
||||
return Success(
|
||||
CodeOptimizationContext(
|
||||
testgen_context=new_code_ctx.testgen_context,
|
||||
read_writable_code=new_code_ctx.read_writable_code,
|
||||
read_only_context_code=new_code_ctx.read_only_context_code,
|
||||
hashing_code_context=new_code_ctx.hashing_code_context,
|
||||
hashing_code_context_hash=new_code_ctx.hashing_code_context_hash,
|
||||
helper_functions=new_code_ctx.helper_functions,
|
||||
testgen_helper_fqns=new_code_ctx.testgen_helper_fqns,
|
||||
preexisting_objects=new_code_ctx.preexisting_objects,
|
||||
)
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def cleanup_leftover_test_return_values() -> None:
|
||||
|
|
@ -1560,178 +1564,27 @@ class FunctionOptimizer:
|
|||
func_qualname = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
|
||||
if func_qualname not in function_to_all_tests:
|
||||
logger.info(f"Did not find any pre-existing tests for '{func_qualname}', will only use generated tests.")
|
||||
# Handle non-Python existing test instrumentation
|
||||
elif not is_python():
|
||||
test_file_invocation_positions = defaultdict(list)
|
||||
for tests_in_file in function_to_all_tests.get(func_qualname):
|
||||
test_file_invocation_positions[
|
||||
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
|
||||
].append(tests_in_file)
|
||||
return unique_instrumented_test_files
|
||||
|
||||
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
|
||||
path_obj_test_file = Path(test_file)
|
||||
if test_type == TestType.EXISTING_UNIT_TEST:
|
||||
existing_test_files_count += 1
|
||||
elif test_type == TestType.REPLAY_TEST:
|
||||
replay_test_files_count += 1
|
||||
elif test_type == TestType.CONCOLIC_COVERAGE_TEST:
|
||||
concolic_coverage_test_files_count += 1
|
||||
else:
|
||||
msg = f"Unexpected test type: {test_type}"
|
||||
raise ValueError(msg)
|
||||
test_file_invocation_positions = defaultdict(list)
|
||||
for tests_in_file in function_to_all_tests.get(func_qualname):
|
||||
test_file_invocation_positions[
|
||||
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
|
||||
].append(tests_in_file)
|
||||
|
||||
# Use language-specific instrumentation
|
||||
success, injected_behavior_test = self.language_support.instrument_existing_test(
|
||||
test_path=path_obj_test_file,
|
||||
call_positions=[test.position for test in tests_in_file_list],
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
tests_project_root=self.test_cfg.tests_project_rootdir,
|
||||
mode="behavior",
|
||||
)
|
||||
if not success:
|
||||
logger.debug(f"Failed to instrument test file {test_file} for behavior testing")
|
||||
continue
|
||||
|
||||
success, injected_perf_test = self.language_support.instrument_existing_test(
|
||||
test_path=path_obj_test_file,
|
||||
call_positions=[test.position for test in tests_in_file_list],
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
tests_project_root=self.test_cfg.tests_project_rootdir,
|
||||
mode="performance",
|
||||
)
|
||||
if not success:
|
||||
logger.debug(f"Failed to instrument test file {test_file} for performance testing")
|
||||
continue
|
||||
|
||||
# Generate instrumented test file paths
|
||||
# For JS/TS, preserve .test.ts or .spec.ts suffix for Jest pattern matching
|
||||
def get_instrumented_path(original_path: str, suffix: str) -> Path:
|
||||
"""Generate instrumented test file path preserving .test/.spec pattern."""
|
||||
path_obj = Path(original_path)
|
||||
stem = path_obj.stem # e.g., "fibonacci.test"
|
||||
ext = path_obj.suffix # e.g., ".ts"
|
||||
|
||||
# Check for .test or .spec in stem (JS/TS pattern)
|
||||
if ".test" in stem:
|
||||
# fibonacci.test -> fibonacci__suffix.test
|
||||
base, _ = stem.rsplit(".test", 1)
|
||||
new_stem = f"{base}{suffix}.test"
|
||||
elif ".spec" in stem:
|
||||
base, _ = stem.rsplit(".spec", 1)
|
||||
new_stem = f"{base}{suffix}.spec"
|
||||
else:
|
||||
# Default Python-style: add suffix before extension
|
||||
new_stem = f"{stem}{suffix}"
|
||||
|
||||
return path_obj.parent / f"{new_stem}{ext}"
|
||||
|
||||
new_behavioral_test_path = get_instrumented_path(test_file, "__perfinstrumented")
|
||||
new_perf_test_path = get_instrumented_path(test_file, "__perfonlyinstrumented")
|
||||
|
||||
if injected_behavior_test is not None:
|
||||
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_behavior_test)
|
||||
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
|
||||
else:
|
||||
msg = "injected_behavior_test is None"
|
||||
raise ValueError(msg)
|
||||
|
||||
if injected_perf_test is not None:
|
||||
with new_perf_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_perf_test)
|
||||
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
|
||||
|
||||
unique_instrumented_test_files.add(new_behavioral_test_path)
|
||||
unique_instrumented_test_files.add(new_perf_test_path)
|
||||
|
||||
if not self.test_files.get_by_original_file_path(path_obj_test_file):
|
||||
self.test_files.add(
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=new_behavioral_test_path,
|
||||
benchmarking_file_path=new_perf_test_path,
|
||||
original_source=None,
|
||||
original_file_path=Path(test_file),
|
||||
test_type=test_type,
|
||||
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
|
||||
)
|
||||
)
|
||||
|
||||
if existing_test_files_count > 0 or replay_test_files_count > 0 or concolic_coverage_test_files_count > 0:
|
||||
logger.info(
|
||||
f"Instrumented {existing_test_files_count} existing unit test file"
|
||||
f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file"
|
||||
f"{'s' if replay_test_files_count != 1 else ''}, and "
|
||||
f"{concolic_coverage_test_files_count} concolic coverage test file"
|
||||
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
|
||||
)
|
||||
console.rule()
|
||||
else:
|
||||
test_file_invocation_positions = defaultdict(list)
|
||||
for tests_in_file in function_to_all_tests.get(func_qualname):
|
||||
test_file_invocation_positions[
|
||||
(tests_in_file.tests_in_file.test_file, tests_in_file.tests_in_file.test_type)
|
||||
].append(tests_in_file)
|
||||
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
|
||||
path_obj_test_file = Path(test_file)
|
||||
if test_type == TestType.EXISTING_UNIT_TEST:
|
||||
existing_test_files_count += 1
|
||||
elif test_type == TestType.REPLAY_TEST:
|
||||
replay_test_files_count += 1
|
||||
elif test_type == TestType.CONCOLIC_COVERAGE_TEST:
|
||||
concolic_coverage_test_files_count += 1
|
||||
else:
|
||||
msg = f"Unexpected test type: {test_type}"
|
||||
raise ValueError(msg)
|
||||
success, injected_behavior_test = inject_profiling_into_existing_test(
|
||||
mode=TestingMode.BEHAVIOR,
|
||||
test_path=path_obj_test_file,
|
||||
call_positions=[test.position for test in tests_in_file_list],
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
tests_project_root=self.test_cfg.tests_project_rootdir,
|
||||
)
|
||||
if not success:
|
||||
continue
|
||||
success, injected_perf_test = inject_profiling_into_existing_test(
|
||||
mode=TestingMode.PERFORMANCE,
|
||||
test_path=path_obj_test_file,
|
||||
call_positions=[test.position for test in tests_in_file_list],
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
tests_project_root=self.test_cfg.tests_project_rootdir,
|
||||
)
|
||||
if not success:
|
||||
continue
|
||||
# TODO: this naming logic should be moved to a function and made more standard
|
||||
new_behavioral_test_path = Path(
|
||||
f"{os.path.splitext(test_file)[0]}__perfinstrumented{os.path.splitext(test_file)[1]}" # noqa: PTH122
|
||||
)
|
||||
new_perf_test_path = Path(
|
||||
f"{os.path.splitext(test_file)[0]}__perfonlyinstrumented{os.path.splitext(test_file)[1]}" # noqa: PTH122
|
||||
)
|
||||
if injected_behavior_test is not None:
|
||||
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_behavior_test)
|
||||
else:
|
||||
msg = "injected_behavior_test is None"
|
||||
raise ValueError(msg)
|
||||
if injected_perf_test is not None:
|
||||
with new_perf_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_perf_test)
|
||||
|
||||
unique_instrumented_test_files.add(new_behavioral_test_path)
|
||||
unique_instrumented_test_files.add(new_perf_test_path)
|
||||
|
||||
if not self.test_files.get_by_original_file_path(path_obj_test_file):
|
||||
self.test_files.add(
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=new_behavioral_test_path,
|
||||
benchmarking_file_path=new_perf_test_path,
|
||||
original_source=None,
|
||||
original_file_path=Path(test_file),
|
||||
test_type=test_type,
|
||||
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
|
||||
)
|
||||
)
|
||||
for test_file, test_type in test_file_invocation_positions:
|
||||
path_obj_test_file = Path(test_file)
|
||||
if test_type == TestType.EXISTING_UNIT_TEST:
|
||||
existing_test_files_count += 1
|
||||
elif test_type == TestType.REPLAY_TEST:
|
||||
replay_test_files_count += 1
|
||||
elif test_type == TestType.CONCOLIC_COVERAGE_TEST:
|
||||
concolic_coverage_test_files_count += 1
|
||||
else:
|
||||
msg = f"Unexpected test type: {test_type}"
|
||||
raise ValueError(msg)
|
||||
|
||||
if existing_test_files_count > 0 or replay_test_files_count > 0 or concolic_coverage_test_files_count > 0:
|
||||
logger.info(
|
||||
f"Discovered {existing_test_files_count} existing unit test file"
|
||||
f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file"
|
||||
|
|
@ -1740,6 +1593,87 @@ class FunctionOptimizer:
|
|||
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
|
||||
)
|
||||
console.rule()
|
||||
|
||||
for (test_file, test_type), tests_in_file_list in test_file_invocation_positions.items():
|
||||
path_obj_test_file = Path(test_file)
|
||||
# Use language-specific instrumentation
|
||||
success, injected_behavior_test = self.language_support.instrument_existing_test(
|
||||
test_path=path_obj_test_file,
|
||||
call_positions=[test.position for test in tests_in_file_list],
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
tests_project_root=self.test_cfg.tests_project_rootdir,
|
||||
mode="behavior",
|
||||
)
|
||||
if not success:
|
||||
logger.debug(f"Failed to instrument test file {test_file} for behavior testing")
|
||||
continue
|
||||
|
||||
success, injected_perf_test = self.language_support.instrument_existing_test(
|
||||
test_path=path_obj_test_file,
|
||||
call_positions=[test.position for test in tests_in_file_list],
|
||||
function_to_optimize=self.function_to_optimize,
|
||||
tests_project_root=self.test_cfg.tests_project_rootdir,
|
||||
mode="performance",
|
||||
)
|
||||
if not success:
|
||||
logger.debug(f"Failed to instrument test file {test_file} for performance testing")
|
||||
continue
|
||||
|
||||
# For JS/TS, preserve .test.ts or .spec.ts suffix for Jest pattern matching
|
||||
def get_instrumented_path(original_path: str, suffix: str) -> Path:
|
||||
path_obj = Path(original_path)
|
||||
stem = path_obj.stem
|
||||
ext = path_obj.suffix
|
||||
|
||||
if ".test" in stem:
|
||||
base, _ = stem.rsplit(".test", 1)
|
||||
new_stem = f"{base}{suffix}.test"
|
||||
elif ".spec" in stem:
|
||||
base, _ = stem.rsplit(".spec", 1)
|
||||
new_stem = f"{base}{suffix}.spec"
|
||||
else:
|
||||
new_stem = f"{stem}{suffix}"
|
||||
|
||||
return path_obj.parent / f"{new_stem}{ext}"
|
||||
|
||||
new_behavioral_test_path = get_instrumented_path(test_file, "__perfinstrumented")
|
||||
new_perf_test_path = get_instrumented_path(test_file, "__perfonlyinstrumented")
|
||||
|
||||
if injected_behavior_test is not None:
|
||||
with new_behavioral_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_behavior_test)
|
||||
logger.debug(f"[PIPELINE] Wrote instrumented behavior test to {new_behavioral_test_path}")
|
||||
else:
|
||||
msg = "injected_behavior_test is None"
|
||||
raise ValueError(msg)
|
||||
|
||||
if injected_perf_test is not None:
|
||||
with new_perf_test_path.open("w", encoding="utf8") as _f:
|
||||
_f.write(injected_perf_test)
|
||||
logger.debug(f"[PIPELINE] Wrote instrumented perf test to {new_perf_test_path}")
|
||||
|
||||
unique_instrumented_test_files.add(new_behavioral_test_path)
|
||||
unique_instrumented_test_files.add(new_perf_test_path)
|
||||
|
||||
if not self.test_files.get_by_original_file_path(path_obj_test_file):
|
||||
self.test_files.add(
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=new_behavioral_test_path,
|
||||
benchmarking_file_path=new_perf_test_path,
|
||||
original_source=None,
|
||||
original_file_path=Path(test_file),
|
||||
test_type=test_type,
|
||||
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
|
||||
)
|
||||
)
|
||||
|
||||
instrumented_count = len(unique_instrumented_test_files) // 2 # each test produces behavior + perf files
|
||||
if instrumented_count > 0:
|
||||
logger.info(
|
||||
f"Instrumented {instrumented_count} existing unit test file"
|
||||
f"{'s' if instrumented_count != 1 else ''} for {func_qualname}"
|
||||
)
|
||||
console.rule()
|
||||
return unique_instrumented_test_files
|
||||
|
||||
def generate_tests(
|
||||
|
|
@ -1764,9 +1698,9 @@ class FunctionOptimizer:
|
|||
future_concolic_tests = None
|
||||
else:
|
||||
future_concolic_tests = self.executor.submit(
|
||||
generate_concolic_tests,
|
||||
self.language_support.generate_concolic_tests,
|
||||
self.test_cfg,
|
||||
self.args,
|
||||
self.args.project_root,
|
||||
self.function_to_optimize,
|
||||
self.function_to_optimize_ast,
|
||||
)
|
||||
|
|
@ -1842,7 +1776,7 @@ class FunctionOptimizer:
|
|||
)
|
||||
|
||||
future_references = self.executor.submit(
|
||||
get_opt_review_metrics,
|
||||
self.get_optimization_review_metrics,
|
||||
self.function_to_optimize_source_code,
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize.qualified_name,
|
||||
|
|
@ -1933,8 +1867,7 @@ class FunctionOptimizer:
|
|||
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
|
||||
# Check test quantity for all languages
|
||||
quantity_ok = quantity_of_tests_critic(original_code_baseline)
|
||||
# TODO: {Self} Only check coverage for Python - coverage infrastructure not yet reliable for JS/TS
|
||||
coverage_ok = coverage_critic(original_code_baseline.coverage_results) if is_python() else True
|
||||
coverage_ok = coverage_critic(original_code_baseline.coverage_results) if self.should_check_coverage() else True
|
||||
if isinstance(original_code_baseline, OriginalCodeBaseline) and (not coverage_ok or not quantity_ok):
|
||||
if self.args.override_fixtures:
|
||||
restore_conftest(original_conftest_content)
|
||||
|
|
@ -2140,7 +2073,6 @@ class FunctionOptimizer:
|
|||
|
||||
if (
|
||||
self.function_to_optimize.is_async
|
||||
and is_python()
|
||||
and original_code_baseline.async_throughput is not None
|
||||
and best_optimization.async_throughput is not None
|
||||
):
|
||||
|
|
@ -2344,25 +2276,13 @@ class FunctionOptimizer:
|
|||
|
||||
test_env = self.get_test_env(codeflash_loop_index=0, codeflash_test_iteration=0, codeflash_tracer_disable=1)
|
||||
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
success = add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.BEHAVIOR,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
if self.function_to_optimize.is_async:
|
||||
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
|
||||
|
||||
# Instrument codeflash capture
|
||||
with progress_bar("Running tests to establish original code behavior..."):
|
||||
try:
|
||||
# Only instrument Python code here - non-Python languages use their own runtime helpers
|
||||
# which are already included in the generated/instrumented tests
|
||||
if is_python():
|
||||
instrument_codeflash_capture(
|
||||
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
|
||||
)
|
||||
self.instrument_capture(file_path_to_helper_classes)
|
||||
|
||||
total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
logger.debug(f"[PIPELINE] Establishing baseline with {len(self.test_files)} test files")
|
||||
|
|
@ -2391,7 +2311,7 @@ class FunctionOptimizer:
|
|||
console.rule()
|
||||
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
|
||||
# Skip coverage check for non-Python languages (coverage not yet supported)
|
||||
if is_python() and not coverage_critic(coverage_results):
|
||||
if self.should_check_coverage() and not coverage_critic(coverage_results):
|
||||
did_pass_all_tests = all(result.did_pass for result in behavioral_results)
|
||||
if not did_pass_all_tests:
|
||||
return Failure("Tests failed to pass for the original code.")
|
||||
|
|
@ -2412,15 +2332,8 @@ class FunctionOptimizer:
|
|||
for idx, tf in enumerate(self.test_files.test_files):
|
||||
logger.debug(f"[BENCHMARK-FILES] Test file {idx}: perf_file={tf.benchmarking_file_path}")
|
||||
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.PERFORMANCE,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
if self.function_to_optimize.is_async:
|
||||
self.instrument_async_for_mode(TestingMode.PERFORMANCE)
|
||||
|
||||
try:
|
||||
benchmarking_results, _ = self.run_and_parse_tests(
|
||||
|
|
@ -2472,22 +2385,9 @@ class FunctionOptimizer:
|
|||
console.rule()
|
||||
logger.debug(f"Total original code runtime (ns): {total_timing}")
|
||||
|
||||
async_throughput = None
|
||||
concurrency_metrics = None
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
async_throughput = calculate_function_throughput_from_test_results(
|
||||
benchmarking_results, self.function_to_optimize.function_name
|
||||
)
|
||||
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
|
||||
|
||||
concurrency_metrics = self.run_concurrency_benchmark(
|
||||
code_context=code_context, original_helper_code=original_helper_code, test_env=test_env
|
||||
)
|
||||
if concurrency_metrics:
|
||||
logger.debug(
|
||||
f"Original concurrency metrics: ratio={concurrency_metrics.concurrency_ratio:.2f}, "
|
||||
f"seq={concurrency_metrics.sequential_time_ns}ns, conc={concurrency_metrics.concurrent_time_ns}ns"
|
||||
)
|
||||
async_throughput, concurrency_metrics = self.collect_async_metrics(
|
||||
benchmarking_results, code_context, original_helper_code, test_env
|
||||
)
|
||||
|
||||
if self.args.benchmark:
|
||||
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
|
||||
|
|
@ -2589,22 +2489,11 @@ class FunctionOptimizer:
|
|||
candidate_helper_code = {}
|
||||
for module_abspath in original_helper_code:
|
||||
candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8")
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.BEHAVIOR,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
if self.function_to_optimize.is_async:
|
||||
self.instrument_async_for_mode(TestingMode.BEHAVIOR)
|
||||
|
||||
try:
|
||||
# Only instrument Python code here - non-Python languages use their own runtime helpers
|
||||
if is_python():
|
||||
instrument_codeflash_capture(
|
||||
self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
|
||||
)
|
||||
self.instrument_capture(file_path_to_helper_classes)
|
||||
|
||||
total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
candidate_behavior_results, _ = self.run_and_parse_tests(
|
||||
|
|
@ -2615,13 +2504,10 @@ class FunctionOptimizer:
|
|||
testing_time=total_looping_time,
|
||||
enable_coverage=False,
|
||||
)
|
||||
# Remove instrumentation
|
||||
finally:
|
||||
# Only restore code for Python - non-Python tests are self-contained
|
||||
if is_python():
|
||||
self.write_code_and_helpers(
|
||||
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
self.write_code_and_helpers(
|
||||
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
console.print(
|
||||
TestResults.report_to_tree(
|
||||
candidate_behavior_results.get_test_pass_fail_report_by_type(),
|
||||
|
|
@ -2630,30 +2516,9 @@ class FunctionOptimizer:
|
|||
)
|
||||
console.rule()
|
||||
|
||||
# Use language-appropriate comparison
|
||||
if not is_python():
|
||||
# Non-Python: Compare using language support with SQLite results if available
|
||||
original_sqlite = get_run_tmp_file(Path("test_return_values_0.sqlite"))
|
||||
candidate_sqlite = get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite"))
|
||||
|
||||
if original_sqlite.exists() and candidate_sqlite.exists():
|
||||
# Full comparison using captured return values via language support
|
||||
# Use js_project_root where node_modules is located
|
||||
js_root = self.test_cfg.js_project_root or self.args.project_root
|
||||
match, diffs = self.language_support.compare_test_results(
|
||||
original_sqlite, candidate_sqlite, project_root=js_root
|
||||
)
|
||||
# Cleanup SQLite files after comparison
|
||||
candidate_sqlite.unlink(missing_ok=True)
|
||||
else:
|
||||
# Fallback: compare test pass/fail status (tests aren't instrumented yet)
|
||||
# If all tests that passed for original also pass for candidate, consider it a match
|
||||
match, diffs = compare_test_results(
|
||||
baseline_results.behavior_test_results, candidate_behavior_results, pass_fail_only=True
|
||||
)
|
||||
else:
|
||||
# Python: Compare using Python comparator
|
||||
match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
|
||||
match, diffs = self.compare_candidate_results(
|
||||
baseline_results, candidate_behavior_results, optimization_candidate_index
|
||||
)
|
||||
|
||||
if match:
|
||||
logger.info("h3|Test results matched ✅")
|
||||
|
|
@ -2667,16 +2532,8 @@ class FunctionOptimizer:
|
|||
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")
|
||||
console.rule()
|
||||
|
||||
# For async functions, instrument at definition site for performance benchmarking
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
|
||||
|
||||
add_async_decorator_to_function(
|
||||
self.function_to_optimize.file_path,
|
||||
self.function_to_optimize,
|
||||
TestingMode.PERFORMANCE,
|
||||
project_root=self.project_root,
|
||||
)
|
||||
if self.function_to_optimize.is_async:
|
||||
self.instrument_async_for_mode(TestingMode.PERFORMANCE)
|
||||
|
||||
try:
|
||||
candidate_benchmarking_results, _ = self.run_and_parse_tests(
|
||||
|
|
@ -2688,8 +2545,7 @@ class FunctionOptimizer:
|
|||
enable_coverage=False,
|
||||
)
|
||||
finally:
|
||||
# Restore original source if we instrumented it
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
if self.function_to_optimize.is_async:
|
||||
self.write_code_and_helpers(
|
||||
candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
|
|
@ -2704,23 +2560,9 @@ class FunctionOptimizer:
|
|||
|
||||
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
|
||||
|
||||
candidate_async_throughput = None
|
||||
candidate_concurrency_metrics = None
|
||||
if self.function_to_optimize.is_async and is_python():
|
||||
candidate_async_throughput = calculate_function_throughput_from_test_results(
|
||||
candidate_benchmarking_results, self.function_to_optimize.function_name
|
||||
)
|
||||
logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
|
||||
|
||||
# Run concurrency benchmark for candidate
|
||||
candidate_concurrency_metrics = self.run_concurrency_benchmark(
|
||||
code_context=code_context, original_helper_code=candidate_helper_code, test_env=test_env
|
||||
)
|
||||
if candidate_concurrency_metrics:
|
||||
logger.debug(
|
||||
f"Candidate concurrency metrics: ratio={candidate_concurrency_metrics.concurrency_ratio:.2f}, "
|
||||
f"seq={candidate_concurrency_metrics.sequential_time_ns}ns, conc={candidate_concurrency_metrics.concurrent_time_ns}ns"
|
||||
)
|
||||
candidate_async_throughput, candidate_concurrency_metrics = self.collect_async_metrics(
|
||||
candidate_benchmarking_results, code_context, candidate_helper_code, test_env
|
||||
)
|
||||
|
||||
if self.args.benchmark:
|
||||
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
|
||||
|
|
@ -2764,40 +2606,36 @@ class FunctionOptimizer:
|
|||
coverage_config_file = None
|
||||
try:
|
||||
if testing_type == TestingMode.BEHAVIOR:
|
||||
result_file_path, run_result, coverage_database_file, coverage_config_file = run_behavioral_tests(
|
||||
test_files,
|
||||
test_framework=self.test_cfg.test_framework,
|
||||
cwd=self.project_root,
|
||||
test_env=test_env,
|
||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
enable_coverage=enable_coverage,
|
||||
js_project_root=self.test_cfg.js_project_root,
|
||||
candidate_index=optimization_iteration,
|
||||
result_file_path, run_result, coverage_database_file, coverage_config_file = (
|
||||
self.language_support.run_behavioral_tests(
|
||||
test_paths=test_files,
|
||||
test_env=test_env,
|
||||
cwd=self.project_root,
|
||||
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
project_root=self.test_cfg.js_project_root,
|
||||
enable_coverage=enable_coverage,
|
||||
candidate_index=optimization_iteration,
|
||||
)
|
||||
)
|
||||
elif testing_type == TestingMode.LINE_PROFILE:
|
||||
result_file_path, run_result = run_line_profile_tests(
|
||||
test_files,
|
||||
cwd=self.project_root,
|
||||
result_file_path, run_result = self.language_support.run_line_profile_tests(
|
||||
test_paths=test_files,
|
||||
test_env=test_env,
|
||||
pytest_cmd=self.test_cfg.pytest_cmd,
|
||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
pytest_target_runtime_seconds=testing_time,
|
||||
test_framework=self.test_cfg.test_framework,
|
||||
js_project_root=self.test_cfg.js_project_root,
|
||||
line_profiler_output_file=line_profiler_output_file,
|
||||
cwd=self.project_root,
|
||||
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
project_root=self.test_cfg.js_project_root,
|
||||
line_profile_output_file=line_profiler_output_file,
|
||||
)
|
||||
elif testing_type == TestingMode.PERFORMANCE:
|
||||
result_file_path, run_result = run_benchmarking_tests(
|
||||
test_files,
|
||||
cwd=self.project_root,
|
||||
result_file_path, run_result = self.language_support.run_benchmarking_tests(
|
||||
test_paths=test_files,
|
||||
test_env=test_env,
|
||||
pytest_cmd=self.test_cfg.pytest_cmd,
|
||||
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
pytest_target_runtime_seconds=testing_time,
|
||||
pytest_min_loops=pytest_min_loops,
|
||||
pytest_max_loops=pytest_max_loops,
|
||||
test_framework=self.test_cfg.test_framework,
|
||||
js_project_root=self.test_cfg.js_project_root,
|
||||
cwd=self.project_root,
|
||||
timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
|
||||
project_root=self.test_cfg.js_project_root,
|
||||
min_loops=pytest_min_loops,
|
||||
max_loops=pytest_max_loops,
|
||||
target_duration_seconds=testing_time,
|
||||
)
|
||||
else:
|
||||
msg = f"Unexpected testing type: {testing_type}"
|
||||
|
|
@ -2829,9 +2667,7 @@ class FunctionOptimizer:
|
|||
console.print(panel)
|
||||
|
||||
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
|
||||
# For non-Python behavior tests, skip SQLite cleanup - files needed for language-native comparison
|
||||
non_python_original_code = not is_python() and optimization_iteration == 0
|
||||
skip_cleanup = (not is_python() and testing_type == TestingMode.BEHAVIOR) or non_python_original_code
|
||||
skip_cleanup = self.should_skip_sqlite_cleanup(testing_type, optimization_iteration)
|
||||
|
||||
results, coverage_results = parse_test_results(
|
||||
test_xml_path=result_file_path,
|
||||
|
|
@ -2849,13 +2685,7 @@ class FunctionOptimizer:
|
|||
if testing_type == TestingMode.PERFORMANCE:
|
||||
results.perf_stdout = run_result.stdout
|
||||
return results, coverage_results
|
||||
# For LINE_PROFILE mode, Python uses .lprof files while JavaScript uses JSON
|
||||
# Return TestResults for JavaScript so _line_profiler_step_javascript can parse the JSON
|
||||
if not is_python():
|
||||
# Return TestResults to indicate tests ran, actual parsing happens in _line_profiler_step_javascript
|
||||
return TestResults(test_results=[]), None
|
||||
results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file)
|
||||
return results, coverage_results
|
||||
return self.parse_line_profile_test_results(line_profiler_output_file)
|
||||
|
||||
def submit_test_generation_tasks(
|
||||
self,
|
||||
|
|
@ -2912,108 +2742,9 @@ class FunctionOptimizer:
|
|||
|
||||
def line_profiler_step(
|
||||
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
|
||||
) -> dict:
|
||||
# Dispatch to language-specific implementation
|
||||
if is_python():
|
||||
return self._line_profiler_step_python(code_context, original_helper_code, candidate_index)
|
||||
|
||||
if self.language_support is not None and hasattr(self.language_support, "instrument_source_for_line_profiler"):
|
||||
try:
|
||||
line_profiler_output_path = get_run_tmp_file(Path("line_profiler_output.json"))
|
||||
# NOTE: currently this handles single file only, add support to multi file instrumentation (or should it be kept for the main file only)
|
||||
original_source = Path(self.function_to_optimize.file_path).read_text()
|
||||
# Instrument source code
|
||||
success = self.language_support.instrument_source_for_line_profiler(
|
||||
func_info=self.function_to_optimize, line_profiler_output_file=line_profiler_output_path
|
||||
)
|
||||
if not success:
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
test_env = self.get_test_env(
|
||||
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
|
||||
)
|
||||
|
||||
_test_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.LINE_PROFILE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=0,
|
||||
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
enable_coverage=False,
|
||||
code_context=code_context,
|
||||
line_profiler_output_file=line_profiler_output_path,
|
||||
)
|
||||
|
||||
if not hasattr(self.language_support, "parse_line_profile_results"):
|
||||
raise ValueError("Language support does not implement parse_line_profile_results") # noqa: TRY301
|
||||
|
||||
return self.language_support.parse_line_profile_results(line_profiler_output_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to run line profiling: {e}")
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
finally:
|
||||
# restore original source
|
||||
Path(self.function_to_optimize.file_path).write_text(original_source)
|
||||
|
||||
logger.warning(f"Language support for {self.language_support.language} doesn't support line profiling")
|
||||
) -> dict[str, Any]:
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
def _line_profiler_step_python(
|
||||
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], candidate_index: int
|
||||
) -> dict:
|
||||
"""Python-specific line profiler using decorator imports."""
|
||||
# Check if candidate code contains JIT decorators - line profiler doesn't work with JIT compiled code
|
||||
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
|
||||
if contains_jit_decorator(candidate_fto_code):
|
||||
logger.info(
|
||||
f"Skipping line profiler for {self.function_to_optimize.function_name} - code contains JIT decorator"
|
||||
)
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
# Check helper code for JIT decorators
|
||||
for module_abspath in original_helper_code:
|
||||
candidate_helper_code = Path(module_abspath).read_text("utf-8")
|
||||
if contains_jit_decorator(candidate_helper_code):
|
||||
logger.info(
|
||||
f"Skipping line profiler for {self.function_to_optimize.function_name} - helper code contains JIT decorator"
|
||||
)
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
|
||||
try:
|
||||
console.rule()
|
||||
|
||||
test_env = self.get_test_env(
|
||||
codeflash_loop_index=0, codeflash_test_iteration=candidate_index, codeflash_tracer_disable=1
|
||||
)
|
||||
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
|
||||
line_profile_results, _ = self.run_and_parse_tests(
|
||||
testing_type=TestingMode.LINE_PROFILE,
|
||||
test_env=test_env,
|
||||
test_files=self.test_files,
|
||||
optimization_iteration=0,
|
||||
testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
enable_coverage=False,
|
||||
code_context=code_context,
|
||||
line_profiler_output_file=line_profiler_output_file,
|
||||
)
|
||||
finally:
|
||||
# Remove codeflash capture
|
||||
self.write_code_and_helpers(
|
||||
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
|
||||
)
|
||||
# this will happen when a timeoutexpired exception happens
|
||||
if isinstance(line_profile_results, TestResults) and not line_profile_results.test_results:
|
||||
logger.warning(
|
||||
f"Timeout occurred while running line profiler for original function {self.function_to_optimize.function_name}"
|
||||
)
|
||||
# set default value for line profiler results
|
||||
return {"timings": {}, "unit": 0, "str_out": ""}
|
||||
if line_profile_results["str_out"] == "":
|
||||
logger.warning(
|
||||
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
|
||||
)
|
||||
return line_profile_results
|
||||
|
||||
def run_concurrency_benchmark(
|
||||
self, code_context: CodeOptimizationContext, original_helper_code: dict[Path, str], test_env: dict[str, str]
|
||||
) -> ConcurrencyMetrics | None:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import copy
|
||||
import os
|
||||
import tempfile
|
||||
|
|
@ -30,20 +29,20 @@ from codeflash.code_utils.git_worktree_utils import (
|
|||
)
|
||||
from codeflash.code_utils.time_utils import humanize_runtime
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.languages import current_language_support, is_javascript, set_current_language
|
||||
from codeflash.languages import current_language_support, set_current_language
|
||||
from codeflash.lsp.helpers import is_subagent_mode
|
||||
from codeflash.models.models import ValidCode
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import ast
|
||||
from argparse import Namespace
|
||||
|
||||
from codeflash.benchmarking.function_ranker import FunctionRanker
|
||||
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import DependencyResolver
|
||||
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest
|
||||
from codeflash.models.models import BenchmarkKey, FunctionCalledInTest, ValidCode
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
|
||||
|
|
@ -72,59 +71,6 @@ class Optimizer:
|
|||
self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None
|
||||
self.patch_files: list[Path] = []
|
||||
|
||||
@staticmethod
|
||||
def _find_js_project_root(file_path: Path) -> Path | None:
|
||||
"""Find the JavaScript/TypeScript project root by looking for package.json.
|
||||
|
||||
Traverses up from the given file path to find the nearest directory
|
||||
containing package.json or jest.config.js.
|
||||
|
||||
Args:
|
||||
file_path: A file path within the JavaScript project.
|
||||
|
||||
Returns:
|
||||
The project root directory, or None if not found.
|
||||
|
||||
"""
|
||||
current = file_path.parent if file_path.is_file() else file_path
|
||||
while current != current.parent: # Stop at filesystem root
|
||||
if (
|
||||
(current / "package.json").exists()
|
||||
or (current / "jest.config.js").exists()
|
||||
or (current / "jest.config.ts").exists()
|
||||
or (current / "tsconfig.json").exists()
|
||||
):
|
||||
return current
|
||||
current = current.parent
|
||||
return None
|
||||
|
||||
def _verify_js_requirements(self) -> None:
|
||||
"""Verify JavaScript/TypeScript requirements before optimization.
|
||||
|
||||
Checks that Node.js, npm, and the test framework are available.
|
||||
Logs warnings if requirements are not met but does not abort.
|
||||
|
||||
"""
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
js_project_root = self.test_cfg.js_project_root
|
||||
if not js_project_root:
|
||||
return
|
||||
|
||||
try:
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
test_framework = get_js_test_framework_or_default()
|
||||
success, errors = js_support.verify_requirements(js_project_root, test_framework)
|
||||
|
||||
if not success:
|
||||
logger.warning("JavaScript requirements check found issues:")
|
||||
for error in errors:
|
||||
logger.warning(f" - {error}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to verify JS requirements: {e}")
|
||||
|
||||
def run_benchmarks(
|
||||
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
|
||||
) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
|
||||
|
|
@ -247,26 +193,8 @@ class Optimizer:
|
|||
function_to_optimize_source_code: str | None = "",
|
||||
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
|
||||
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
|
||||
original_module_ast: ast.Module | None = None,
|
||||
original_module_path: Path | None = None,
|
||||
call_graph: DependencyResolver | None = None,
|
||||
) -> FunctionOptimizer | None:
|
||||
from codeflash.languages.python.static_analysis.static_analysis import (
|
||||
get_first_top_level_function_or_method_ast,
|
||||
)
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
if function_to_optimize_ast is None and original_module_ast is not None:
|
||||
function_to_optimize_ast = get_first_top_level_function_or_method_ast(
|
||||
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
|
||||
)
|
||||
if function_to_optimize_ast is None:
|
||||
logger.info(
|
||||
f"Function {function_to_optimize.qualified_name} not found in {original_module_path}.\n"
|
||||
f"Skipping optimization."
|
||||
)
|
||||
return None
|
||||
|
||||
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
|
||||
|
||||
function_specific_timings = None
|
||||
|
|
@ -279,7 +207,11 @@ class Optimizer:
|
|||
):
|
||||
function_specific_timings = function_benchmark_timings[qualified_name_w_module]
|
||||
|
||||
return FunctionOptimizer(
|
||||
cls = current_language_support().function_optimizer_class
|
||||
|
||||
# TODO: _resolve_function_ast re-parses source via ast.parse() per function, even when the caller already
|
||||
# has a parsed module AST. Consider passing the pre-parsed AST through to avoid redundant parsing.
|
||||
function_optimizer = cls(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=self.test_cfg,
|
||||
function_to_optimize_source_code=function_to_optimize_source_code,
|
||||
|
|
@ -292,62 +224,26 @@ class Optimizer:
|
|||
replay_tests_dir=self.replay_tests_dir,
|
||||
call_graph=call_graph,
|
||||
)
|
||||
if function_optimizer.function_to_optimize_ast is None and function_optimizer.requires_function_ast():
|
||||
logger.info(
|
||||
f"Function {function_to_optimize.qualified_name} not found in "
|
||||
f"{function_to_optimize.file_path}.\nSkipping optimization."
|
||||
)
|
||||
return None
|
||||
return function_optimizer
|
||||
|
||||
def prepare_module_for_optimization(
|
||||
self, original_module_path: Path
|
||||
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
|
||||
from codeflash.languages.python.static_analysis.code_replacer import normalize_code, normalize_node
|
||||
from codeflash.languages.python.static_analysis.static_analysis import analyze_imported_modules
|
||||
|
||||
logger.info(f"loading|Examining file {original_module_path!s}")
|
||||
console.rule()
|
||||
|
||||
original_module_code: str = original_module_path.read_text(encoding="utf8")
|
||||
|
||||
# For JavaScript/TypeScript, skip Python-specific AST parsing
|
||||
if is_javascript():
|
||||
validated_original_code: dict[Path, ValidCode] = {
|
||||
original_module_path: ValidCode(source_code=original_module_code, normalized_code=original_module_code)
|
||||
}
|
||||
return validated_original_code, None
|
||||
|
||||
# Python-specific parsing
|
||||
try:
|
||||
original_module_ast = ast.parse(original_module_code)
|
||||
except SyntaxError as e:
|
||||
logger.warning(f"Syntax error parsing code in {original_module_path}: {e}")
|
||||
logger.info("Skipping optimization due to file error.")
|
||||
return None
|
||||
normalized_original_module_code = ast.unparse(normalize_node(original_module_ast))
|
||||
validated_original_code = {
|
||||
original_module_path: ValidCode(
|
||||
source_code=original_module_code, normalized_code=normalized_original_module_code
|
||||
)
|
||||
}
|
||||
|
||||
imported_module_analyses = analyze_imported_modules(
|
||||
return current_language_support().prepare_module(
|
||||
original_module_code, original_module_path, self.args.project_root
|
||||
)
|
||||
|
||||
has_syntax_error = False
|
||||
for analysis in imported_module_analyses:
|
||||
callee_original_code = analysis.file_path.read_text(encoding="utf8")
|
||||
try:
|
||||
normalized_callee_original_code = normalize_code(callee_original_code)
|
||||
except SyntaxError as e:
|
||||
logger.warning(f"Syntax error parsing code in callee module {analysis.file_path}: {e}")
|
||||
logger.info("Skipping optimization due to helper file error.")
|
||||
has_syntax_error = True
|
||||
break
|
||||
validated_original_code[analysis.file_path] = ValidCode(
|
||||
source_code=callee_original_code, normalized_code=normalized_callee_original_code
|
||||
)
|
||||
|
||||
if has_syntax_error:
|
||||
return None
|
||||
|
||||
return validated_original_code, original_module_ast
|
||||
|
||||
def discover_tests(
|
||||
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
|
||||
) -> tuple[dict[str, set[FunctionCalledInTest]], int]:
|
||||
|
|
@ -556,11 +452,7 @@ class Optimizer:
|
|||
if funcs and funcs[0].language:
|
||||
set_current_language(funcs[0].language)
|
||||
self.test_cfg.set_language(funcs[0].language)
|
||||
# For JavaScript, also set js_project_root for test execution
|
||||
if is_javascript():
|
||||
self.test_cfg.js_project_root = self._find_js_project_root(file_path)
|
||||
# Verify JS requirements before proceeding
|
||||
self._verify_js_requirements()
|
||||
current_language_support().setup_test_config(self.test_cfg, file_path)
|
||||
break
|
||||
|
||||
if self.args.all:
|
||||
|
|
@ -624,7 +516,7 @@ class Optimizer:
|
|||
continue
|
||||
prepared_modules[original_module_path] = module_prep_result
|
||||
|
||||
validated_original_code, original_module_ast = prepared_modules[original_module_path]
|
||||
validated_original_code, _original_module_ast = prepared_modules[original_module_path]
|
||||
|
||||
function_iterator_count = i + 1
|
||||
logger.info(
|
||||
|
|
@ -640,8 +532,6 @@ class Optimizer:
|
|||
function_to_optimize_source_code=validated_original_code[original_module_path].source_code,
|
||||
function_benchmark_timings=function_benchmark_timings,
|
||||
total_benchmark_timings=total_benchmark_timings,
|
||||
original_module_ast=original_module_ast,
|
||||
original_module_path=original_module_path,
|
||||
call_graph=resolver,
|
||||
)
|
||||
if function_optimizer is None:
|
||||
|
|
@ -762,6 +652,8 @@ class Optimizer:
|
|||
if hasattr(get_run_tmp_file, "tmpdir"):
|
||||
get_run_tmp_file.tmpdir.cleanup()
|
||||
del get_run_tmp_file.tmpdir
|
||||
if hasattr(get_run_tmp_file, "tmpdir_path"):
|
||||
del get_run_tmp_file.tmpdir_path
|
||||
|
||||
# Always clean up concolic test directory
|
||||
cleanup_paths([self.test_cfg.concolic_test_root_dir])
|
||||
|
|
|
|||
|
|
@ -1,133 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import console, logger
|
||||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.shell_utils import make_env_with_project_root
|
||||
from codeflash.discovery.discover_unit_tests import discover_unit_tests
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.concolic_utils import clean_concolic_tests, is_valid_concolic_test
|
||||
from codeflash.languages.python.static_analysis.static_analysis import has_typed_parameters
|
||||
from codeflash.lsp.helpers import is_LSP_enabled
|
||||
from codeflash.telemetry.posthog_cf import ph
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
CROSSHAIR_AVAILABLE = importlib.util.find_spec("crosshair") is not None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionCalledInTest
|
||||
|
||||
|
||||
def generate_concolic_tests(
|
||||
test_cfg: TestConfig, args: Namespace, function_to_optimize: FunctionToOptimize, function_to_optimize_ast: ast.AST
|
||||
) -> tuple[dict[str, set[FunctionCalledInTest]], str]:
|
||||
"""Generate concolic tests using CrossHair (Python only).
|
||||
|
||||
CrossHair is a Python-specific symbolic execution tool. For non-Python languages
|
||||
(JavaScript, TypeScript, etc.), this function returns early with empty results.
|
||||
|
||||
Args:
|
||||
test_cfg: Test configuration
|
||||
args: Command line arguments
|
||||
function_to_optimize: The function being optimized
|
||||
function_to_optimize_ast: AST of the function (Python ast.FunctionDef)
|
||||
|
||||
Returns:
|
||||
Tuple of (function_to_tests mapping, concolic test suite code)
|
||||
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
function_to_concolic_tests = {}
|
||||
concolic_test_suite_code = ""
|
||||
|
||||
# CrossHair is Python-only - skip for other languages
|
||||
if not is_python():
|
||||
logger.debug("Skipping concolic test generation for non-Python languages (CrossHair is Python-only)")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if not CROSSHAIR_AVAILABLE:
|
||||
logger.debug("Skipping concolic test generation (crosshair-tool is not installed)")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if is_LSP_enabled():
|
||||
logger.debug("Skipping concolic test generation in LSP mode")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if (
|
||||
test_cfg.concolic_test_root_dir
|
||||
and isinstance(function_to_optimize_ast, ast.FunctionDef)
|
||||
and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
|
||||
):
|
||||
logger.info("Generating concolic opcode coverage tests for the original code…")
|
||||
console.rule()
|
||||
try:
|
||||
env = make_env_with_project_root(args.project_root)
|
||||
cover_result = subprocess.run(
|
||||
[
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
"-m",
|
||||
"crosshair",
|
||||
"cover",
|
||||
"--example_output_format=pytest",
|
||||
"--per_condition_timeout=20",
|
||||
".".join(
|
||||
[
|
||||
function_to_optimize.file_path.relative_to(args.project_root)
|
||||
.with_suffix("")
|
||||
.as_posix()
|
||||
.replace("/", "."),
|
||||
function_to_optimize.qualified_name,
|
||||
]
|
||||
),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=args.project_root,
|
||||
check=False,
|
||||
timeout=600,
|
||||
env=env,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.debug("CrossHair Cover test generation timed out")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
||||
if cover_result.returncode == 0:
|
||||
generated_concolic_test: str = cover_result.stdout
|
||||
if not is_valid_concolic_test(generated_concolic_test, project_root=str(args.project_root)):
|
||||
logger.debug("CrossHair generated invalid test, skipping")
|
||||
console.rule()
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test)
|
||||
concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
|
||||
concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
|
||||
concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
|
||||
|
||||
concolic_test_cfg = TestConfig(
|
||||
tests_root=concolic_test_suite_dir,
|
||||
tests_project_rootdir=test_cfg.concolic_test_root_dir,
|
||||
project_root_path=args.project_root,
|
||||
)
|
||||
function_to_concolic_tests, num_discovered_concolic_tests, _ = discover_unit_tests(concolic_test_cfg)
|
||||
logger.info(
|
||||
f"Created {num_discovered_concolic_tests} "
|
||||
f"concolic unit test case{'s' if num_discovered_concolic_tests != 1 else ''} "
|
||||
)
|
||||
console.rule()
|
||||
ph("cli-optimize-concolic-tests", {"num_tests": num_discovered_concolic_tests})
|
||||
|
||||
else:
|
||||
logger.debug(f"Error running CrossHair Cover {': ' + cover_result.stderr if cover_result.stderr else '.'}")
|
||||
console.rule()
|
||||
end_time = time.perf_counter()
|
||||
logger.debug(f"Generated concolic tests in {end_time - start_time:.2f} seconds")
|
||||
return function_to_concolic_tests, concolic_test_suite_code
|
||||
|
|
@ -20,7 +20,7 @@ from codeflash.code_utils.code_utils import (
|
|||
module_name_from_file_path,
|
||||
)
|
||||
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
||||
from codeflash.languages import is_javascript
|
||||
from codeflash.languages import Language
|
||||
|
||||
# Import Jest-specific parsing from the JavaScript language module
|
||||
from codeflash.languages.javascript.parse import parse_jest_test_xml as _parse_jest_test_xml
|
||||
|
|
@ -32,7 +32,6 @@ from codeflash.models.models import (
|
|||
TestType,
|
||||
VerificationType,
|
||||
)
|
||||
from codeflash.verification.coverage_utils import CoverageUtils, JestCoverageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import subprocess
|
||||
|
|
@ -438,8 +437,10 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
finally:
|
||||
db.close()
|
||||
|
||||
# Check if this is a JavaScript test (use JSON) or Python test (use pickle)
|
||||
is_jest = is_javascript()
|
||||
# Check serialization format: JavaScript uses JSON, Python uses pickle
|
||||
from codeflash.languages.current import current_language_support
|
||||
|
||||
is_json_format = current_language_support().test_result_serialization_format == "json"
|
||||
|
||||
for val in data:
|
||||
try:
|
||||
|
|
@ -452,7 +453,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
# - A module-style path: "tests.fibonacci.test.ts" (dots as separators)
|
||||
# - A file path: "tests/fibonacci.test.ts" (slashes as separators)
|
||||
# For Python, it's a module path (e.g., "tests.test_foo") that needs conversion
|
||||
if is_jest:
|
||||
if is_json_format:
|
||||
# Jest test file extensions (including .test.ts, .spec.ts patterns)
|
||||
jest_test_extensions = (
|
||||
".test.ts",
|
||||
|
|
@ -517,7 +518,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
||||
logger.debug(f"[PARSE-DEBUG] by_instrumented_file_path: {test_type}")
|
||||
# Default to GENERATED_REGRESSION for Jest tests when test type can't be determined
|
||||
if test_type is None and is_jest:
|
||||
if test_type is None and is_json_format:
|
||||
test_type = TestType.GENERATED_REGRESSION
|
||||
logger.debug("[PARSE-DEBUG] defaulting to GENERATED_REGRESSION (Jest)")
|
||||
elif test_type is None:
|
||||
|
|
@ -532,7 +533,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
ret_val = None
|
||||
if loop_index == 1 and val[7]:
|
||||
try:
|
||||
if is_jest:
|
||||
if is_json_format:
|
||||
# Jest comparison happens via Node.js script (language_support.compare_test_results)
|
||||
# Store a marker indicating data exists but is not deserialized in Python
|
||||
ret_val = ("__serialized__", val[7])
|
||||
|
|
@ -578,7 +579,9 @@ def parse_test_xml(
|
|||
run_result: subprocess.CompletedProcess | None = None,
|
||||
) -> TestResults:
|
||||
# Route to Jest-specific parser for JavaScript/TypeScript tests
|
||||
if is_javascript():
|
||||
from codeflash.languages.current import current_language
|
||||
|
||||
if current_language() in (Language.JAVASCRIPT, Language.TYPESCRIPT):
|
||||
return _parse_jest_test_xml(
|
||||
test_xml_file_path,
|
||||
test_files,
|
||||
|
|
@ -1000,7 +1003,9 @@ def parse_test_results(
|
|||
# Also try to read legacy binary format for Python tests
|
||||
# Binary file may contain additional results (e.g., from codeflash_wrap) even if SQLite has data
|
||||
# from @codeflash_capture. We need to merge both sources.
|
||||
if not is_javascript():
|
||||
from codeflash.languages.current import current_language_support as _cls
|
||||
|
||||
if _cls().test_result_serialization_format == "pickle":
|
||||
try:
|
||||
bin_results_file = get_run_tmp_file(Path(f"test_return_values_{optimization_iteration}.bin"))
|
||||
if bin_results_file.exists():
|
||||
|
|
@ -1036,23 +1041,13 @@ def parse_test_results(
|
|||
coverage = None
|
||||
if coverage_database_file and source_file and code_context and function_name:
|
||||
all_args = True
|
||||
if is_javascript():
|
||||
# Jest uses coverage-final.json (coverage_database_file points to this)
|
||||
coverage = JestCoverageUtils.load_from_jest_json(
|
||||
coverage_json_path=coverage_database_file,
|
||||
function_name=function_name,
|
||||
code_context=code_context,
|
||||
source_code_path=source_file,
|
||||
)
|
||||
else:
|
||||
# Python uses coverage.py SQLite database
|
||||
coverage = CoverageUtils.load_from_sqlite_database(
|
||||
database_path=coverage_database_file,
|
||||
config_path=coverage_config_file,
|
||||
source_code_path=source_file,
|
||||
code_context=code_context,
|
||||
function_name=function_name,
|
||||
)
|
||||
coverage = _cls().load_coverage(
|
||||
coverage_database_file=coverage_database_file,
|
||||
function_name=function_name,
|
||||
code_context=code_context,
|
||||
source_file=source_file,
|
||||
coverage_config_file=coverage_config_file,
|
||||
)
|
||||
coverage.log_coverage()
|
||||
try:
|
||||
failures = parse_test_failures_from_stdout(run_result.stdout)
|
||||
|
|
|
|||
|
|
@ -1,29 +1,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
|
||||
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
|
||||
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME_EFFECTIVE
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
from codeflash.languages import is_python
|
||||
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
|
||||
from codeflash.languages.registry import get_language_support, get_language_support_by_framework
|
||||
from codeflash.models.models import TestFiles, TestType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.models.models import TestFiles
|
||||
from pathlib import Path
|
||||
|
||||
BEHAVIORAL_BLOCKLISTED_PLUGINS = ["benchmark", "codspeed", "xdist", "sugar"]
|
||||
BENCHMARKING_BLOCKLISTED_PLUGINS = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import custom_addopts
|
||||
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
|
||||
from codeflash.languages.registry import get_language_support
|
||||
|
||||
# Pattern to extract timing from stdout markers: !######...:<duration_ns>######!
|
||||
# Jest markers have multiple colons: !######module:test:func:loop:id:duration######!
|
||||
|
|
@ -112,267 +100,3 @@ def execute_test_subprocess(
|
|||
cwd=cwd, env=env, timeout=timeout, check=False, text=True, capture_output=True
|
||||
)
|
||||
return subprocess.run(cmd_list, **run_args) # noqa: PLW1510
|
||||
|
||||
|
||||
def run_behavioral_tests(
|
||||
test_paths: TestFiles,
|
||||
test_framework: str,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
*,
|
||||
pytest_timeout: int | None = None,
|
||||
pytest_cmd: str = "pytest",
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
enable_coverage: bool = False,
|
||||
js_project_root: Path | None = None,
|
||||
candidate_index: int = 0,
|
||||
) -> tuple[Path, subprocess.CompletedProcess, Path | None, Path | None]:
|
||||
"""Run behavioral tests with optional coverage."""
|
||||
# Check if there's a language support for this test framework that implements run_behavioral_tests
|
||||
language_support = get_language_support_by_framework(test_framework)
|
||||
if language_support is not None and hasattr(language_support, "run_behavioral_tests"):
|
||||
return language_support.run_behavioral_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=pytest_timeout,
|
||||
project_root=js_project_root,
|
||||
enable_coverage=enable_coverage,
|
||||
candidate_index=candidate_index,
|
||||
)
|
||||
if is_python():
|
||||
test_files: list[str] = []
|
||||
for file in test_paths.test_files:
|
||||
if file.test_type == TestType.REPLAY_TEST:
|
||||
# Replay tests need specific test targeting because one file contains tests for multiple functions
|
||||
if file.tests_in_file:
|
||||
test_files.extend(
|
||||
[
|
||||
str(file.instrumented_behavior_file_path) + "::" + test.test_function
|
||||
for test in file.tests_in_file
|
||||
]
|
||||
)
|
||||
else:
|
||||
test_files.append(str(file.instrumented_behavior_file_path))
|
||||
|
||||
pytest_cmd_list = (
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
|
||||
if pytest_cmd == "pytest"
|
||||
else [SAFE_SYS_EXECUTABLE, "-m", *shlex.split(pytest_cmd, posix=IS_POSIX)]
|
||||
)
|
||||
test_files = list(set(test_files)) # remove multiple calls in the same test function
|
||||
|
||||
common_pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
"-q",
|
||||
"--codeflash_loops_scope=session",
|
||||
"--codeflash_min_loops=1",
|
||||
"--codeflash_max_loops=1",
|
||||
f"--codeflash_seconds={pytest_target_runtime_seconds}",
|
||||
]
|
||||
if pytest_timeout is not None:
|
||||
common_pytest_args.append(f"--timeout={pytest_timeout}")
|
||||
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
|
||||
if enable_coverage:
|
||||
coverage_database_file, coverage_config_file = prepare_coverage_files()
|
||||
# disable jit for coverage
|
||||
pytest_test_env["NUMBA_DISABLE_JIT"] = str(1)
|
||||
pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1)
|
||||
pytest_test_env["PYTORCH_JIT"] = str(0)
|
||||
pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
|
||||
pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
|
||||
pytest_test_env["JAX_DISABLE_JIT"] = str(0)
|
||||
|
||||
is_windows = sys.platform == "win32"
|
||||
if is_windows:
|
||||
# On Windows, delete coverage database file directly instead of using 'coverage erase', to avoid locking issues
|
||||
if coverage_database_file.exists():
|
||||
with contextlib.suppress(PermissionError, OSError):
|
||||
coverage_database_file.unlink()
|
||||
else:
|
||||
cov_erase = execute_test_subprocess(
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30
|
||||
) # this cleanup is necessary to avoid coverage data from previous runs, if there are any, then the current run will be appended to the previous data, which skews the results
|
||||
logger.debug(cov_erase)
|
||||
coverage_cmd = [
|
||||
SAFE_SYS_EXECUTABLE,
|
||||
"-m",
|
||||
"coverage",
|
||||
"run",
|
||||
f"--rcfile={coverage_config_file.as_posix()}",
|
||||
"-m",
|
||||
]
|
||||
|
||||
if pytest_cmd == "pytest":
|
||||
coverage_cmd.extend(["pytest"])
|
||||
else:
|
||||
coverage_cmd.extend(shlex.split(pytest_cmd, posix=IS_POSIX)[1:])
|
||||
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"]
|
||||
results = execute_test_subprocess(
|
||||
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600,
|
||||
)
|
||||
logger.debug(
|
||||
f"Result return code: {results.returncode}, "
|
||||
f"{'Result stderr:' + str(results.stderr) if results.stderr else ''}"
|
||||
)
|
||||
else:
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS]
|
||||
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600, # TODO: Make this dynamic
|
||||
)
|
||||
logger.debug(
|
||||
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ""}"""
|
||||
)
|
||||
else:
|
||||
msg = f"Unsupported test framework: {test_framework}"
|
||||
raise ValueError(msg)
|
||||
|
||||
return (
|
||||
result_file_path,
|
||||
results,
|
||||
coverage_database_file if enable_coverage else None,
|
||||
coverage_config_file if enable_coverage else None,
|
||||
)
|
||||
|
||||
|
||||
def run_line_profile_tests(
|
||||
test_paths: TestFiles,
|
||||
pytest_cmd: str,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
test_framework: str,
|
||||
*,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
pytest_timeout: int | None = None,
|
||||
pytest_min_loops: int = 5,
|
||||
pytest_max_loops: int = 100_000,
|
||||
js_project_root: Path | None = None,
|
||||
line_profiler_output_file: Path | None = None,
|
||||
) -> tuple[Path, subprocess.CompletedProcess]:
|
||||
# Check if there's a language support for this test framework that implements run_line_profile_tests
|
||||
language_support = get_language_support_by_framework(test_framework)
|
||||
if language_support is not None and hasattr(language_support, "run_line_profile_tests"):
|
||||
return language_support.run_line_profile_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=pytest_timeout,
|
||||
project_root=js_project_root,
|
||||
line_profile_output_file=line_profiler_output_file,
|
||||
)
|
||||
if is_python(): # pytest runs both pytest and unittest tests
|
||||
pytest_cmd_list = (
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
|
||||
if pytest_cmd == "pytest"
|
||||
else shlex.split(pytest_cmd)
|
||||
)
|
||||
# Always use file path - pytest discovers all tests including parametrized ones
|
||||
test_files: list[str] = list(
|
||||
{str(file.benchmarking_file_path) for file in test_paths.test_files}
|
||||
) # remove multiple calls in the same test function
|
||||
pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
"-q",
|
||||
"--codeflash_loops_scope=session",
|
||||
"--codeflash_min_loops=1",
|
||||
"--codeflash_max_loops=1",
|
||||
f"--codeflash_seconds={pytest_target_runtime_seconds}",
|
||||
]
|
||||
if pytest_timeout is not None:
|
||||
pytest_args.append(f"--timeout={pytest_timeout}")
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
|
||||
pytest_test_env["LINE_PROFILE"] = "1"
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600, # TODO: Make this dynamic
|
||||
)
|
||||
else:
|
||||
msg = f"Unsupported test framework: {test_framework}"
|
||||
raise ValueError(msg)
|
||||
return result_file_path, results
|
||||
|
||||
|
||||
def run_benchmarking_tests(
|
||||
test_paths: TestFiles,
|
||||
pytest_cmd: str,
|
||||
test_env: dict[str, str],
|
||||
cwd: Path,
|
||||
test_framework: str,
|
||||
*,
|
||||
pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME_EFFECTIVE,
|
||||
pytest_timeout: int | None = None,
|
||||
pytest_min_loops: int = 5,
|
||||
pytest_max_loops: int = 100_000,
|
||||
js_project_root: Path | None = None,
|
||||
) -> tuple[Path, subprocess.CompletedProcess]:
|
||||
logger.debug(f"run_benchmarking_tests called: framework={test_framework}, num_files={len(test_paths.test_files)}")
|
||||
# Check if there's a language support for this test framework that implements run_benchmarking_tests
|
||||
language_support = get_language_support_by_framework(test_framework)
|
||||
if language_support is not None and hasattr(language_support, "run_benchmarking_tests"):
|
||||
return language_support.run_benchmarking_tests(
|
||||
test_paths=test_paths,
|
||||
test_env=test_env,
|
||||
cwd=cwd,
|
||||
timeout=pytest_timeout,
|
||||
project_root=js_project_root,
|
||||
min_loops=pytest_min_loops,
|
||||
max_loops=pytest_max_loops,
|
||||
target_duration_seconds=pytest_target_runtime_seconds,
|
||||
)
|
||||
if is_python(): # pytest runs both pytest and unittest tests
|
||||
pytest_cmd_list = (
|
||||
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
|
||||
if pytest_cmd == "pytest"
|
||||
else shlex.split(pytest_cmd)
|
||||
)
|
||||
# Always use file path - pytest discovers all tests including parametrized ones
|
||||
test_files: list[str] = list(
|
||||
{str(file.benchmarking_file_path) for file in test_paths.test_files}
|
||||
) # remove multiple calls in the same test function
|
||||
pytest_args = [
|
||||
"--capture=tee-sys",
|
||||
"-q",
|
||||
"--codeflash_loops_scope=session",
|
||||
f"--codeflash_min_loops={pytest_min_loops}",
|
||||
f"--codeflash_max_loops={pytest_max_loops}",
|
||||
f"--codeflash_seconds={pytest_target_runtime_seconds}",
|
||||
"--codeflash_stability_check=true",
|
||||
]
|
||||
if pytest_timeout is not None:
|
||||
pytest_args.append(f"--timeout={pytest_timeout}")
|
||||
|
||||
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
|
||||
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
|
||||
pytest_test_env = test_env.copy()
|
||||
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
|
||||
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
|
||||
results = execute_test_subprocess(
|
||||
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
|
||||
cwd=cwd,
|
||||
env=pytest_test_env,
|
||||
timeout=600, # TODO: Make this dynamic
|
||||
)
|
||||
else:
|
||||
msg = f"Unsupported test framework: {test_framework}"
|
||||
raise ValueError(msg)
|
||||
return result_file_path, results
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from typing import Optional
|
|||
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
from codeflash.languages import current_language_support, is_javascript
|
||||
from codeflash.languages import current_language_support
|
||||
|
||||
|
||||
def get_test_file_path(
|
||||
|
|
@ -19,7 +19,7 @@ def get_test_file_path(
|
|||
assert test_type in {"unit", "inspired", "replay", "perf"}
|
||||
function_name = function_name.replace(".", "_")
|
||||
# Use appropriate file extension based on language
|
||||
extension = current_language_support().get_test_file_suffix() if is_javascript() else ".py"
|
||||
extension = current_language_support().get_test_file_suffix()
|
||||
|
||||
# For JavaScript/TypeScript, place generated tests in a subdirectory that matches
|
||||
# Vitest/Jest include patterns (e.g., test/**/*.test.ts)
|
||||
|
|
@ -164,16 +164,8 @@ class TestConfig:
|
|||
|
||||
@property
|
||||
def test_framework(self) -> str:
|
||||
"""Returns the appropriate test framework based on language.
|
||||
|
||||
For JavaScript/TypeScript: uses the configured framework (vitest, jest, or mocha).
|
||||
For Python: uses pytest as default.
|
||||
"""
|
||||
if is_javascript():
|
||||
from codeflash.languages.test_framework import get_js_test_framework_or_default
|
||||
|
||||
return get_js_test_framework_or_default()
|
||||
return "pytest"
|
||||
"""Returns the appropriate test framework based on language."""
|
||||
return current_language_support().test_framework
|
||||
|
||||
def set_language(self, language: str) -> None:
|
||||
"""Set the language for this test config.
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
|
||||
from codeflash.languages import is_javascript
|
||||
from codeflash.code_utils.code_utils import module_name_from_file_path
|
||||
from codeflash.languages.current import current_language_support
|
||||
from codeflash.verification.verification_utils import ModifyInspiredTests, delete_multiple_if_name_main
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -35,15 +35,13 @@ def generate_tests(
|
|||
start_time = time.perf_counter()
|
||||
test_module_path = Path(module_name_from_file_path(test_path, test_cfg.tests_project_rootdir))
|
||||
|
||||
# Detect module system for JavaScript/TypeScript before calling aiservice
|
||||
project_module_system = None
|
||||
if is_javascript():
|
||||
from codeflash.languages.javascript.module_system import detect_module_system
|
||||
# Detect module system via language support (non-None for JS/TS, None for Python)
|
||||
lang_support = current_language_support()
|
||||
source_file = Path(function_to_optimize.file_path)
|
||||
project_module_system = lang_support.detect_module_system(test_cfg.tests_project_rootdir, source_file)
|
||||
|
||||
source_file = Path(function_to_optimize.file_path)
|
||||
project_module_system = detect_module_system(test_cfg.tests_project_rootdir, source_file)
|
||||
|
||||
# For JavaScript, calculate the correct import path from the actual test location
|
||||
if project_module_system is not None:
|
||||
# For JavaScript/TypeScript, calculate the correct import path from the actual test location
|
||||
# (test_path) to the source file, not from tests_root
|
||||
import os
|
||||
|
||||
|
|
@ -73,65 +71,18 @@ def generate_tests(
|
|||
)
|
||||
if response and isinstance(response, tuple) and len(response) == 3:
|
||||
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = response
|
||||
temp_run_dir = get_run_tmp_file(Path()).as_posix()
|
||||
|
||||
# For JavaScript/TypeScript, instrumentation is done locally (aiservice returns uninstrumented code)
|
||||
if is_javascript():
|
||||
from codeflash.languages.javascript.instrument import (
|
||||
TestingMode,
|
||||
fix_imports_inside_test_blocks,
|
||||
fix_jest_mock_paths,
|
||||
instrument_generated_js_test,
|
||||
validate_and_fix_import_style,
|
||||
)
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ensure_module_system_compatibility,
|
||||
ensure_vitest_imports,
|
||||
)
|
||||
|
||||
source_file = Path(function_to_optimize.file_path)
|
||||
|
||||
# Fix import statements that appear inside test blocks (invalid JS syntax)
|
||||
generated_test_source = fix_imports_inside_test_blocks(generated_test_source)
|
||||
|
||||
# Fix relative paths in jest.mock() calls
|
||||
generated_test_source = fix_jest_mock_paths(
|
||||
generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir
|
||||
)
|
||||
|
||||
# Validate and fix import styles (default vs named exports)
|
||||
generated_test_source = validate_and_fix_import_style(
|
||||
generated_test_source, source_file, function_to_optimize.function_name
|
||||
)
|
||||
|
||||
# Convert module system if needed (e.g., CommonJS -> ESM for ESM projects)
|
||||
# Skip conversion if ts-jest is installed (handles interop natively)
|
||||
generated_test_source = ensure_module_system_compatibility(
|
||||
generated_test_source, project_module_system, test_cfg.tests_project_rootdir
|
||||
)
|
||||
|
||||
# Ensure vitest imports are present when using vitest framework
|
||||
generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework)
|
||||
|
||||
# Instrument for behavior verification (writes to SQLite)
|
||||
instrumented_behavior_test_source = instrument_generated_js_test(
|
||||
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
# Instrument for performance measurement (prints to stdout)
|
||||
instrumented_perf_test_source = instrument_generated_js_test(
|
||||
test_code=generated_test_source, function_to_optimize=function_to_optimize, mode=TestingMode.PERFORMANCE
|
||||
)
|
||||
|
||||
logger.debug(f"Instrumented JS/TS tests locally for {function_to_optimize.function_name}")
|
||||
else:
|
||||
# Python: instrumentation is done by aiservice, just replace temp dir placeholders
|
||||
instrumented_behavior_test_source = instrumented_behavior_test_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
|
||||
)
|
||||
instrumented_perf_test_source = instrumented_perf_test_source.replace(
|
||||
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
|
||||
generated_test_source, instrumented_behavior_test_source, instrumented_perf_test_source = (
|
||||
lang_support.process_generated_test_strings(
|
||||
generated_test_source=generated_test_source,
|
||||
instrumented_behavior_test_source=instrumented_behavior_test_source,
|
||||
instrumented_perf_test_source=instrumented_perf_test_source,
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_path=test_path,
|
||||
test_cfg=test_cfg,
|
||||
project_module_system=project_module_system,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Failed to generate and instrument tests for {function_to_optimize.function_name}")
|
||||
return None
|
||||
|
|
|
|||
1
docs/FRICTIONLESS_SETUP_PLAN.md
Normal file
1
docs/FRICTIONLESS_SETUP_PLAN.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
1
docs/JS_PROMPT_PARITY_RECOMMENDATIONS.md
Normal file
1
docs/JS_PROMPT_PARITY_RECOMMENDATIONS.md
Normal file
|
|
@ -0,0 +1 @@
|
|||
|
||||
|
|
@ -3,25 +3,31 @@ title: "How Codeflash Works"
|
|||
description: "Understand Codeflash's generate-and-verify approach to code optimization and correctness verification"
|
||||
icon: "gear"
|
||||
sidebarTitle: "How It Works"
|
||||
keywords: ["architecture", "verification", "correctness", "testing", "optimization", "LLM", "benchmarking"]
|
||||
keywords: ["architecture", "verification", "correctness", "testing", "optimization", "LLM", "benchmarking", "javascript", "typescript", "python"]
|
||||
---
|
||||
# How Codeflash Works
|
||||
|
||||
Codeflash follows a "generate and verify" approach to optimize code. It uses LLMs to generate optimizations, then it rigorously verifies if those optimizations are indeed
|
||||
faster and if they have the same behavior. The basic unit of optimization is a function—Codeflash tries to speed up the function, and tries to ensure that it still behaves the same way. This way if you merge the optimized code, it simply runs faster without breaking any functionality.
|
||||
|
||||
Codeflash supports **Python**, **JavaScript**, and **TypeScript** projects.
|
||||
|
||||
## Analysis of your code
|
||||
|
||||
Codeflash scans your codebase to identify all available functions. It locates existing unit tests in your projects and maps which functions they test. When optimizing a function, Codeflash runs these discovered tests to verify nothing has broken.
|
||||
|
||||
For Python, code analysis uses `libcst` and `jedi`. For JavaScript/TypeScript, it uses `tree-sitter` for AST parsing.
|
||||
|
||||
#### What kind of functions can Codeflash optimize?
|
||||
|
||||
Codeflash works best with self-contained functions that have minimal side effects (like communicating with external systems or sending network requests). Codeflash optimizes a group of functions - consisting of an entry point function and any other functions it directly calls.
|
||||
Codeflash supports optimizing async functions.
|
||||
Codeflash supports optimizing async functions in all supported languages.
|
||||
|
||||
#### Test Discovery
|
||||
|
||||
Codeflash currently only runs tests that directly call the target function in their test body. To discover tests that indirectly call the function, you can use the Codeflash Tracer. The Tracer analyzes your test suite and identifies all tests that eventually call a function.
|
||||
Codeflash discovers tests that directly call the target function in their test body. For Python, it finds pytest and unittest tests. For JavaScript/TypeScript, it finds Jest and Vitest test files.
|
||||
|
||||
To discover tests that indirectly call the function, you can use the Codeflash Tracer. The Tracer analyzes your test suite and identifies all tests that eventually call a function.
|
||||
|
||||
## Optimization Generation
|
||||
|
||||
|
|
@ -48,12 +54,12 @@ We recommend manually reviewing the optimized code since there might be importan
|
|||
|
||||
Codeflash generates two types of tests:
|
||||
|
||||
- LLM Generated tests - Codeflash uses LLMs to create several regression test cases that cover typical function usage, edge cases, and large-scale inputs to verify both correctness and performance.
|
||||
- Concolic coverage tests - Codeflash uses state-of-the-art concolic testing with an SMT Solver (a theorem prover) to explore execution paths and generate function arguments. This aims to maximize code coverage for the function being optimized. Codeflash runs the resulting test file to verify correctness. Currently, this feature only supports pytest.
|
||||
- **LLM Generated tests** - Codeflash uses LLMs to create several regression test cases that cover typical function usage, edge cases, and large-scale inputs to verify both correctness and performance. This works for Python, JavaScript, and TypeScript.
|
||||
- **Concolic coverage tests** - Codeflash uses state-of-the-art concolic testing with an SMT Solver (a theorem prover) to explore execution paths and generate function arguments. This aims to maximize code coverage for the function being optimized. Currently, this feature only supports Python (pytest).
|
||||
|
||||
## Code Execution
|
||||
|
||||
Codeflash runs tests for the target function using either pytest or unittest frameworks. The tests execute on your machine, ensuring access to the Python environment and any other dependencies associated to let Codeflash run your code properly. Running on your machine also ensures accurate performance measurements since runtime varies by system.
|
||||
Codeflash runs tests for the target function on your machine. For Python, it uses pytest or unittest. For JavaScript/TypeScript, it uses Jest or Vitest. Running on your machine ensures access to your environment and dependencies, and provides accurate performance measurements since runtime varies by system.
|
||||
|
||||
#### Performance benchmarking
|
||||
|
||||
|
|
|
|||
|
|
@ -1,83 +1,26 @@
|
|||
---
|
||||
title: "Manual Configuration"
|
||||
description: "Configure Codeflash for your project with pyproject.toml settings and advanced options"
|
||||
description: "Configure Codeflash for your project"
|
||||
icon: "gear"
|
||||
sidebarTitle: "Manual Configuration"
|
||||
keywords:
|
||||
[
|
||||
"configuration",
|
||||
"pyproject.toml",
|
||||
"setup",
|
||||
"settings",
|
||||
"pytest",
|
||||
"formatter",
|
||||
]
|
||||
---
|
||||
|
||||
# Manual Configuration
|
||||
|
||||
Codeflash is installed and configured on a per-project basis.
|
||||
`codeflash init` should guide you through the configuration process, but if you need to manually configure Codeflash or set advanced settings, you can do so by editing the `pyproject.toml` file in the root directory of your project.
|
||||
`codeflash init` should guide you through the configuration process, but if you need to manually configure Codeflash or set advanced settings, follow the guide for your language:
|
||||
|
||||
## Configuration Options
|
||||
|
||||
Codeflash config looks like the following
|
||||
|
||||
```toml
|
||||
[tool.codeflash]
|
||||
module-root = "my_module"
|
||||
tests-root = "tests"
|
||||
formatter-cmds = ["black $file"]
|
||||
# optional configuration
|
||||
benchmarks-root = "tests/benchmarks" # Required when running with --benchmark
|
||||
ignore-paths = ["my_module/build/"]
|
||||
pytest-cmd = "pytest"
|
||||
disable-imports-sorting = false
|
||||
disable-telemetry = false
|
||||
git-remote = "origin"
|
||||
override-fixtures = false
|
||||
```
|
||||
|
||||
All file paths are relative to the directory of the `pyproject.toml` file.
|
||||
|
||||
Required Options:
|
||||
|
||||
- `module-root`: The Python module you want Codeflash to optimize going forward. Only code under this directory will be optimized. It should also have an `__init__.py` file to make the module importable.
|
||||
- `tests-root`: The directory where your tests are located. Codeflash will use this directory to discover existing tests as well as generate new tests.
|
||||
|
||||
Optional Configuration:
|
||||
|
||||
- `benchmarks-root`: The directory where your benchmarks are located. Codeflash will use this directory to discover existing benchmarks. Note that this option is required when running with `--benchmark`.
|
||||
- `ignore-paths`: A list of paths within the `module-root` to ignore when optimizing code. Codeflash will not optimize code in these paths. Useful for ignoring build directories or other generated code. You can also leave this empty if not needed.
|
||||
- `pytest-cmd`: The command to run your tests. Defaults to `pytest`. You can specify extra commandline arguments here for pytest.
|
||||
- `formatter-cmds`: The command line to run your code formatter or linter. Defaults to `["black $file"]`. In the command line `$file` refers to the current file being optimized. The assumption with using tools here is that they overwrite the same file and returns a zero exit code. You can also specify multiple tools here that run in a chain as a toml array. You can also disable code formatting by setting this to `["disabled"]`.
|
||||
- `ruff` - A recommended way to run ruff linting and formatting is `["ruff check --exit-zero --fix $file", "ruff format $file"]`. To make `ruff check --fix` return a 0 exit code please add a `--exit-zero` argument.
|
||||
- `disable-imports-sorting`: By default, codeflash uses isort to organize your imports before creating suggestions. You can disable this by setting this field to `true`. This could be useful if you don't sort your imports or while using linters like ruff that sort imports too.
|
||||
- `disable-telemetry`: Disable telemetry data collection. Defaults to `false`. Set this to `true` to disable telemetry data collection. Codeflash collects anonymized telemetry data to understand how users are using Codeflash and to improve the product. Telemetry does not collect any code data.
|
||||
- `git-remote`: The git remote to use for pull requests. Defaults to `"origin"`.
|
||||
- `override-fixtures`: Override pytest fixtures during optimization. Defaults to `false`.
|
||||
|
||||
## Example Configuration
|
||||
|
||||
Here's an example project with the following structure:
|
||||
|
||||
```text
|
||||
acme-project/
|
||||
|- foo_module/
|
||||
| |- __init__.py
|
||||
| |- foo.py
|
||||
| |- main.py
|
||||
|- tests/
|
||||
| |- __init__.py
|
||||
| |- test_script.py
|
||||
|- pyproject.toml
|
||||
```
|
||||
|
||||
Here's a sample `pyproject.toml` file for the above project:
|
||||
|
||||
```toml
|
||||
[tool.codeflash]
|
||||
module-root = "foo_module"
|
||||
tests-root = "tests"
|
||||
ignore-paths = []
|
||||
```
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Python Configuration" icon="python" href="/configuration/python">
|
||||
Configure via `pyproject.toml`
|
||||
</Card>
|
||||
<Card title="JavaScript / TypeScript Configuration" icon="js" href="/configuration/javascript">
|
||||
Configure via `package.json`
|
||||
</Card>
|
||||
</CardGroup>
|
||||
220
docs/configuration/javascript.mdx
Normal file
220
docs/configuration/javascript.mdx
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
---
|
||||
title: "JavaScript / TypeScript Configuration"
|
||||
description: "Configure Codeflash for JavaScript and TypeScript projects using package.json"
|
||||
icon: "js"
|
||||
sidebarTitle: "JavaScript / TypeScript"
|
||||
keywords:
|
||||
[
|
||||
"configuration",
|
||||
"package.json",
|
||||
"javascript",
|
||||
"typescript",
|
||||
"jest",
|
||||
"vitest",
|
||||
"prettier",
|
||||
"eslint",
|
||||
"monorepo",
|
||||
]
|
||||
---
|
||||
|
||||
# JavaScript / TypeScript Configuration
|
||||
|
||||
Codeflash stores its configuration in `package.json` under the `"codeflash"` key.
|
||||
|
||||
## Full Reference
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "my-project",
|
||||
"codeflash": {
|
||||
"moduleRoot": "src",
|
||||
"testsRoot": "tests",
|
||||
"testRunner": "jest",
|
||||
"formatterCmds": ["prettier --write $file"],
|
||||
"ignorePaths": ["src/generated/"],
|
||||
"disableTelemetry": false,
|
||||
"gitRemote": "origin"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
All file paths are relative to the directory containing `package.json`.
|
||||
|
||||
<Info>
|
||||
Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed.
|
||||
</Info>
|
||||
|
||||
## Auto-Detection
|
||||
|
||||
When you run `codeflash init`, Codeflash inspects your project and auto-detects:
|
||||
|
||||
| Setting | Detection logic |
|
||||
|---------|----------------|
|
||||
| `moduleRoot` | Looks for `src/`, `lib/`, or the main source directory |
|
||||
| `testsRoot` | Looks for `tests/`, `test/`, `__tests__/`, or files matching `*.test.js` / `*.spec.js` |
|
||||
| `testRunner` | Checks `devDependencies` for `jest` or `vitest` |
|
||||
| `formatterCmds` | Checks for `prettier`, `eslint`, or `biome` in dependencies and config files |
|
||||
| Module system | Reads `"type"` field in `package.json` (ESM vs CommonJS) |
|
||||
| TypeScript | Detects `tsconfig.json` |
|
||||
|
||||
You can always override any auto-detected value in the `"codeflash"` section.
|
||||
|
||||
## Required Options
|
||||
|
||||
- `moduleRoot`: The source directory to optimize. Only code under this directory will be optimized.
|
||||
- `testsRoot`: The directory where your tests are located. Codeflash discovers existing tests and generates new ones here.
|
||||
|
||||
## Optional Options
|
||||
|
||||
- `testRunner`: Test framework to use. Auto-detected from your dependencies. Supported values: `"jest"`, `"vitest"`.
|
||||
- `formatterCmds`: Formatter commands. `$file` refers to the file being optimized. Disable with `["disabled"]`.
|
||||
- **Prettier**: `["prettier --write $file"]`
|
||||
- **ESLint + Prettier**: `["eslint --fix $file", "prettier --write $file"]`
|
||||
- **Biome**: `["biome check --write $file"]`
|
||||
- `ignorePaths`: Paths within `moduleRoot` to skip during optimization.
|
||||
- `disableTelemetry`: Disable anonymized telemetry. Defaults to `false`.
|
||||
- `gitRemote`: Git remote for pull requests. Defaults to `"origin"`.
|
||||
|
||||
## Module Systems
|
||||
|
||||
Codeflash handles both ES Modules and CommonJS automatically. It detects the module system from your `package.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "module"
|
||||
}
|
||||
```
|
||||
|
||||
- `"type": "module"` — Files are treated as ESM (`import`/`export`)
|
||||
- `"type": "commonjs"` or omitted — Files are treated as CommonJS (`require`/`module.exports`)
|
||||
|
||||
No additional configuration is needed. Codeflash respects `.mjs`/`.cjs` extensions as well.
|
||||
|
||||
## TypeScript
|
||||
|
||||
TypeScript projects work out of the box. Codeflash detects TypeScript from the presence of `tsconfig.json` and handles `.ts`/`.tsx` files automatically.
|
||||
|
||||
No separate configuration is needed for TypeScript vs JavaScript.
|
||||
|
||||
## Test Framework Support
|
||||
|
||||
| Framework | Auto-detected from | Notes |
|
||||
|-----------|-------------------|-------|
|
||||
| **Jest** | `jest` in dependencies | Default for most projects |
|
||||
| **Vitest** | `vitest` in dependencies | ESM-native support |
|
||||
|
||||
<Info>
|
||||
**Functions must be exported** to be optimizable. Codeflash uses tree-sitter AST analysis to discover functions and check export status. Supported export patterns:
|
||||
|
||||
- `export function foo() {}`
|
||||
- `export const foo = () => {}`
|
||||
- `export default function foo() {}`
|
||||
- `const foo = () => {}; export { foo };`
|
||||
- `module.exports = { foo }`
|
||||
- `const utils = { foo() {} }; module.exports = utils;`
|
||||
</Info>
|
||||
|
||||
## Monorepo Configuration
|
||||
|
||||
For monorepo projects (Yarn workspaces, pnpm workspaces, Lerna, Nx, Turborepo), configure each package individually:
|
||||
|
||||
```text
|
||||
my-monorepo/
|
||||
|- packages/
|
||||
| |- core/
|
||||
| | |- src/
|
||||
| | |- tests/
|
||||
| | |- package.json <-- "codeflash" config here
|
||||
| |- utils/
|
||||
| | |- src/
|
||||
| | |- __tests__/
|
||||
| | |- package.json <-- "codeflash" config here
|
||||
|- package.json <-- workspace root (no codeflash config)
|
||||
```
|
||||
|
||||
Run `codeflash init` from within each package:
|
||||
|
||||
```bash
|
||||
cd packages/core
|
||||
npx codeflash init
|
||||
```
|
||||
|
||||
<Warning>
|
||||
**Always run codeflash from the package directory**, not the monorepo root. Codeflash needs to find the `package.json` with the `"codeflash"` config in the current working directory.
|
||||
</Warning>
|
||||
|
||||
### Hoisted dependencies
|
||||
|
||||
If your monorepo hoists `node_modules` to the root (Yarn Berry with `nodeLinker: node-modules`, pnpm with `shamefully-hoist`), Codeflash resolves modules using Node.js standard resolution. This works automatically.
|
||||
|
||||
For **pnpm strict mode** (non-hoisted), ensure `codeflash` is a direct dependency of the package:
|
||||
|
||||
```bash
|
||||
pnpm add --filter @my-org/core --save-dev codeflash
|
||||
```
|
||||
|
||||
## Example
|
||||
|
||||
### Standard project
|
||||
|
||||
```text
|
||||
my-app/
|
||||
|- src/
|
||||
| |- utils.js
|
||||
| |- index.js
|
||||
|- tests/
|
||||
| |- utils.test.js
|
||||
|- package.json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "my-app",
|
||||
"codeflash": {
|
||||
"moduleRoot": "src",
|
||||
"testsRoot": "tests"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Project with co-located tests
|
||||
|
||||
```text
|
||||
my-app/
|
||||
|- src/
|
||||
| |- utils.js
|
||||
| |- utils.test.js
|
||||
| |- index.js
|
||||
|- package.json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "my-app",
|
||||
"codeflash": {
|
||||
"moduleRoot": "src",
|
||||
"testsRoot": "src"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### CommonJS library with no separate test directory
|
||||
|
||||
```text
|
||||
my-lib/
|
||||
|- lib/
|
||||
| |- helpers.js
|
||||
|- test/
|
||||
| |- helpers.spec.js
|
||||
|- package.json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "my-lib",
|
||||
"codeflash": {
|
||||
"moduleRoot": "lib",
|
||||
"testsRoot": "test"
|
||||
}
|
||||
}
|
||||
```
|
||||
80
docs/configuration/python.mdx
Normal file
80
docs/configuration/python.mdx
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
---
|
||||
title: "Python Configuration"
|
||||
description: "Configure Codeflash for Python projects using pyproject.toml"
|
||||
icon: "python"
|
||||
sidebarTitle: "Python"
|
||||
keywords:
|
||||
[
|
||||
"configuration",
|
||||
"pyproject.toml",
|
||||
"python",
|
||||
"pytest",
|
||||
"formatter",
|
||||
"ruff",
|
||||
"black",
|
||||
]
|
||||
---
|
||||
|
||||
# Python Configuration
|
||||
|
||||
Codeflash stores its configuration in `pyproject.toml` under the `[tool.codeflash]` section.
|
||||
|
||||
## Full Reference
|
||||
|
||||
```toml
|
||||
[tool.codeflash]
|
||||
# Required
|
||||
module-root = "my_module"
|
||||
tests-root = "tests"
|
||||
|
||||
# Optional
|
||||
formatter-cmds = ["black $file"]
|
||||
benchmarks-root = "tests/benchmarks"
|
||||
ignore-paths = ["my_module/build/"]
|
||||
pytest-cmd = "pytest"
|
||||
disable-imports-sorting = false
|
||||
disable-telemetry = false
|
||||
git-remote = "origin"
|
||||
override-fixtures = false
|
||||
```
|
||||
|
||||
All file paths are relative to the directory of the `pyproject.toml` file.
|
||||
|
||||
## Required Options
|
||||
|
||||
- `module-root`: The Python module to optimize. Only code under this directory will be optimized. It should have an `__init__.py` file to make the module importable.
|
||||
- `tests-root`: The directory where your tests are located. Codeflash discovers existing tests and generates new ones here.
|
||||
|
||||
## Optional Options
|
||||
|
||||
- `benchmarks-root`: Directory for benchmarks. Required when running with `--benchmark`.
|
||||
- `ignore-paths`: Paths within `module-root` to skip. Useful for build directories or generated code.
|
||||
- `pytest-cmd`: Command to run your tests. Defaults to `pytest`. You can add extra arguments here.
|
||||
- `formatter-cmds`: Formatter/linter commands. `$file` refers to the file being optimized. Disable with `["disabled"]`.
|
||||
- **ruff** (recommended): `["ruff check --exit-zero --fix $file", "ruff format $file"]`
|
||||
- **black**: `["black $file"]`
|
||||
- `disable-imports-sorting`: Disable isort import sorting. Defaults to `false`.
|
||||
- `disable-telemetry`: Disable anonymized telemetry. Defaults to `false`.
|
||||
- `git-remote`: Git remote for pull requests. Defaults to `"origin"`.
|
||||
- `override-fixtures`: Override pytest fixtures during optimization. Defaults to `false`.
|
||||
|
||||
## Example
|
||||
|
||||
```text
|
||||
acme-project/
|
||||
|- foo_module/
|
||||
| |- __init__.py
|
||||
| |- foo.py
|
||||
| |- main.py
|
||||
|- tests/
|
||||
| |- __init__.py
|
||||
| |- test_script.py
|
||||
|- pyproject.toml
|
||||
```
|
||||
|
||||
```toml
|
||||
[tool.codeflash]
|
||||
module-root = "foo_module"
|
||||
tests-root = "tests"
|
||||
ignore-paths = []
|
||||
```
|
||||
|
|
@ -1,38 +1,51 @@
|
|||
---
|
||||
title: "JavaScript Installation"
|
||||
title: "JavaScript / TypeScript Installation"
|
||||
description: "Install and configure Codeflash for your JavaScript/TypeScript project"
|
||||
icon: "node-js"
|
||||
keywords:
|
||||
[
|
||||
"installation",
|
||||
"javascript",
|
||||
"typescript",
|
||||
"npm",
|
||||
"yarn",
|
||||
"pnpm",
|
||||
"bun",
|
||||
"jest",
|
||||
"vitest",
|
||||
"monorepo",
|
||||
]
|
||||
---
|
||||
|
||||
Codeflash now supports JavaScript and TypeScript projects with optimized test data serialization using V8 native serialization.
|
||||
Codeflash supports JavaScript and TypeScript projects. It uses V8 native serialization for test data capture and works with Jest and Vitest test frameworks.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Before installing Codeflash for JavaScript, ensure you have:
|
||||
Before installing Codeflash, ensure you have:
|
||||
|
||||
1. **Node.js 16 or above** installed
|
||||
1. **Node.js 18 or above** installed
|
||||
2. **A JavaScript/TypeScript project** with a package manager (npm, yarn, pnpm, or bun)
|
||||
3. **Project dependencies installed**
|
||||
|
||||
Good to have (optional):
|
||||
|
||||
1. **Unit Tests** that Codeflash uses to ensure correctness of the optimizations
|
||||
1. **Unit tests** (Jest or Vitest) — Codeflash uses them to verify correctness of optimizations
|
||||
|
||||
<Warning>
|
||||
**Node.js Runtime Required**
|
||||
**Node.js 18+ Required**
|
||||
|
||||
Codeflash JavaScript support uses V8 serialization API, which is available natively in Node.js. Make sure you're running on Node.js 16+ for optimal compatibility.
|
||||
Codeflash requires Node.js 18 or above. Check your version:
|
||||
|
||||
```bash
|
||||
node --version # Should show v16.0.0 or higher
|
||||
node --version # Should show v18.0.0 or higher
|
||||
```
|
||||
|
||||
</Warning>
|
||||
|
||||
<Steps>
|
||||
<Step title="Install Codeflash CLI">
|
||||
<Step title="Install the Codeflash npm package">
|
||||
|
||||
Install Codeflash globally or as a development dependency in your project:
|
||||
Install Codeflash as a development dependency in your project:
|
||||
|
||||
<CodeGroup>
|
||||
```bash npm
|
||||
|
|
@ -50,321 +63,285 @@ pnpm add --save-dev codeflash
|
|||
```bash bun
|
||||
bun add --dev codeflash
|
||||
```
|
||||
|
||||
```bash global
|
||||
npm install -g codeflash
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
<Tip>
|
||||
**Development Dependency Recommended**
|
||||
|
||||
Codeflash is intended for development and CI workflows. Installing as a dev dependency keeps your production bundle clean.
|
||||
|
||||
**Dev dependency recommended** — Codeflash is for development and CI workflows. Installing as a dev dependency keeps your production bundle clean.
|
||||
</Tip>
|
||||
|
||||
<Info>
|
||||
**Codeflash also requires a Python installation** (3.9+) to run the CLI optimizer. Install the Python CLI globally:
|
||||
|
||||
```bash
|
||||
pip install codeflash
|
||||
# or
|
||||
uv tool install codeflash
|
||||
```
|
||||
|
||||
The Python CLI orchestrates the optimization pipeline, while the npm package provides the JavaScript runtime (test runners, serialization, reporters).
|
||||
</Info>
|
||||
</Step>
|
||||
|
||||
<Step title="Generate a Codeflash API Key">
|
||||
Codeflash uses cloud-hosted AI models. You need an API key:
|
||||
|
||||
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
|
||||
2. Sign up with your GitHub account (free tier available)
|
||||
3. Navigate to the [API Key](https://app.codeflash.ai/app/apikeys) page to generate your key
|
||||
|
||||
Set it as an environment variable:
|
||||
|
||||
```bash
|
||||
export CODEFLASH_API_KEY="your-api-key-here"
|
||||
```
|
||||
|
||||
Or add it to your shell profile (`~/.bashrc`, `~/.zshrc`) for persistence.
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Run Automatic Configuration">
|
||||
Navigate to your project's root directory (where your `package.json` file is) and run:
|
||||
Navigate to your project root (where `package.json` is) and run:
|
||||
|
||||
```bash
|
||||
<CodeGroup>
|
||||
```bash npm / yarn / pnpm
|
||||
npx codeflash init
|
||||
```
|
||||
|
||||
```bash bun
|
||||
bunx codeflash init
|
||||
```
|
||||
|
||||
```bash Global install
|
||||
codeflash init
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
When running `codeflash init`, you will see the following prompts:
|
||||
### What `codeflash init` does
|
||||
|
||||
```text
|
||||
1. Enter your Codeflash API key (or login with Codeflash)
|
||||
2. Which JavaScript/TypeScript module do you want me to optimize? (e.g. src/)
|
||||
3. Where are your tests located? (e.g. tests/, __tests__/, *.test.js)
|
||||
4. Which test framework do you use? (jest/vitest/mocha/ava/other)
|
||||
5. Which code formatter do you use? (prettier/eslint/biome/disabled)
|
||||
6. Which git remote should Codeflash use for Pull Requests? (if multiple remotes exist)
|
||||
7. Help us improve Codeflash by sharing anonymous usage data?
|
||||
8. Install the GitHub app
|
||||
9. Install GitHub actions for Continuous optimization?
|
||||
Codeflash **auto-detects** most settings from your project:
|
||||
|
||||
| Setting | How it's detected |
|
||||
|---------|------------------|
|
||||
| **Module root** | Looks for `src/`, `lib/`, or the directory containing your source files |
|
||||
| **Tests root** | Looks for `tests/`, `test/`, `__tests__/`, or files matching `*.test.js` / `*.spec.js` |
|
||||
| **Test framework** | Checks `devDependencies` for `jest` or `vitest` |
|
||||
| **Formatter** | Checks for `prettier`, `eslint`, or `biome` in dependencies and config files |
|
||||
| **Module system** | Reads `"type"` field in `package.json` (ESM vs CommonJS) |
|
||||
| **TypeScript** | Detects `tsconfig.json` presence |
|
||||
|
||||
You'll be prompted to confirm or override the detected values. The configuration is saved in your `package.json` under the `"codeflash"` key:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "my-project",
|
||||
"codeflash": {
|
||||
"moduleRoot": "src",
|
||||
"testsRoot": "tests"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
After you have answered these questions, the Codeflash configuration will be saved in a `codeflash.config.js` file.
|
||||
|
||||
<Info>
|
||||
**Test Data Serialization Strategy**
|
||||
|
||||
Codeflash uses **V8 serialization** for JavaScript test data capture. This provides:
|
||||
- ⚡ **Best performance**: 2-3x faster than alternatives
|
||||
- 🎯 **Perfect type preservation**: Maintains Date, Map, Set, TypedArrays, and more
|
||||
- 📦 **Compact binary storage**: Smallest file sizes
|
||||
- 🔄 **Framework agnostic**: Works with React, Vue, Angular, Svelte, and vanilla JS
|
||||
|
||||
**No separate config file needed.** Codeflash stores all configuration inside your existing `package.json`, not in a separate config file.
|
||||
</Info>
|
||||
|
||||
</Step>
|
||||
|
||||
<Step title="Generate a Codeflash API Key">
|
||||
Codeflash uses cloud-hosted AI models and integrations with GitHub. If you haven't created one already, you'll need to create an API key to authorize your access.
|
||||
<Step title="Install the Codeflash GitHub App (optional)">
|
||||
|
||||
1. Visit the [Codeflash Web App](https://app.codeflash.ai/)
|
||||
2. Sign up with your GitHub account (free)
|
||||
3. Navigate to the [API Key](https://app.codeflash.ai/app/apikeys) page to generate your API key
|
||||
To receive optimization PRs automatically, install the Codeflash GitHub App:
|
||||
|
||||
<Note>
|
||||
**Free Tier Available**
|
||||
[Install Codeflash GitHub App](https://github.com/apps/codeflash-ai/installations/select_target)
|
||||
|
||||
Codeflash offers a **free tier** with a limited number of optimizations. Perfect for trying it out on small projects!
|
||||
|
||||
</Note>
|
||||
</Step>
|
||||
|
||||
<Step title="Install the Codeflash GitHub App">
|
||||
|
||||
Finally, if you have not done so already, Codeflash will ask you to install the GitHub App in your repository.
|
||||
The Codeflash GitHub App allows the codeflash-ai bot to open PRs, review code, and provide optimization suggestions.
|
||||
|
||||
Please [install the Codeflash GitHub
|
||||
app](https://github.com/apps/codeflash-ai/installations/select_target) by choosing the repository you want to install
|
||||
Codeflash on.
|
||||
This enables the codeflash-ai bot to open PRs with optimization suggestions. If you skip this step, you can still optimize locally using `--no-pr`.
|
||||
|
||||
</Step>
|
||||
|
||||
</Steps>
|
||||
|
||||
## Framework Support
|
||||
## Monorepo Setup
|
||||
|
||||
Codeflash JavaScript support works seamlessly with all major frameworks and testing libraries:
|
||||
For monorepos (Yarn workspaces, pnpm workspaces, Lerna, Nx, Turborepo), run `codeflash init` from within each package you want to optimize:
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Frontend Frameworks" icon="react">
|
||||
- React
|
||||
- Vue.js
|
||||
- Angular
|
||||
- Svelte
|
||||
- Solid.js
|
||||
</Card>
|
||||
```bash
|
||||
# Navigate to the specific package
|
||||
cd packages/my-library
|
||||
|
||||
<Card title="Test Frameworks" icon="flask">
|
||||
- Jest
|
||||
- Vitest
|
||||
- Mocha
|
||||
- AVA
|
||||
- Playwright
|
||||
- Cypress
|
||||
</Card>
|
||||
# Run init from the package directory
|
||||
npx codeflash init
|
||||
```
|
||||
|
||||
<Card title="Backend" icon="server">
|
||||
- Express
|
||||
- NestJS
|
||||
- Fastify
|
||||
- Koa
|
||||
- Hono
|
||||
</Card>
|
||||
Each package gets its own `"codeflash"` section in its `package.json`. The `moduleRoot` and `testsRoot` paths are relative to that package's `package.json`.
|
||||
|
||||
<Card title="Runtimes" icon="gears">
|
||||
- Node.js ✅ (Recommended)
|
||||
- Bun (Coming soon)
|
||||
- Deno (Coming soon)
|
||||
</Card>
|
||||
</CardGroup>
|
||||
### Example: Yarn workspaces monorepo
|
||||
|
||||
## Understanding V8 Serialization
|
||||
```text
|
||||
my-monorepo/
|
||||
|- packages/
|
||||
| |- core/
|
||||
| | |- src/
|
||||
| | |- tests/
|
||||
| | |- package.json <-- codeflash config here
|
||||
| |- utils/
|
||||
| | |- src/
|
||||
| | |- __tests__/
|
||||
| | |- package.json <-- codeflash config here
|
||||
|- package.json <-- root workspace (no codeflash config needed)
|
||||
```
|
||||
|
||||
Codeflash uses Node.js's native V8 serialization API to capture and compare test data. Here's what makes it powerful:
|
||||
|
||||
### Type Preservation
|
||||
|
||||
Unlike JSON serialization, V8 serialization preserves JavaScript-specific types:
|
||||
|
||||
```javascript
|
||||
// These types are preserved perfectly:
|
||||
const testData = {
|
||||
date: new Date(), // ✅ Date objects
|
||||
map: new Map([['key', 'value']]), // ✅ Map instances
|
||||
set: new Set([1, 2, 3]), // ✅ Set instances
|
||||
buffer: Buffer.from('hello'), // ✅ Buffers
|
||||
typed: new Uint8Array([1, 2, 3]), // ✅ TypedArrays
|
||||
bigint: 9007199254740991n, // ✅ BigInt
|
||||
regex: /pattern/gi, // ✅ RegExp
|
||||
undef: undefined, // ✅ undefined (not null!)
|
||||
circular: {} // ✅ Circular references
|
||||
};
|
||||
testData.circular.self = testData.circular;
|
||||
```json
|
||||
// packages/core/package.json
|
||||
{
|
||||
"name": "@my-org/core",
|
||||
"codeflash": {
|
||||
"moduleRoot": "src",
|
||||
"testsRoot": "tests"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
<Warning>
|
||||
**Why Not JSON?**
|
||||
|
||||
JSON serialization would cause bugs to slip through:
|
||||
- `Date` becomes string → date arithmetic fails silently
|
||||
- `Map` becomes `{}` → `.get()` calls return undefined
|
||||
- `undefined` becomes `null` → type checks break
|
||||
- TypedArrays become plain objects → binary operations fail
|
||||
|
||||
V8 serialization catches these issues during optimization verification.
|
||||
**Run codeflash from the package directory**, not the monorepo root. Codeflash needs to find the `package.json` with the `"codeflash"` config in the current working directory.
|
||||
</Warning>
|
||||
|
||||
## Try It Out!
|
||||
<Info>
|
||||
**Hoisted dependencies work fine.** If your monorepo hoists `node_modules` to the root (common in Yarn Berry, pnpm with `shamefully-hoist`), Codeflash resolves modules using Node.js standard resolution and will find them correctly.
|
||||
</Info>
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Quick Start">
|
||||
Once configured, you can start optimizing your JavaScript/TypeScript code immediately:
|
||||
## Test Framework Support
|
||||
|
||||
```bash
|
||||
# Optimize a specific function
|
||||
codeflash --file path/to/your/file.js --function functionName
|
||||
| Framework | Status | Auto-detected from |
|
||||
|-----------|--------|-------------------|
|
||||
| **Jest** | Supported | `jest` in dependencies |
|
||||
| **Vitest** | Supported | `vitest` in dependencies |
|
||||
| **Mocha** | Coming soon | — |
|
||||
|
||||
# Or optimize all functions in your codebase
|
||||
<Info>
|
||||
**Functions must be exported** to be optimizable. Codeflash can only discover and optimize functions that are exported from their module (via `export`, `export default`, or `module.exports`).
|
||||
</Info>
|
||||
|
||||
## Try It Out
|
||||
|
||||
Once configured, optimize your code:
|
||||
|
||||
<CodeGroup>
|
||||
```bash Optimize a function
|
||||
codeflash --file src/utils.js --function processData
|
||||
```
|
||||
|
||||
```bash Optimize locally (no PR)
|
||||
codeflash --file src/utils.ts --function processData --no-pr
|
||||
```
|
||||
|
||||
```bash Optimize entire codebase
|
||||
codeflash --all
|
||||
```
|
||||
|
||||
</Tab>
|
||||
|
||||
<Tab title="TypeScript Support">
|
||||
Codeflash fully supports TypeScript projects:
|
||||
|
||||
```bash
|
||||
# Optimize TypeScript files directly
|
||||
codeflash --file src/utils.ts --function processData
|
||||
|
||||
# Works with TSX for React components
|
||||
codeflash --file src/components/DataTable.tsx --function DataTable
|
||||
```bash Trace and optimize
|
||||
codeflash optimize --jest
|
||||
```
|
||||
|
||||
<Info>
|
||||
Codeflash preserves TypeScript types during optimization. Your type annotations and interfaces remain intact.
|
||||
</Info>
|
||||
|
||||
</Tab>
|
||||
|
||||
<Tab title="Test Framework Examples">
|
||||
|
||||
<Accordion title="Jest Example">
|
||||
```javascript
|
||||
// sum.test.js
|
||||
test('adds 1 + 2 to equal 3', () => {
|
||||
expect(sum(1, 2)).toBe(3);
|
||||
});
|
||||
|
||||
// Optimize the sum function
|
||||
codeflash --file sum.js --function sum
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Vitest Example">
|
||||
```javascript
|
||||
// calculator.test.js
|
||||
import { describe, it, expect } from 'vitest';
|
||||
|
||||
describe('calculator', () => {
|
||||
it('should multiply correctly', () => {
|
||||
expect(multiply(2, 3)).toBe(6);
|
||||
});
|
||||
});
|
||||
|
||||
// Optimize the multiply function
|
||||
codeflash --file calculator.js --function multiply
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
</Tab>
|
||||
</Tabs>
|
||||
</CodeGroup>
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
<AccordionGroup>
|
||||
<Accordion title="📦 Module not found errors">
|
||||
Make sure:
|
||||
- ✅ All project dependencies are installed
|
||||
- ✅ Your `node_modules` directory exists
|
||||
<Accordion title="Function not found or not exported">
|
||||
Codeflash only optimizes **exported** functions. Make sure your function is exported:
|
||||
|
||||
```bash
|
||||
# Reinstall dependencies
|
||||
npm install
|
||||
# or
|
||||
yarn install
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="🔧 V8 serialization errors">
|
||||
If you encounter serialization errors:
|
||||
|
||||
**Functions and classes** cannot be serialized:
|
||||
```javascript
|
||||
// ❌ Won't work - contains function
|
||||
const data = { callback: () => {} };
|
||||
// ES Modules
|
||||
export function processData(data) { ... }
|
||||
// or
|
||||
const processData = (data) => { ... };
|
||||
export { processData };
|
||||
|
||||
// ✅ Works - pure data
|
||||
const data = { value: 42, items: [1, 2, 3] };
|
||||
// CommonJS
|
||||
function processData(data) { ... }
|
||||
module.exports = { processData };
|
||||
```
|
||||
|
||||
**Symbols** are not serializable:
|
||||
```javascript
|
||||
// ❌ Won't work
|
||||
const data = { [Symbol('key')]: 'value' };
|
||||
|
||||
// ✅ Use string keys
|
||||
const data = { key: 'value' };
|
||||
```
|
||||
If codeflash reports the function exists but is not exported, add an export statement.
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="🧪 No optimizations found">
|
||||
Not all functions can be optimized - some code is already optimal. This is expected.
|
||||
<Accordion title="codeflash npm package not found / module errors">
|
||||
Ensure the codeflash npm package is installed in your project:
|
||||
|
||||
Use the `--verbose` flag for detailed output:
|
||||
```bash
|
||||
codeflash optimize --verbose
|
||||
<CodeGroup>
|
||||
```bash npm
|
||||
npm install --save-dev codeflash
|
||||
```
|
||||
```bash yarn
|
||||
yarn add --dev codeflash
|
||||
```
|
||||
```bash pnpm
|
||||
pnpm add --save-dev codeflash
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
This will show:
|
||||
- 🔍 Which functions are being analyzed
|
||||
- 🚫 Why certain functions were skipped
|
||||
- ⚠️ Detailed error messages
|
||||
- 📊 Performance analysis results
|
||||
For **monorepos**, make sure it's installed in the package you're optimizing, or at the workspace root if dependencies are hoisted.
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="🔍 Test discovery issues">
|
||||
Verify:
|
||||
- 📁 Your test directory path is correct in `codeflash.config.js`
|
||||
- 🔍 Tests are discoverable by your test framework
|
||||
- 📝 Test files follow naming conventions (`*.test.js`, `*.spec.js`)
|
||||
<Accordion title="Test framework not detected">
|
||||
Codeflash auto-detects the test framework from your `devDependencies`. If detection fails:
|
||||
|
||||
1. Verify your test framework is in `devDependencies`:
|
||||
```bash
|
||||
npm ls jest # or: npm ls vitest
|
||||
```
|
||||
2. Or set it manually in `package.json`:
|
||||
```json
|
||||
{
|
||||
"codeflash": {
|
||||
"testRunner": "jest"
|
||||
}
|
||||
}
|
||||
```
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Jest tests timing out">
|
||||
If Jest tests take too long, Codeflash has a default timeout. For large test suites:
|
||||
|
||||
- Use `--file` and `--function` to target specific functions instead of `--all`
|
||||
- Ensure your tests don't have expensive setup/teardown that runs for every test file
|
||||
- Check if `jest.config.js` has a `setupFiles` that takes a long time
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="TypeScript compilation errors">
|
||||
Codeflash uses your project's TypeScript configuration. If you see TS errors:
|
||||
|
||||
1. Verify `npx tsc --noEmit` passes on its own
|
||||
2. Check that `tsconfig.json` is in the project root or the module root
|
||||
3. For projects using `moduleResolution: "bundler"`, Codeflash creates a temporary tsconfig overlay — this is expected behavior
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="Monorepo: wrong package.json detected">
|
||||
Run codeflash from the correct package directory:
|
||||
|
||||
```bash
|
||||
# Test if your test framework can discover tests
|
||||
npm test -- --listTests # Jest
|
||||
# or
|
||||
npx vitest list # Vitest
|
||||
cd packages/my-library
|
||||
codeflash --file src/utils.ts --function myFunc
|
||||
```
|
||||
|
||||
If your monorepo tool hoists dependencies, you may need to ensure the `codeflash` npm package is accessible from the package directory. For pnpm, add `.npmrc` with `shamefully-hoist=true` or use `pnpm add --filter my-library --save-dev codeflash`.
|
||||
</Accordion>
|
||||
|
||||
<Accordion title="No optimizations found">
|
||||
Not all functions can be optimized — some code is already efficient. This is normal.
|
||||
|
||||
For better results:
|
||||
- Target functions with loops, string manipulation, or data transformations
|
||||
- Ensure the function has existing tests for correctness verification
|
||||
- Use `codeflash optimize --jest` to trace real execution and capture realistic inputs
|
||||
</Accordion>
|
||||
</AccordionGroup>
|
||||
|
||||
## Configuration
|
||||
## Configuration Reference
|
||||
|
||||
Your `codeflash.config.js` file controls how Codeflash analyzes your JavaScript project:
|
||||
|
||||
```javascript
|
||||
module.exports = {
|
||||
// Source code to optimize
|
||||
module: 'src',
|
||||
|
||||
// Test location
|
||||
tests: 'tests',
|
||||
|
||||
// Test framework
|
||||
testFramework: 'jest',
|
||||
|
||||
// Serialization strategy (automatically set to 'v8')
|
||||
serialization: 'v8',
|
||||
|
||||
// Formatter
|
||||
formatter: 'prettier',
|
||||
|
||||
// Additional options
|
||||
exclude: ['node_modules', 'dist', 'build'],
|
||||
verbose: false
|
||||
};
|
||||
```
|
||||
See [JavaScript / TypeScript Configuration](/configuration/javascript) for the full list of options.
|
||||
|
||||
### Next Steps
|
||||
|
||||
- Learn about [Codeflash Concepts](/codeflash-concepts/how-codeflash-works)
|
||||
- Explore [Optimization workflows](/optimizing-with-codeflash/one-function)
|
||||
- Learn [how Codeflash works](/codeflash-concepts/how-codeflash-works)
|
||||
- [Optimize a single function](/optimizing-with-codeflash/one-function)
|
||||
- Set up [Pull Request Optimization](/optimizing-with-codeflash/codeflash-github-actions)
|
||||
- Read [configuration options](/configuration) for advanced setups
|
||||
- Explore [Trace and Optimize](/optimizing-with-codeflash/trace-and-optimize) for workflow optimization
|
||||
|
|
|
|||
|
|
@ -1,27 +1,38 @@
|
|||
---
|
||||
title: "Codeflash is an AI performance optimizer for Python code"
|
||||
title: "Codeflash is an AI performance optimizer for your code"
|
||||
icon: "rocket"
|
||||
sidebarTitle: "Overview"
|
||||
keywords: ["python", "performance", "optimization", "AI", "code analysis", "benchmarking"]
|
||||
keywords: ["python", "javascript", "typescript", "performance", "optimization", "AI", "code analysis", "benchmarking"]
|
||||
---
|
||||
|
||||
Codeflash speeds up any Python code by figuring out the best way to rewrite it while verifying that the behavior of the code is unchanged, and verifying real speed
|
||||
gains through performance benchmarking.
|
||||
Codeflash speeds up your code by figuring out the best way to rewrite it while verifying that the behavior is unchanged, and verifying real speed
|
||||
gains through performance benchmarking. It supports **Python**, **JavaScript**, and **TypeScript**.
|
||||
|
||||
The optimizations Codeflash finds are generally better algorithms, opportunities to remove wasteful compute, better logic, utilizing caching and utilization of more efficient library methods. Codeflash
|
||||
does not modify the system architecture of your code, but it tries to find the most efficient implementation of your current architecture.
|
||||
|
||||
### Get Started
|
||||
|
||||
<CardGroup cols={2}>
|
||||
<Card title="Python Setup" icon="python" href="/getting-started/local-installation">
|
||||
Install via pip, uv, or poetry
|
||||
</Card>
|
||||
<Card title="JavaScript / TypeScript Setup" icon="js" href="/getting-started/javascript-installation">
|
||||
Install via npm, yarn, pnpm, or bun
|
||||
</Card>
|
||||
</CardGroup>
|
||||
|
||||
### How to use Codeflash
|
||||
|
||||
<CardGroup cols={1}>
|
||||
<Card title="Optimize a Single Function" icon="bullseye" href="/optimizing-with-codeflash/one-function">
|
||||
Target and optimize individual Python functions for maximum performance gains.
|
||||
Target and optimize individual functions for maximum performance gains.
|
||||
```bash
|
||||
codeflash --file path.py --function my_function
|
||||
codeflash --file path/to/file --function my_function
|
||||
```
|
||||
</Card>
|
||||
|
||||
<Card title="Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
|
||||
<Card title="Optimize Pull Requests" icon="code-pull-request" href="/optimizing-with-codeflash/codeflash-github-actions">
|
||||
Automatically find optimizations for Pull Requests with GitHub Actions integration.
|
||||
```bash
|
||||
codeflash init-actions
|
||||
|
|
@ -29,7 +40,7 @@ does not modify the system architecture of your code, but it tries to find the m
|
|||
</Card>
|
||||
|
||||
<Card title="Optimize Workflows with Tracing" icon="route" href="/optimizing-with-codeflash/trace-and-optimize">
|
||||
End-to-end optimization of entire Python workflows with execution tracing.
|
||||
End-to-end optimization of entire workflows with execution tracing.
|
||||
```bash
|
||||
codeflash optimize myscript.py
|
||||
```
|
||||
|
|
@ -42,7 +53,6 @@ does not modify the system architecture of your code, but it tries to find the m
|
|||
```
|
||||
</Card>
|
||||
|
||||
|
||||
</CardGroup>
|
||||
|
||||
### How does Codeflash verify correctness?
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
---
|
||||
title: "Optimize Performance Benchmarks with every Pull Request"
|
||||
description: "Configure and use pytest-benchmark integration for performance-critical code optimization"
|
||||
description: "Configure and use benchmark integration for performance-critical code optimization"
|
||||
icon: "chart-line"
|
||||
sidebarTitle: Setup Benchmarks to Optimize
|
||||
keywords:
|
||||
|
|
@ -26,6 +26,10 @@ It will then try to optimize the new code for the benchmark and calculate the im
|
|||
|
||||
## Using Codeflash in Benchmark Mode
|
||||
|
||||
<Note>
|
||||
Benchmark mode currently supports Python projects using pytest-benchmark. JavaScript/TypeScript benchmark support is coming soon.
|
||||
</Note>
|
||||
|
||||
1. **Create a benchmarks root:**
|
||||
|
||||
Create a directory for benchmarks if it does not already exist.
|
||||
|
|
@ -44,7 +48,7 @@ It will then try to optimize the new code for the benchmark and calculate the im
|
|||
|
||||
2. **Define your benchmarks:**
|
||||
|
||||
Currently, Codeflash only supports benchmarks written as pytest-benchmarks. Check out the [pytest-benchmark](https://pytest-benchmark.readthedocs.io/en/stable/index.html) documentation for more information on syntax.
|
||||
Codeflash supports benchmarks written as pytest-benchmarks. Check out the [pytest-benchmark](https://pytest-benchmark.readthedocs.io/en/stable/index.html) documentation for more information on syntax.
|
||||
|
||||
For example:
|
||||
|
||||
|
|
@ -58,7 +62,7 @@ It will then try to optimize the new code for the benchmark and calculate the im
|
|||
|
||||
Note that these benchmarks should be defined in such a way that they don't take a long time to run.
|
||||
|
||||
The pytest-benchmark format is simply used as an interface. The plugin is actually not used - Codeflash will run these benchmarks with its own pytest plugin
|
||||
The pytest-benchmark format is simply used as an interface. The plugin is actually not used - Codeflash will run these benchmarks with its own pytest plugin.
|
||||
|
||||
3. **Run and Test Codeflash:**
|
||||
|
||||
|
|
@ -74,7 +78,7 @@ It will then try to optimize the new code for the benchmark and calculate the im
|
|||
codeflash --file test_file.py --benchmark --benchmarks-root path/to/benchmarks
|
||||
```
|
||||
|
||||
4. **Run Codeflash :**
|
||||
4. **Run Codeflash with GitHub Actions:**
|
||||
|
||||
Benchmark mode is best used together with Codeflash as a GitHub Action. This way,
|
||||
Codeflash will trace through your benchmark and optimize the functions modified in your Pull Request to speed up the benchmark.
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ title: "Optimize Your Entire Codebase"
|
|||
description: "Automatically optimize all codepaths in your project with Codeflash's comprehensive analysis"
|
||||
icon: "database"
|
||||
sidebarTitle: "Optimize Entire Codebase"
|
||||
keywords: ["codebase optimization", "all functions", "batch optimization", "github app", "checkpoint", "recovery"]
|
||||
keywords: ["codebase optimization", "all functions", "batch optimization", "github app", "checkpoint", "recovery", "javascript", "typescript", "python"]
|
||||
---
|
||||
|
||||
# Optimize your entire codebase
|
||||
|
||||
Codeflash can optimize your entire codebase by analyzing all the functions in your project and generating optimized versions of them.
|
||||
It iterates through all the functions in your codebase and optimizes them one by one.
|
||||
It iterates through all the functions in your codebase and optimizes them one by one. This works for Python, JavaScript, and TypeScript projects.
|
||||
|
||||
To optimize your entire codebase, run the following command in your project directory:
|
||||
|
||||
|
|
@ -30,15 +30,27 @@ codeflash --all path/to/dir
|
|||
```
|
||||
|
||||
<Tip>
|
||||
If your project has a good number of unit tests, we can trace those to achieve higher quality results.
|
||||
The following approach is recommended instead:
|
||||
If your project has a good number of unit tests, tracing them achieves higher quality results.
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Python">
|
||||
```bash
|
||||
codeflash optimize --trace-only -m pytest tests/ ; codeflash --all
|
||||
```
|
||||
This will run your test suite, trace all the code covered by your tests, ensuring higher correctness guarantees
|
||||
and better performance benchmarking, and help create optimizations for code where the LLMs struggle to generate and run tests.
|
||||
</Tab>
|
||||
<Tab title="JavaScript / TypeScript">
|
||||
```bash
|
||||
codeflash optimize --trace-only --jest ; codeflash --all
|
||||
# or for Vitest projects
|
||||
codeflash optimize --trace-only --vitest ; codeflash --all
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
Even though `codeflash --all` discovers any existing unit tests. It currently can only discover any test that directly calls the
|
||||
This runs your test suite, traces all the code covered by your tests, ensuring higher correctness guarantees
|
||||
and better performance benchmarking, and helps create optimizations for code where the LLMs struggle to generate and run tests.
|
||||
|
||||
`codeflash --all` discovers any existing unit tests, but it currently can only discover tests that directly call the
|
||||
function under optimization. Tracing all the tests helps ensure correctness for code that may be indirectly called by your tests.
|
||||
|
||||
</Tip>
|
||||
|
|
|
|||
|
|
@ -26,9 +26,9 @@ We highly recommend setting this up, since once you set it up all your new code
|
|||
|
||||
✅ A Codeflash API key from the [Codeflash Web App](https://app.codeflash.ai/)
|
||||
|
||||
✅ Completed [local installation](/getting-started/local-installation) with `codeflash init`
|
||||
✅ Completed local installation with `codeflash init` ([Python](/getting-started/local-installation) or [JavaScript/TypeScript](/getting-started/javascript-installation))
|
||||
|
||||
✅ A Python project with a configured `pyproject.toml` file
|
||||
✅ A configured project (`pyproject.toml` for Python, `package.json` for JavaScript/TypeScript)
|
||||
</Warning>
|
||||
|
||||
## Setup Options
|
||||
|
|
@ -113,7 +113,7 @@ jobs:
|
|||
</Step>
|
||||
|
||||
<Step title="Choose Your Package Manager">
|
||||
Customize the dependency installation based on your Python package manager:
|
||||
Customize the dependency installation based on your package manager:
|
||||
|
||||
The workflow will need to be set up in such a way the Codeflash can create and
|
||||
run tests for functionality and speed, so the stock YAML may need to be altered to
|
||||
|
|
@ -121,7 +121,7 @@ suit the specific codebase. Typically the setup steps for a unit test workflow c
|
|||
be copied.
|
||||
|
||||
<CodeGroup>
|
||||
```yaml Poetry
|
||||
```yaml Poetry (Python)
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
|
@ -129,11 +129,11 @@ be copied.
|
|||
poetry install --with dev
|
||||
- name: Run Codeflash to optimize code
|
||||
run: |
|
||||
poetry env use python
|
||||
poetry env use python
|
||||
poetry run codeflash
|
||||
```
|
||||
|
||||
```yaml uv
|
||||
```yaml uv (Python)
|
||||
- uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
|
|
@ -142,7 +142,7 @@ be copied.
|
|||
run: uv run codeflash
|
||||
```
|
||||
|
||||
```yaml pip
|
||||
```yaml pip (Python)
|
||||
- name: Install Project Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
|
@ -151,7 +151,50 @@ be copied.
|
|||
- name: Run Codeflash to optimize code
|
||||
run: codeflash
|
||||
```
|
||||
|
||||
```yaml npm (JavaScript/TypeScript)
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
- name: Install Project Dependencies
|
||||
run: npm ci
|
||||
- name: Run Codeflash to optimize code
|
||||
run: npx codeflash
|
||||
```
|
||||
|
||||
```yaml yarn (JavaScript/TypeScript)
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
- name: Install Project Dependencies
|
||||
run: yarn install --immutable
|
||||
- name: Run Codeflash to optimize code
|
||||
run: yarn codeflash
|
||||
```
|
||||
|
||||
```yaml pnpm (JavaScript/TypeScript)
|
||||
- uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 9
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '18'
|
||||
cache: 'pnpm'
|
||||
- name: Install Project Dependencies
|
||||
run: pnpm install --frozen-lockfile
|
||||
- name: Run Codeflash to optimize code
|
||||
run: pnpm codeflash
|
||||
```
|
||||
</CodeGroup>
|
||||
|
||||
<Info>
|
||||
**Monorepo?** If your codeflash config is in a subdirectory, add `working-directory` to the steps:
|
||||
```yaml
|
||||
- name: Run Codeflash to optimize code
|
||||
run: npx codeflash
|
||||
working-directory: packages/my-library
|
||||
```
|
||||
</Info>
|
||||
</Step>
|
||||
|
||||
<Step title="Add Repository Secret">
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
---
|
||||
title: "Optimize a Single Function"
|
||||
description: "Target and optimize individual Python functions for maximum performance gains"
|
||||
description: "Target and optimize individual functions for maximum performance gains"
|
||||
icon: "bullseye"
|
||||
sidebarTitle: "Optimize Single Function"
|
||||
keywords:
|
||||
|
|
@ -10,6 +10,9 @@ keywords:
|
|||
"class methods",
|
||||
"performance",
|
||||
"targeted optimization",
|
||||
"javascript",
|
||||
"typescript",
|
||||
"python",
|
||||
]
|
||||
---
|
||||
|
||||
|
|
@ -24,23 +27,55 @@ your mileage may vary.
|
|||
|
||||
## How to optimize a function
|
||||
|
||||
To optimize a function, you can run the following command in your project:
|
||||
To optimize a function, run the following command in your project:
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Python">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.py --function function_name
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="JavaScript">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.js --function functionName
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="TypeScript">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.ts --function functionName
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
If you have installed the GitHub App to your repository, the above command will open a pull request with the optimized function.
|
||||
If you want to optimize a function locally, you can add a `--no-pr` argument as follows:
|
||||
If you want to optimize a function locally, add a `--no-pr` argument:
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Python">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.py --function function_name --no-pr
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="JavaScript / TypeScript">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.ts --function functionName --no-pr
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
### Optimizing class methods
|
||||
|
||||
To optimize a method `method_name` in a class `ClassName`, you can run the following command:
|
||||
To optimize a method `methodName` in a class `ClassName`:
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Python">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.py --function ClassName.method_name
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="JavaScript / TypeScript">
|
||||
```bash
|
||||
codeflash --file path/to/your/file.ts --function ClassName.methodName
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
---
|
||||
title: "Trace & Optimize E2E Workflows"
|
||||
description: "End-to-end optimization of entire Python workflows with execution tracing"
|
||||
description: "End-to-end optimization of entire workflows with execution tracing"
|
||||
icon: "route"
|
||||
sidebarTitle: "Optimize E2E Workflows"
|
||||
keywords:
|
||||
|
|
@ -11,28 +11,50 @@ keywords:
|
|||
"end-to-end",
|
||||
"script optimization",
|
||||
"context manager",
|
||||
"javascript",
|
||||
"typescript",
|
||||
"jest",
|
||||
"vitest",
|
||||
]
|
||||
---
|
||||
|
||||
Codeflash can optimize an entire Python script end-to-end by tracing the script's execution and generating Replay Tests.
|
||||
Tracing follows the execution of a script, profiles it and captures inputs to all functions it called, allowing them to be replayed during optimization.
|
||||
Codeflash uses these Replay Tests to optimize the most important functions called in the script, delivering the best performance for your workflow.
|
||||
Codeflash can optimize an entire script or test suite end-to-end by tracing its execution and generating Replay Tests.
|
||||
Tracing follows the execution of your code, profiles it and captures inputs to all functions it called, allowing them to be replayed during optimization.
|
||||
Codeflash uses these Replay Tests to optimize the most important functions called in the workflow, delivering the best performance.
|
||||
|
||||

|
||||
|
||||
To optimize a script, `python myscript.py`, simply replace `python` with `codeflash optimize` and run the following command:
|
||||
<Tabs>
|
||||
<Tab title="Python">
|
||||
To optimize a script, `python myscript.py`, simply replace `python` with `codeflash optimize`:
|
||||
|
||||
```bash
|
||||
codeflash optimize myscript.py
|
||||
```
|
||||
|
||||
You can also optimize code called by pytest tests that you could normally run like `python -m pytest tests/`, this provides for a good workload to optimize. Run this command:
|
||||
You can also optimize code called by pytest tests:
|
||||
|
||||
```bash
|
||||
codeflash optimize -m pytest tests/
|
||||
```
|
||||
</Tab>
|
||||
<Tab title="JavaScript / TypeScript">
|
||||
To trace and optimize your Jest or Vitest tests:
|
||||
|
||||
The powerful `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results. If your workflow is longer, consider tracing it into smaller sections by using the Codeflash tracer as a context manager (point 3 below).
|
||||
```bash
|
||||
# Jest
|
||||
codeflash optimize --jest
|
||||
|
||||
# Vitest
|
||||
codeflash optimize --vitest
|
||||
|
||||
# Or trace a specific script
|
||||
codeflash optimize --language javascript script.js
|
||||
```
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
||||
The `codeflash optimize` command creates high-quality optimizations, making it ideal when you need to optimize a workflow or script. The initial tracing process can be slow, so try to limit your script's runtime to under 1 minute for best results.
|
||||
|
||||
The generated replay tests and the trace file are for the immediate optimization use, don't add them to git.
|
||||
|
||||
|
|
@ -61,6 +83,9 @@ This way you can be _sure_ that the optimized function causes no changes of beha
|
|||
|
||||
## Using codeflash optimize
|
||||
|
||||
<Tabs>
|
||||
<Tab title="Python">
|
||||
|
||||
Codeflash script optimizer can be used in three ways:
|
||||
|
||||
1. **As an integrated command**
|
||||
|
|
@ -100,10 +125,10 @@ Codeflash script optimizer can be used in three ways:
|
|||
|
||||
- `--timeout`: The maximum time in seconds to trace the entire workflow. Default is indefinite. This is useful while tracing really long workflows.
|
||||
|
||||
3. **As a Context Manager -**
|
||||
3. **As a Context Manager**
|
||||
|
||||
To trace only specific sections of your code, You can also use the Codeflash Tracer as a context manager.
|
||||
You can wrap the code you want to trace in a `with` statement as follows -
|
||||
To trace only specific sections of your code, you can use the Codeflash Tracer as a context manager.
|
||||
You can wrap the code you want to trace in a `with` statement as follows:
|
||||
|
||||
```python
|
||||
from codeflash.tracer import Tracer
|
||||
|
|
@ -128,3 +153,46 @@ Codeflash script optimizer can be used in three ways:
|
|||
- `output`: The file to save the trace to. Default is `codeflash.trace`.
|
||||
- `config_file_path`: The path to the `pyproject.toml` file which stores the Codeflash config. This is auto-discovered by default.
|
||||
You can also disable the tracer in the code by setting the `disable=True` option in the `Tracer` constructor.
|
||||
|
||||
</Tab>
|
||||
<Tab title="JavaScript / TypeScript">
|
||||
|
||||
The JavaScript tracer uses Babel instrumentation to capture function calls during your test suite execution.
|
||||
|
||||
1. **Trace your test suite**
|
||||
|
||||
```bash
|
||||
# Jest projects
|
||||
codeflash optimize --jest
|
||||
|
||||
# Vitest projects
|
||||
codeflash optimize --vitest
|
||||
|
||||
# Trace a specific script
|
||||
codeflash optimize --language javascript src/main.js
|
||||
```
|
||||
|
||||
2. **Trace specific functions only**
|
||||
|
||||
```bash
|
||||
codeflash optimize --jest --only-functions processData,transformInput
|
||||
```
|
||||
|
||||
3. **Trace and optimize as two separate steps**
|
||||
|
||||
```bash
|
||||
# Step 1: Create trace file
|
||||
codeflash optimize --trace-only --jest --output trace_file.sqlite
|
||||
|
||||
# Step 2: Optimize with replay tests
|
||||
codeflash --replay-test /path/to/test_replay_test_0.test.js
|
||||
```
|
||||
|
||||
More Options:
|
||||
|
||||
- `--timeout`: Maximum tracing time in seconds.
|
||||
- `--max-function-count`: Maximum traces per function (default: 256).
|
||||
- `--only-functions`: Comma-separated list of function names to trace.
|
||||
|
||||
</Tab>
|
||||
</Tabs>
|
||||
|
|
|
|||
4
packages/codeflash/package-lock.json
generated
4
packages/codeflash/package-lock.json
generated
|
|
@ -1,12 +1,12 @@
|
|||
{
|
||||
"name": "codeflash",
|
||||
"version": "0.8.0",
|
||||
"version": "0.9.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "codeflash",
|
||||
"version": "0.8.0",
|
||||
"version": "0.9.0",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ def validate_output(stdout: str, return_code: int, expected_improvement_pct: int
|
|||
return False
|
||||
|
||||
if config.expected_unit_test_files is not None:
|
||||
# Match the per-function test discovery message from function_optimizer.py
|
||||
# Match the per-function discovery message from function_optimizer.py
|
||||
# Format: "Discovered X existing unit test files, Y replay test files, and Z concolic..."
|
||||
unit_test_files_match = re.search(r"Discovered (\d+) existing unit test files?", stdout)
|
||||
if not unit_test_files_match:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from codeflash.code_utils.deduplicate_code import are_codes_duplicate, normalize_code
|
||||
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code
|
||||
|
||||
|
||||
def test_deduplicate1():
|
||||
|
|
@ -23,7 +23,7 @@ def compute_sum(numbers):
|
|||
"""
|
||||
|
||||
assert normalize_code(code1) == normalize_code(code2)
|
||||
assert are_codes_duplicate(code1, code2)
|
||||
assert normalize_code(code1) == normalize_code(code2)
|
||||
|
||||
# Example 3: Same function and parameter names, different local variables (should match)
|
||||
code3 = """
|
||||
|
|
@ -43,7 +43,7 @@ def calculate_sum(numbers):
|
|||
"""
|
||||
|
||||
assert normalize_code(code3) == normalize_code(code4)
|
||||
assert are_codes_duplicate(code3, code4)
|
||||
assert normalize_code(code3) == normalize_code(code4)
|
||||
|
||||
# Example 4: Nested functions and classes (preserving names)
|
||||
code5 = """
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from codeflash.languages.python.static_analysis.code_replacer import (
|
|||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, FunctionSource
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
|
||||
|
|
@ -54,7 +54,7 @@ def sorter(arr):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -834,7 +834,7 @@ class MainClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config)
|
||||
code_context = func_optimizer.get_code_optimization_context().unwrap()
|
||||
assert code_context.testgen_context.flat.rstrip() == get_code_output.rstrip()
|
||||
|
||||
|
|
@ -1745,7 +1745,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -1824,7 +1824,7 @@ a=2
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -1904,7 +1904,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -1983,7 +1983,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -2063,7 +2063,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -2153,7 +2153,7 @@ class NewClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
|
|
@ -3453,7 +3453,7 @@ def hydrate_input_text_actions_with_field_names(
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from codeflash.code_utils.code_utils import get_run_tmp_file
|
|||
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestingMode, TestType, VerificationType
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
|
||||
from codeflash.verification.test_runner import execute_test_subprocess
|
||||
|
|
@ -459,7 +459,7 @@ class MyClass:
|
|||
file_path=sample_code_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -582,7 +582,7 @@ class MyClass(ParentClass):
|
|||
file_path=sample_code_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -709,7 +709,7 @@ class MyClass:
|
|||
file_path=sample_code_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -872,7 +872,7 @@ class AnotherHelperClass:
|
|||
file_path=fto_file_path,
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -1021,7 +1021,7 @@ class AnotherHelperClass:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -1055,7 +1055,7 @@ class AnotherHelperClass:
|
|||
)
|
||||
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
|
||||
assert len(test_results.test_results) == 4
|
||||
assert test_results[0].id.test_function_name == "test_helper_classes"
|
||||
|
|
@ -1106,7 +1106,7 @@ class MyClass:
|
|||
testing_time=0.1,
|
||||
)
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
|
||||
# Now, this fto_code mutates the instance so it should fail
|
||||
mutated_fto_code = """
|
||||
|
|
@ -1145,7 +1145,7 @@ class MyClass:
|
|||
testing_time=0.1,
|
||||
)
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
match, _ = compare_test_results(test_results, mutated_test_results)
|
||||
assert not match
|
||||
|
||||
|
|
@ -1184,7 +1184,7 @@ class MyClass:
|
|||
testing_time=0.1,
|
||||
)
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
match, _ = compare_test_results(test_results, no_helper1_test_results)
|
||||
assert match
|
||||
|
||||
|
|
@ -1446,7 +1446,7 @@ def calculate_portfolio_metrics(
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -1477,7 +1477,7 @@ def calculate_portfolio_metrics(
|
|||
)
|
||||
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
|
||||
# Now, let's say we optimize the code and make changes.
|
||||
new_fto_code = """import math
|
||||
|
|
@ -1543,7 +1543,7 @@ def calculate_portfolio_metrics(
|
|||
testing_time=0.1,
|
||||
)
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
matched, diffs = compare_test_results(test_results, modified_test_results)
|
||||
|
||||
assert not matched
|
||||
|
|
@ -1606,7 +1606,7 @@ def calculate_portfolio_metrics(
|
|||
testing_time=0.1,
|
||||
)
|
||||
# Remove instrumentation
|
||||
FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
PythonFunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
|
||||
matched, diffs = compare_test_results(test_results, modified_test_results_2)
|
||||
# now the test should match and no diffs should be found
|
||||
assert len(diffs) == 0
|
||||
|
|
@ -1671,7 +1671,7 @@ class SlotsClass:
|
|||
file_path=sample_code_path,
|
||||
parents=[FunctionParent(name="SlotsClass", type="ClassDef")],
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -132,7 +132,7 @@ def test_class_method_dependencies() -> None:
|
|||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=TestConfig(
|
||||
tests_root=file_path,
|
||||
|
|
@ -202,7 +202,7 @@ def test_recursive_function_context() -> None:
|
|||
starting_line=None,
|
||||
ending_line=None,
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=TestConfig(
|
||||
tests_root=file_path,
|
||||
|
|
|
|||
|
|
@ -680,8 +680,14 @@ def test_in_dunder_tests():
|
|||
|
||||
# Combine all discovered functions
|
||||
all_functions = {}
|
||||
for discovered in [discovered_source, discovered_test, discovered_test_underscore,
|
||||
discovered_spec, discovered_tests_dir, discovered_dunder_tests]:
|
||||
for discovered in [
|
||||
discovered_source,
|
||||
discovered_test,
|
||||
discovered_test_underscore,
|
||||
discovered_spec,
|
||||
discovered_tests_dir,
|
||||
discovered_dunder_tests,
|
||||
]:
|
||||
all_functions.update(discovered)
|
||||
|
||||
# Test Case 1: tests_root == module_root (overlapping case)
|
||||
|
|
@ -781,9 +787,7 @@ def test_filter_functions_strict_string_matching():
|
|||
|
||||
# Strict check: exactly these 3 files should remain (those with 'test' as substring only)
|
||||
expected_files = {contest_file, latest_file, attestation_file}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[contest_file]] == ["run_contest"], (
|
||||
|
|
@ -871,9 +875,7 @@ def test_filter_functions_test_directory_patterns():
|
|||
|
||||
# Strict check: exactly these 2 files should remain (those in non-test directories)
|
||||
expected_files = {contest_file, latest_file}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[contest_file]] == ["get_scores"], (
|
||||
|
|
@ -936,9 +938,7 @@ def test_filter_functions_non_overlapping_tests_root():
|
|||
|
||||
# Strict check: exactly these 2 files should remain (both in src/, not in tests/)
|
||||
expected_files = {source_file, test_in_src}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected files {expected_files}, got {set(filtered.keys())}"
|
||||
|
||||
# Strict check: each file should have exactly 1 function with the expected name
|
||||
assert [fn.function_name for fn in filtered[source_file]] == ["process"], (
|
||||
|
|
@ -1047,20 +1047,15 @@ def test_deep_copy():
|
|||
)
|
||||
|
||||
root_functions = [fn.function_name for fn in filtered.get(root_source_file, [])]
|
||||
assert root_functions == ["main"], (
|
||||
f"Expected ['main'], got {root_functions}"
|
||||
)
|
||||
assert root_functions == ["main"], f"Expected ['main'], got {root_functions}"
|
||||
|
||||
# Strict check: exactly 3 functions (2 from utils.py + 1 from main.py)
|
||||
assert count == 3, (
|
||||
f"Expected exactly 3 functions, got {count}. "
|
||||
f"Some source files may have been incorrectly filtered."
|
||||
f"Expected exactly 3 functions, got {count}. Some source files may have been incorrectly filtered."
|
||||
)
|
||||
|
||||
# Verify test file was properly filtered (should not be in results)
|
||||
assert test_file not in filtered, (
|
||||
f"Test file {test_file} should have been filtered but wasn't"
|
||||
)
|
||||
assert test_file not in filtered, f"Test file {test_file} should have been filtered but wasn't"
|
||||
|
||||
|
||||
def test_filter_functions_typescript_project_in_tests_folder():
|
||||
|
|
@ -1214,9 +1209,7 @@ def sample_data():
|
|||
# source_file and file_in_test_dir should remain
|
||||
# test_prefix_file, conftest_file, and test_in_subdir should be filtered
|
||||
expected_files = {source_file, file_in_test_dir}
|
||||
assert set(filtered.keys()) == expected_files, (
|
||||
f"Expected {expected_files}, got {set(filtered.keys())}"
|
||||
)
|
||||
assert set(filtered.keys()) == expected_files, f"Expected {expected_files}, got {set(filtered.keys())}"
|
||||
assert count == 2, f"Expected exactly 2 functions, got {count}"
|
||||
|
||||
|
||||
|
|
@ -1266,7 +1259,8 @@ class TestHelpers:
|
|||
""")
|
||||
|
||||
support = PythonSupport()
|
||||
functions = support.discover_functions(fixture_file)
|
||||
source = fixture_file.read_text(encoding="utf-8")
|
||||
functions = support.discover_functions(source, fixture_file)
|
||||
function_names = [fn.function_name for fn in functions]
|
||||
|
||||
assert "regular_function" in function_names
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import pytest
|
|||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.either import is_successful
|
||||
from codeflash.models.models import FunctionParent, get_code_block_splitter
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
|
@ -233,7 +233,7 @@ class _PersistentCache(Generic[_P, _R, _CacheBackendT]):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
ctx_result = func_optimizer.get_code_optimization_context()
|
||||
|
|
@ -404,7 +404,7 @@ def test_bubble_sort_deps() -> None:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=function_to_optimize, test_cfg=test_config)
|
||||
with open(file_path) as f:
|
||||
original_code = f.read()
|
||||
ctx_result = func_optimizer.get_code_optimization_context()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from tempfile import TemporaryDirectory
|
|||
from codeflash.languages.python.static_analysis.line_profile_utils import add_decorator_imports, contains_jit_decorator
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -22,7 +22,7 @@ def test_add_decorator_imports_helper_in_class():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -94,7 +94,7 @@ def test_add_decorator_imports_helper_in_nested_class():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -143,7 +143,7 @@ def test_add_decorator_imports_nodeps():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -194,7 +194,7 @@ def test_add_decorator_imports_helper_outside():
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
@ -271,7 +271,7 @@ class helper:
|
|||
pytest_cmd="pytest",
|
||||
)
|
||||
func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
os.chdir(run_cwd)
|
||||
# func_optimizer = pass
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from codeflash.models.models import (
|
|||
TestsInFile,
|
||||
TestType,
|
||||
)
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
|
||||
|
|
@ -434,7 +434,7 @@ def test_sort():
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_env = os.environ.copy()
|
||||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
|
|
@ -695,7 +695,7 @@ def test_sort_parametrized(input, expected_output):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -984,7 +984,7 @@ def test_sort_parametrized_loop(input, expected_output):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -1341,7 +1341,7 @@ def test_sort():
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -1723,7 +1723,7 @@ class TestPigLatin(unittest.TestCase):
|
|||
test_framework="unittest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -1973,7 +1973,7 @@ import unittest
|
|||
test_framework="unittest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -2229,7 +2229,7 @@ import unittest
|
|||
test_framework="unittest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
test_env=test_env,
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
|
|
@ -2481,7 +2481,7 @@ import unittest
|
|||
test_framework="unittest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=f, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=f, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
|
|
@ -3144,7 +3144,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time):
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
|
|
@ -3279,7 +3279,7 @@ import unittest
|
|||
test_framework="unittest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.PERFORMANCE,
|
||||
test_env=test_env,
|
||||
|
|
|
|||
|
|
@ -20,12 +20,15 @@ All assertions use strict string equality to verify exact extraction output.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context_for_language
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -61,7 +64,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
|
||||
|
|
@ -87,7 +91,8 @@ export const multiply = (a, b) => a * b;
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "multiply"
|
||||
|
|
@ -121,7 +126,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -173,7 +179,8 @@ export async function processItems(items, callback, options = {}) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -243,7 +250,8 @@ export class CacheManager {
|
|||
file_path = temp_project / "cache.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
get_or_compute = next(f for f in functions if f.function_name == "getOrCompute")
|
||||
|
||||
context = js_support.extract_code_context(get_or_compute, temp_project, temp_project)
|
||||
|
|
@ -339,7 +347,8 @@ export function validateUserData(data, validators) {
|
|||
file_path = temp_project / "validator.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = next(f for f in functions if f.function_name == "validateUserData")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -429,7 +438,8 @@ export async function fetchWithRetry(endpoint, options = {}) {
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = next(f for f in functions if f.function_name == "fetchWithRetry")
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -515,7 +525,8 @@ export function validateField(value, fieldType) {
|
|||
file_path = temp_project / "validation.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -578,7 +589,8 @@ export function processUserInput(rawInput) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
process_func = next(f for f in functions if f.function_name == "processUserInput")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
|
@ -633,7 +645,8 @@ export function generateReport(data) {
|
|||
file_path = temp_project / "report.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
report_func = next(f for f in functions if f.function_name == "generateReport")
|
||||
|
||||
context = js_support.extract_code_context(report_func, temp_project, temp_project)
|
||||
|
|
@ -731,7 +744,8 @@ export class Graph {
|
|||
file_path = temp_project / "graph.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
topo_sort = next(f for f in functions if f.function_name == "topologicalSort")
|
||||
|
||||
context = js_support.extract_code_context(topo_sort, temp_project, temp_project)
|
||||
|
|
@ -819,7 +833,8 @@ export class MainClass {
|
|||
file_path = temp_project / "classes.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
main_method = next(f for f in functions if f.function_name == "mainMethod" and f.class_name == "MainClass")
|
||||
|
||||
context = js_support.extract_code_context(main_method, temp_project, temp_project)
|
||||
|
|
@ -875,7 +890,8 @@ module.exports = { sortFromAnotherFile };
|
|||
main_path = temp_project / "bubble_sort_imported.js"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
source = main_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, main_path)
|
||||
main_func = next(f for f in functions if f.function_name == "sortFromAnotherFile")
|
||||
|
||||
context = js_support.extract_code_context(main_func, temp_project, temp_project)
|
||||
|
|
@ -926,7 +942,8 @@ export function processNumber(n) {
|
|||
main_path = temp_project / "main.js"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
source = main_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, main_path)
|
||||
process_func = next(f for f in functions if f.function_name == "processNumber")
|
||||
|
||||
context = js_support.extract_code_context(process_func, temp_project, temp_project)
|
||||
|
|
@ -992,7 +1009,8 @@ export function handleUserInput(rawInput) {
|
|||
main_path = temp_project / "main.js"
|
||||
main_path.write_text(main_code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(main_path)
|
||||
source = main_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, main_path)
|
||||
handle_func = next(f for f in functions if f.function_name == "handleUserInput")
|
||||
|
||||
context = js_support.extract_code_context(handle_func, temp_project, temp_project)
|
||||
|
|
@ -1043,7 +1061,8 @@ export function createEntity<T extends object>(data: T): Entity<T> {
|
|||
file_path = temp_project / "entity.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = ts_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1133,7 +1152,8 @@ export class TypedCache<T> {
|
|||
file_path = temp_project / "cache.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
context = ts_support.extract_code_context(get_method, temp_project, temp_project)
|
||||
|
|
@ -1217,7 +1237,8 @@ export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE
|
|||
service_path = temp_project / "service.ts"
|
||||
service_path.write_text(service_code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(service_path)
|
||||
source = service_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, service_path)
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
context = ts_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1271,7 +1292,8 @@ export function factorial(n) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1301,7 +1323,8 @@ export function isOdd(n) {
|
|||
file_path = temp_project / "parity.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
is_even = next(f for f in functions if f.function_name == "isEven")
|
||||
|
||||
context = js_support.extract_code_context(is_even, temp_project, temp_project)
|
||||
|
|
@ -1319,12 +1342,15 @@ export function isEven(n) {
|
|||
assert helper_names == ["isOdd"]
|
||||
|
||||
# Verify helper source
|
||||
assert context.helper_functions[0].source_code == """\
|
||||
assert (
|
||||
context.helper_functions[0].source_code
|
||||
== """\
|
||||
export function isOdd(n) {
|
||||
if (n === 0) return false;
|
||||
return isEven(n - 1);
|
||||
}
|
||||
"""
|
||||
)
|
||||
|
||||
def test_complex_recursive_tree_traversal(self, js_support, temp_project):
|
||||
"""Test complex recursive tree traversal with multiple recursive calls."""
|
||||
|
|
@ -1363,7 +1389,8 @@ export function collectAllValues(root) {
|
|||
file_path = temp_project / "tree.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
collect_func = next(f for f in functions if f.function_name == "collectAllValues")
|
||||
|
||||
context = js_support.extract_code_context(collect_func, temp_project, temp_project)
|
||||
|
|
@ -1428,7 +1455,8 @@ export async function fetchUserProfile(userId) {
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
profile_func = next(f for f in functions if f.function_name == "fetchUserProfile")
|
||||
|
||||
context = js_support.extract_code_context(profile_func, temp_project, temp_project)
|
||||
|
|
@ -1483,7 +1511,8 @@ module.exports = { Counter };
|
|||
file_path = temp_project / "counter.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context
|
||||
|
|
@ -1563,7 +1592,8 @@ export function processApiResponse({
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1605,7 +1635,8 @@ export function* fibonacci(limit) {
|
|||
file_path = temp_project / "generators.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
range_func = next(f for f in functions if f.function_name == "range")
|
||||
|
||||
context = js_support.extract_code_context(range_func, temp_project, temp_project)
|
||||
|
|
@ -1640,7 +1671,8 @@ export function createUserObject(name, email, age) {
|
|||
file_path = temp_project / "user.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
context = js_support.extract_code_context(func, temp_project, temp_project)
|
||||
|
|
@ -1790,7 +1822,8 @@ export const sendSlackMessage = async (
|
|||
file_path.write_text(code, encoding="utf-8")
|
||||
target_func = "sendSlackMessage"
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func_info = next(f for f in functions if f.function_name == target_func)
|
||||
fto = FunctionToOptimize(
|
||||
function_name=target_func,
|
||||
|
|
@ -1804,9 +1837,11 @@ export const sendSlackMessage = async (
|
|||
language="typescript",
|
||||
)
|
||||
|
||||
ctx = get_code_optimization_context_for_language(
|
||||
fto, temp_project
|
||||
test_config = TestConfig(
|
||||
tests_root=temp_project, tests_project_rootdir=temp_project, project_root_path=temp_project
|
||||
)
|
||||
func_optimizer = JavaScriptFunctionOptimizer(function_to_optimize=fto, test_cfg=test_config, aiservice_client=MagicMock())
|
||||
ctx = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
# The read_writable_code should contain the target function AND helper functions
|
||||
expected_read_writable = """```typescript:slack_util.ts
|
||||
|
|
@ -1899,7 +1934,6 @@ let web: WebClient | null = null"""
|
|||
assert ctx.read_only_context_code == expected_read_only
|
||||
|
||||
|
||||
|
||||
class TestContextProperties:
|
||||
"""Tests for CodeContext object properties."""
|
||||
|
||||
|
|
@ -1913,7 +1947,8 @@ export function test() {
|
|||
file_path = temp_project / "test.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
context = js_support.extract_code_context(functions[0], temp_project, temp_project)
|
||||
|
||||
assert context.language == Language.JAVASCRIPT
|
||||
|
|
@ -1932,7 +1967,8 @@ export function test(): number {
|
|||
file_path = temp_project / "test.ts"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
context = ts_support.extract_code_context(functions[0], temp_project, temp_project)
|
||||
|
||||
# TypeScript uses JavaScript language enum
|
||||
|
|
@ -1974,7 +2010,8 @@ export class Calculator {
|
|||
file_path = temp_project / "calculator.js"
|
||||
file_path.write_text(code, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
|
||||
for func in functions:
|
||||
if func.function_name != "constructor":
|
||||
|
|
|
|||
|
|
@ -107,10 +107,8 @@ class TestJavaScriptCodeContext:
|
|||
"""Test extracting code context for a JavaScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
fib_file = js_project_dir / "fibonacci.js"
|
||||
if not fib_file.exists():
|
||||
|
|
@ -122,7 +120,11 @@ class TestJavaScriptCodeContext:
|
|||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
context = get_code_optimization_context(fib_func, js_project_dir)
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
code_context = js_support.extract_code_context(fib_func, js_project_dir, js_project_dir)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, fib_file, "javascript", js_project_dir
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "javascript"
|
||||
|
|
|
|||
|
|
@ -71,10 +71,8 @@ module.exports = { add };
|
|||
"""Verify language is preserved in code context extraction."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
ts_file = tmp_path / "utils.ts"
|
||||
ts_file.write_text("""
|
||||
|
|
@ -86,7 +84,11 @@ export function add(a: number, b: number): number {
|
|||
functions = find_all_functions_in_file(ts_file)
|
||||
func = functions[ts_file][0]
|
||||
|
||||
context = get_code_optimization_context(func, tmp_path)
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(func, tmp_path, tmp_path)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, ts_file, "typescript", tmp_path
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "typescript"
|
||||
|
|
@ -373,10 +375,7 @@ describe('fibonacci', () => {
|
|||
"""Test get_code_optimization_context for JavaScript."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
src_file = js_project / "utils.js"
|
||||
functions = find_all_functions_in_file(src_file)
|
||||
|
|
@ -398,7 +397,7 @@ describe('fibonacci', () => {
|
|||
pytest_cmd="jest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
@ -415,10 +414,7 @@ describe('fibonacci', () => {
|
|||
"""Test get_code_optimization_context for TypeScript."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
src_file = ts_project / "utils.ts"
|
||||
functions = find_all_functions_in_file(src_file)
|
||||
|
|
@ -440,7 +436,7 @@ describe('fibonacci', () => {
|
|||
pytest_cmd="vitest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
@ -461,10 +457,7 @@ class TestHelperFunctionLanguageAttribute:
|
|||
"""Verify helper functions have language='javascript' for .js files."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.JAVASCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
# Create a file with helper functions
|
||||
src_file = tmp_path / "main.js"
|
||||
|
|
@ -499,7 +492,7 @@ module.exports = { main };
|
|||
pytest_cmd="jest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
@ -515,10 +508,7 @@ module.exports = { main };
|
|||
"""Verify helper functions have language='typescript' for .ts files."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
# Create a file with helper functions
|
||||
src_file = tmp_path / "main.ts"
|
||||
|
|
@ -551,7 +541,7 @@ export function main(): number {
|
|||
pytest_cmd="vitest",
|
||||
)
|
||||
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func_to_optimize,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
|
|
|
|||
|
|
@ -16,8 +16,6 @@ NOTE: These tests require:
|
|||
Tests will be skipped if dependencies are not available.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
|
@ -26,7 +24,7 @@ import pytest
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.models.models import FunctionParent, TestFile, TestFiles, TestType, TestingMode
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -58,13 +56,7 @@ def install_dependencies(project_dir: Path) -> bool:
|
|||
if has_node_modules(project_dir):
|
||||
return True
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["npm", "install"],
|
||||
cwd=project_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=120
|
||||
)
|
||||
result = subprocess.run(["npm", "install"], cwd=project_dir, capture_output=True, text=True, timeout=120)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
|
@ -82,6 +74,7 @@ def skip_if_js_not_supported():
|
|||
"""Skip test if JavaScript/TypeScript languages are not supported."""
|
||||
try:
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
get_language_support(Language.JAVASCRIPT)
|
||||
except Exception as e:
|
||||
pytest.skip(f"JavaScript/TypeScript language support not available: {e}")
|
||||
|
|
@ -157,8 +150,8 @@ module.exports = {
|
|||
"""Test that JavaScript test instrumentation module can be imported."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
# Verify the instrumentation module can be imported
|
||||
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
|
||||
|
||||
# Get JavaScript support
|
||||
js_support = get_language_support(Language.JAVASCRIPT)
|
||||
|
|
@ -272,8 +265,8 @@ export default defineConfig({
|
|||
"""Test that TypeScript test instrumentation module can be imported."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.languages import get_language_support
|
||||
|
||||
# Verify the instrumentation module can be imported
|
||||
from codeflash.languages.javascript.instrument import inject_profiling_into_existing_js_test
|
||||
|
||||
test_file = ts_project_dir / "tests" / "math.test.ts"
|
||||
|
||||
|
|
@ -356,10 +349,7 @@ class TestRunAndParseJavaScriptTests:
|
|||
"""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
# Find the fibonacci function
|
||||
fib_file = vitest_project / "fibonacci.ts"
|
||||
|
|
@ -389,10 +379,8 @@ class TestRunAndParseJavaScriptTests:
|
|||
)
|
||||
|
||||
# Create optimizer
|
||||
func_optimizer = FunctionOptimizer(
|
||||
function_to_optimize=func,
|
||||
test_cfg=test_config,
|
||||
aiservice_client=MagicMock(),
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
|
||||
# Get code context - this should work
|
||||
|
|
@ -419,8 +407,8 @@ class TestTimingMarkerParsing:
|
|||
# The marker format used by codeflash for JavaScript
|
||||
# Start marker: !$######{tag}######$!
|
||||
# End marker: !######{tag}:{duration}######!
|
||||
start_pattern = r'!\$######(.+?)######\$!'
|
||||
end_pattern = r'!######(.+?):(\d+)######!'
|
||||
start_pattern = r"!\$######(.+?)######\$!"
|
||||
end_pattern = r"!######(.+?):(\d+)######!"
|
||||
|
||||
start_marker = "!$######test/math.test.ts:TestMath.test_add:add:1:0_0######$!"
|
||||
end_marker = "!######test/math.test.ts:TestMath.test_add:add:1:0_0:12345######!"
|
||||
|
|
@ -472,6 +460,7 @@ class TestJavaScriptTestResultParsing:
|
|||
|
||||
# Parse the XML
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
tree = ET.parse(junit_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
|
|
@ -504,6 +493,7 @@ class TestJavaScriptTestResultParsing:
|
|||
|
||||
# Parse the XML
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
tree = ET.parse(junit_xml)
|
||||
root = tree.getroot()
|
||||
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ export function add(a, b) {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
|
@ -76,7 +76,7 @@ export function multiply(a, b) {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 3
|
||||
names = {func.function_name for func in functions}
|
||||
|
|
@ -94,7 +94,7 @@ export const multiply = (x, y) => x * y;
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
names = {func.function_name for func in functions}
|
||||
|
|
@ -114,7 +114,7 @@ export function withoutReturn() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only the function with return should be discovered
|
||||
assert len(functions) == 1
|
||||
|
|
@ -136,7 +136,7 @@ export class Calculator {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
for func in functions:
|
||||
|
|
@ -157,7 +157,7 @@ export function syncFunction() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
|
||||
|
|
@ -182,7 +182,7 @@ export function syncFunc() {
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_async=False)
|
||||
functions = js_support.discover_functions(Path(f.name), criteria)
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "syncFunc"
|
||||
|
|
@ -204,7 +204,7 @@ export class MyClass {
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_methods=False)
|
||||
functions = js_support.discover_functions(Path(f.name), criteria)
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "standalone"
|
||||
|
|
@ -224,7 +224,7 @@ export function func2() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
func1 = next(f for f in functions if f.function_name == "func1")
|
||||
func2 = next(f for f in functions if f.function_name == "func2")
|
||||
|
|
@ -246,7 +246,7 @@ export function* numberGenerator() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "numberGenerator"
|
||||
|
|
@ -257,14 +257,14 @@ export function* numberGenerator() {
|
|||
f.write("this is not valid javascript {{{{")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
# Tree-sitter is lenient, so it may still parse partial code
|
||||
# The important thing is it doesn't crash
|
||||
assert isinstance(functions, list)
|
||||
|
||||
def test_discover_nonexistent_file_returns_empty(self, js_support):
|
||||
"""Test that nonexistent file returns empty list."""
|
||||
functions = js_support.discover_functions(Path("/nonexistent/file.js"))
|
||||
functions = js_support.discover_functions("", Path("/nonexistent/file.js"))
|
||||
assert functions == []
|
||||
|
||||
def test_discover_function_expression(self, js_support):
|
||||
|
|
@ -277,7 +277,7 @@ export const add = function(a, b) {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
|
@ -296,7 +296,7 @@ export function named() {
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = js_support.discover_functions(Path(f.name))
|
||||
functions = js_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only the named function should be discovered
|
||||
assert len(functions) == 1
|
||||
|
|
@ -507,7 +507,7 @@ export function main(a) {
|
|||
file_path = Path(f.name)
|
||||
|
||||
# First discover functions to get accurate line numbers
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
context = js_support.extract_code_context(main_func, file_path.parent, file_path.parent)
|
||||
|
|
@ -535,7 +535,7 @@ class TestIntegration:
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "fibonacci"
|
||||
|
|
@ -584,7 +584,7 @@ export function standalone() {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find 4 functions
|
||||
assert len(functions) == 4
|
||||
|
|
@ -623,7 +623,7 @@ export default Button;
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find both components
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -653,7 +653,7 @@ describe('Math functions', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -687,7 +687,7 @@ class TestClassMethodExtraction:
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover the method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Extract code context
|
||||
|
|
@ -725,7 +725,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -763,7 +763,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
fib_method = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -802,7 +802,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
|
|
@ -832,7 +832,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
fetch_method = next(f for f in functions if f.function_name == "fetchData")
|
||||
|
||||
context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -865,7 +865,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
||||
if add_method:
|
||||
|
|
@ -894,7 +894,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
method = next(f for f in functions if f.function_name == "simpleMethod")
|
||||
|
||||
context = js_support.extract_code_context(method, file_path.parent, file_path.parent)
|
||||
|
|
@ -1079,7 +1079,7 @@ class TestClassMethodEdgeCases:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find constructor and increment
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -1109,7 +1109,7 @@ class TestClassMethodEdgeCases:
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find at least greet
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -1137,7 +1137,7 @@ export class Dog extends Animal {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Find Dog's fetch method
|
||||
fetch_method = next((f for f in functions if f.function_name == "fetch" and f.class_name == "Dog"), None)
|
||||
|
|
@ -1172,7 +1172,7 @@ export class Dog extends Animal {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should at least find publicMethod
|
||||
names = {f.function_name for f in functions}
|
||||
|
|
@ -1192,7 +1192,7 @@ module.exports = { Calculator };
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
|
@ -1212,7 +1212,7 @@ module.exports = { Calculator };
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Find the add method
|
||||
add_method = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
|
@ -1265,7 +1265,7 @@ module.exports = { Counter };
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
increment_func = next(fn for fn in functions if fn.function_name == "increment")
|
||||
|
||||
# Step 1: Extract code context (includes constructor for AI context)
|
||||
|
|
@ -1362,7 +1362,7 @@ export class User {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
functions = ts_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
get_name_func = next(fn for fn in functions if fn.function_name == "getName")
|
||||
|
||||
# Step 1: Extract code context (includes fields and constructor)
|
||||
|
|
@ -1462,7 +1462,7 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_func = next(fn for fn in functions if fn.function_name == "add")
|
||||
|
||||
# Extract context for add
|
||||
|
|
@ -1546,7 +1546,7 @@ export class MathUtils {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
functions = js_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
add_func = next(fn for fn in functions if fn.function_name == "add")
|
||||
|
||||
# Extract context
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ describe('add function', () => {
|
|||
""")
|
||||
|
||||
# Discover functions first
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
assert len(functions) == 1
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -90,7 +90,7 @@ describe('multiply', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -124,7 +124,7 @@ test('formats date correctly', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -170,7 +170,7 @@ describe('String Utils', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -208,7 +208,7 @@ describe('sum function', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -242,7 +242,7 @@ test('subtract two numbers', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -270,7 +270,7 @@ test('greets by name', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -316,7 +316,7 @@ describe('Calculator class', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests for class methods
|
||||
|
|
@ -363,7 +363,7 @@ describe('clamp', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -399,7 +399,7 @@ describe('async utilities', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -436,7 +436,7 @@ describe('Button component', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# JSX tests should be discovered
|
||||
|
|
@ -466,7 +466,7 @@ test('other test', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should not find tests for our function
|
||||
|
|
@ -502,7 +502,7 @@ describe('validators', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests for isEmail
|
||||
|
|
@ -546,7 +546,7 @@ test('helper2 returns 2', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -574,7 +574,7 @@ test(`formatNumber with decimal`, () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# May or may not find depending on template literal handling
|
||||
|
|
@ -605,7 +605,7 @@ describe('transform', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should still find tests since original name is imported
|
||||
|
|
@ -626,7 +626,7 @@ it('third test', () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -651,7 +651,7 @@ describe('Suite B', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -675,7 +675,7 @@ describe('Outer', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -699,7 +699,7 @@ describe.skip('skipped describe', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -720,7 +720,7 @@ describe.only('only describe', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -738,7 +738,7 @@ describe('describe single', () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -757,7 +757,7 @@ describe("describe double", () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -773,7 +773,7 @@ describe("describe double", () => {});
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -806,7 +806,7 @@ test('funcA works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# funcA should have tests
|
||||
|
|
@ -833,7 +833,7 @@ test('funcX works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# funcX should have tests
|
||||
|
|
@ -859,7 +859,7 @@ test('mainFunc works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -896,7 +896,7 @@ test('block commented', () => {
|
|||
*/
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -921,7 +921,7 @@ test('broken test' { // Missing arrow function
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
# Should not crash
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
assert isinstance(tests, dict)
|
||||
|
|
@ -949,7 +949,7 @@ describe('conflict tests', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should still work despite naming conflicts
|
||||
|
|
@ -966,7 +966,7 @@ export function lonelyFunc() { return 'alone'; }
|
|||
module.exports = { lonelyFunc };
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should return empty dict, not crash
|
||||
|
|
@ -1001,7 +1001,7 @@ test('funcA works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions_a = js_support.discover_functions(file_a)
|
||||
functions_a = js_support.discover_functions(file_a.read_text(encoding="utf-8"), file_a)
|
||||
tests = js_support.discover_tests(tmpdir, functions_a)
|
||||
|
||||
# Should handle circular imports gracefully
|
||||
|
|
@ -1047,7 +1047,7 @@ test.each([
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1073,7 +1073,7 @@ describe.each([
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1098,7 +1098,7 @@ describe('Math operations', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1174,7 +1174,7 @@ describe('formatName', () => {
|
|||
""")
|
||||
|
||||
# Discover functions
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
assert len(functions) == 3
|
||||
|
||||
# Discover tests
|
||||
|
|
@ -1242,7 +1242,7 @@ describe('Database', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -1280,7 +1280,7 @@ test('funcA works', () => {
|
|||
""")
|
||||
|
||||
# Discover functions from moduleB
|
||||
functions_b = js_support.discover_functions(source_b)
|
||||
functions_b = js_support.discover_functions(source_b.read_text(encoding="utf-8"), source_b)
|
||||
tests = js_support.discover_tests(tmpdir, functions_b)
|
||||
|
||||
# funcB should not have any tests since test file doesn't import it
|
||||
|
|
@ -1312,7 +1312,7 @@ test('funcOne returns 1', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Check that tests were found
|
||||
|
|
@ -1340,7 +1340,7 @@ test('mentions targetFunc in string', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Current implementation may still match on string occurrence
|
||||
|
|
@ -1367,7 +1367,7 @@ test('calculate doubles', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests since 'calculate' appears in source
|
||||
|
|
@ -1399,7 +1399,7 @@ describe('MyClass', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should find tests for class methods
|
||||
|
|
@ -1432,7 +1432,7 @@ test('deepHelper works', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
assert len(tests) > 0
|
||||
|
|
@ -1456,7 +1456,7 @@ testCases.forEach(name => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1484,7 +1484,7 @@ describe('conditional tests', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1508,7 +1508,7 @@ test('slow test', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1531,7 +1531,7 @@ test.todo('also needs implementation');
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1554,7 +1554,7 @@ test.concurrent('concurrent test 2', async () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1597,7 +1597,7 @@ describe('subtractNumbers', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# All three functions should be discovered
|
||||
|
|
@ -1628,7 +1628,7 @@ describe('Unrelated name', () => {
|
|||
});
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
tests = js_support.discover_tests(tmpdir, functions)
|
||||
|
||||
# Should still find tests
|
||||
|
|
@ -1653,7 +1653,7 @@ describe('Array', function() {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1684,7 +1684,7 @@ describe('User', () => {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
source = file_path.read_text()
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
|
||||
|
||||
analyzer = get_analyzer_for_file(file_path)
|
||||
|
|
@ -1712,7 +1712,7 @@ export class Calculator {
|
|||
module.exports = { Calculator };
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
|
||||
# Check qualified names include class
|
||||
add_func = next((f for f in functions if f.function_name == "add"), None)
|
||||
|
|
@ -1737,7 +1737,7 @@ export class Outer {
|
|||
module.exports = { Outer };
|
||||
""")
|
||||
|
||||
functions = js_support.discover_functions(source_file)
|
||||
functions = js_support.discover_functions(source_file.read_text(encoding="utf-8"), source_file)
|
||||
|
||||
# Should find at least the Outer class method
|
||||
assert any(f.class_name == "Outer" for f in functions)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from codeflash.languages.base import Language
|
|||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
|
@ -37,7 +37,7 @@ class TestCodeExtractorCJS:
|
|||
def test_discover_class_methods(self, js_support, cjs_project):
|
||||
"""Test that class methods are discovered correctly."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
|
|
@ -47,17 +47,19 @@ class TestCodeExtractorCJS:
|
|||
def test_class_method_has_correct_parent(self, js_support, cjs_project):
|
||||
"""Test parent class information for methods."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
for func in functions:
|
||||
# All methods should belong to Calculator class
|
||||
assert func.is_method is True, f"{func.function_name} should be a method"
|
||||
assert func.class_name == "Calculator", f"{func.function_name} should belong to Calculator, got {func.class_name}"
|
||||
assert func.class_name == "Calculator", (
|
||||
f"{func.function_name} should belong to Calculator, got {func.class_name}"
|
||||
)
|
||||
|
||||
def test_extract_permutation_code(self, js_support, cjs_project):
|
||||
"""Test permutation method code extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -93,7 +95,7 @@ class Calculator {
|
|||
def test_extract_context_includes_direct_helpers(self, js_support, cjs_project):
|
||||
"""Test that direct helper functions are included in context."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -129,7 +131,7 @@ export function factorial(n) {
|
|||
def test_extract_compound_interest_code(self, js_support, cjs_project):
|
||||
"""Test calculateCompoundInterest code extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -175,7 +177,7 @@ class Calculator {
|
|||
def test_extract_compound_interest_helpers(self, js_support, cjs_project):
|
||||
"""Test helper extraction for calculateCompoundInterest."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -235,7 +237,7 @@ export function validateInput(value, name) {
|
|||
def test_extract_context_includes_imports(self, js_support, cjs_project):
|
||||
"""Test import statement extraction."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -256,7 +258,7 @@ export function validateInput(value, name) {
|
|||
def test_extract_static_method(self, js_support, cjs_project):
|
||||
"""Test static method extraction (quickAdd)."""
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
quick_add_func = next(f for f in functions if f.function_name == "quickAdd")
|
||||
|
||||
|
|
@ -315,7 +317,7 @@ class TestCodeExtractorESM:
|
|||
def test_discover_esm_methods(self, js_support, esm_project):
|
||||
"""Test method discovery in ESM project."""
|
||||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
|
|
@ -326,7 +328,7 @@ class TestCodeExtractorESM:
|
|||
def test_esm_permutation_extraction(self, js_support, esm_project):
|
||||
"""Test permutation method extraction in ESM."""
|
||||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -376,7 +378,7 @@ export function factorial(n) {
|
|||
def test_esm_compound_interest_extraction(self, js_support, esm_project):
|
||||
"""Test calculateCompoundInterest extraction in ESM with import syntax."""
|
||||
calculator_file = esm_project / "calculator.js"
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -502,7 +504,7 @@ class TestCodeExtractorTypeScript:
|
|||
def test_discover_ts_methods(self, ts_support, ts_project):
|
||||
"""Test method discovery in TypeScript."""
|
||||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
method_names = {f.function_name for f in functions}
|
||||
|
||||
|
|
@ -513,7 +515,7 @@ class TestCodeExtractorTypeScript:
|
|||
def test_ts_permutation_extraction(self, ts_support, ts_project):
|
||||
"""Test permutation method extraction in TypeScript."""
|
||||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
permutation_func = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
|
|
@ -566,7 +568,7 @@ export function factorial(n: number): number {
|
|||
def test_ts_compound_interest_extraction(self, ts_support, ts_project):
|
||||
"""Test calculateCompoundInterest extraction in TypeScript."""
|
||||
calculator_file = ts_project / "calculator.ts"
|
||||
functions = ts_support.discover_functions(calculator_file)
|
||||
functions = ts_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
|
||||
compound_func = next(f for f in functions if f.function_name == "calculateCompoundInterest")
|
||||
|
||||
|
|
@ -676,7 +678,7 @@ module.exports = { standalone };
|
|||
test_file = tmp_path / "standalone.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "standalone")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -709,7 +711,7 @@ module.exports = { processArray };
|
|||
test_file = tmp_path / "processor.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "processArray")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -744,7 +746,7 @@ module.exports = { fibonacci };
|
|||
test_file = tmp_path / "recursive.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "fibonacci")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -777,7 +779,7 @@ module.exports = { processValue };
|
|||
test_file = tmp_path / "arrow.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
func = next(f for f in functions if f.function_name == "processValue")
|
||||
|
||||
context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -835,7 +837,7 @@ module.exports = { Counter };
|
|||
test_file = tmp_path / "counter.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
increment_func = next(f for f in functions if f.function_name == "increment")
|
||||
|
||||
context = js_support.extract_code_context(function=increment_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -874,7 +876,7 @@ module.exports = { MathUtils };
|
|||
test_file = tmp_path / "math_utils.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = js_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -910,7 +912,7 @@ export class User {
|
|||
test_file = tmp_path / "user.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_name_func = next(f for f in functions if f.function_name == "getName")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_name_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -949,7 +951,7 @@ export class Config {
|
|||
test_file = tmp_path / "config.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_url_func = next(f for f in functions if f.function_name == "getUrl")
|
||||
|
||||
context = ts_support.extract_code_context(function=get_url_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -990,7 +992,7 @@ module.exports = { Logger };
|
|||
test_file = tmp_path / "logger.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_prefix_func = next(f for f in functions if f.function_name == "getPrefix")
|
||||
|
||||
context = js_support.extract_code_context(function=get_prefix_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1032,7 +1034,7 @@ module.exports = { Factory };
|
|||
test_file = tmp_path / "factory.js"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = js_support.discover_functions(test_file)
|
||||
functions = js_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
create_func = next(f for f in functions if f.function_name == "create")
|
||||
|
||||
context = js_support.extract_code_context(function=create_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1074,7 +1076,7 @@ class TestCodeExtractorIntegration:
|
|||
js_support = get_language_support("javascript")
|
||||
calculator_file = cjs_project / "calculator.js"
|
||||
|
||||
functions = js_support.discover_functions(calculator_file)
|
||||
functions = js_support.discover_functions(calculator_file.read_text(encoding="utf-8"), calculator_file)
|
||||
target = next(f for f in functions if f.function_name == "permutation")
|
||||
|
||||
parents = [FunctionParent(name=p.name, type=p.type) for p in target.parents]
|
||||
|
|
@ -1099,7 +1101,7 @@ class TestCodeExtractorIntegration:
|
|||
pytest_cmd="jest",
|
||||
)
|
||||
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
result = func_optimizer.get_code_optimization_context()
|
||||
|
|
@ -1182,7 +1184,7 @@ export function distance(p1: Point, p2: Point): number {
|
|||
test_file = tmp_path / "geometry.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
distance_func = next(f for f in functions if f.function_name == "distance")
|
||||
|
||||
context = ts_support.extract_code_context(function=distance_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1224,7 +1226,7 @@ export function processStatus(status: Status): string {
|
|||
test_file = tmp_path / "status.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
process_func = next(f for f in functions if f.function_name == "processStatus")
|
||||
|
||||
context = ts_support.extract_code_context(function=process_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1259,7 +1261,7 @@ export function compute(x: number): Result<number> {
|
|||
test_file = tmp_path / "compute.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
compute_func = next(f for f in functions if f.function_name == "compute")
|
||||
|
||||
context = ts_support.extract_code_context(function=compute_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1301,7 +1303,7 @@ export class Service {
|
|||
test_file = tmp_path / "service.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
get_timeout_func = next(f for f in functions if f.function_name == "getTimeout")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
|
|
@ -1332,7 +1334,7 @@ export function add(a: number, b: number): number {
|
|||
test_file = tmp_path / "add.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
add_func = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
context = ts_support.extract_code_context(function=add_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
@ -1363,7 +1365,7 @@ export function createRect(origin: Point, size: Size): { origin: Point; size: Si
|
|||
test_file = tmp_path / "rect.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
create_rect_func = next(f for f in functions if f.function_name == "createRect")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
|
|
@ -1409,7 +1411,7 @@ export function calculateDistance(p1: Point, p2: Point, config: CalculationConfi
|
|||
}
|
||||
""")
|
||||
|
||||
functions = ts_support.discover_functions(geometry_file)
|
||||
functions = ts_support.discover_functions(geometry_file.read_text(encoding="utf-8"), geometry_file)
|
||||
calc_distance_func = next(f for f in functions if f.function_name == "calculateDistance")
|
||||
|
||||
context = ts_support.extract_code_context(
|
||||
|
|
@ -1460,7 +1462,7 @@ export function greetUser(user: User): string {
|
|||
test_file = tmp_path / "user.ts"
|
||||
test_file.write_text(source)
|
||||
|
||||
functions = ts_support.discover_functions(test_file)
|
||||
functions = ts_support.discover_functions(test_file.read_text(encoding="utf-8"), test_file)
|
||||
greet_func = next(f for f in functions if f.function_name == "greetUser")
|
||||
|
||||
context = ts_support.extract_code_context(function=greet_func, project_root=tmp_path, module_root=tmp_path)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ These tests verify that code replacement correctly handles:
|
|||
- ES Modules (import/export) syntax
|
||||
- TypeScript import handling
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
|
|
@ -14,8 +15,8 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.code_replacer import replace_function_definitions_for_language
|
||||
from codeflash.languages.current import set_current_language
|
||||
from codeflash.languages.javascript.module_system import (
|
||||
ModuleSystem,
|
||||
|
|
@ -25,7 +26,6 @@ from codeflash.languages.javascript.module_system import (
|
|||
ensure_module_system_compatibility,
|
||||
get_import_statement,
|
||||
)
|
||||
|
||||
from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
|
||||
|
|
@ -50,7 +50,6 @@ def temp_project(tmp_path):
|
|||
return project_root
|
||||
|
||||
|
||||
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
|
|
@ -308,7 +307,9 @@ class TestTsJestSkipsConversion:
|
|||
When ts-jest is installed, it handles module interoperability internally,
|
||||
so we skip conversion to avoid breaking valid imports.
|
||||
"""
|
||||
def __init__(self):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_language(self):
|
||||
set_current_language(Language.TYPESCRIPT)
|
||||
|
||||
def test_commonjs_not_converted_when_ts_jest_installed(self, tmp_path):
|
||||
|
|
@ -751,6 +752,7 @@ class TestIntegrationWithFixtures:
|
|||
f"import statements should be converted to require.\nFound import lines: {import_lines}"
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleFunctionReplacement:
|
||||
"""Tests for simple function body replacement with strict assertions."""
|
||||
|
||||
|
|
@ -764,7 +766,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
# Optimized version with different body
|
||||
|
|
@ -800,7 +803,8 @@ export function processData(data) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
# Optimized version using map
|
||||
|
|
@ -839,7 +843,8 @@ module.exports = { targetFunction, otherFunction };
|
|||
file_path = temp_project / "module.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
target_func = next(f for f in functions if f.function_name == "targetFunction")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -891,7 +896,8 @@ export class Calculator {
|
|||
file_path = temp_project / "calculator.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
add_method = next(f for f in functions if f.function_name == "add")
|
||||
|
||||
# Optimized version provided in class context
|
||||
|
|
@ -954,7 +960,8 @@ export class DataProcessor {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
process_method = next(f for f in functions if f.function_name == "process")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1016,7 +1023,8 @@ export function add(a, b) {
|
|||
file_path = temp_project / "math.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1070,7 +1078,8 @@ export class Cache {
|
|||
file_path = temp_project / "cache.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1131,7 +1140,8 @@ export async function fetchData(url) {
|
|||
file_path = temp_project / "api.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1172,7 +1182,8 @@ export class ApiClient {
|
|||
file_path = temp_project / "client.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
get_method = next(f for f in functions if f.function_name == "get")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1223,7 +1234,8 @@ export function* range(start, end) {
|
|||
file_path = temp_project / "generators.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1262,7 +1274,8 @@ export function processArray(items: number[]): number {
|
|||
file_path = temp_project / "processor.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1303,7 +1316,8 @@ export class Container<T> {
|
|||
file_path = temp_project / "container.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
get_all_method = next(f for f in functions if f.function_name == "getAll")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1356,7 +1370,8 @@ export function createUser(name: string, email: string): User {
|
|||
file_path = temp_project / "user.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
func = next(f for f in functions if f.function_name == "createUser")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1411,7 +1426,8 @@ export function processItems(items) {
|
|||
file_path = temp_project / "processor.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
process_func = next(f for f in functions if f.function_name == "processItems")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1458,7 +1474,8 @@ export class MathUtils {
|
|||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# First replacement: sum method
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
sum_method = next(f for f in functions if f.function_name == "sum")
|
||||
|
||||
optimized_sum = """\
|
||||
|
|
@ -1505,7 +1522,8 @@ export function processConfig({ server: { host, port }, database: { url, poolSiz
|
|||
file_path = temp_project / "config.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1544,7 +1562,8 @@ export function minimal() {
|
|||
file_path = temp_project / "minimal.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1571,7 +1590,8 @@ export function identity(x) { return x; }
|
|||
file_path = temp_project / "utils.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1598,7 +1618,8 @@ export function formatMessage(name) {
|
|||
file_path = temp_project / "formatter.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1633,7 +1654,8 @@ export function validateEmail(email) {
|
|||
file_path = temp_project / "validator.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1676,7 +1698,8 @@ module.exports = { main, helper };
|
|||
file_path = temp_project / "module.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1719,7 +1742,8 @@ export function main(data) {
|
|||
file_path = temp_project / "module.js"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
main_func = next(f for f in functions if f.function_name == "main")
|
||||
|
||||
optimized_code = """\
|
||||
|
|
@ -1750,20 +1774,16 @@ class TestSyntaxValidation:
|
|||
"""Test that various replacements all produce valid JavaScript."""
|
||||
test_cases = [
|
||||
# (original, optimized, description)
|
||||
(
|
||||
"export function f(x) { return x + 1; }",
|
||||
"export function f(x) { return ++x; }",
|
||||
"increment replacement"
|
||||
),
|
||||
("export function f(x) { return x + 1; }", "export function f(x) { return ++x; }", "increment replacement"),
|
||||
(
|
||||
"export function f(arr) { return arr.length > 0; }",
|
||||
"export function f(arr) { return !!arr.length; }",
|
||||
"boolean conversion"
|
||||
"boolean conversion",
|
||||
),
|
||||
(
|
||||
"export function f(a, b) { if (a) { return a; } return b; }",
|
||||
"export function f(a, b) { return a || b; }",
|
||||
"logical OR replacement"
|
||||
"logical OR replacement",
|
||||
),
|
||||
]
|
||||
|
||||
|
|
@ -1771,7 +1791,8 @@ class TestSyntaxValidation:
|
|||
file_path = temp_project / f"test_{i}.js"
|
||||
file_path.write_text(original, encoding="utf-8")
|
||||
|
||||
functions = js_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = js_support.discover_functions(source, file_path)
|
||||
func = functions[0]
|
||||
|
||||
result = js_support.replace_function(original, func, optimized)
|
||||
|
|
@ -1875,7 +1896,8 @@ export class DataProcessor<T> {
|
|||
target_func = "findDuplicates"
|
||||
parent_class = "DataProcessor"
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
# find function
|
||||
target_func_info = None
|
||||
for func in functions:
|
||||
|
|
@ -1920,11 +1942,15 @@ class DataProcessor<T> {
|
|||
```
|
||||
"""
|
||||
code_markdown = CodeStringsMarkdown.parse_markdown_code(new_code)
|
||||
replaced = replace_function_definitions_for_language([f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project)
|
||||
replaced = replace_function_definitions_for_language(
|
||||
[f"{parent_class}.{target_func}"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
assert replaced
|
||||
|
||||
new_code = file_path.read_text()
|
||||
assert new_code == """/**
|
||||
assert (
|
||||
new_code
|
||||
== """/**
|
||||
* DataProcessor class - demonstrates class method optimization in TypeScript.
|
||||
* Contains intentionally inefficient implementations for optimization testing.
|
||||
*/
|
||||
|
|
@ -2015,7 +2041,7 @@ export class DataProcessor<T> {
|
|||
}
|
||||
}
|
||||
"""
|
||||
|
||||
)
|
||||
|
||||
|
||||
class TestNewVariableFromOptimizedCode:
|
||||
|
|
@ -2030,9 +2056,9 @@ class TestNewVariableFromOptimizedCode:
|
|||
1. Add the new variable after the constant it references
|
||||
2. Replace the function with the optimized version
|
||||
"""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
original_source = '''\
|
||||
original_source = """\
|
||||
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
||||
"1234",
|
||||
]);
|
||||
|
|
@ -2040,43 +2066,34 @@ const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
|||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
|
||||
}
|
||||
'''
|
||||
"""
|
||||
file_path = temp_project / "auth.ts"
|
||||
file_path.write_text(original_source, encoding="utf-8")
|
||||
|
||||
# Optimized code introduces a bound method variable for performance
|
||||
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
||||
optimized_code = """const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
||||
CODEFLASH_EMPLOYEE_GITHUB_IDS
|
||||
);
|
||||
|
||||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return _has(userId);
|
||||
}
|
||||
'''
|
||||
"""
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code,
|
||||
file_path=Path("auth.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
code_strings=[CodeString(code=optimized_code, file_path=Path("auth.ts"), language="typescript")],
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
replaced = replace_function_definitions_for_language(
|
||||
["isCodeflashEmployee"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
["isCodeflashEmployee"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
|
||||
assert replaced
|
||||
result = file_path.read_text()
|
||||
|
||||
# Expected result for strict equality check
|
||||
expected_result = '''\
|
||||
expected_result = """\
|
||||
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
|
||||
"1234",
|
||||
]);
|
||||
|
|
@ -2088,11 +2105,9 @@ const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
|
|||
export function isCodeflashEmployee(userId: string): boolean {
|
||||
return _has(userId);
|
||||
}
|
||||
'''
|
||||
"""
|
||||
assert result == expected_result, (
|
||||
f"Result does not match expected output.\n"
|
||||
f"Expected:\n{expected_result}\n\n"
|
||||
f"Got:\n{result}"
|
||||
f"Result does not match expected output.\nExpected:\n{expected_result}\n\nGot:\n{result}"
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -2113,7 +2128,7 @@ class TestImportedTypeNotDuplicated:
|
|||
contains the TreeNode interface definition (from read-only context),
|
||||
the replacement should NOT add the interface to the original file.
|
||||
"""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
# Original source imports TreeNode
|
||||
original_source = """\
|
||||
|
|
@ -2163,20 +2178,13 @@ export function getNearestAbove(
|
|||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code_with_interface,
|
||||
file_path=Path("helpers.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
CodeString(code=optimized_code_with_interface, file_path=Path("helpers.ts"), language="typescript")
|
||||
],
|
||||
language="typescript"
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
replace_function_definitions_for_language(
|
||||
["getNearestAbove"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
["getNearestAbove"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
|
||||
result = file_path.read_text()
|
||||
|
|
@ -2203,7 +2211,7 @@ export function getNearestAbove(
|
|||
|
||||
def test_multiple_imported_types_not_duplicated(self, ts_support, temp_project):
|
||||
"""Test that multiple imported types are not duplicated."""
|
||||
from codeflash.models.models import CodeStringsMarkdown, CodeString
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
|
||||
original_source = """\
|
||||
import type { TreeNode, NodeSpace } from "./constants";
|
||||
|
|
@ -2235,21 +2243,12 @@ export function processNode(node: TreeNode, space: NodeSpace): number {
|
|||
"""
|
||||
|
||||
code_markdown = CodeStringsMarkdown(
|
||||
code_strings=[
|
||||
CodeString(
|
||||
code=optimized_code,
|
||||
file_path=Path("processor.ts"),
|
||||
language="typescript"
|
||||
)
|
||||
],
|
||||
language="typescript"
|
||||
code_strings=[CodeString(code=optimized_code, file_path=Path("processor.ts"), language="typescript")],
|
||||
language="typescript",
|
||||
)
|
||||
|
||||
replace_function_definitions_for_language(
|
||||
["processNode"],
|
||||
code_markdown,
|
||||
file_path,
|
||||
temp_project,
|
||||
["processNode"], code_markdown, file_path, temp_project, lang_support=ts_support
|
||||
)
|
||||
|
||||
result = file_path.read_text()
|
||||
|
|
|
|||
|
|
@ -345,8 +345,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
|
||||
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find exactly one function
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -365,8 +365,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(MULTIPLE_FUNCTIONS.python, ".py")
|
||||
js_file = write_temp_file(MULTIPLE_FUNCTIONS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 3 functions
|
||||
assert len(py_funcs) == 3, f"Python found {len(py_funcs)}, expected 3"
|
||||
|
|
@ -384,8 +384,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(WITH_AND_WITHOUT_RETURN.python, ".py")
|
||||
js_file = write_temp_file(WITH_AND_WITHOUT_RETURN.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find only 1 function (the one with return)
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -400,8 +400,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(CLASS_METHODS.python, ".py")
|
||||
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 2 methods
|
||||
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
|
||||
|
|
@ -421,8 +421,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(ASYNC_FUNCTIONS.python, ".py")
|
||||
js_file = write_temp_file(ASYNC_FUNCTIONS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 2 functions
|
||||
assert len(py_funcs) == 2, f"Python found {len(py_funcs)}, expected 2"
|
||||
|
|
@ -444,8 +444,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(NESTED_FUNCTIONS.python, ".py")
|
||||
js_file = write_temp_file(NESTED_FUNCTIONS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Python skips nested functions — only outer is discovered
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -465,8 +465,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(STATIC_METHODS.python, ".py")
|
||||
js_file = write_temp_file(STATIC_METHODS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 1 function
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -483,8 +483,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(COMPLEX_FILE.python, ".py")
|
||||
js_file = write_temp_file(COMPLEX_FILE.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find 4 functions
|
||||
assert len(py_funcs) == 4, f"Python found {len(py_funcs)}, expected 4"
|
||||
|
|
@ -515,8 +515,8 @@ class TestDiscoverFunctionsParity:
|
|||
|
||||
criteria = FunctionFilterCriteria(include_async=False)
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file, criteria)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
|
||||
|
||||
# Both should find only 1 function (the sync one)
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -533,8 +533,8 @@ class TestDiscoverFunctionsParity:
|
|||
|
||||
criteria = FunctionFilterCriteria(include_methods=False)
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file, criteria)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file, criteria)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file, criteria)
|
||||
|
||||
# Both should find only 1 function (standalone)
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)}, expected 1"
|
||||
|
|
@ -545,11 +545,11 @@ class TestDiscoverFunctionsParity:
|
|||
assert js_funcs[0].function_name == "standalone"
|
||||
|
||||
def test_nonexistent_file_returns_empty(self, python_support, js_support):
|
||||
"""Python raises on nonexistent files; JavaScript returns empty list."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
python_support.discover_functions(Path("/nonexistent/file.py"))
|
||||
"""Both languages return empty list for empty source."""
|
||||
py_funcs = python_support.discover_functions("", Path("/nonexistent/file.py"))
|
||||
assert py_funcs == []
|
||||
|
||||
js_funcs = js_support.discover_functions(Path("/nonexistent/file.js"))
|
||||
js_funcs = js_support.discover_functions("", Path("/nonexistent/file.js"))
|
||||
assert js_funcs == []
|
||||
|
||||
def test_line_numbers_captured(self, python_support, js_support):
|
||||
|
|
@ -557,8 +557,8 @@ class TestDiscoverFunctionsParity:
|
|||
py_file = write_temp_file(SIMPLE_FUNCTION.python, ".py")
|
||||
js_file = write_temp_file(SIMPLE_FUNCTION.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should have start_line and end_line
|
||||
assert py_funcs[0].starting_line is not None
|
||||
|
|
@ -908,8 +908,8 @@ class TestIntegrationParity:
|
|||
js_file = write_temp_file(js_original, ".js")
|
||||
|
||||
# Discover
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert len(py_funcs) == 1
|
||||
assert len(js_funcs) == 1
|
||||
|
|
@ -960,8 +960,8 @@ class TestFeatureGaps:
|
|||
py_file = write_temp_file(CLASS_METHODS.python, ".py")
|
||||
js_file = write_temp_file(CLASS_METHODS.javascript, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
for py_func in py_funcs:
|
||||
# Check all expected fields are populated
|
||||
|
|
@ -994,7 +994,7 @@ export const multiply = (x, y) => x * y;
|
|||
export const identity = x => x;
|
||||
"""
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
funcs = js_support.discover_functions(js_file)
|
||||
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Should find all arrow functions
|
||||
names = {f.function_name for f in funcs}
|
||||
|
|
@ -1021,8 +1021,8 @@ export function* numberGenerator() {
|
|||
py_file = write_temp_file(py_code, ".py")
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Both should find the generator
|
||||
assert len(py_funcs) == 1, f"Python found {len(py_funcs)} generators"
|
||||
|
|
@ -1045,7 +1045,7 @@ def multi_decorated():
|
|||
return 3
|
||||
"""
|
||||
py_file = write_temp_file(py_code, ".py")
|
||||
funcs = python_support.discover_functions(py_file)
|
||||
funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
|
||||
# Should find all functions regardless of decorators
|
||||
names = {f.function_name for f in funcs}
|
||||
|
|
@ -1065,7 +1065,7 @@ export const namedExpr = function myFunc(x) {
|
|||
};
|
||||
"""
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
funcs = js_support.discover_functions(js_file)
|
||||
funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
# Should find function expressions
|
||||
names = {f.function_name for f in funcs}
|
||||
|
|
@ -1085,8 +1085,8 @@ class TestEdgeCases:
|
|||
py_file = write_temp_file("", ".py")
|
||||
js_file = write_temp_file("", ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert py_funcs == []
|
||||
assert js_funcs == []
|
||||
|
|
@ -1110,8 +1110,8 @@ Multiline comment
|
|||
py_file = write_temp_file(py_code, ".py")
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert py_funcs == []
|
||||
assert js_funcs == []
|
||||
|
|
@ -1130,8 +1130,8 @@ export function greeting() {
|
|||
py_file = write_temp_file(py_code, ".py")
|
||||
js_file = write_temp_file(js_code, ".js")
|
||||
|
||||
py_funcs = python_support.discover_functions(py_file)
|
||||
js_funcs = js_support.discover_functions(js_file)
|
||||
py_funcs = python_support.discover_functions(py_file.read_text(encoding="utf-8"), py_file)
|
||||
js_funcs = js_support.discover_functions(js_file.read_text(encoding="utf-8"), js_file)
|
||||
|
||||
assert len(py_funcs) == 1
|
||||
assert len(js_funcs) == 1
|
||||
|
|
|
|||
|
|
@ -82,9 +82,9 @@ from pathlib import Path
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
from codeflash.languages.registry import get_language_support
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -110,7 +110,7 @@ def test_js_replcement() -> None:
|
|||
original_helper = helper_file.read_text("utf-8")
|
||||
|
||||
js_support = get_language_support("javascript")
|
||||
functions = js_support.discover_functions(main_file)
|
||||
functions = js_support.discover_functions(main_file.read_text(encoding="utf-8"), main_file)
|
||||
target = None
|
||||
for func in functions:
|
||||
if func.function_name == "calculateStats":
|
||||
|
|
@ -135,7 +135,7 @@ def test_js_replcement() -> None:
|
|||
project_root_path=root_dir,
|
||||
pytest_cmd="jest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(
|
||||
func_optimizer = JavaScriptFunctionOptimizer(
|
||||
function_to_optimize=func, test_cfg=test_config, aiservice_client=MagicMock()
|
||||
)
|
||||
result = func_optimizer.get_code_optimization_context()
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def add(a, b):
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
|
@ -70,7 +70,7 @@ def multiply(a, b):
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 3
|
||||
names = {func.function_name for func in functions}
|
||||
|
|
@ -88,7 +88,7 @@ def without_return():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only the function with return should be discovered
|
||||
assert len(functions) == 1
|
||||
|
|
@ -107,7 +107,7 @@ class Calculator:
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
for func in functions:
|
||||
|
|
@ -126,7 +126,7 @@ def sync_function():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 2
|
||||
|
||||
|
|
@ -147,7 +147,7 @@ def outer():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
# Only outer should be discovered; inner is nested and skipped
|
||||
assert len(functions) == 1
|
||||
|
|
@ -164,7 +164,7 @@ class Utils:
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "helper"
|
||||
|
|
@ -183,7 +183,9 @@ def sync_func():
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_async=False)
|
||||
functions = python_support.discover_functions(Path(f.name), criteria)
|
||||
functions = python_support.discover_functions(
|
||||
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
|
||||
)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "sync_func"
|
||||
|
|
@ -202,7 +204,9 @@ class MyClass:
|
|||
f.flush()
|
||||
|
||||
criteria = FunctionFilterCriteria(include_methods=False)
|
||||
functions = python_support.discover_functions(Path(f.name), criteria)
|
||||
functions = python_support.discover_functions(
|
||||
Path(f.name).read_text(encoding="utf-8"), Path(f.name), criteria
|
||||
)
|
||||
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "standalone"
|
||||
|
|
@ -220,7 +224,7 @@ def func2():
|
|||
""")
|
||||
f.flush()
|
||||
|
||||
functions = python_support.discover_functions(Path(f.name))
|
||||
functions = python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
func1 = next(f for f in functions if f.function_name == "func1")
|
||||
func2 = next(f for f in functions if f.function_name == "func2")
|
||||
|
|
@ -239,12 +243,12 @@ def func2():
|
|||
f.flush()
|
||||
|
||||
with pytest.raises(ParserSyntaxError):
|
||||
python_support.discover_functions(Path(f.name))
|
||||
python_support.discover_functions(Path(f.name).read_text(encoding="utf-8"), Path(f.name))
|
||||
|
||||
def test_discover_nonexistent_file_raises(self, python_support):
|
||||
"""Test that nonexistent file raises FileNotFoundError."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
python_support.discover_functions(Path("/nonexistent/file.py"))
|
||||
def test_discover_empty_source_returns_empty(self, python_support):
|
||||
"""Test that empty source returns empty list."""
|
||||
functions = python_support.discover_functions("", Path("/nonexistent/file.py"))
|
||||
assert functions == []
|
||||
|
||||
|
||||
class TestReplaceFunction:
|
||||
|
|
@ -495,7 +499,7 @@ class TestIntegration:
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover
|
||||
functions = python_support.discover_functions(file_path)
|
||||
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
assert len(functions) == 1
|
||||
func = functions[0]
|
||||
assert func.function_name == "fibonacci"
|
||||
|
|
@ -536,7 +540,7 @@ def standalone():
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = python_support.discover_functions(file_path)
|
||||
functions = python_support.discover_functions(file_path.read_text(encoding="utf-8"), file_path)
|
||||
|
||||
# Should find 4 functions
|
||||
assert len(functions) == 4
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.base import FunctionInfo, Language, ParentInfo
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.javascript.support import TypeScriptSupport
|
||||
|
||||
|
||||
|
|
@ -126,14 +126,13 @@ export function add(a: number, b: number): number {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "add"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -164,14 +163,13 @@ export async function execMongoEval(queryExpression, appsmithMongoURI) {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "execMongoEval"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -215,14 +213,13 @@ export async function figureOutContentsPath(root: string): Promise<string> {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
assert functions[0].function_name == "figureOutContentsPath"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -246,12 +243,11 @@ export function readConfig(filename: string): string {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Check that imports are captured
|
||||
assert len(code_context.imports) > 0
|
||||
|
|
@ -278,12 +274,11 @@ export async function fetchWithRetry(url: string): Promise<any> {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
assert len(functions) == 1
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
functions[0], file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(functions[0], file_path.parent, file_path.parent)
|
||||
|
||||
# Verify extracted code is valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -324,7 +319,8 @@ export class EndpointGroup {
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover the 'post' method
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
post_method = None
|
||||
for func in functions:
|
||||
if func.function_name == "post":
|
||||
|
|
@ -334,9 +330,7 @@ export class EndpointGroup {
|
|||
assert post_method is not None, "post method should be discovered"
|
||||
|
||||
# Extract code context
|
||||
code_context = ts_support.extract_code_context(
|
||||
post_method, file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(post_method, file_path.parent, file_path.parent)
|
||||
|
||||
# The extracted code should be syntactically valid
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True, (
|
||||
|
|
@ -352,9 +346,7 @@ export class EndpointGroup {
|
|||
# Check that addEndpoint appears BEFORE the closing brace of the class
|
||||
class_end_index = code_context.target_code.rfind("}")
|
||||
add_endpoint_index = code_context.target_code.find("addEndpoint")
|
||||
assert add_endpoint_index < class_end_index, (
|
||||
"addEndpoint should be inside the class wrapper"
|
||||
)
|
||||
assert add_endpoint_index < class_end_index, "addEndpoint should be inside the class wrapper"
|
||||
|
||||
def test_multiple_private_helpers_inside_class(self, ts_support):
|
||||
"""Test that multiple private helpers are all included inside the class."""
|
||||
|
|
@ -386,7 +378,8 @@ export class Router {
|
|||
file_path = Path(f.name)
|
||||
|
||||
# Discover the 'addRoute' method
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
add_route_method = None
|
||||
for func in functions:
|
||||
if func.function_name == "addRoute":
|
||||
|
|
@ -395,9 +388,7 @@ export class Router {
|
|||
|
||||
assert add_route_method is not None
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
add_route_method, file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(add_route_method, file_path.parent, file_path.parent)
|
||||
|
||||
# Should be valid TypeScript
|
||||
assert ts_support.validate_syntax(code_context.target_code) is True
|
||||
|
|
@ -424,7 +415,8 @@ export class Calculator {
|
|||
f.flush()
|
||||
file_path = Path(f.name)
|
||||
|
||||
functions = ts_support.discover_functions(file_path)
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
functions = ts_support.discover_functions(source, file_path)
|
||||
add_method = None
|
||||
for func in functions:
|
||||
if func.function_name == "add":
|
||||
|
|
@ -433,18 +425,14 @@ export class Calculator {
|
|||
|
||||
assert add_method is not None
|
||||
|
||||
code_context = ts_support.extract_code_context(
|
||||
add_method, file_path.parent, file_path.parent
|
||||
)
|
||||
code_context = ts_support.extract_code_context(add_method, file_path.parent, file_path.parent)
|
||||
|
||||
# 'compute' should be in target_code (inside class)
|
||||
assert "compute" in code_context.target_code
|
||||
|
||||
# 'compute' should NOT be in helper_functions (would be duplicate)
|
||||
helper_names = [h.name for h in code_context.helper_functions]
|
||||
assert "compute" not in helper_names, (
|
||||
"Same-class helper 'compute' should not be in helper_functions list"
|
||||
)
|
||||
assert "compute" not in helper_names, "Same-class helper 'compute' should not be in helper_functions list"
|
||||
|
||||
|
||||
class TestTypeScriptLanguageProperties:
|
||||
|
|
|
|||
|
|
@ -124,10 +124,8 @@ class TestTypeScriptCodeContext:
|
|||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_ts_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
fib_file = ts_project_dir / "fibonacci.ts"
|
||||
if not fib_file.exists():
|
||||
|
|
@ -139,7 +137,11 @@ class TestTypeScriptCodeContext:
|
|||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
context = get_code_optimization_context(fib_func, ts_project_dir)
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(fib_func, ts_project_dir, ts_project_dir)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, fib_file, "typescript", ts_project_dir
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
# Critical: language should be "typescript", not "javascript"
|
||||
|
|
|
|||
|
|
@ -118,11 +118,9 @@ class TestVitestCodeContext:
|
|||
"""Test extracting code context for a TypeScript function."""
|
||||
skip_if_js_not_supported()
|
||||
from codeflash.discovery.functions_to_optimize import find_all_functions_in_file
|
||||
from codeflash.languages import current as lang_current
|
||||
from codeflash.languages import get_language_support
|
||||
from codeflash.languages.base import Language
|
||||
from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context
|
||||
|
||||
lang_current._current_language = Language.TYPESCRIPT
|
||||
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
|
||||
|
||||
fib_file = vitest_project_dir / "fibonacci.ts"
|
||||
if not fib_file.exists():
|
||||
|
|
@ -134,7 +132,11 @@ class TestVitestCodeContext:
|
|||
fib_func = next((f for f in func_list if f.function_name == "fibonacci"), None)
|
||||
assert fib_func is not None
|
||||
|
||||
context = get_code_optimization_context(fib_func, vitest_project_dir)
|
||||
ts_support = get_language_support(Language.TYPESCRIPT)
|
||||
code_context = ts_support.extract_code_context(fib_func, vitest_project_dir, vitest_project_dir)
|
||||
context = JavaScriptFunctionOptimizer._build_optimization_context(
|
||||
code_context, fib_file, "typescript", vitest_project_dir
|
||||
)
|
||||
|
||||
assert context.read_writable_code is not None
|
||||
assert context.read_writable_code.language == "typescript"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from pathlib import Path
|
|||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -106,7 +106,7 @@ def _get_string_usage(text: str) -> Usage:
|
|||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
func_optimizer = PythonFunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import tempfile
|
|||
from pathlib import Path
|
||||
|
||||
from codeflash.code_utils.code_utils import ImportErrorPattern
|
||||
from codeflash.languages import current_language_support
|
||||
from codeflash.models.models import TestFile, TestFiles, TestType
|
||||
from codeflash.verification.parse_test_output import parse_test_xml
|
||||
from codeflash.verification.test_runner import run_behavioral_tests
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -48,8 +48,8 @@ class TestUnittestRunnerSorter(unittest.TestCase):
|
|||
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
|
||||
)
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
result_file, process, _, _ = run_behavioral_tests(
|
||||
test_files, test_framework=config.test_framework, cwd=Path(config.project_root_path), test_env=test_env
|
||||
result_file, process, _, _ = current_language_support().run_behavioral_tests(
|
||||
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path)
|
||||
)
|
||||
results = parse_test_xml(result_file, test_files, config, process)
|
||||
assert results[0].did_pass, "Test did not pass as expected"
|
||||
|
|
@ -89,13 +89,8 @@ def test_sort():
|
|||
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
|
||||
)
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
result_file, process, _, _ = run_behavioral_tests(
|
||||
test_files,
|
||||
test_framework=config.test_framework,
|
||||
cwd=Path(config.project_root_path),
|
||||
test_env=test_env,
|
||||
pytest_timeout=1,
|
||||
pytest_target_runtime_seconds=1,
|
||||
result_file, process, _, _ = current_language_support().run_behavioral_tests(
|
||||
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
|
||||
)
|
||||
results = parse_test_xml(
|
||||
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
|
||||
|
|
@ -136,13 +131,8 @@ def test_sort():
|
|||
test_files=[TestFile(instrumented_behavior_file_path=test_file_path, test_type=TestType.EXISTING_UNIT_TEST)]
|
||||
)
|
||||
test_file_path.write_text(code, encoding="utf-8")
|
||||
result_file, process, _, _ = run_behavioral_tests(
|
||||
test_files,
|
||||
test_framework=config.test_framework,
|
||||
cwd=Path(config.project_root_path),
|
||||
test_env=test_env,
|
||||
pytest_timeout=1,
|
||||
pytest_target_runtime_seconds=1,
|
||||
result_file, process, _, _ = current_language_support().run_behavioral_tests(
|
||||
test_paths=test_files, test_env=test_env, cwd=Path(config.project_root_path), timeout=1
|
||||
)
|
||||
results = parse_test_xml(
|
||||
test_xml_file_path=result_file, test_files=test_files, test_config=config, run_result=process
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from codeflash.languages.python.context.unused_definition_remover import (
|
|||
detect_unused_helper_functions,
|
||||
revert_unused_helper_functions,
|
||||
)
|
||||
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
|
||||
from codeflash.models.models import CodeStringsMarkdown
|
||||
from codeflash.optimization.function_optimizer import FunctionOptimizer
|
||||
from codeflash.verification.verification_utils import TestConfig
|
||||
|
||||
|
||||
|
|
@ -83,7 +83,7 @@ def helper_function_2(x):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -194,7 +194,7 @@ def helper_function_2(x):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -269,7 +269,7 @@ def helper_function_2(x):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -365,7 +365,7 @@ def entrypoint_function(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -559,7 +559,7 @@ class Calculator:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -710,7 +710,7 @@ class Processor:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -895,7 +895,7 @@ class OuterClass:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1051,7 +1051,7 @@ def entrypoint_function(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1215,7 +1215,7 @@ def entrypoint_function(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1442,7 +1442,7 @@ class MathUtils:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1576,7 +1576,7 @@ async def async_entrypoint(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1664,7 +1664,7 @@ def sync_entrypoint(n):
|
|||
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="sync_entrypoint", parents=[])
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1773,7 +1773,7 @@ async def mixed_entrypoint(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1874,7 +1874,7 @@ class AsyncProcessor:
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -1960,7 +1960,7 @@ async def async_entrypoint(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -2039,7 +2039,7 @@ def gcd_recursive(a: int, b: int) -> int:
|
|||
function_to_optimize = FunctionToOptimize(file_path=main_file, function_name="gcd_recursive", parents=[])
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
@ -2152,7 +2152,7 @@ async def async_entrypoint_with_generators(n):
|
|||
)
|
||||
|
||||
# Create function optimizer
|
||||
optimizer = FunctionOptimizer(
|
||||
optimizer = PythonFunctionOptimizer(
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_cfg=test_cfg,
|
||||
function_to_optimize_source_code=main_file.read_text(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue